[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2022 Megvii Inc.\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "## BSRT: Improving Burst Super-Resolution with Swin Transformer and Flow-Guided Deformable Alignment (CVPRW 2022)\n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bsrt-improving-burst-super-resolution-with/burst-image-super-resolution-on-burstsr)](https://paperswithcode.com/sota/burst-image-super-resolution-on-burstsr?p=bsrt-improving-burst-super-resolution-with) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bsrt-improving-burst-super-resolution-with/burst-image-super-resolution-on)](https://paperswithcode.com/sota/burst-image-super-resolution-on?p=bsrt-improving-burst-super-resolution-with)![visitors](https://visitor-badge.glitch.me/badge?page_id=Algolzw/BSRT) \n\n#### [BSRT](https://arxiv.org/abs/2204.08332), the winner of the NTIRE 2022 Burst Super-Resolution Challenge Real-World Track. \nYou can also find our winner method in NTIRE 2021 Burst Super-Resolution Challenge [here](https://github.com/Algolzw/EBSR).\n\n> This work addresses the Burst Super-Resolution (BurstSR) task using a new architecture, which requires restoring a high-quality image from a sequence of noisy, misaligned, and low-resolution RAW bursts. To overcome the challenges in BurstSR, we propose a **B**urst **S**uper-**R**esolution **T**ransformer (**BSRT**), which can significantly improve the capability of extracting inter-frame information and reconstruction. To achieve this goal, we propose a Pyramid Flow-Guided Deformable Convolution Network (Pyramid FG-DCN) and incorporate Swin Transformer Blocks and Groups as our main backbone.  More specifically,  we combine optical flows and deformable convolutions, hence our BSRT can handle misalignment and aggregate the potential texture information in multi-frames more efficiently. In addition, our Transformer-based structure can capture long-range dependency to further improve the performance. 
The evaluation on both synthetic and real-world tracks demonstrates that our approach achieves a new state-of-the-art in the BurstSR task. Further, our BSRT wins the championship in the NTIRE 2022 Burst Super-Resolution Challenge.\n\n\n#### Comparison with State-of-the-art Burst Super-Resolution Methods\n\n![ts](figs/ts.png)\n\n\n\n## Overview Architecture\n\n![overview.png](figs/overview.png)\n\n## Dependencies\n- OS: Ubuntu 18.04\n- Python: Python 3.7\n- NVIDIA:\n   - CUDA: 10.1\n   - cuDNN: 7.6.1\n- Other requirements: see `requirements.txt`\n\n## Quick Start\n1. Create a conda virtual environment and activate it\n```bash\nconda create -n pytorch_1.6 python=3.7\nsource activate pytorch_1.6\n```\n2. Install PyTorch and torchvision following the official instructions\n```bash\nconda install pytorch==1.6.0 torchvision==0.7.0 cudatoolkit=10.1 -c pytorch\n```\n3. Install build requirements\n```bash\npip3 install -r requirements.txt\n```\n4. Install DCN\n```bash\ncd DCNv2\npython3 setup.py build develop # build\npython3 test.py # run examples and check\n```\n\n## Training\n\nWe provide all pretrained model weights [here](https://drive.google.com/file/d/1Bv1ZwoE3s8trhG--wjB0Yt6WJIQPpvsn/view?usp=sharing). 
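The synthetic track trains on LR RAW bursts generated on the fly from sRGB images: each image is unprocessed to linear sensor space, randomly transformed, downsampled, mosaicked into Bayer planes, and corrupted with noise (see `data_processing/`). As a rough standalone illustration of the RGGB mosaicking step (a NumPy sketch mirroring the torch `mosaic()` in `data_processing/camera_pipeline.py`, not the repo's actual implementation):

```python
import numpy as np

def mosaic_rggb(rgb):
    """Pack an (H, W, 3) RGB image into 4 half-resolution RGGB planes.

    Standalone NumPy sketch of the RGGB packing performed (in torch) by
    data_processing/camera_pipeline.py; assumes even H and W.
    """
    r  = rgb[0::2, 0::2, 0]  # red samples on even rows, even cols
    gr = rgb[0::2, 1::2, 1]  # green samples on the red rows
    gb = rgb[1::2, 0::2, 1]  # green samples on the blue rows
    b  = rgb[1::2, 1::2, 2]  # blue samples on odd rows, odd cols
    return np.stack((r, gr, gb, b), axis=0)  # shape (4, H//2, W//2)
```

Each burst frame thus becomes a 4-channel image at half the original spatial resolution before noise is added.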
\n\n#### For Synthetic data\n\n```bash\ncd code/synthetic/bsrt\n# Modify the root paths of the training dataset, the model, etc.\n# Training requires more than one GPU\npython main.py --n_GPUs 8 --print_every 40 --lr 0.0001 --decay 150-300 --save bsrt_tiny --model BSRT --fp16 --model_level S --swinfeature --batch_size 32 --burst_size 14 --patch_size 256\n```\n\n#### For Real-World data\n\n```bash\ncd code/real/bsrt\n# Modify the root paths of the training dataset, the model, etc.\n# Training requires more than one GPU\npython main.py --n_GPUs 8 --print_every 20 --lr 0.00005 --decay 40-80 --save bsrt_tiny --model BSRT --fp16 --model_level S --swinfeature --batch_size 8 --burst_size 14 --patch_size 80 --pre_train ../../synthetic/train_log/bsrt/real_models/bsrt_tiny/bsrt_best_epoch.pth \n```\n\nThe pretrained PWC-Net model can be downloaded [here](https://drive.google.com/file/d/1dD6vB9QN3qwmOBi3AGKzJbbSojwDDlgV/view?usp=sharing). \n\n## Test\n\n#### For Synthetic data\n```bash\n# Modify the paths of the test dataset and the trained model\npython test_synburst.py --n_GPUs 1 --model BSRT --model_level S --swinfeature --burst_size 14 --patch_size 384 --pre_train ../train_log/bsrt/real_models/bsrt_tiny/bsrt_best_epoch.pth --root /data/dataset/ntire21/burstsr/synthetic\n```\n\n#### For Real-World data\n```bash\n# Modify the paths of the test dataset and the trained model\npython test_real.py --n_GPUs 1 --model BSRT --model_level S --swinfeature --batch_size 1 --burst_size 14 --patch_size 80 --pre_train ../train_log/bsrt/real_models/bsrt_tiny/bsrt_best_epoch.pth --root /data/dataset/ntire21/burstsr/real\n```\n\n## Results\n\n### Comparison on Synthetic dataset\n![cmp_syn.png](figs/cmp_syn.png)\n\n### Comparison on Real-World dataset\n![cmp_real.png](figs/cmp_real.png)\n\n\n## Citations\nIf our code helps your research or work, please consider citing our paper.\nThe following is a BibTeX reference.\n\n```\n@inproceedings{luo2022bsrt,\n  title={BSRT: 
Improving Burst Super-Resolution with Swin Transformer and Flow-Guided Deformable Alignment},\n  author={Luo, Ziwei and Li, Youwei and Cheng, Shen and Yu, Lei and Wu, Qi and Wen, Zhihong and Fan, Haoqiang and Sun, Jian and Liu, Shuaicheng},\n  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},\n  pages={998--1008},\n  year={2022}\n}\n```\n\n## Contact\nemail: ziwei.ro@gmail.com\n"
  },
  {
    "path": "code/real/bsrt/README.md",
    "content": "# BSRT: Improving Burst Super-Resolution with Swin Transformer and Flow-Guided Deformable Alignment (Real-World)\n\n## Dependencies\n- OS: Ubuntu 18.04\n- Python: Python 3.7\n- nvidia :\n   - cuda: 10.1\n   - cudnn: 7.6.1\n- Other reference requirements\n\n## Quick Start\n1.Create a conda virtual environment and activate it\n```python3\nconda create -n pytorch_1.6 python=3.7\nsource activate pytorch_1.6\n```\n2.Install PyTorch and torchvision following the official instructions\n```python3\nconda install pytorch==1.6.0 torchvision==0.7.0 cudatoolkit=10.1 -c pytorch\n```\n3.Install build requirements\n```python3\npip3 install -r requirements.txt\n```\n4.Install DCN\n```python3\ncd DCNv2\npython3 setup.py build develop # build\npython3 test.py # run examples and check\n```\n\n## Training\n\nThe pretrained PWC-Net model can be downloaded [here](https://drive.google.com/file/d/1dD6vB9QN3qwmOBi3AGKzJbbSojwDDlgV/view?usp=sharing). \n\n```python3\n# Modify the root path of training dataset and model etc.\n# The number of GPUs should be more than 1\npython main.py --n_GPUs 8 --print_every 20 --lr 0.00004 --decay 40-80 --save bsrt_tiny --model BSRT --fp16 --model_level S --swinfeature --batch_size 8 --burst_size 14 --patch_size 80 --pre_train ../../synthetic/train_log/bsrt/real_models/bsrt_tiny/bsrt_best_epoch.pth \n```\n## Test\n```python3\n# Modify the path of test dataset and the path of the trained model\npython test_real.py --n_GPUs 1 --model BSRT --model_level S --swinfeature --batch_size 1 --burst_size 14 --patch_size 80 --pre_train ../train_log/bsrt/real_models/bsrt_tiny/bsrtbest_epoch.pth --root /data/dataset/ntire21/burstsr/real\n```"
  },
  {
    "path": "code/real/bsrt/data_processing/__init__.py",
    "content": ""
  },
  {
    "path": "code/real/bsrt/data_processing/camera_pipeline.py",
    "content": "import torch\nimport random\nimport math\nimport cv2 as cv\nimport numpy as np\nimport utils.data_format_utils as df_utils\n\"\"\" Based on http://timothybrooks.com/tech/unprocessing\nFunctions for forward and inverse camera pipeline. All functions input a torch float tensor of shape (c, h, w).\nAdditionally, some also support batch operations, i.e. inputs of shape (b, c, h, w)\n\"\"\"\n\n\ndef random_ccm():\n    \"\"\"Generates random RGB -> Camera color correction matrices.\"\"\"\n    # Takes a random convex combination of XYZ -> Camera CCMs.\n    xyz2cams = [[[1.0234, -0.2969, -0.2266],\n               [-0.5625, 1.6328, -0.0469],\n               [-0.0703, 0.2188, 0.6406]],\n              [[0.4913, -0.0541, -0.0202],\n               [-0.613, 1.3513, 0.2906],\n               [-0.1564, 0.2151, 0.7183]],\n              [[0.838, -0.263, -0.0639],\n               [-0.2887, 1.0725, 0.2496],\n               [-0.0627, 0.1427, 0.5438]],\n              [[0.6596, -0.2079, -0.0562],\n               [-0.4782, 1.3016, 0.1933],\n               [-0.097, 0.1581, 0.5181]]]\n\n    num_ccms = len(xyz2cams)\n    xyz2cams = torch.tensor(xyz2cams)\n\n    weights = torch.FloatTensor(num_ccms, 1, 1).uniform_(0.0, 1.0)\n    weights_sum = weights.sum()\n    xyz2cam = (xyz2cams * weights).sum(dim=0) / weights_sum\n\n    # Multiplies with RGB -> XYZ to get RGB -> Camera CCM.\n    rgb2xyz = torch.tensor([[0.4124564, 0.3575761, 0.1804375],\n                            [0.2126729, 0.7151522, 0.0721750],\n                            [0.0193339, 0.1191920, 0.9503041]])\n    rgb2cam = torch.mm(xyz2cam, rgb2xyz)\n\n    # Normalizes each row.\n    rgb2cam = rgb2cam / rgb2cam.sum(dim=-1, keepdims=True)\n    return rgb2cam\n\n\ndef random_gains():\n    \"\"\"Generates random gains for brightening and white balance.\"\"\"\n    # RGB gain represents brightening.\n    rgb_gain = 1.0 / random.gauss(mu=0.8, sigma=0.1)\n\n    # Red and blue gains represent white balance.\n    red_gain = 
random.uniform(1.9, 2.4)\n    blue_gain = random.uniform(1.5, 1.9)\n    return rgb_gain, red_gain, blue_gain\n\n\ndef apply_smoothstep(image):\n    \"\"\"Apply global tone mapping curve.\"\"\"\n    image_out = 3 * image**2 - 2 * image**3\n    return image_out\n\n\ndef invert_smoothstep(image):\n    \"\"\"Approximately inverts a global tone mapping curve.\"\"\"\n    image = image.clamp(0.0, 1.0)\n    return 0.5 - torch.sin(torch.asin(1.0 - 2.0 * image) / 3.0)\n\n\ndef gamma_expansion(image):\n    \"\"\"Converts from gamma to linear space.\"\"\"\n    # Clamps to prevent numerical instability of gradients near zero.\n    return image.clamp(1e-8) ** 2.2\n\n\ndef gamma_compression(image):\n    \"\"\"Converts from linear to gamma space.\"\"\"\n    # Clamps to prevent numerical instability of gradients near zero.\n    return image.clamp(1e-8) ** (1.0 / 2.2)\n\n\ndef apply_ccm(image, ccm):\n    \"\"\"Applies a color correction matrix.\"\"\"\n    assert image.dim() == 3 and image.shape[0] == 3\n\n    shape = image.shape\n    image = image.view(3, -1)\n    ccm = ccm.to(image.device).type_as(image)\n\n    image = torch.mm(ccm, image)\n\n    return image.view(shape)\n\n\ndef apply_gains(image, rgb_gain, red_gain, blue_gain):\n    \"\"\"Applies white balance gains and brightening, clamping the result to [0, 1].\"\"\"\n    assert image.dim() == 3 and image.shape[0] in [3, 4]\n\n    if image.shape[0] == 3:\n        gains = torch.tensor([red_gain, 1.0, blue_gain]) * rgb_gain\n    else:\n        gains = torch.tensor([red_gain, 1.0, 1.0, blue_gain]) * rgb_gain\n    gains = gains.view(-1, 1, 1)\n    gains = gains.to(image.device).type_as(image)\n\n    return (image * gains).clamp(0.0, 1.0)\n\n\ndef safe_invert_gains(image, rgb_gain, red_gain, blue_gain):\n    \"\"\"Inverts gains while safely handling saturated pixels.\"\"\"\n    assert image.dim() == 3 and image.shape[0] == 3\n\n    gains = torch.tensor([1.0 / red_gain, 1.0, 1.0 / blue_gain]) / rgb_gain\n    gains = gains.view(-1, 1, 1)\n\n    # Prevents 
dimming of saturated pixels by smoothly masking gains near white.\n    gray = image.mean(dim=0, keepdims=True)\n    inflection = 0.9\n    mask = ((gray - inflection).clamp(0.0) / (1.0 - inflection)) ** 2.0\n\n    safe_gains = torch.max(mask + (1.0 - mask) * gains, gains)\n    return image * safe_gains\n\n\ndef mosaic(image, mode='rggb'):\n    \"\"\"Extracts RGGB Bayer planes from an RGB image.\"\"\"\n    shape = image.shape\n    if image.dim() == 3:\n        image = image.unsqueeze(0)\n\n    if mode == 'rggb':\n        red = image[:, 0, 0::2, 0::2]\n        green_red = image[:, 1, 0::2, 1::2]\n        green_blue = image[:, 1, 1::2, 0::2]\n        blue = image[:, 2, 1::2, 1::2]\n        image = torch.stack((red, green_red, green_blue, blue), dim=1)\n    elif mode == 'grbg':\n        green_red = image[:, 1, 0::2, 0::2]\n        red = image[:, 0, 0::2, 1::2]\n        blue = image[:, 2, 0::2, 1::2]\n        green_blue = image[:, 1, 1::2, 1::2]\n\n        image = torch.stack((green_red, red, blue, green_blue), dim=1)\n\n    if len(shape) == 3:\n        return image.view((4, shape[-2] // 2, shape[-1] // 2))\n    else:\n        return image.view((-1, 4, shape[-2] // 2, shape[-1] // 2))\n\n\ndef demosaic(image):\n    assert isinstance(image, torch.Tensor)\n    image = image.clamp(0.0, 1.0) * 255\n\n    if image.dim() == 4:\n        num_images = image.shape[0]  # batch size, not the tensor rank\n        batch_input = True\n    else:\n        num_images = 1\n        batch_input = False\n        image = image.unsqueeze(0)\n\n    # Generate single channel input for opencv\n    im_sc = torch.zeros((num_images, image.shape[-2] * 2, image.shape[-1] * 2, 1))\n    im_sc[:, ::2, ::2, 0] = image[:, 0, :, :]\n    im_sc[:, ::2, 1::2, 0] = image[:, 1, :, :]\n    im_sc[:, 1::2, ::2, 0] = image[:, 2, :, :]\n    im_sc[:, 1::2, 1::2, 0] = image[:, 3, :, :]\n\n    im_sc = im_sc.numpy().astype(np.uint8)\n\n    out = []\n\n    for im in im_sc:\n        # cv.imwrite('frames/tmp.png', im)\n        im_dem_np = cv.cvtColor(im, 
cv.COLOR_BAYER_BG2RGB)#_VNG)\n\n        # Convert to torch image\n        im_t = df_utils.npimage_to_torch(im_dem_np, input_bgr=False)\n        out.append(im_t)\n\n    if batch_input:\n        return torch.stack(out, dim=0)\n    else:\n        return out[0]\n\n\ndef random_noise_levels():\n    \"\"\"Generates random noise levels from a log-log linear distribution.\"\"\"\n    log_min_shot_noise = math.log(0.0001)\n    log_max_shot_noise = math.log(0.012)\n    log_shot_noise = random.uniform(log_min_shot_noise, log_max_shot_noise)\n    shot_noise = math.exp(log_shot_noise)\n\n    line = lambda x: 2.18 * x + 1.20\n    log_read_noise = line(log_shot_noise) + random.gauss(mu=0.0, sigma=0.26)\n    read_noise = math.exp(log_read_noise)\n    return shot_noise, read_noise\n\n\ndef add_noise(image, shot_noise=0.01, read_noise=0.0005):\n    \"\"\"Adds random shot (proportional to image) and read (independent) noise.\"\"\"\n    variance = image * shot_noise + read_noise\n    noise = torch.FloatTensor(image.shape).normal_().to(image.device)*variance.sqrt()\n    return image + noise\n\n\ndef process_linear_image_rgb(image, meta_info, return_np=False):\n    image = apply_gains(image, meta_info['rgb_gain'], meta_info['red_gain'], meta_info['blue_gain'])\n    image = apply_ccm(image, meta_info['cam2rgb'])\n\n    if meta_info['gamma']:\n        image = gamma_compression(image)\n\n    if meta_info['smoothstep']:\n        image = apply_smoothstep(image)\n\n    image = image.clamp(0.0, 1.0)\n\n    if return_np:\n        image = df_utils.torch_to_npimage(image)\n    return image\n\n\ndef process_linear_image_raw(image, meta_info):\n    image = apply_gains(image, meta_info['rgb_gain'], meta_info['red_gain'], meta_info['blue_gain'])\n    image = demosaic(image)\n    image = apply_ccm(image, meta_info['cam2rgb'])\n\n    if meta_info['gamma']:\n        image = gamma_compression(image)\n\n    if meta_info['smoothstep']:\n        image = apply_smoothstep(image)\n    return image.clamp(0.0, 
1.0)\n"
  },
  {
    "path": "code/real/bsrt/data_processing/synthetic_burst_generation.py",
    "content": "import torch\nimport random\nimport cv2\nimport numpy as np\nimport torch.nn.functional as F\nfrom data_processing.camera_pipeline import *\nfrom utils.data_format_utils import torch_to_numpy, numpy_to_torch\n\n\ndef random_crop(frames, crop_sz):\n    \"\"\" Extract a random crop of size crop_sz from the input frames. If the crop_sz is larger than the input image size,\n    then the largest possible crop of same aspect ratio as crop_sz will be extracted from frames, and upsampled to\n    crop_sz.\n    \"\"\"\n    if not isinstance(crop_sz, (tuple, list)):\n        crop_sz = (crop_sz, crop_sz)\n    crop_sz = torch.tensor(crop_sz).float()\n\n    shape = frames.shape\n\n    # Select scale_factor. Ensure the crop fits inside the image\n    max_scale_factor = torch.tensor(shape[-2:]).float() / crop_sz\n    max_scale_factor = max_scale_factor.min().item()\n\n    if max_scale_factor < 1.0:\n        scale_factor = max_scale_factor\n    else:\n        scale_factor = 1.0\n\n    # Extract the crop\n    orig_crop_sz = (crop_sz * scale_factor).floor()\n\n    assert orig_crop_sz[-2] <= shape[-2] and orig_crop_sz[-1] <= shape[-1], 'Bug in crop size estimation!'\n\n    r1 = random.randint(0, shape[-2] - orig_crop_sz[-2])\n    c1 = random.randint(0, shape[-1] - orig_crop_sz[-1])\n\n    r2 = r1 + orig_crop_sz[0].int().item()\n    c2 = c1 + orig_crop_sz[1].int().item()\n\n    frames_crop = frames[:, r1:r2, c1:c2]\n\n    # Resize to crop_sz\n    if scale_factor < 1.0:\n        frames_crop = F.interpolate(frames_crop.unsqueeze(0), size=crop_sz.int().tolist(), mode='bilinear', align_corners=False).squeeze(0)\n    return frames_crop\n\n\ndef rgb2rawburst(image, burst_size, downsample_factor=1, burst_transformation_params=None,\n                 image_processing_params=None, interpolation_type='bilinear'):\n    \"\"\" Generates a synthetic LR RAW burst from the input image. 
The input sRGB image is first converted to linear\n    sensor space using an inverse camera pipeline. An LR burst is then generated by applying random\n    transformations defined by burst_transformation_params to the input image, and downsampling it by the\n    downsample_factor. The generated burst is then mosaicked and corrupted by random noise.\n    \"\"\"\n\n    if image_processing_params is None:\n        image_processing_params = {}\n\n    _defaults = {'random_ccm': True, 'random_gains': True, 'smoothstep': True, 'gamma': True, 'add_noise': True}\n    for k, v in _defaults.items():\n        if k not in image_processing_params:\n            image_processing_params[k] = v\n\n    # Sample camera pipeline params\n    if image_processing_params['random_ccm']:\n        rgb2cam = random_ccm()\n    else:\n        rgb2cam = torch.eye(3).float()\n    cam2rgb = rgb2cam.inverse()\n\n    # Sample gains\n    if image_processing_params['random_gains']:\n        rgb_gain, red_gain, blue_gain = random_gains()\n    else:\n        rgb_gain, red_gain, blue_gain = (1.0, 1.0, 1.0)\n\n    # Approximately inverts global tone mapping.\n    use_smoothstep = image_processing_params['smoothstep']\n    if use_smoothstep:\n        image = invert_smoothstep(image)\n\n    # Inverts gamma compression.\n    use_gamma = image_processing_params['gamma']\n    if use_gamma:\n        image = gamma_expansion(image)\n\n    # Inverts color correction.\n    image = apply_ccm(image, rgb2cam)\n\n    # Approximately inverts white balance and brightening.\n    image = safe_invert_gains(image, rgb_gain, red_gain, blue_gain)\n\n    # Clip saturated pixels.\n    image = image.clamp(0.0, 1.0)\n\n    # Generate LR burst\n    image_burst_rgb, flow_vectors = single2lrburst(image, burst_size=burst_size,\n                                                   downsample_factor=downsample_factor,\n                                                   transformation_params=burst_transformation_params,\n                     
                              interpolation_type=interpolation_type)\n\n    # mosaic\n    image_burst = mosaic(image_burst_rgb.clone())\n\n    # Add noise\n    if image_processing_params['add_noise']:\n        shot_noise_level, read_noise_level = random_noise_levels()\n        image_burst = add_noise(image_burst, shot_noise_level, read_noise_level)\n    else:\n        shot_noise_level = 0\n        read_noise_level = 0\n\n    # Clip saturated pixels.\n    image_burst = image_burst.clamp(0.0, 1.0)\n\n    meta_info = {'rgb2cam': rgb2cam, 'cam2rgb': cam2rgb, 'rgb_gain': rgb_gain, 'red_gain': red_gain,\n                 'blue_gain': blue_gain, 'smoothstep': use_smoothstep, 'gamma': use_gamma,\n                 'shot_noise_level': shot_noise_level, 'read_noise_level': read_noise_level}\n    return image_burst, image, image_burst_rgb, flow_vectors, meta_info\n\n\ndef get_tmat(image_shape, translation, theta, shear_values, scale_factors):\n    \"\"\" Generates a transformation matrix corresponding to the input transformation parameters \"\"\"\n    im_h, im_w = image_shape\n\n    t_mat = np.identity(3)\n\n    t_mat[0, 2] = translation[0]\n    t_mat[1, 2] = translation[1]\n    t_rot = cv2.getRotationMatrix2D((im_w * 0.5, im_h * 0.5), theta, 1.0)\n    t_rot = np.concatenate((t_rot, np.array([0.0, 0.0, 1.0]).reshape(1, 3)))\n\n    t_shear = np.array([[1.0, shear_values[0], -shear_values[0] * 0.5 * im_w],\n                        [shear_values[1], 1.0, -shear_values[1] * 0.5 * im_h],\n                        [0.0, 0.0, 1.0]])\n\n    t_scale = np.array([[scale_factors[0], 0.0, 0.0],\n                        [0.0, scale_factors[1], 0.0],\n                        [0.0, 0.0, 1.0]])\n\n    t_mat = t_scale @ t_rot @ t_shear @ t_mat\n\n    t_mat = t_mat[:2, :]\n\n    return t_mat\n\n\ndef single2lrburst(image, burst_size, downsample_factor=1, transformation_params=None,\n                   interpolation_type='bilinear'):\n    \"\"\" Generates a burst of size burst_size from the input 
image by applying random transformations defined by\n    transformation_params, and downsampling the resulting burst by downsample_factor.\n    \"\"\"\n\n    if interpolation_type == 'bilinear':\n        interpolation = cv2.INTER_LINEAR\n    elif interpolation_type == 'lanczos':\n        interpolation = cv2.INTER_LANCZOS4\n    else:\n        raise ValueError\n\n    normalize = False\n    if isinstance(image, torch.Tensor):\n        if image.max() < 2.0:\n            image = image * 255.0\n            normalize = True\n        image = torch_to_numpy(image).astype(np.uint8)\n\n    burst = []\n    sample_pos_inv_all = []\n\n    rvs, cvs = torch.meshgrid([torch.arange(0, image.shape[0]),\n                               torch.arange(0, image.shape[1])])\n\n    sample_grid = torch.stack((cvs, rvs, torch.ones_like(cvs)), dim=-1).float()\n\n    for i in range(burst_size):\n        if i == 0:\n            # For base image, do not apply any random transformations. We only translate the image to center the\n            # sampling grid\n            shift = (downsample_factor / 2.0) - 0.5\n            translation = (shift, shift)\n            theta = 0.0\n            shear_factor = (0.0, 0.0)\n            scale_factor = (1.0, 1.0)\n        else:\n            # Sample random image transformation parameters\n            max_translation = transformation_params.get('max_translation', 0.0)\n\n            if max_translation <= 0.01:\n                shift = (downsample_factor / 2.0) - 0.5\n                translation = (shift, shift)\n            else:\n                translation = (random.uniform(-max_translation, max_translation),\n                               random.uniform(-max_translation, max_translation))\n\n            max_rotation = transformation_params.get('max_rotation', 0.0)\n            theta = random.uniform(-max_rotation, max_rotation)\n\n            max_shear = transformation_params.get('max_shear', 0.0)\n            shear_x = random.uniform(-max_shear, 
max_shear)\n            shear_y = random.uniform(-max_shear, max_shear)\n            shear_factor = (shear_x, shear_y)\n\n            max_ar_factor = transformation_params.get('max_ar_factor', 0.0)\n            ar_factor = np.exp(random.uniform(-max_ar_factor, max_ar_factor))\n\n            max_scale = transformation_params.get('max_scale', 0.0)\n            scale_factor = np.exp(random.uniform(-max_scale, max_scale))\n\n            scale_factor = (scale_factor, scale_factor * ar_factor)\n\n        output_sz = (image.shape[1], image.shape[0])\n\n        # Generate an affine transformation matrix corresponding to the sampled parameters\n        t_mat = get_tmat((image.shape[0], image.shape[1]), translation, theta, shear_factor, scale_factor)\n        t_mat_tensor = torch.from_numpy(t_mat)\n\n        # Apply the sampled affine transformation\n        image_t = cv2.warpAffine(image, t_mat, output_sz, flags=interpolation,\n                                 borderMode=cv2.BORDER_CONSTANT)\n\n        t_mat_tensor_3x3 = torch.cat((t_mat_tensor.float(), torch.tensor([0.0, 0.0, 1.0]).view(1, 3)), dim=0)\n        t_mat_tensor_inverse = t_mat_tensor_3x3.inverse()[:2, :].contiguous()\n\n        sample_pos_inv = torch.mm(sample_grid.view(-1, 3), t_mat_tensor_inverse.t().float()).view(\n            *sample_grid.shape[:2], -1)\n\n        if transformation_params.get('border_crop') is not None:\n            border_crop = transformation_params.get('border_crop')\n\n            image_t = image_t[border_crop:-border_crop, border_crop:-border_crop, :]\n            sample_pos_inv = sample_pos_inv[border_crop:-border_crop, border_crop:-border_crop, :]\n\n        # Downsample the image\n        image_t = cv2.resize(image_t, None, fx=1.0 / downsample_factor, fy=1.0 / downsample_factor,\n                             interpolation=interpolation)\n        sample_pos_inv = cv2.resize(sample_pos_inv.numpy(), None, fx=1.0 / downsample_factor,\n                                    fy=1.0 / 
downsample_factor,\n                                    interpolation=interpolation)\n\n        sample_pos_inv = torch.from_numpy(sample_pos_inv).permute(2, 0, 1).contiguous()\n\n        if normalize:\n            image_t = numpy_to_torch(image_t).float() / 255.0\n        else:\n            image_t = numpy_to_torch(image_t).float()\n        burst.append(image_t)\n        sample_pos_inv_all.append(sample_pos_inv / downsample_factor)\n\n    burst_images = torch.stack(burst)\n    sample_pos_inv_all = torch.stack(sample_pos_inv_all)\n\n    # Compute the flow vectors to go from the i'th burst image to the base image\n    flow_vectors = sample_pos_inv_all - sample_pos_inv_all[:, :1, ...]\n\n    return burst_images, flow_vectors\n"
  },
  {
    "path": "code/real/bsrt/datasets/__init__.py",
    "content": ""
  },
  {
    "path": "code/real/bsrt/datasets/burstsr_dataset.py",
    "content": "import os\nimport torch\nimport cv2\nimport numpy as np\nimport pickle as pkl\nimport torch.nn.functional as F\nimport random\nimport time\n\nclass SamsungRAWImage:\n    @staticmethod\n    def load(path):\n        im_raw = cv2.imread('{}/im_raw.png'.format(path), cv2.IMREAD_UNCHANGED)\n\n        im_raw = np.transpose(im_raw, (2, 0, 1)).astype(np.int16)\n        im_raw = torch.from_numpy(im_raw)\n\n        meta_data = pkl.load(open('{}/meta_info.pkl'.format(path), \"rb\", -1))\n\n        return SamsungRAWImage(im_raw, meta_data['black_level'], meta_data['cam_wb'],\n                               meta_data['daylight_wb'], meta_data['color_matrix'], meta_data['exif_data'],\n                               meta_data.get('crop_info', None), meta_data.get('im_preview', None))\n\n    def __init__(self, im_raw, black_level, cam_wb, daylight_wb, color_matrix, exif_data, crop_info=None,\n                 im_preview=None):\n        self.im_raw = im_raw\n\n        self.black_level = black_level\n        self.cam_wb = cam_wb\n        self.daylight_wb = daylight_wb\n        self.color_matrix = color_matrix\n        self.exif_data = exif_data\n        self.crop_info = crop_info\n        self.im_preview = im_preview\n\n        self.norm_factor = 1023.0\n\n    def get_all_meta_data(self):\n        return {'black_level': self.black_level, 'cam_wb': self.cam_wb, 'daylight_wb': self.daylight_wb,\n                'color_matrix': self.color_matrix.tolist()}\n\n    def get_exposure_time(self):\n        return self.exif_data['Image ExposureTime'].values[0].decimal()\n\n    def get_noise_profile(self):\n        noise = self.exif_data['Image Tag 0xC761'].values\n        noise = [n[0] for n in noise]\n        noise = np.array(noise).reshape(3, 2)\n        return noise\n\n    def get_f_number(self):\n        return self.exif_data['Image FNumber'].values[0].decimal()\n\n    def get_iso(self):\n        return self.exif_data['Image ISOSpeedRatings'].values[0]\n\n    def 
get_image_data(self, substract_black_level=False, white_balance=False, normalize=False):\n        im_raw = self.im_raw.float()\n\n        if substract_black_level:\n            im_raw = im_raw - torch.tensor(self.black_level).view(4, 1, 1)\n\n        if white_balance:\n            im_raw = im_raw * torch.tensor(self.cam_wb).view(4, 1, 1)\n\n        if normalize:\n            im_raw = im_raw / self.norm_factor\n\n\n        return im_raw\n\n    def shape(self):\n        shape = (4, self.im_raw.shape[1], self.im_raw.shape[2])\n        return shape\n\n    def crop_image(self, r1, r2, c1, c2):\n        self.im_raw = self.im_raw[:, r1:r2, c1:c2]\n\n    def get_crop(self, r1, r2, c1, c2):\n        im_raw = self.im_raw[:, r1:r2, c1:c2]\n\n        if self.im_preview is not None:\n            im_preview = self.im_preview[2*r1:2*r2, 2*c1:2*c2]\n        else:\n            im_preview = None\n\n        return SamsungRAWImage(im_raw, self.black_level, self.cam_wb, self.daylight_wb, self.color_matrix,\n                               self.exif_data, im_preview=im_preview)\n\n    def postprocess(self, return_np=True, norm_factor=None):\n        # Convert to rgb\n        # im = torch.from_numpy(self.im_raw.astype(np.float32))\n        im = self.im_raw\n\n        im = (im - torch.tensor(self.black_level).view(4, 1, 1)) * torch.tensor(self.cam_wb).view(4, 1, 1)\n\n        if norm_factor is None:\n            im = im / im.max()\n        else:\n            im = im / norm_factor\n\n        im = torch.stack((im[0], (im[1] + im[2])/2, im[3]), dim=0)\n        # im = torch.stack((im[0], im[1], im[3]), dim=0)\n\n        im_out = im.clamp(0.0, 1.0)\n\n        if return_np:\n            im_out = im_out.permute(1, 2, 0).numpy() * 255.0\n            im_out = im_out.astype(np.uint8)\n        return im_out\n\n\nclass CanonImage:\n    @staticmethod\n    def load(path, split='train'):\n        im_raw = cv2.imread('{}/im_raw.png'.format(path), cv2.IMREAD_UNCHANGED)\n        im_raw = 
np.transpose(im_raw, (2, 0, 1)).astype(np.int16)\n        im_raw = torch.from_numpy(im_raw)\n        meta_data = pkl.load(open('{}/meta_info.pkl'.format(path), \"rb\", -1))\n\n        return CanonImage(im_raw.float(), meta_data['black_level'], meta_data['cam_wb'],\n                          meta_data['daylight_wb'], meta_data['rgb_xyz_matrix'], meta_data.get('exif_data', None),\n                          meta_data.get('crop_info', None))\n\n    def __init__(self, im_raw, black_level, cam_wb, daylight_wb, rgb_xyz_matrix, exif_data, crop_info=None):\n        super(CanonImage, self).__init__()\n        self.im_raw = im_raw\n\n        if len(black_level) == 4:\n            black_level = [black_level[0], black_level[1], black_level[3]]\n        self.black_level = black_level\n\n        if len(cam_wb) == 4:\n            cam_wb = [cam_wb[0], cam_wb[1], cam_wb[3]]\n        self.cam_wb = cam_wb\n\n        if len(daylight_wb) == 4:\n            daylight_wb = [daylight_wb[0], daylight_wb[1], daylight_wb[3]]\n        self.daylight_wb = daylight_wb\n\n        self.rgb_xyz_matrix = rgb_xyz_matrix\n        self.xyz_srgb_matrix = torch.tensor([3.2404542, -1.5371385, -0.4985314,\n                                             -0.9692660,  1.8760108,  0.0415560,\n                                             0.0556434, -0.2040259,  1.0572252]).view(3, 3)\n        self.exif_data = exif_data\n        self.crop_info = crop_info\n\n        self.norm_factor = 16383\n\n    def shape(self):\n        shape = (3, self.im_raw.shape[1], self.im_raw.shape[2])\n        return shape\n\n    def get_all_meta_data(self):\n        return {'black_level': self.black_level, 'cam_wb': self.cam_wb, 'daylight_wb': self.daylight_wb,\n                'rgb_xyz_matrix': self.rgb_xyz_matrix.tolist(), 'crop_info': self.crop_info,\n                'norm_factor': self.norm_factor}\n\n    def get_exposure_time(self):\n        return self.exif_data['EXIF ExposureTime'].values[0].decimal()\n\n    def 
get_f_number(self):\n        return self.exif_data['EXIF FNumber'].values[0].decimal()\n\n    def get_iso(self):\n        return self.exif_data['EXIF ISOSpeedRatings'].values[0]\n\n    def get_image_data(self, substract_black_level=False, white_balance=False, normalize=False):\n        im_raw = self.im_raw.float()\n\n        if substract_black_level:\n            im_raw = im_raw - torch.tensor(self.black_level).view(3, 1, 1)\n\n        if white_balance:\n            im_raw = im_raw * torch.tensor(self.cam_wb).view(3, 1, 1) / 1024.0\n\n        if normalize:\n            im_raw = im_raw / self.norm_factor\n\n        return im_raw\n\n    def set_image_data(self, im_data):\n        self.im_raw = im_data\n\n    def crop_image(self, r1, r2, c1, c2):\n        self.im_raw = self.im_raw[:, r1:r2, c1:c2]\n\n    def get_crop(self, r1, r2, c1, c2):\n        im_raw = self.im_raw[:, r1:r2, c1:c2]\n        return CanonImage(im_raw, self.black_level, self.cam_wb, self.daylight_wb, self.rgb_xyz_matrix,\n                          self.exif_data, self.crop_info)\n\n    def set_crop_info(self, crop_info):\n        self.crop_info = crop_info\n\n    def resize(self, size=None, scale_factor=None):\n\n        self.im_raw = F.interpolate(self.im_raw.unsqueeze(0), size=size, scale_factor=scale_factor,\n                                    mode='bilinear').squeeze(0)\n\n    def postprocess(self, return_np=True):\n        # Convert to rgb\n        im = self.im_raw\n\n        im = (im - torch.tensor(self.black_level).view(3, 1, 1)).float() * torch.tensor(self.cam_wb).view(3, 1, 1)\n\n        im_out = im / im.max()\n        im_out = im_out.clamp(0.0, 1.0)\n\n        if return_np:\n            im_out = im_out.permute(1, 2, 0).numpy() * 255.0\n            im_out = im_out.astype(np.uint8)\n        return im_out\n\n\ndef load_txt(path):\n    with open(path, 'r') as fh:\n        out = [d.rstrip() for d in fh.readlines()]\n\n    return out\n\n\nclass BurstSRDataset(torch.utils.data.Dataset):\n    
\"\"\" Real-world burst super-resolution dataset. \"\"\"\n    def __init__(self, root, burst_size=8, crop_sz=80, center_crop=False, random_flip=False, split='train'):\n        \"\"\"\n        args:\n            root : path of the root directory\n            burst_size : Burst size. Maximum allowed burst size is 14.\n            crop_sz: Size of the extracted crop. Maximum allowed crop size is 80\n            center_crop: Whether to extract a random crop, or a centered crop.\n            random_flip: Whether to apply random horizontal and vertical flip\n            split: Can be 'train' or 'val'\n        \"\"\"\n        assert burst_size <= 14, 'burst_sz must be less than or equal to 14'\n        assert crop_sz <= 80, 'crop_sz must be less than or equal to 80'\n        assert split in ['train', 'val']\n\n        root = root + '/' + split\n        super().__init__()\n\n        self.burst_size = burst_size\n        self.crop_sz = crop_sz\n        self.split = split\n        self.center_crop = center_crop\n        self.random_flip = random_flip\n\n        self.root = root\n\n        self.substract_black_level = True\n        self.white_balance = False\n\n        self.burst_list = self._get_burst_list()\n\n    def _get_burst_list(self):\n        burst_list = sorted(os.listdir('{}'.format(self.root)))\n        # print(burst_list)\n        return burst_list\n\n    def get_burst_info(self, burst_id):\n        burst_info = {'burst_size': 14, 'burst_name': self.burst_list[burst_id]}\n        return burst_info\n\n    def _get_raw_image(self, burst_id, im_id):\n        raw_image = SamsungRAWImage.load('{}/{}/samsung_{:02d}'.format(self.root, self.burst_list[burst_id], im_id))\n        return raw_image\n\n    def _get_gt_image(self, burst_id):\n        canon_im = CanonImage.load('{}/{}/canon'.format(self.root, self.burst_list[burst_id]), split=self.split)\n        return canon_im\n\n    def get_burst(self, burst_id, im_ids, info=None):\n        frames = 
[self._get_raw_image(burst_id, i) for i in im_ids]\n\n        gt = self._get_gt_image(burst_id)\n        if info is None:\n            info = self.get_burst_info(burst_id)\n\n        return frames, gt, info\n\n    def _sample_images(self):\n        burst_size = 14\n\n        ids = random.sample(range(1, burst_size), k=self.burst_size - 1)\n        ids = [0, ] + ids\n        return ids\n\n    def __len__(self):\n        return len(self.burst_list)\n\n    def __getitem__(self, index):\n        # Sample the images in the burst, in case a burst_size < 14 is used.\n        im_ids = self._sample_images()\n\n        # Read the burst images along with HR ground truth\n        frames, gt, meta_info = self.get_burst(index, im_ids)\n\n        # Extract crop if needed\n        if frames[0].shape()[-1] != self.crop_sz:\n            if getattr(self, 'center_crop', False):\n                r1 = (frames[0].shape()[-2] - self.crop_sz) // 2\n                c1 = (frames[0].shape()[-1] - self.crop_sz) // 2\n            else:\n                r1 = random.randint(0, frames[0].shape()[-2] - self.crop_sz)\n                c1 = random.randint(0, frames[0].shape()[-1] - self.crop_sz)\n            r2 = r1 + self.crop_sz\n            c2 = c1 + self.crop_sz\n\n            scale_factor = gt.shape()[-1] // frames[0].shape()[-1]\n            frames = [im.get_crop(r1, r2, c1, c2) for im in frames]\n\n            gt = gt.get_crop(scale_factor * r1, scale_factor * r2, scale_factor * c1, scale_factor * c2)\n\n        # Load the RAW image data\n        burst_image_data = [im.get_image_data(normalize=True, substract_black_level=self.substract_black_level,\n                                              white_balance=self.white_balance) for im in frames]\n\n        # Convert to tensor\n        gt_image_data = gt.get_image_data(normalize=True, white_balance=self.white_balance,\n                                          substract_black_level=self.substract_black_level)\n\n        if self.random_flip:\n    
        burst_image_data = [flatten_raw_image(im) for im in burst_image_data]\n\n            pad = [0, 0, 0, 0]\n            if random.random() > 0.5:\n                burst_image_data = [im.flip([1, ])[:, 1:-1].contiguous() for im in burst_image_data]\n                gt_image_data = gt_image_data.flip([2, ])[:, :, 2:-2].contiguous()\n                pad[1] = 1\n\n            if random.random() > 0.5:\n                burst_image_data = [im.flip([0, ])[1:-1, :].contiguous() for im in burst_image_data]\n                gt_image_data = gt_image_data.flip([1, ])[:, 2:-2, :].contiguous()\n                pad[3] = 1\n\n            burst_image_data = [pack_raw_image(im) for im in burst_image_data]\n            burst_image_data = [F.pad(im.unsqueeze(0), pad, mode='replicate').squeeze(0) for im in burst_image_data]\n\n            gt_image_data = F.pad(gt_image_data.unsqueeze(0), [4 * p for p in pad], mode='replicate').squeeze(0)\n\n        burst_image_meta_info = frames[0].get_all_meta_data()\n\n        burst_image_meta_info['black_level_subtracted'] = self.substract_black_level\n        burst_image_meta_info['while_balance_applied'] = self.white_balance\n        burst_image_meta_info['norm_factor'] = frames[0].norm_factor\n\n        gt_image_meta_info = gt.get_all_meta_data()\n\n        burst = torch.stack(burst_image_data, dim=0)\n\n        burst_exposure = frames[0].get_exposure_time()\n        canon_exposure = gt.get_exposure_time()\n\n        burst_f_number = frames[0].get_f_number()\n        canon_f_number = gt.get_f_number()\n\n        burst_iso = frames[0].get_iso()\n        canon_iso = gt.get_iso()\n\n        # Normalize the GT image to account for differences in exposure, ISO etc\n        light_factor_burst = burst_exposure * burst_iso / (burst_f_number ** 2)\n        light_factor_canon = canon_exposure * canon_iso / (canon_f_number ** 2)\n\n        exp_scale_factor = (light_factor_burst / light_factor_canon)\n        gt_image_data = gt_image_data * 
exp_scale_factor\n\n        gt_image_meta_info['black_level_subtracted'] = self.substract_black_level\n        gt_image_meta_info['while_balance_applied'] = self.white_balance\n        gt_image_meta_info['norm_factor'] = gt.norm_factor / exp_scale_factor\n\n        burst_image_meta_info['exposure'] = burst_exposure\n        burst_image_meta_info['f_number'] = burst_f_number\n        burst_image_meta_info['iso'] = burst_iso\n\n        gt_image_meta_info['exposure'] = canon_exposure\n        gt_image_meta_info['f_number'] = canon_f_number\n        gt_image_meta_info['iso'] = canon_iso\n\n        burst = burst.float()\n        frame_gt = gt_image_data.float()\n\n        meta_info_burst = burst_image_meta_info\n        meta_info_gt = gt_image_meta_info\n\n        del meta_info_gt['crop_info']\n\n        for k, v in meta_info_gt.items():\n            if isinstance(v, (list, tuple)):\n                meta_info_gt[k] = torch.tensor(v)\n\n        for k, v in meta_info_burst.items():\n            if isinstance(v, (list, tuple)):\n                meta_info_burst[k] = torch.tensor(v)\n\n        meta_info_burst['burst_name'] = meta_info['burst_name']\n\n        return burst, frame_gt, meta_info_burst, meta_info_gt\n\n\ndef pack_raw_image(im_raw):\n    if isinstance(im_raw, np.ndarray):\n        im_out = np.zeros_like(im_raw, shape=(4, im_raw.shape[0] // 2, im_raw.shape[1] // 2))\n    elif isinstance(im_raw, torch.Tensor):\n        im_out = torch.zeros((4, im_raw.shape[0] // 2, im_raw.shape[1] // 2), dtype=im_raw.dtype).to(im_raw.device)\n    else:\n        raise Exception\n\n    im_out[0, :, :] = im_raw[0::2, 0::2]\n    im_out[1, :, :] = im_raw[0::2, 1::2]\n    im_out[2, :, :] = im_raw[1::2, 0::2]\n    im_out[3, :, :] = im_raw[1::2, 1::2]\n    return im_out\n\n\ndef flatten_raw_image(im_raw_4ch):\n    if isinstance(im_raw_4ch, np.ndarray):\n        im_out = np.zeros_like(im_raw_4ch, shape=(im_raw_4ch.shape[1] * 2, im_raw_4ch.shape[2] * 2))\n    elif isinstance(im_raw_4ch, 
torch.Tensor):\n        im_out = torch.zeros((im_raw_4ch.shape[1] * 2, im_raw_4ch.shape[2] * 2), dtype=im_raw_4ch.dtype).to(im_raw_4ch.device)\n    else:\n        raise Exception\n\n    im_out[0::2, 0::2] = im_raw_4ch[0, :, :]\n    im_out[0::2, 1::2] = im_raw_4ch[1, :, :]\n    im_out[1::2, 0::2] = im_raw_4ch[2, :, :]\n    im_out[1::2, 1::2] = im_raw_4ch[3, :, :]\n\n    return im_out\n\ndef pack_raw_image_batch(im_raw):\n    im_out = torch.zeros((im_raw.shape[0], im_raw.shape[1], 4, im_raw.shape[3] // 2, im_raw.shape[4] // 2), dtype=im_raw.dtype).to(im_raw.device)\n    im_out[:, :, 0, :, :] = im_raw[:, :, 0, 0::2, 0::2]\n    im_out[:, :, 1, :, :] = im_raw[:, :, 0, 0::2, 1::2]\n    im_out[:, :, 2, :, :] = im_raw[:, :, 0, 1::2, 0::2]\n    im_out[:, :, 3, :, :] = im_raw[:, :, 0, 1::2, 1::2]\n    return im_out\n\n\ndef flatten_raw_image_batch(im_raw_4ch):\n    im_out = torch.zeros((im_raw_4ch.shape[0], im_raw_4ch.shape[1], 1, im_raw_4ch.shape[3] * 2, im_raw_4ch.shape[4] * 2), dtype=im_raw_4ch.dtype).to(im_raw_4ch.device)\n    im_out[:, :, 0, 0::2, 0::2] = im_raw_4ch[:, :, 0, :, :]\n    im_out[:, :, 0, 0::2, 1::2] = im_raw_4ch[:, :, 1, :, :]\n    im_out[:, :, 0, 1::2, 0::2] = im_raw_4ch[:, :, 2, :, :]\n    im_out[:, :, 0, 1::2, 1::2] = im_raw_4ch[:, :, 3, :, :]\n\n    return im_out\n"
  },
  {
    "path": "code/real/bsrt/datasets/burstsr_test_dataset.py",
    "content": "import os\nimport torch\nimport torch.nn.functional as F\nimport random\nfrom .burstsr_dataset import SamsungRAWImage, flatten_raw_image, pack_raw_image\n\n\nclass BurstSRDataset(torch.utils.data.Dataset):\n    \"\"\" Real-world burst super-resolution dataset. \"\"\"\n    def __init__(self, root, burst_size=8, crop_sz=80, center_crop=False, random_flip=False, split='test'):\n        \"\"\"\n        args:\n            root : path of the root directory\n            burst_size : Burst size. Maximum allowed burst size is 14.\n            crop_sz: Size of the extracted crop. Maximum allowed crop size is 80\n            center_crop: Whether to extract a random crop, or a centered crop.\n            random_flip: Whether to apply random horizontal and vertical flip\n            split: Can be 'train' or 'val'\n        \"\"\"\n        assert burst_size <= 14, 'burst_sz must be less than or equal to 14'\n        assert crop_sz <= 80, 'crop_sz must be less than or equal to 80'\n        assert split in ['test']\n\n        root = root + '/' + split\n        super().__init__()\n\n        self.burst_size = burst_size\n        self.crop_sz = crop_sz\n        self.split = split\n        self.center_crop = center_crop\n        self.random_flip = random_flip\n\n        self.root = root\n\n        self.substract_black_level = True\n        self.white_balance = False\n\n        self.burst_list = self._get_burst_list()\n\n    def _get_burst_list(self):\n        burst_list = sorted(os.listdir('{}'.format(self.root)))\n\n        return burst_list\n\n    def get_burst_info(self, burst_id):\n        burst_info = {'burst_size': 14, 'burst_name': self.burst_list[burst_id]}\n        return burst_info\n\n    def _get_raw_image(self, burst_id, im_id):\n        raw_image = SamsungRAWImage.load('{}/{}/samsung_{:02d}'.format(self.root, self.burst_list[burst_id], im_id))\n        return raw_image\n\n    def get_burst(self, burst_id, im_ids, info=None):\n        frames = 
[self._get_raw_image(burst_id, i) for i in im_ids]\n\n        if info is None:\n            info = self.get_burst_info(burst_id)\n\n        return frames, info\n\n    def _sample_images(self):\n        burst_size = 14\n\n        ids = random.sample(range(1, burst_size), k=self.burst_size - 1)\n        ids = [0, ] + ids\n        return ids\n\n    def __len__(self):\n        return len(self.burst_list)\n\n    def __getitem__(self, index):\n        # Sample the images in the burst, in case a burst_size < 14 is used.\n        im_ids = self._sample_images()\n\n        # Read the burst images (no HR ground truth is available for the test split)\n        frames, meta_info = self.get_burst(index, im_ids)\n\n        # Extract crop if needed\n        if frames[0].shape()[-1] != self.crop_sz:\n            if getattr(self, 'center_crop', False):\n                r1 = (frames[0].shape()[-2] - self.crop_sz) // 2\n                c1 = (frames[0].shape()[-1] - self.crop_sz) // 2\n            else:\n                r1 = random.randint(0, frames[0].shape()[-2] - self.crop_sz)\n                c1 = random.randint(0, frames[0].shape()[-1] - self.crop_sz)\n            r2 = r1 + self.crop_sz\n            c2 = c1 + self.crop_sz\n\n            frames = [im.get_crop(r1, r2, c1, c2) for im in frames]\n\n        # Load the RAW image data\n        burst_image_data = [im.get_image_data(normalize=True, substract_black_level=self.substract_black_level,\n                                              white_balance=self.white_balance) for im in frames]\n\n        if self.random_flip:\n            burst_image_data = [flatten_raw_image(im) for im in burst_image_data]\n\n            pad = [0, 0, 0, 0]\n            if random.random() > 0.5:\n                burst_image_data = [im.flip([1, ])[:, 1:-1].contiguous() for im in burst_image_data]\n                pad[1] = 1\n\n            if random.random() > 0.5:\n                burst_image_data = [im.flip([0, ])[1:-1, :].contiguous() for im in burst_image_data]\n                
pad[3] = 1\n\n            burst_image_data = [pack_raw_image(im) for im in burst_image_data]\n            burst_image_data = [F.pad(im.unsqueeze(0), pad, mode='replicate').squeeze(0) for im in burst_image_data]\n\n        burst_image_meta_info = frames[0].get_all_meta_data()\n\n        burst_image_meta_info['black_level_subtracted'] = self.substract_black_level\n        burst_image_meta_info['while_balance_applied'] = self.white_balance\n        burst_image_meta_info['norm_factor'] = frames[0].norm_factor\n\n        burst = torch.stack(burst_image_data, dim=0)\n\n        burst_exposure = frames[0].get_exposure_time()\n\n        burst_f_number = frames[0].get_f_number()\n\n        burst_iso = frames[0].get_iso()\n\n        burst_image_meta_info['exposure'] = burst_exposure\n        burst_image_meta_info['f_number'] = burst_f_number\n        burst_image_meta_info['iso'] = burst_iso\n\n        burst = burst.float()\n\n        meta_info_burst = burst_image_meta_info\n\n        for k, v in meta_info_burst.items():\n            if isinstance(v, (list, tuple)):\n                meta_info_burst[k] = torch.tensor(v)\n\n        return burst, meta_info_burst"
  },
  {
    "path": "code/real/bsrt/datasets/data_sampler.py",
    "content": "\"\"\"\nModified from torch.utils.data.distributed.DistributedSampler\nSupport enlarging the dataset for *iter-oriented* training, for saving time when restart the\ndataloader after each epoch\n\"\"\"\nimport math\n\nimport torch\nimport torch.distributed as dist\nfrom torch.utils.data.sampler import Sampler\n\n\nclass DistIterSampler(Sampler):\n    \"\"\"Sampler that restricts data loading to a subset of the dataset.\n\n    It is especially useful in conjunction with\n    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each\n    process can pass a DistributedSampler instance as a DataLoader sampler,\n    and load a subset of the original dataset that is exclusive to it.\n\n    .. note::\n        Dataset is assumed to be of constant size.\n\n    Arguments:\n        dataset: Dataset used for sampling.\n        num_replicas (optional): Number of processes participating in\n            distributed training.\n        rank (optional): Rank of the current process within num_replicas.\n    \"\"\"\n\n    def __init__(self, dataset, num_replicas=None, rank=None, ratio=100):\n        if num_replicas is None:\n            if not dist.is_available():\n                raise RuntimeError(\"Requires distributed package to be available\")\n            num_replicas = dist.get_world_size()\n        if rank is None:\n            if not dist.is_available():\n                raise RuntimeError(\"Requires distributed package to be available\")\n            rank = dist.get_rank()\n        self.dataset = dataset\n        self.num_replicas = num_replicas\n        self.rank = rank\n        self.epoch = 0\n        self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas))\n        self.total_size = self.num_samples * self.num_replicas\n\n    def __iter__(self):\n        # deterministically shuffle based on epoch\n        g = torch.Generator()\n        g.manual_seed(self.epoch)\n        indices = torch.randperm(\n            
self.total_size, generator=g\n        ).tolist()  # Returns a random permutation of integers from 0 to n - 1\n\n        dsize = len(self.dataset)\n        indices = [v % dsize for v in indices]\n\n        # subsample\n        indices = indices[self.rank : self.total_size : self.num_replicas]\n        assert len(indices) == self.num_samples\n\n        return iter(indices)\n\n    def __len__(self):\n        return self.num_samples\n\n    def set_epoch(self, epoch):\n        self.epoch = epoch\n"
  },
  {
    "path": "code/real/bsrt/datasets/realworld_burst_test_set.py",
    "content": "import torch\nimport cv2\nimport numpy as np\nimport pickle as pkl\n\n\nclass RealWorldBurstTest(torch.utils.data.Dataset):\n    \"\"\"\n    \"\"\"\n    def __init__(self, root):\n        self.root = root\n        self.burst_list = list(range(20))\n        self.burst_size = 14\n\n    def __len__(self):\n        return len(self.burst_list)\n\n    def _read_burst_image(self, index, image_id):\n        im = cv2.imread('{}/{:04d}/im_raw_{:02d}.png'.format(self.root, index, image_id), cv2.IMREAD_UNCHANGED)\n        im_t = torch.from_numpy(im.astype(np.float32)).permute(2, 0, 1).float() / (2**14)\n        return im_t\n\n    def __getitem__(self, index):\n        \"\"\"\n                args:\n                    index: Index of the burst\n\n                returns:\n                    burst: LR RAW burst, a torch tensor of shape\n                           The 4 channels correspond to 'R', 'G', 'G', and 'B' values in the RGGB bayer mosaick.\n                    meta_info: Meta information about the burst\n                \"\"\"\n        burst_name = '{:04d}'.format(index)\n        burst = [self._read_burst_image(index, i) for i in range(self.burst_size)]\n        burst = torch.stack(burst, 0)\n\n        meta_info = {}\n        meta_info['burst_name'] = burst_name\n\n        return burst, meta_info\n"
  },
  {
    "path": "code/real/bsrt/datasets/synthetic_burst_test_set.py",
    "content": "import torch\nimport cv2\nimport numpy as np\nimport pickle as pkl\n\n\nclass SyntheticBurstTest(torch.utils.data.Dataset):\n    \"\"\" Synthetic burst test set. The test burst have been generated using the same synthetic pipeline as\n    employed in SyntheticBurst dataset.\n    \"\"\"\n    def __init__(self, root):\n        self.root = root\n        self.burst_list = list(range(92))\n        self.burst_size = 14\n\n    def __len__(self):\n        return len(self.burst_list)\n\n    def _read_burst_image(self, index, image_id):\n        im = cv2.imread('{}/{:04d}/im_raw_{:02d}.png'.format(self.root, index, image_id), cv2.IMREAD_UNCHANGED)\n        im_t = torch.from_numpy(im.astype(np.float32)).permute(2, 0, 1).float() / (2**14)\n        return im_t\n\n    def __getitem__(self, index):\n        \"\"\" Generates a synthetic burst\n                args:\n                    index: Index of the burst\n\n                returns:\n                    burst: LR RAW burst, a torch tensor of shape\n                           The 4 channels correspond to 'R', 'G', 'G', and 'B' values in the RGGB bayer mosaick.\n                    meta_info: Meta information about the burst\n                \"\"\"\n        burst_name = '{:04d}'.format(index)\n        burst = [self._read_burst_image(index, i) for i in range(self.burst_size)]\n        burst = torch.stack(burst, 0)\n\n        meta_info = {}\n        meta_info['burst_name'] = burst_name\n\n        return burst, meta_info\n"
  },
  {
    "path": "code/real/bsrt/datasets/synthetic_burst_train_set.py",
    "content": "import torch\nimport numpy as np\nfrom PIL import Image\nfrom data_processing.synthetic_burst_generation import rgb2rawburst, random_crop #syn_burst_utils\nimport torchvision.transforms as tfm\n\n\nclass SyntheticBurst(torch.utils.data.Dataset):\n    \"\"\" Synthetic burst dataset for joint denoising, demosaicking, and super-resolution. RAW Burst sequences are\n    synthetically generated on the fly as follows. First, a single image is loaded from the base_dataset. The sampled\n    image is converted to linear sensor space using the inverse camera pipeline employed in [1]. A burst\n    sequence is then generated by adding random translations and rotations to the converted image. The generated burst\n    is then converted is then mosaicked, and corrupted by random noise to obtain the RAW burst.\n\n    [1] Unprocessing Images for Learned Raw Denoising, Brooks, Tim and Mildenhall, Ben and Xue, Tianfan and Chen,\n    Jiawen and Sharlet, Dillon and Barron, Jonathan T, CVPR 2019\n    \"\"\"\n    def __init__(self, base_dataset, burst_size=8, crop_sz=384, transform=tfm.ToTensor()):\n        self.base_dataset = base_dataset\n\n        self.burst_size = burst_size\n        self.crop_sz = crop_sz\n        self.transform = transform\n\n        self.downsample_factor = 4\n        self.burst_transformation_params = {'max_translation': 24.0,\n                                            'max_rotation': 1.0,\n                                            'max_shear': 0.0,\n                                            'max_scale': 0.0,\n                                            'border_crop': 24}\n\n        self.image_processing_params = {'random_ccm': True, 'random_gains': True, 'smoothstep': True,\n                                        'gamma': True,\n                                        'add_noise': True}\n        self.interpolation_type = 'bilinear'\n\n    def __len__(self):\n        return len(self.base_dataset)\n\n    def __getitem__(self, index):\n        
\"\"\" Generates a synthetic burst\n        args:\n            index: Index of the image in the base_dataset used to generate the burst\n\n        returns:\n            burst: Generated LR RAW burst, a torch tensor of shape\n                   [burst_size, 4, self.crop_sz / (2*self.downsample_factor), self.crop_sz / (2*self.downsample_factor)]\n                   The 4 channels correspond to 'R', 'G', 'G', and 'B' values in the RGGB bayer mosaick.\n                   The extra factor 2 in the denominator (2*self.downsample_factor) corresponds to the mosaicking\n                   operation.\n\n            frame_gt: The HR RGB ground truth in the linear sensor space, a torch tensor of shape\n                      [3, self.crop_sz, self.crop_sz]\n\n            flow_vectors: The ground truth flow vectors between a burst image and the base image (i.e. the first image in the burst).\n                          The flow_vectors can be used to warp the burst images to the base frame, using the 'warp'\n                          function in utils.warp package.\n                          flow_vectors is torch tensor of shape\n                          [burst_size, 2, self.crop_sz / self.downsample_factor, self.crop_sz / self.downsample_factor].\n                          Note that the flow_vectors are in the LR RGB space, before mosaicking. Hence it has twice\n                          the number of rows and columns, compared to the output burst.\n\n                          NOTE: The flow_vectors are only available during training for the purpose of using any\n                                auxiliary losses if needed. The flow_vectors will NOT be provided for the bursts in the\n                                test set\n\n            meta_info: A dictionary containing the parameters used to generate the synthetic burst.\n        \"\"\"\n        frame = self.base_dataset[index]\n\n        # Augmentation, e.g. 
convert to tensor\n        if self.transform is not None:\n            # frame = Image.fromarray(frame)\n            frame = self.transform(frame)\n\n        # Extract a random crop from the image\n        crop_sz = self.crop_sz + 2 * self.burst_transformation_params.get('border_crop', 0)\n        frame_crop = random_crop(frame, crop_sz)\n\n        # Generate RAW burst\n        burst, frame_gt, burst_rgb, flow_vectors, meta_info = rgb2rawburst(frame_crop,\n                                                                           self.burst_size,\n                                                                           self.downsample_factor,\n                                                                           burst_transformation_params=self.burst_transformation_params,\n                                                                           image_processing_params=self.image_processing_params,\n                                                                           interpolation_type=self.interpolation_type\n                                                                           )\n\n        if self.burst_transformation_params.get('border_crop') is not None:\n            border_crop = self.burst_transformation_params.get('border_crop')\n            frame_gt = frame_gt[:, border_crop:-border_crop, border_crop:-border_crop]\n\n        return burst, frame_gt, flow_vectors, meta_info\n"
  },
  {
    "path": "code/real/bsrt/datasets/synthetic_burst_val_set.py",
    "content": "import os\nimport torch\nimport cv2\nimport numpy as np\nimport pickle as pkl\n\n\nclass SyntheticBurstVal(torch.utils.data.Dataset):\n    \"\"\" Synthetic burst validation set introduced in [1]. The validation burst have been generated using a\n    synthetic data generation pipeline. The dataset can be downloaded from\n    https://data.vision.ee.ethz.ch/bhatg/SyntheticBurstVal.zip\n\n    [1] Deep Burst Super-Resolution. Goutam Bhat, Martin Danelljan, Luc Van Gool, and Radu Timofte. CVPR 2021\n    \"\"\"\n    def __init__(self, root=None, initialize=True):\n        \"\"\"\n        args:\n            root - Path to root dataset directory\n            initialize - boolean indicating whether to load the meta-data for the dataset\n        \"\"\"\n        self.root = os.path.join(root, 'val')\n        self.burst_list = list(range(300))\n        self.burst_size = 14\n\n    def initialize(self):\n        pass\n\n    def __len__(self):\n        return len(self.burst_list)\n\n    def _read_burst_image(self, index, image_id):\n        im = cv2.imread('{}/bursts/{:04d}/im_raw_{:02d}.png'.format(self.root, index, image_id), cv2.IMREAD_UNCHANGED)\n        im_t = torch.from_numpy(im.astype(np.float32)).permute(2, 0, 1).float() / (2**14)\n\n        return im_t\n\n    def _read_gt_image(self, index):\n        gt = cv2.imread('{}/gt/{:04d}/im_rgb.png'.format(self.root, index), cv2.IMREAD_UNCHANGED)\n        gt_t = (torch.from_numpy(gt.astype(np.float32)) / 2 ** 14).permute(2, 0, 1).float()\n        return gt_t\n\n    def _read_meta_info(self, index):\n        with open('{}/gt/{:04d}/meta_info.pkl'.format(self.root, index), \"rb\") as input_file:\n            meta_info = pkl.load(input_file)\n\n        return meta_info\n\n    def __getitem__(self, index):\n        \"\"\" Generates a synthetic burst\n        args:\n            index: Index of the burst\n\n        returns:\n            burst: LR RAW burst, a torch tensor of shape\n                   [14, 4, 48, 48]\n   
                The 4 channels correspond to 'R', 'G', 'G', and 'B' values in the RGGB bayer mosaick.\n            gt : Ground truth linear image\n            meta_info: Meta info about the burst which can be used to convert gt to sRGB space\n        \"\"\"\n        burst_name = '{:04d}'.format(index)\n        burst = [self._read_burst_image(index, i) for i in range(self.burst_size)]\n        burst = torch.stack(burst, 0)\n\n        gt = self._read_gt_image(index)\n        meta_info = self._read_meta_info(index)\n        meta_info['burst_name'] = burst_name\n        return burst, gt, meta_info\n"
  },
  {
    "path": "code/real/bsrt/datasets/zurich_raw2rgb_dataset.py",
    "content": "import torch\nimport os\nimport numpy as np\nfrom cv2 import imread\n\n\nclass ZurichRAW2RGB(torch.utils.data.Dataset):\n    \"\"\" Canon RGB images from the \"Zurich RAW to RGB mapping\" dataset. You can download the full\n    dataset (22 GB) from http://people.ee.ethz.ch/~ihnatova/pynet.html#dataset. Alternatively, you can only download the\n    Canon RGB images (5.5 GB) from https://data.vision.ee.ethz.ch/bhatg/zurich-raw-to-rgb.zip\n    \"\"\"\n    def __init__(self, root, split='train'):\n        super().__init__()\n\n        if split in ['train', 'test']:\n            self.img_pth = os.path.join(root, split, 'canon')\n        else:\n            raise Exception('Unknown split {}'.format(split))\n\n        self.image_list = self._get_image_list(split)\n\n    def _get_image_list(self, split):\n        if split == 'train':\n            image_list = ['{:d}.jpg'.format(i) for i in range(46839)]\n        elif split == 'test':\n            # image_list = ['{:d}.jpg'.format(int(i)) for i in np.linspace(1, 1200, 200)]\n            image_list = ['{:d}.jpg'.format(i) for i in range(1200)]\n        else:\n            raise Exception\n\n        return image_list\n\n    def _get_image(self, im_id):\n        path = os.path.join(self.img_pth, self.image_list[im_id])\n        img = imread(path)\n        return img\n\n    def get_image(self, im_id):\n        frame = self._get_image(im_id)\n\n        return frame\n\n    def __len__(self):\n        return len(self.image_list)\n\n    def __getitem__(self, index):\n        frame = self._get_image(index)\n\n        return frame\n"
  },
  {
    "path": "code/real/bsrt/demo.sh",
    "content": "#!/usr/bin/env bash\n\n\npython main.py --n_GPUs 8 --print_every 20 --lr 0.00004 --decay 40-80 --save bsrt_tiny --model BSRT --fp16 --model_level S --swinfeature --batch_size 8 --burst_size 14 --patch_size 80 --pre_train ../../synthetic/train_log/bsrt/real_models/bsrt_tiny/bsrt_best_epoch.pth \n# python main.py --n_GPUs 8 --print_every 20 --lr 0.00004 --decay 40-80 --save bsrt_large --model BSRT --fp16 --model_level L --swinfeature --batch_size 8 --burst_size 14 --patch_size 48 --pre_train ../../synthetic/train_log/bsrt/real_models/bsrt_large/bsrt_best_epoch.pth \n\n\n# python test_real.py --n_GPUs 1 --model BSRT --model_level S --swinfeature --batch_size 1 --burst_size 14 --patch_size 80 --pre_train ../train_log/bsrt/real_models/bsrt_tiny/bsrtbest_epoch.pth --root /data/dataset/ntire21/burstsr/real\n# python test_real.py --n_GPUs 1 --model BSRT --model_level L --swinfeature --batch_size 1 --burst_size 14 --patch_size 80 --pre_train ../train_log/bsrt/real_models/bsrt_large/bsrt_realworld.pth --root /data/dataset/ntire21/burstsr/real"
  },
  {
    "path": "code/real/bsrt/loss/Charbonnier.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass CharbonnierLoss(nn.Module):\n    \"\"\"L1 charbonnier loss.\"\"\"\n\n    def __init__(self, epsilon=1e-3, reduce=True):\n        super(CharbonnierLoss, self).__init__()\n        self.eps = epsilon * epsilon\n        self.reduce = reduce\n\n    def forward(self, X, Y):\n        diff = torch.add(X, -Y)\n        error = torch.sqrt(diff * diff + self.eps)\n        if self.reduce:\n            loss = torch.mean(error)\n        else:\n            loss = error\n        return loss"
  },
  {
    "path": "code/real/bsrt/loss/__init__.py",
    "content": "import os\nfrom importlib import import_module\n\nimport matplotlib\nmatplotlib.use('Agg')\nimport matplotlib.pyplot as plt\n\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Loss(nn.modules.loss._Loss):\n    def __init__(self, args, ckp):\n        super(Loss, self).__init__()\n        if args.local_rank == 0:\n            print('Preparing loss function:')\n\n        self.n_GPUs = args.n_GPUs\n        self.loss = []\n        self.loss_module = nn.ModuleList()\n        for loss in args.loss.split('+'):\n            weight, loss_type = loss.split('*')\n            if loss_type == 'MSE':\n                loss_function = nn.MSELoss()\n            elif loss_type == 'L1':\n                loss_function = nn.L1Loss()\n            elif loss_type.find('VGG') >= 0:\n                module = import_module('loss.vgg')\n                loss_function = getattr(module, 'VGG')(\n                    loss_type[3:],\n                    rgb_range=args.rgb_range\n                )\n            elif loss_type.find('GAN') >= 0:\n                module = import_module('loss.adversarial')\n                loss_function = getattr(module, 'Adversarial')(\n                    args,\n                    loss_type\n                )\n            elif loss_type == 'FILTER':\n                module = import_module('loss.filter')\n                loss_function = getattr(module, 'Filter')(args)\n            elif loss_type == 'SSIM':\n                module = import_module('loss.mssim')\n                loss_function = getattr(module, 'SSIM')(args)\n            elif loss_type == 'MSSSIM':\n                module = import_module('loss.mssim')\n                loss_function = getattr(module, 'MSSSIM')(args)\n\n            self.loss.append({\n                'type': loss_type,\n                'weight': float(weight),\n                'function': loss_function}\n            )\n            if loss_type.find('GAN') >= 0:\n               
 self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})\n\n        if len(self.loss) > 1:\n            self.loss.append({'type': 'Total', 'weight': 0, 'function': None})\n\n        for l in self.loss:\n            if l['function'] is not None:\n                if args.local_rank == 0:\n                    print('{:.3f} * {}'.format(l['weight'], l['type']))\n                self.loss_module.append(l['function'])\n\n        self.log = torch.Tensor()\n\n        device = torch.device('cpu' if args.cpu else 'cuda')\n        self.loss_module.to(device)\n        if args.precision == 'half': self.loss_module.half()\n        if not args.cpu and args.n_GPUs > 1:\n            self.loss_module = nn.DataParallel(\n                self.loss_module, range(args.n_GPUs)\n            )\n\n        if args.load != '': self.load(ckp.dir, cpu=args.cpu)\n\n    def forward(self, sr, hr):\n        losses = []\n        for i, l in enumerate(self.loss):\n            if l['function'] is not None:\n                loss = l['function'](sr, hr)\n                effective_loss = l['weight'] * loss\n                losses.append(effective_loss)\n                self.log[-1, i] += effective_loss.item()\n            elif l['type'] == 'DIS':\n                self.log[-1, i] += self.loss[i - 1]['function'].loss\n\n        loss_sum = sum(losses)\n        if len(self.loss) > 1:\n            self.log[-1, -1] += loss_sum.item()\n\n        return loss_sum\n\n    def step(self):\n        for l in self.get_loss_module():\n            if hasattr(l, 'scheduler'):\n                l.scheduler.step()\n\n    def start_log(self):\n        self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))\n\n    def end_log(self, n_batches):\n        self.log[-1].div_(n_batches)\n\n    def display_loss(self, batch):\n        n_samples = batch + 1\n        log = []\n        for l, c in zip(self.loss, self.log[-1]):\n            log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))\n\n        return 
''.join(log)\n\n    def plot_loss(self, apath, epoch):\n        axis = np.linspace(1, epoch, epoch)\n        for i, l in enumerate(self.loss):\n            label = '{} Loss'.format(l['type'])\n            fig = plt.figure()\n            plt.title(label)\n            plt.plot(axis, self.log[:, i].numpy(), label=label)\n            plt.legend()\n            plt.xlabel('Epochs')\n            plt.ylabel('Loss')\n            plt.grid(True)\n            plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type'])))\n            plt.close(fig)\n\n    def get_loss_module(self):\n        if self.n_GPUs == 1:\n            return self.loss_module\n        else:\n            return self.loss_module.module\n\n    def save(self, apath):\n        torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))\n        torch.save(self.log, os.path.join(apath, 'loss_log.pt'))\n\n    def load(self, apath, cpu=False):\n        if cpu:\n            kwargs = {'map_location': lambda storage, loc: storage}\n        else:\n            kwargs = {}\n\n        self.load_state_dict(torch.load(\n            os.path.join(apath, 'loss.pt'),\n            **kwargs\n        ))\n        self.log = torch.load(os.path.join(apath, 'loss_log.pt'))\n        for l in self.get_loss_module():\n            if hasattr(l, 'scheduler'):\n                for _ in range(len(self.log)): l.scheduler.step()\n\n"
  },
  {
    "path": "code/real/bsrt/loss/adversarial.py",
    "content": "import utility\nfrom types import SimpleNamespace\n\nfrom model import common\nfrom loss import discriminator\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\n\nclass Adversarial(nn.Module):\n    def __init__(self, args, gan_type):\n        super(Adversarial, self).__init__()\n        self.gan_type = gan_type\n        self.gan_k = args.gan_k\n        self.dis = discriminator.Discriminator(args)\n        # if gan_type == 'WGAN_GP':\n        if True:\n            # see https://arxiv.org/pdf/1704.00028.pdf pp.4\n            optim_dict = {\n                'optimizer': 'ADAM',\n                'betas': (0.5, 0.9),\n                'epsilon': 1e-8,\n                'lr': 1e-5,\n                'weight_decay': args.weight_decay,\n                'decay': args.decay,\n                'gamma': args.gamma\n            }\n            optim_args = SimpleNamespace(**optim_dict)\n        else:\n            optim_args = args\n\n        self.optimizer = utility.make_optimizer(optim_args, self.dis)\n\n    def forward(self, fake, real):\n        # updating discriminator...\n        self.loss = 0\n        fake_detach = fake.detach()     # do not backpropagate through G\n        for _ in range(self.gan_k):\n            self.optimizer.zero_grad()\n            # d: B x 1 tensor\n            d_fake = self.dis(fake_detach)\n            d_real = self.dis(real)\n            retain_graph = False\n            if self.gan_type in ['GAN', 'SNGAN']:\n                loss_d = self.bce(d_real, d_fake)\n            elif self.gan_type.find('WGAN') >= 0:\n                loss_d = (d_fake - d_real).mean()\n                if self.gan_type.find('GP') >= 0:\n                    epsilon = torch.rand_like(fake).view(-1, 1, 1, 1)\n                    hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)\n                    hat.requires_grad = True\n                    d_hat = self.dis(hat)\n                    gradients = 
torch.autograd.grad(\n                        outputs=d_hat.sum(), inputs=hat,\n                        retain_graph=True, create_graph=True, only_inputs=True\n                    )[0]\n                    gradients = gradients.view(gradients.size(0), -1)\n                    gradient_norm = gradients.norm(2, dim=1)\n                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()\n                    loss_d += gradient_penalty\n            # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks\n            elif self.gan_type == 'RGAN':\n                better_real = d_real - d_fake.mean(dim=0, keepdim=True)\n                better_fake = d_fake - d_real.mean(dim=0, keepdim=True)\n                loss_d = self.bce(better_real, better_fake)\n                retain_graph = True\n\n            # Discriminator update\n            self.loss += loss_d.item()\n            loss_d.backward(retain_graph=retain_graph)\n            self.optimizer.step()\n\n            if self.gan_type == 'WGAN':\n                for p in self.dis.parameters():\n                    p.data.clamp_(-1, 1)\n\n        self.loss /= self.gan_k\n\n        # updating generator...\n        d_fake_bp = self.dis(fake)      # for backpropagation, use fake as it is\n        if self.gan_type in ['GAN', 'SNGAN']:\n            label_real = torch.ones_like(d_fake_bp)\n            loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)\n        elif self.gan_type.find('WGAN') >= 0:\n            loss_g = -d_fake_bp.mean()\n        elif self.gan_type == 'RGAN':\n            better_real = d_real.detach() - d_fake_bp.mean(dim=0, keepdim=True)\n            better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True).detach()\n            loss_g = self.bce(better_fake, better_real)\n\n        # Generator loss\n        return loss_g\n\n    def state_dict(self, *args, **kwargs):\n        state_discriminator = self.dis.state_dict(*args, **kwargs)\n        state_optimizer = 
self.optimizer.state_dict()\n\n        return dict(**state_discriminator, **state_optimizer)\n\n    def bce(self, real, fake):\n        label_real = torch.ones_like(real)\n        label_fake = torch.zeros_like(fake)\n        bce_real = F.binary_cross_entropy_with_logits(real, label_real)\n        bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)\n        bce_loss = bce_real + bce_fake\n        return bce_loss\n\n# Some references\n# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py\n# OR\n# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py\n"
  },
  {
    "path": "code/real/bsrt/loss/discriminator.py",
    "content": "from model import common\n\nimport torch.nn as nn\n\nclass Discriminator(nn.Module):\n    '''\n        output is not normalized\n    '''\n    def __init__(self, args, gan_type='GAN'):\n        super(Discriminator, self).__init__()\n\n        in_channels = args.n_colors\n        out_channels = 32\n        depth = 6\n\n        def _block(_in_channels, _out_channels, stride=1):\n\n            Conv = nn.Conv2d(\n                    _in_channels,\n                    _out_channels,\n                    3,\n                    padding=1,\n                    stride=stride,\n                    bias=False\n                )\n\n            if gan_type == 'SNGAN':\n                return nn.Sequential(\n                            spectral_norm(Conv),\n                            nn.BatchNorm2d(_out_channels),\n                            nn.LeakyReLU(negative_slope=0.2, inplace=True)\n                )\n            else:\n                return nn.Sequential(\n                    Conv,\n                    nn.BatchNorm2d(_out_channels),\n                    nn.LeakyReLU(negative_slope=0.2, inplace=True)\n                )\n\n        m_features = [_block(in_channels, out_channels)]\n        for i in range(depth):\n            in_channels = out_channels\n            # if i % 2 == 1:\n            #     stride = 1\n            #     out_channels *= 2\n            # else:\n            out_channels *= 2\n            stride = 2\n            m_features.append(_block(in_channels, out_channels, stride=stride))\n\n        patch_size = args.patch_size // 2**(depth-1)\n\n        # print(out_channels, patch_size)\n\n        m_classifier = [\n            nn.Flatten(),\n            nn.Linear(out_channels*patch_size**2, 512),\n            nn.LeakyReLU(0.2, True),\n            nn.Linear(512, 1)\n        ]\n\n        self.features = nn.Sequential(*m_features)\n        self.classifier = nn.Sequential(*m_classifier)\n\n    def forward(self, x):\n        features = 
self.features(x)\n        # print(features.shape)\n        output = self.classifier(features)\n\n        return output\n\n"
  },
  {
    "path": "code/real/bsrt/loss/filter.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Filter(nn.Module):\n    def __init__(self, args):\n        super().__init__()\n        self.args = args\n\n        kernel = torch.tensor([[1, 4, 1], [4, -20, 4], [1, 4, 1]])\n        self.conv = nn.Conv2d(args.n_colors, args.n_colors, 3, 3)\n        with torch.no_grad():\n            self.conv.weight.copy_(kernel.float())\n        self.loss = nn.L1Loss()\n\n    def forward(self, x, y):\n        preds_x = self.conv(x)\n        preds_y = self.conv(y)\n\n        return self.loss(preds_x, preds_y)\n"
  },
  {
    "path": "code/real/bsrt/loss/hist_entropy.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass HistEntropy(nn.Module):\n    def __init__(self, args):\n        super().__init__()\n        self.args = args\n\n    def forward(self, x):\n        p = torch.softmax(x, dim=1)\n        logp = torch.log_softmax(x, dim=1)\n\n        entropy = (-p * logp).sum(dim=(2, 3)).mean()\n\n        return entropy\n"
  },
  {
    "path": "code/real/bsrt/loss/mssim.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom math import exp\nimport numpy as np\n\n\ndef gaussian(window_size, sigma):\n    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])\n    return gauss/gauss.sum()\n\n\ndef create_window(window_size, channel=1):\n    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)\n    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)\n    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()\n    return window\n\n\ndef ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):\n    # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).\n    if val_range is None:\n        if torch.max(img1) > 128:\n            max_val = 255\n        else:\n            max_val = 1\n\n        if torch.min(img1) < -0.5:\n            min_val = -1\n        else:\n            min_val = 0\n        L = max_val - min_val\n    else:\n        L = val_range\n\n    padd = 0\n    (_, channel, height, width) = img1.size()\n    if window is None:\n        real_size = min(window_size, height, width)\n        window = create_window(real_size, channel=channel).to(img1.device)\n\n    mu1 = F.conv2d(img1, window, padding=padd, groups=channel)\n    mu2 = F.conv2d(img2, window, padding=padd, groups=channel)\n\n    mu1_sq = mu1.pow(2)\n    mu2_sq = mu2.pow(2)\n    mu1_mu2 = mu1 * mu2\n\n    sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq\n    sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq\n    sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2\n\n    C1 = (0.01 * L) ** 2\n    C2 = (0.03 * L) ** 2\n\n    v1 = 2.0 * sigma12 + C2\n    v2 = sigma1_sq + sigma2_sq + C2\n    cs = torch.mean(v1 / v2)  # contrast sensitivity\n\n    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * 
v2)\n\n    if size_average:\n        ret = ssim_map.mean()\n    else:\n        ret = ssim_map.mean(1).mean(1).mean(1)\n\n    if full:\n        return ret, cs\n    return ret\n\n\ndef msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=None):\n    device = img1.device\n    weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)\n    levels = weights.size()[0]\n    ssims = []\n    mcs = []\n    for _ in range(levels):\n        sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)\n\n        # Relu normalize (not compliant with original definition)\n        if normalize == \"relu\":\n            ssims.append(torch.relu(sim))\n            mcs.append(torch.relu(cs))\n        else:\n            ssims.append(sim)\n            mcs.append(cs)\n\n        img1 = F.avg_pool2d(img1, (2, 2))\n        img2 = F.avg_pool2d(img2, (2, 2))\n\n    ssims = torch.stack(ssims)\n    mcs = torch.stack(mcs)\n\n    # Simple normalize (not compliant with original definition)\n    # TODO: remove support for normalize == True (kept for backward support)\n    if normalize == \"simple\" or normalize == True:\n        ssims = (ssims + 1) / 2\n        mcs = (mcs + 1) / 2\n\n    pow1 = mcs ** weights\n    pow2 = ssims ** weights\n\n    # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/\n    output = torch.prod(pow1[:-1] * pow2[-1])\n    return output\n\n\n# Classes to re-use window\nclass SSIM(torch.nn.Module):\n    def __init__(self, window_size=11, size_average=True, val_range=None):\n        super(SSIM, self).__init__()\n        self.window_size = window_size\n        self.size_average = size_average\n        self.val_range = val_range\n\n        # Assume 1 channel for SSIM\n        self.channel = 1\n        self.window = create_window(window_size)\n\n    def forward(self, img1, img2):\n        (_, channel, _, _) = img1.size()\n\n        if channel == 
self.channel and self.window.dtype == img1.dtype:\n            window = self.window\n        else:\n            window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)\n            self.window = window\n            self.channel = channel\n\n        return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)\n\nclass MSSSIM(torch.nn.Module):\n    def __init__(self, window_size=11, size_average=True, channel=3):\n        super(MSSSIM, self).__init__()\n        self.window_size = window_size\n        self.size_average = size_average\n        self.channel = channel\n\n    def forward(self, img1, img2):\n        # TODO: store window between calls if possible\n        return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)"
  },
  {
    "path": "code/real/bsrt/loss/vgg.py",
    "content": "from model import common\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision.models as models\n\nclass VGG(nn.Module):\n    def __init__(self, conv_index, rgb_range=1):\n        super(VGG, self).__init__()\n        vgg_features = models.vgg19(pretrained=True).features\n        modules = [m for m in vgg_features]\n        if conv_index.find('22') >= 0:\n            self.vgg = nn.Sequential(*modules[:8])\n        elif conv_index.find('54') >= 0:\n            self.vgg = nn.Sequential(*modules[:35])\n\n        vgg_mean = (0.485, 0.456, 0.406)\n        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)\n        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)\n        for p in self.parameters():\n            p.requires_grad = False\n\n    def forward(self, sr, hr):\n        def _forward(x):\n            # x = self.sub_mean(x)\n            x = self.vgg(x)\n            return x\n\n        sr = sr.repeat(1, 3, 1, 1)\n        hr = hr.repeat(1, 3, 1, 1)\n\n        vgg_sr = _forward(sr)\n        with torch.no_grad():\n            vgg_hr = _forward(hr.detach())\n\n        loss = F.mse_loss(vgg_sr, vgg_hr)\n\n        return loss\n"
  },
  {
    "path": "code/real/bsrt/main.py",
    "content": "import torch\nimport random\nimport numpy as np\nfrom torch.utils.data import DataLoader\nimport os\n\nimport utility\nimport model\nimport loss\nfrom option import args\nfrom trainer import Trainer\nfrom datasets.burstsr_dataset import BurstSRDataset, flatten_raw_image\nimport torch.multiprocessing as mp\nimport torch.backends.cudnn as cudnn\nimport torch.distributed as dist\nimport torch.utils.data.distributed\n\n\ndef init_seeds(seed=0, cuda_deterministic=True):\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html\n    if cuda_deterministic:  # slower, more reproducible\n        cudnn.deterministic = True\n        cudnn.benchmark = False\n    else:  # faster, less reproducible\n        cudnn.deterministic = False\n        cudnn.benchmark = True\n\ncheckpoint = utility.checkpoint(args)\n\ndef main():\n    mp.spawn(main_worker, nprocs=args.n_GPUs, args=(args.n_GPUs, args))\n\n\ndef main_worker(local_rank, nprocs, args):\n    # print(local_rank)\n    if checkpoint.ok:\n        args.local_rank = local_rank\n        init_seeds(local_rank+1)\n        cudnn.benchmark = True\n        utility.setup(local_rank, nprocs)\n        torch.cuda.set_device(local_rank)\n\n        batch_size = int(args.batch_size / nprocs)\n        train_data = BurstSRDataset(root=args.root,\n                                    burst_size=args.burst_size,\n                                    crop_sz=args.patch_size, random_flip=True,\n                                    center_crop=True, split='train')\n        valid_data = BurstSRDataset(root=args.root,\n                                    burst_size=14,\n                                    crop_sz=80, split='val')\n\n        if local_rank <= 0:\n            print(f\"train data: {len(train_data)}, test data: {len(valid_data)}\")\n\n        if nprocs > 1:\n            train_sampler = 
torch.utils.data.distributed.DistributedSampler(train_data)\n            valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_data, shuffle=False)\n            train_loader = DataLoader(dataset=train_data, batch_size=batch_size, num_workers=args.batch_size,\n                                        pin_memory=True, drop_last=True, sampler=train_sampler)  # args.cpus\n            valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size, num_workers=args.batch_size,\n                                        pin_memory=True, drop_last=True, sampler=valid_sampler)  # args.cpus\n        else:\n            train_sampler = None\n            train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, num_workers=8,\n                                    shuffle=True, pin_memory=True, drop_last=True)  # args.cpus\n            valid_loader = DataLoader(dataset=valid_data, batch_size=args.batch_size, num_workers=4, shuffle=False,\n                                    pin_memory=True, drop_last=True)  # args.cpus\n\n        _model = model.Model(args, checkpoint)\n\n        _loss = loss.Loss(args, checkpoint) if not args.test_only else None\n        t = Trainer(args, train_loader, train_sampler, valid_loader, _model, _loss, checkpoint)\n        while not t.terminate():\n            t.train()\n\n        del _model\n        del _loss\n        del train_loader\n        del valid_loader\n\n        # checkpoint.done()\n\nif __name__ == '__main__':\n    # if not args.cpu: torch.cuda.set_device(0)\n    main()\n"
  },
  {
    "path": "code/real/bsrt/model/DCNv2/LICENSE",
    "content": "BSD 3-Clause License\n\nCopyright (c) 2019, Charles Shang\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n1. Redistributions of source code must retain the above copyright notice, this\n   list of conditions and the following disclaimer.\n\n2. Redistributions in binary form must reproduce the above copyright notice,\n   this list of conditions and the following disclaimer in the documentation\n   and/or other materials provided with the distribution.\n\n3. Neither the name of the copyright holder nor the names of its\n   contributors may be used to endorse or promote products derived from\n   this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\nFOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\nDAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\nSERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\nCAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\nOR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."
  },
  {
    "path": "code/real/bsrt/model/DCNv2/README.md",
    "content": "## Deformable Convolutional Networks V2 with Pytorch 1.0\n\n### Build\n```bash\n    ./make.sh         # build\n    python test.py    # run examples and gradient check \n```\n\n### An Example\n- deformable conv\n```python\n    from dcn_v2 import DCN\n    input = torch.randn(2, 64, 128, 128).cuda()\n    # wrap all things (offset and mask) in DCN\n    dcn = DCN(64, 64, kernel_size=(3,3), stride=1, padding=1, deformable_groups=2).cuda()\n    output = dcn(input)\n    print(output.shape)\n```\n- deformable roi pooling\n```python\n    from dcn_v2 import DCNPooling\n    input = torch.randn(2, 32, 64, 64).cuda()\n    batch_inds = torch.randint(2, (20, 1)).cuda().float()\n    x = torch.randint(256, (20, 1)).cuda().float()\n    y = torch.randint(256, (20, 1)).cuda().float()\n    w = torch.randint(64, (20, 1)).cuda().float()\n    h = torch.randint(64, (20, 1)).cuda().float()\n    rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)\n\n    # mdformable pooling (V2)\n    # wrap all things (offset and mask) in DCNPooling\n    dpooling = DCNPooling(spatial_scale=1.0 / 4,\n                         pooled_size=7,\n                         output_dim=32,\n                         no_trans=False,\n                         group_size=1,\n                         trans_std=0.1).cuda()\n\n    dout = dpooling(input, rois)\n```\n### Note\nNow the master branch is for pytorch 1.0 (new ATen API), you can switch back to pytorch 0.4 with,\n```bash\ngit checkout pytorch_0.4\n```\n\n### Known Issues:\n\n- [x] Gradient check w.r.t offset (solved)\n- [ ] Backward is not reentrant (minor)\n\nThis is an adaption of the official [Deformable-ConvNets](https://github.com/msracver/Deformable-ConvNets/tree/master/DCNv2_op).\n\n<s>I have ran the gradient check for many times with DOUBLE type. Every tensor **except offset** passes.\nHowever, when I set the offset to 0.5, it passes. I'm still wondering what cause this problem. Is it because some\nnon-differential points? 
</s>\n\nUpdate: all gradient check passes with double precision. \n\nAnother issue is that it raises `RuntimeError: Backward is not reentrant`. However, the error is very small (`<1e-7` for \nfloat `<1e-15` for double), \nso it may not be a serious problem (?)\n\nPlease post an issue or PR if you have any comments.\n    "
  },
  {
    "path": "code/real/bsrt/model/DCNv2/__init__.py",
    "content": ""
  },
  {
    "path": "code/real/bsrt/model/DCNv2/dcn_v2.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import absolute_import, division, print_function\n\nimport math\n\nimport torch\nfrom torch import nn\nfrom torch.autograd import Function\nfrom torch.autograd.function import once_differentiable\nfrom torch.nn.modules.utils import _pair\nfrom torch.cuda.amp import custom_fwd, custom_bwd\n# from apex import amp\n\nimport _ext as _backend\n\n\nclass _DCNv2(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)\n    # @amp.float_function\n    def forward(\n        ctx, input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups\n    ):\n        ctx.stride = _pair(stride)\n        ctx.padding = _pair(padding)\n        ctx.dilation = _pair(dilation)\n        ctx.kernel_size = _pair(weight.shape[2:4])\n        ctx.deformable_groups = deformable_groups\n        output = _backend.dcn_v2_forward(\n            input,\n            weight,\n            bias,\n            offset,\n            mask,\n            ctx.kernel_size[0],\n            ctx.kernel_size[1],\n            ctx.stride[0],\n            ctx.stride[1],\n            ctx.padding[0],\n            ctx.padding[1],\n            ctx.dilation[0],\n            ctx.dilation[1],\n            ctx.deformable_groups,\n        )\n        ctx.save_for_backward(input, offset, mask, weight, bias)\n        return output\n\n    @staticmethod\n    @once_differentiable\n    @custom_bwd\n    # @amp.float_function\n    def backward(ctx, grad_output):\n        input, offset, mask, weight, bias = ctx.saved_tensors\n        grad_input, grad_offset, grad_mask, grad_weight, grad_bias = _backend.dcn_v2_backward(\n            input,\n            weight,\n            bias,\n            offset,\n            mask,\n            grad_output,\n            ctx.kernel_size[0],\n            ctx.kernel_size[1],\n            ctx.stride[0],\n            ctx.stride[1],\n            ctx.padding[0],\n            ctx.padding[1],\n            ctx.dilation[0],\n      
      ctx.dilation[1],\n            ctx.deformable_groups,\n        )\n\n        return grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None\n\n    @staticmethod\n    def symbolic(\n        g, input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups\n    ):\n        from torch.nn.modules.utils import _pair\n\n        stride = _pair(stride)\n        padding = _pair(padding)\n        dilation = _pair(dilation)\n        # as of trt 7, the dcn operation will be translated again by modifying the onnx file\n        # so the exporting code is kept to resemble the forward()\n        return g.op(\n            \"DCNv2_2\",\n            input,\n            offset,\n            mask,\n            weight,\n            bias,\n            stride_i=stride,\n            padding_i=padding,\n            dilation_i=dilation,\n            deformable_groups_i=deformable_groups,\n        )\n\n\ndcn_v2_conv = _DCNv2.apply\n\n\nclass DCNv2(nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        kernel_size,\n        stride,\n        padding,\n        dilation=1,\n        deformable_groups=1,\n    ):\n        super(DCNv2, self).__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.kernel_size = _pair(kernel_size)\n        self.stride = _pair(stride)\n        self.padding = _pair(padding)\n        self.dilation = _pair(dilation)\n        self.deformable_groups = deformable_groups\n\n        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size))\n        self.bias = nn.Parameter(torch.Tensor(out_channels))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        n = self.in_channels\n        for k in self.kernel_size:\n            n *= k\n        stdv = 1.0 / math.sqrt(n)\n        self.weight.data.uniform_(-stdv, stdv)\n        self.bias.data.zero_()\n\n    def forward(self, input, offset, 
mask):\n        assert (\n            2 * self.deformable_groups * self.kernel_size[0] * self.kernel_size[1]\n            == offset.shape[1]\n        )\n        assert self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == mask.shape[1]\n        return dcn_v2_conv(\n            input,\n            offset,\n            mask,\n            self.weight,\n            self.bias,\n            self.stride,\n            self.padding,\n            self.dilation,\n            self.deformable_groups,\n        )\n\n\nclass DCN(DCNv2):\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        kernel_size,\n        stride,\n        padding,\n        dilation=1,\n        deformable_groups=1,\n    ):\n        super(DCN, self).__init__(\n            in_channels, out_channels, kernel_size, stride, padding, dilation, deformable_groups\n        )\n\n        channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]\n        self.conv_offset_mask = nn.Conv2d(\n            self.in_channels,\n            channels_,\n            kernel_size=self.kernel_size,\n            stride=self.stride,\n            padding=self.padding,\n            bias=True,\n        )\n        self.init_offset()\n\n    def init_offset(self):\n        self.conv_offset_mask.weight.data.zero_()\n        self.conv_offset_mask.bias.data.zero_()\n\n    def forward(self, input):\n        out = self.conv_offset_mask(input)\n        o1, o2, mask = torch.chunk(out, 3, dim=1)\n        offset = torch.cat((o1, o2), dim=1)\n        mask = torch.sigmoid(mask)\n        return dcn_v2_conv(\n            input,\n            offset,\n            mask,\n            self.weight,\n            self.bias,\n            self.stride,\n            self.padding,\n            self.dilation,\n            self.deformable_groups,\n        )\n\n\nclass DCN_sep(DCNv2):\n    '''Use other features to generate offsets and masks'''\n\n    def __init__(self,\n                 
in_channels,\n                 out_channels,\n                 kernel_size,\n                 stride,\n                 padding,\n                 dilation=1,\n                 deformable_groups=1):\n        super(DCN_sep, self).__init__(in_channels, out_channels, kernel_size, stride, padding,\n                                      dilation, deformable_groups)\n\n        channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]\n        self.conv_offset_mask = nn.Conv2d(\n            self.in_channels,\n            channels_,\n            kernel_size=self.kernel_size,\n            stride=self.stride,\n            padding=self.padding,\n            bias=True)\n        self.init_offset()\n\n    def init_offset(self):\n        self.conv_offset_mask.weight.data.zero_()\n        self.conv_offset_mask.bias.data.zero_()\n\n    def forward(self, input, fea):\n        '''input: input features for deformable conv\n        fea: other features used for generating offsets and mask'''\n        out = self.conv_offset_mask(fea)\n        o1, o2, mask = torch.chunk(out, 3, dim=1)\n        offset = torch.cat((o1, o2), dim=1)\n        # offset = torch.clamp(offset, -100, 100)\n\n        offset_mean = torch.mean(torch.abs(offset))\n        if offset_mean > 250:\n            print('Offset mean is {}, larger than 250.'.format(offset_mean))\n            # return None\n            # offset[offset>=150] = 1e-3\n            # offset = offset.clamp(-50, 50)\n\n        mask = torch.sigmoid(mask)\n        return dcn_v2_conv(input, offset, mask, self.weight, self.bias, self.stride, self.padding,\n                           self.dilation, self.deformable_groups)\n\n\nclass FlowGuidedDCN(DCNv2):\n    '''Use other features to generate offsets and masks'''\n\n    def __init__(self,\n                 in_channels,\n                 out_channels,\n                 kernel_size,\n                 stride,\n                 padding,\n                 dilation=1,\n                 
deformable_groups=1):\n        super(FlowGuidedDCN, self).__init__(in_channels, out_channels, kernel_size, stride, padding,\n                                      dilation, deformable_groups)\n\n        channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]\n        self.conv_offset_mask = nn.Conv2d(\n            in_channels, channels_, kernel_size, stride, padding, bias=True)\n\n        self.init_offset()\n\n    def init_offset(self):\n        self.conv_offset_mask.weight.data.zero_()\n        self.conv_offset_mask.bias.data.zero_()\n\n    def forward(self, input, fea, flows):\n        '''input: input features for deformable conv: N, C, H, W.\n           fea: other features used for generating offsets and mask: N, C, H, W.\n           flows: N, 2, H, W.\n        '''\n        out = self.conv_offset_mask(fea)\n        o1, o2, mask = torch.chunk(out, 3, dim=1)\n\n        offset = torch.tanh(torch.cat((o1, o2), dim=1)) * 10 # max_residue_magnitude\n        offset = offset + flows.flip(1).repeat(1, offset.size(1)//2, 1, 1)\n\n        offset_mean = torch.mean(torch.abs(offset))\n        if offset_mean > 250:\n            print('FlowGuidedDCN: Offset mean is {}, larger than 250.'.format(offset_mean))\n            # offset = offset.clamp(-50, 50)\n            # return None\n\n        mask = torch.sigmoid(mask)\n        return dcn_v2_conv(input, offset, mask, self.weight, self.bias, self.stride, self.padding,\n                           self.dilation, self.deformable_groups)\n\n\n\nclass InsideFlowGuidedDCN(DCNv2):\n    '''Use other features to generate offsets and masks'''\n\n    def __init__(self,\n                 in_channels,\n                 out_channels,\n                 kernel_size,\n                 stride,\n                 padding,\n                 dilation=1,\n                 deformable_groups=1):\n        super(InsideFlowGuidedDCN, self).__init__(in_channels, out_channels, kernel_size, stride, padding,\n                           
           dilation, deformable_groups)\n\n        channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]\n        self.conv_offset_mask = nn.Sequential(\n            nn.Conv2d(in_channels*2+2, out_channels, kernel_size, stride, padding, bias=True),\n            nn.LeakyReLU(negative_slope=0.1, inplace=True),\n            nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=True),\n            nn.LeakyReLU(negative_slope=0.1, inplace=True),\n            nn.Conv2d(out_channels, channels_, kernel_size, stride, padding, bias=True)\n        )\n\n        self.reset_parameters()\n        self.init_offset()\n\n    def reset_parameters(self):\n        n = self.in_channels\n        for k in self.kernel_size:\n            n *= k\n        stdv = 1.0 / math.sqrt(n)\n        self.weight.data.uniform_(-stdv, stdv)\n        self.bias.data.zero_()\n\n    def init_offset(self):\n        self.conv_offset_mask[-1].weight.data.zero_()\n        self.conv_offset_mask[-1].bias.data.zero_()\n\n    def forward(self, input, warped, ref, flows):\n        '''input: input features for deformable conv: N, C, H, W.\n           warped: flow-warped features used for generating offsets and mask: N, C, H, W.\n           ref: reference features used for generating offsets and mask: N, C, H, W.\n           flows: N, 2, H, W.\n        '''\n        out = self.conv_offset_mask(torch.cat([warped, ref, flows], dim=1))\n        o1, o2, mask = torch.chunk(out, 3, dim=1)\n\n        offset = torch.tanh(torch.cat((o1, o2), dim=1)) * 10 # max_residue_magnitude\n        offset = offset + flows.flip(1).repeat(1, offset.size(1)//2, 1, 1)\n\n        offset_mean = torch.mean(torch.abs(offset))\n        if offset_mean > 250:\n            print('InsideFlowGuidedDCN: Offset mean is {}, larger than 250.'.format(offset_mean))\n            print('flow mean is {}'.format(torch.abs(flows).mean()))\n            offset = offset.clamp(-50, 50)\n            # return None\n\n        mask = torch.sigmoid(mask)\n        return dcn_v2_conv(input, offset, mask, self.weight, 
self.bias, self.stride, self.padding,\n                           self.dilation, self.deformable_groups)\n\n\n\nclass _DCNv2Pooling(Function):\n    @staticmethod\n    def forward(\n        ctx,\n        input,\n        rois,\n        offset,\n        spatial_scale,\n        pooled_size,\n        output_dim,\n        no_trans,\n        group_size=1,\n        part_size=None,\n        sample_per_part=4,\n        trans_std=0.0,\n    ):\n        ctx.spatial_scale = spatial_scale\n        ctx.no_trans = int(no_trans)\n        ctx.output_dim = output_dim\n        ctx.group_size = group_size\n        ctx.pooled_size = pooled_size\n        ctx.part_size = pooled_size if part_size is None else part_size\n        ctx.sample_per_part = sample_per_part\n        ctx.trans_std = trans_std\n\n        output, output_count = _backend.dcn_v2_psroi_pooling_forward(\n            input,\n            rois,\n            offset,\n            ctx.no_trans,\n            ctx.spatial_scale,\n            ctx.output_dim,\n            ctx.group_size,\n            ctx.pooled_size,\n            ctx.part_size,\n            ctx.sample_per_part,\n            ctx.trans_std,\n        )\n        ctx.save_for_backward(input, rois, offset, output_count)\n        return output\n\n    @staticmethod\n    @once_differentiable\n    def backward(ctx, grad_output):\n        input, rois, offset, output_count = ctx.saved_tensors\n        grad_input, grad_offset = _backend.dcn_v2_psroi_pooling_backward(\n            grad_output,\n            input,\n            rois,\n            offset,\n            output_count,\n            ctx.no_trans,\n            ctx.spatial_scale,\n            ctx.output_dim,\n            ctx.group_size,\n            ctx.pooled_size,\n            ctx.part_size,\n            ctx.sample_per_part,\n            ctx.trans_std,\n        )\n\n        return grad_input, None, grad_offset, None, None, None, None, None, None, None, None\n\n\ndcn_v2_pooling = _DCNv2Pooling.apply\n\n\nclass 
DCNv2Pooling(nn.Module):\n    def __init__(\n        self,\n        spatial_scale,\n        pooled_size,\n        output_dim,\n        no_trans,\n        group_size=1,\n        part_size=None,\n        sample_per_part=4,\n        trans_std=0.0,\n    ):\n        super(DCNv2Pooling, self).__init__()\n        self.spatial_scale = spatial_scale\n        self.pooled_size = pooled_size\n        self.output_dim = output_dim\n        self.no_trans = no_trans\n        self.group_size = group_size\n        self.part_size = pooled_size if part_size is None else part_size\n        self.sample_per_part = sample_per_part\n        self.trans_std = trans_std\n\n    def forward(self, input, rois, offset):\n        assert input.shape[1] == self.output_dim\n        if self.no_trans:\n            offset = input.new()\n        return dcn_v2_pooling(\n            input,\n            rois,\n            offset,\n            self.spatial_scale,\n            self.pooled_size,\n            self.output_dim,\n            self.no_trans,\n            self.group_size,\n            self.part_size,\n            self.sample_per_part,\n            self.trans_std,\n        )\n\n\nclass DCNPooling(DCNv2Pooling):\n    def __init__(\n        self,\n        spatial_scale,\n        pooled_size,\n        output_dim,\n        no_trans,\n        group_size=1,\n        part_size=None,\n        sample_per_part=4,\n        trans_std=0.0,\n        deform_fc_dim=1024,\n    ):\n        super(DCNPooling, self).__init__(\n            spatial_scale,\n            pooled_size,\n            output_dim,\n            no_trans,\n            group_size,\n            part_size,\n            sample_per_part,\n            trans_std,\n        )\n\n        self.deform_fc_dim = deform_fc_dim\n\n        if not no_trans:\n            self.offset_mask_fc = nn.Sequential(\n                nn.Linear(\n                    self.pooled_size * self.pooled_size * self.output_dim, self.deform_fc_dim\n                ),\n                
nn.ReLU(inplace=True),\n                nn.Linear(self.deform_fc_dim, self.deform_fc_dim),\n                nn.ReLU(inplace=True),\n                nn.Linear(self.deform_fc_dim, self.pooled_size * self.pooled_size * 3),\n            )\n            self.offset_mask_fc[4].weight.data.zero_()\n            self.offset_mask_fc[4].bias.data.zero_()\n\n    def forward(self, input, rois):\n        offset = input.new()\n\n        if not self.no_trans:\n\n            # do roi_align first\n            n = rois.shape[0]\n            roi = dcn_v2_pooling(\n                input,\n                rois,\n                offset,\n                self.spatial_scale,\n                self.pooled_size,\n                self.output_dim,\n                True,  # no trans\n                self.group_size,\n                self.part_size,\n                self.sample_per_part,\n                self.trans_std,\n            )\n\n            # build mask and offset\n            offset_mask = self.offset_mask_fc(roi.view(n, -1))\n            offset_mask = offset_mask.view(n, 3, self.pooled_size, self.pooled_size)\n            o1, o2, mask = torch.chunk(offset_mask, 3, dim=1)\n            offset = torch.cat((o1, o2), dim=1)\n            mask = torch.sigmoid(mask)\n\n            # do pooling with offset and mask\n            return (\n                dcn_v2_pooling(\n                    input,\n                    rois,\n                    offset,\n                    self.spatial_scale,\n                    self.pooled_size,\n                    self.output_dim,\n                    self.no_trans,\n                    self.group_size,\n                    self.part_size,\n                    self.sample_per_part,\n                    self.trans_std,\n                )\n                * mask\n            )\n        # only roi_align\n        return dcn_v2_pooling(\n            input,\n            rois,\n            offset,\n            self.spatial_scale,\n            self.pooled_size,\n     
       self.output_dim,\n            self.no_trans,\n            self.group_size,\n            self.part_size,\n            self.sample_per_part,\n            self.trans_std,\n        )\n"
  },
  {
    "path": "code/real/bsrt/model/DCNv2/files.txt",
    "content": "/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/_ext.cpython-37m-x86_64-linux-gnu.so\n/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/_ext.py\n/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/PKG-INFO\n/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/SOURCES.txt\n/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/dependency_links.txt\n/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/native_libs.txt\n/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/not-zip-safe\n/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/top_level.txt\n/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/__pycache__/_ext.cpython-37.pyc\n"
  },
  {
    "path": "code/real/bsrt/model/DCNv2/make.sh",
    "content": "#!/usr/bin/env bash\npython setup.py build develop\n"
  },
  {
    "path": "code/real/bsrt/model/DCNv2/setup.py",
    "content": "#!/usr/bin/env python\n\nimport glob\nimport os\n\nimport torch\nfrom setuptools import find_packages, setup\nfrom torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension\n\nrequirements = [\"torch\", \"torchvision\"]\n\n\ndef get_extensions():\n    this_dir = os.path.dirname(os.path.abspath(__file__))\n    extensions_dir = os.path.join(this_dir, \"src\")\n\n    main_file = glob.glob(os.path.join(extensions_dir, \"*.cpp\"))\n    source_cpu = glob.glob(os.path.join(extensions_dir, \"cpu\", \"*.cpp\"))\n    source_cuda = glob.glob(os.path.join(extensions_dir, \"cuda\", \"*.cu\"))\n\n    os.environ[\"CC\"] = \"g++\"\n    sources = main_file + source_cpu\n    extension = CppExtension\n    extra_compile_args = {\"cxx\": []}\n    define_macros = []\n\n    if True:\n        extension = CUDAExtension\n        sources += source_cuda\n        define_macros += [(\"WITH_CUDA\", None)]\n        extra_compile_args[\"nvcc\"] = [\n            \"-DCUDA_HAS_FP16=1\",\n            \"-D__CUDA_NO_HALF_OPERATORS__\",\n            \"-D__CUDA_NO_HALF_CONVERSIONS__\",\n            \"-D__CUDA_NO_HALF2_OPERATORS__\",\n        ]\n    else:\n        # raise NotImplementedError('Cuda is not available')\n        pass\n\n    sources = [os.path.join(extensions_dir, s) for s in sources]\n    include_dirs = [extensions_dir]\n    ext_modules = [\n        extension(\n            \"_ext\",\n            sources,\n            include_dirs=include_dirs,\n            define_macros=define_macros,\n            extra_compile_args=extra_compile_args,\n        )\n    ]\n    return ext_modules\n\n\nsetup(\n    name=\"DCNv2\",\n    version=\"0.1\",\n    author=\"charlesshang\",\n    url=\"https://github.com/charlesshang/DCNv2\",\n    description=\"deformable convolutional networks\",\n    packages=find_packages(exclude=(\"configs\", \"tests\")),\n    # install_requires=requirements,\n    ext_modules=get_extensions(),\n    cmdclass={\"build_ext\": 
torch.utils.cpp_extension.BuildExtension},\n)\n"
  },
  {
    "path": "code/real/bsrt/model/DCNv2/src/cpu/dcn_v2_cpu.cpp",
    "content": "#include <vector>\n#include \"cpu/dcn_v2_im2col_cpu.h\"\n\n#include <ATen/ATen.h>\n//#include <ATen/cuda/CUDAContext.h>\n\n#include <TH/TH.h>\n//#include <THC/THCAtomics.cuh>\n//#include <THC/THCDeviceUtils.cuh>\n\n//extern THCState *state;\n\n// author: Charles Shang\n// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu\n// modified from the CUDA version for CPU use by Daniel K. Suhendro\n\nat::Tensor\ndcn_v2_cpu_forward(const at::Tensor &input,\n                    const at::Tensor &weight,\n                    const at::Tensor &bias,\n                    const at::Tensor &offset,\n                    const at::Tensor &mask,\n                    const int kernel_h,\n                    const int kernel_w,\n                    const int stride_h,\n                    const int stride_w,\n                    const int pad_h,\n                    const int pad_w,\n                    const int dilation_h,\n                    const int dilation_w,\n                    const int deformable_group)\n{\n    // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask));\n    /*AT_ASSERTM(input.type().is_cuda(), \"input must be a CUDA tensor\");\n    AT_ASSERTM(weight.type().is_cuda(), \"weight must be a CUDA tensor\");\n    AT_ASSERTM(bias.type().is_cuda(), \"bias must be a CUDA tensor\");\n    AT_ASSERTM(offset.type().is_cuda(), \"offset must be a CUDA tensor\");\n    AT_ASSERTM(mask.type().is_cuda(), \"mask must be a CUDA tensor\");*/\n\n    const int batch = input.size(0);\n    const int channels = input.size(1);\n    const int height = input.size(2);\n    const int width = input.size(3);\n\n    const int channels_out = weight.size(0);\n    const int channels_kernel = weight.size(1);\n    const int kernel_h_ = weight.size(2);\n    const int kernel_w_ = weight.size(3);\n\n    // printf(\"Kernels: %d %d %d %d\\n\", kernel_h_, kernel_w_, kernel_w, kernel_h);\n    // printf(\"Channels: %d 
%d\\n\", channels, channels_kernel);\n    // printf(\"Channels: %d %d\\n\", channels_out, channels_kernel);\n\n    AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,\n               \"Input shape and kernel shape wont match: (%d x %d vs %d x %d).\", kernel_h_, kernel_w, kernel_h_, kernel_w_);\n\n    AT_ASSERTM(channels == channels_kernel,\n               \"Input shape and kernel channels wont match: (%d vs %d).\", channels, channels_kernel);\n\n    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;\n    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;\n\n    auto ones = at::ones({height_out, width_out}, input.options());\n    auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());\n    auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());\n\n    using scalar_t = float;\n    for (int b = 0; b < batch; b++)\n    {\n        auto input_n = input.select(0, b);\n        auto offset_n = offset.select(0, b);\n        auto mask_n = mask.select(0, b);\n        auto output_n = output.select(0, b);\n\n        // Do Bias first:\n        // M,N,K are dims of matrix A and B\n        // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)\n        // (N x 1) (1 x M)\n        long m_ = channels_out;\n        long n_ = height_out * width_out;\n        long k_ = 1;\n        THFloatBlas_gemm('t', 'n', n_, m_, k_, 1.0f,\n                         ones.contiguous().data<scalar_t>(), k_,\n                         bias.contiguous().data<scalar_t>(), k_, 0.0f,\n                         output_n.data<scalar_t>(), n_);\n\n        modulated_deformable_im2col_cpu(input_n.data<scalar_t>(),\n                                         offset_n.data<scalar_t>(),\n                                         mask_n.data<scalar_t>(),\n                                         1, channels, height, width,\n                 
                        height_out, width_out, kernel_h, kernel_w,\n                                         pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,\n                                         deformable_group,\n                                         columns.data<scalar_t>());\n\n        //(k * m)  x  (m * n)\n        // Y = WC\n        long m = channels_out;\n        long n = height_out * width_out;\n        long k = channels * kernel_h * kernel_w;\n        THFloatBlas_gemm('n', 'n', n, m, k, 1.0f,\n                         columns.data<scalar_t>(), n,\n                         weight.data<scalar_t>(), k, 1.0f,\n                         output_n.data<scalar_t>(), n);\n    }\n    return output;\n}\n\nstd::vector<at::Tensor> dcn_v2_cpu_backward(const at::Tensor &input,\n                                             const at::Tensor &weight,\n                                             const at::Tensor &bias,\n                                             const at::Tensor &offset,\n                                             const at::Tensor &mask,\n                                             const at::Tensor &grad_output,\n                                             int kernel_h, int kernel_w,\n                                             int stride_h, int stride_w,\n                                             int pad_h, int pad_w,\n                                             int dilation_h, int dilation_w,\n                                             int deformable_group)\n{\n\n    THArgCheck(input.is_contiguous(), 1, \"input tensor has to be contiguous\");\n    THArgCheck(weight.is_contiguous(), 2, \"weight tensor has to be contiguous\");\n\n    /*AT_ASSERTM(input.type().is_cuda(), \"input must be a CUDA tensor\");\n    AT_ASSERTM(weight.type().is_cuda(), \"weight must be a CUDA tensor\");\n    AT_ASSERTM(bias.type().is_cuda(), \"bias must be a CUDA tensor\");\n    AT_ASSERTM(offset.type().is_cuda(), \"offset must be a CUDA tensor\");\n    
AT_ASSERTM(mask.type().is_cuda(), \"mask must be a CUDA tensor\");*/\n\n    const int batch = input.size(0);\n    const int channels = input.size(1);\n    const int height = input.size(2);\n    const int width = input.size(3);\n\n    const int channels_out = weight.size(0);\n    const int channels_kernel = weight.size(1);\n    const int kernel_h_ = weight.size(2);\n    const int kernel_w_ = weight.size(3);\n\n    AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,\n               \"Input shape and kernel shape won't match: (%d x %d vs %d x %d).\", kernel_h, kernel_w, kernel_h_, kernel_w_);\n\n    AT_ASSERTM(channels == channels_kernel,\n               \"Input shape and kernel channels won't match: (%d vs %d).\", channels, channels_kernel);\n\n    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;\n    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;\n\n    auto ones = at::ones({height_out, width_out}, input.options());\n    auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());\n    auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());\n\n    auto grad_input = at::zeros_like(input);\n    auto grad_weight = at::zeros_like(weight);\n    auto grad_bias = at::zeros_like(bias);\n    auto grad_offset = at::zeros_like(offset);\n    auto grad_mask = at::zeros_like(mask);\n\n    using scalar_t = float;\n\n    for (int b = 0; b < batch; b++)\n    {\n        auto input_n = input.select(0, b);\n        auto offset_n = offset.select(0, b);\n        auto mask_n = mask.select(0, b);\n        auto grad_output_n = grad_output.select(0, b);\n        auto grad_input_n = grad_input.select(0, b);\n        auto grad_offset_n = grad_offset.select(0, b);\n        auto grad_mask_n = grad_mask.select(0, b);\n\n        long m = channels * kernel_h * kernel_w;\n        long n = height_out * width_out;\n        long 
k = channels_out;\n\n        THFloatBlas_gemm('n', 't', n, m, k, 1.0f,\n                         grad_output_n.data<scalar_t>(), n,\n                         weight.data<scalar_t>(), m, 0.0f,\n                         columns.data<scalar_t>(), n);\n\n        // gradient w.r.t. input coordinate data\n        modulated_deformable_col2im_coord_cpu(columns.data<scalar_t>(),\n                                               input_n.data<scalar_t>(),\n                                               offset_n.data<scalar_t>(),\n                                               mask_n.data<scalar_t>(),\n                                               1, channels, height, width,\n                                               height_out, width_out, kernel_h, kernel_w,\n                                               pad_h, pad_w, stride_h, stride_w,\n                                               dilation_h, dilation_w, deformable_group,\n                                               grad_offset_n.data<scalar_t>(),\n                                               grad_mask_n.data<scalar_t>());\n        // gradient w.r.t. input data\n        modulated_deformable_col2im_cpu(columns.data<scalar_t>(),\n                                         offset_n.data<scalar_t>(),\n                                         mask_n.data<scalar_t>(),\n                                         1, channels, height, width,\n                                         height_out, width_out, kernel_h, kernel_w,\n                                         pad_h, pad_w, stride_h, stride_w,\n                                         dilation_h, dilation_w, deformable_group,\n                                         grad_input_n.data<scalar_t>());\n\n        // gradient w.r.t. 
weight, dWeight should accumulate across the batch and group\n        modulated_deformable_im2col_cpu(input_n.data<scalar_t>(),\n                                         offset_n.data<scalar_t>(),\n                                         mask_n.data<scalar_t>(),\n                                         1, channels, height, width,\n                                         height_out, width_out, kernel_h, kernel_w,\n                                         pad_h, pad_w, stride_h, stride_w,\n                                         dilation_h, dilation_w, deformable_group,\n                                         columns.data<scalar_t>());\n\n        long m_ = channels_out;\n        long n_ = channels * kernel_h * kernel_w;\n        long k_ = height_out * width_out;\n\n        THFloatBlas_gemm('t', 'n', n_, m_, k_, 1.0f,\n                         columns.data<scalar_t>(), k_,\n                         grad_output_n.data<scalar_t>(), k_, 1.0f,\n                         grad_weight.data<scalar_t>(), n_);\n\n        // gradient w.r.t. bias\n        // long m_ = channels_out;\n        // long k__ = height_out * width_out;\n        // THFloatBlas_gemv('t', k_, m_, 1.0f,\n        //                  grad_output_n.data<scalar_t>(), k_,\n        //                  ones.data<scalar_t>(), 1, 1.0f,\n        //                  grad_bias.data<scalar_t>(), 1);\n    }\n\n    return {\n        grad_input, grad_offset, grad_mask, grad_weight, grad_bias\n    };\n}"
  },
  {
    "path": "code/real/bsrt/model/DCNv2/src/cpu/dcn_v2_im2col_cpu.cpp",
    "content": "#include \"dcn_v2_im2col_cpu.h\"\n#include <cstdio>\n#include <algorithm>\n#include <cstring>\n\n#include <ATen/ATen.h>\n//#include <ATen/cuda/CUDAContext.h>\n\n#include <TH/TH.h>\n//#include <THC/THCAtomics.cuh>\n//#include <THC/THCDeviceUtils.cuh>\n\n// modified from the CUDA version for CPU use by Daniel K. Suhendro\n\n/*#define CUDA_KERNEL_LOOP(i, n)                          \\\n  for (int i = blockIdx.x * blockDim.x + threadIdx.x;   \\\n      i < (n);                                          \\\n      i += blockDim.x * gridDim.x)\n\nconst int CUDA_NUM_THREADS = 1024;\ninline int GET_BLOCKS(const int N)\n{\n  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;\n}*/\n\n\nfloat dmcn_im2col_bilinear_cpu(const float *bottom_data, const int data_width,\n                           const int height, const int width, float h, float w)\n{\n  int h_low = floor(h);\n  int w_low = floor(w);\n  int h_high = h_low + 1;\n  int w_high = w_low + 1;\n\n  float lh = h - h_low;\n  float lw = w - w_low;\n  float hh = 1 - lh, hw = 1 - lw;\n\n  float v1 = 0;\n  if (h_low >= 0 && w_low >= 0)\n    v1 = bottom_data[h_low * data_width + w_low];\n  float v2 = 0;\n  if (h_low >= 0 && w_high <= width - 1)\n    v2 = bottom_data[h_low * data_width + w_high];\n  float v3 = 0;\n  if (h_high <= height - 1 && w_low >= 0)\n    v3 = bottom_data[h_high * data_width + w_low];\n  float v4 = 0;\n  if (h_high <= height - 1 && w_high <= width - 1)\n    v4 = bottom_data[h_high * data_width + w_high];\n\n  float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;\n\n  float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);\n  return val;\n}\n\nfloat dmcn_get_gradient_weight_cpu(float argmax_h, float argmax_w,\n                               const int h, const int w, const int height, const int width)\n{\n  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)\n  {\n    //empty\n    return 0;\n  }\n\n  int argmax_h_low = floor(argmax_h);\n  int argmax_w_low = 
floor(argmax_w);\n  int argmax_h_high = argmax_h_low + 1;\n  int argmax_w_high = argmax_w_low + 1;\n\n  float weight = 0;\n  if (h == argmax_h_low && w == argmax_w_low)\n    weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);\n  if (h == argmax_h_low && w == argmax_w_high)\n    weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);\n  if (h == argmax_h_high && w == argmax_w_low)\n    weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);\n  if (h == argmax_h_high && w == argmax_w_high)\n    weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);\n  return weight;\n}\n\nfloat dmcn_get_coordinate_weight_cpu(float argmax_h, float argmax_w,\n                                 const int height, const int width, const float *im_data,\n                                 const int data_width, const int bp_dir)\n{\n  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)\n  {\n    //empty\n    return 0;\n  }\n\n  int argmax_h_low = floor(argmax_h);\n  int argmax_w_low = floor(argmax_w);\n  int argmax_h_high = argmax_h_low + 1;\n  int argmax_w_high = argmax_w_low + 1;\n\n  float weight = 0;\n\n  if (bp_dir == 0)\n  {\n    if (argmax_h_low >= 0 && argmax_w_low >= 0)\n      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];\n    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)\n      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];\n    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)\n      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];\n    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)\n      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];\n  }\n  else if (bp_dir == 1)\n  {\n    if (argmax_h_low >= 0 && argmax_w_low >= 0)\n      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];\n    if (argmax_h_low >= 0 && 
argmax_w_high <= width - 1)\n      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];\n    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)\n      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];\n    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)\n      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];\n  }\n\n  return weight;\n}\n\nvoid modulated_deformable_im2col_cpu_kernel(const int n, const float *data_im, const float *data_offset, const float *data_mask,\n                                                       const int height, const int width, const int kernel_h, const int kernel_w,\n                                                       const int pad_h, const int pad_w,\n                                                       const int stride_h, const int stride_w,\n                                                       const int dilation_h, const int dilation_w,\n                                                       const int channel_per_deformable_group,\n                                                       const int batch_size, const int num_channels, const int deformable_group,\n                                                       const int height_col, const int width_col,\n                                                       float *data_col)\n{\n  // launch channels * batch_size * height_col * width_col cores\n  for(int index=0; index<n; index++)\n  {\n    // NOTE(CharlesShang): different from Dai Jifeng's MXNet implementation, col_buffer is of shape (c*kw*kh, N, oh, ow)\n    // here columns is of shape (N, c*kw*kh, oh * ow), need to adapt axis\n\n    // index index of output matrix\n    const int w_col = index % width_col;\n    const int h_col = (index / width_col) % height_col;\n    // const int b_col = (index / width_col / height_col) % batch_size;\n    const int b_col = (index / width_col / 
height_col / num_channels) % batch_size;\n    // const int c_im = (index / width_col / height_col) / batch_size;\n    const int c_im = (index / width_col / height_col) % num_channels;\n    // const int c_col = c_im * kernel_h * kernel_w;\n    const int c_col = c_im * kernel_h * kernel_w;\n\n    // compute deformable group index\n    const int deformable_group_index = c_im / channel_per_deformable_group;\n\n    const int h_in = h_col * stride_h - pad_h;\n    const int w_in = w_col * stride_w - pad_w;\n\n    //  float *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;\n    float *data_col_ptr = data_col + ((b_col * num_channels * kernel_w * kernel_h + c_col) * height_col + h_col) * width_col + w_col;\n    //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;\n    const float *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;\n    const float *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;\n\n    const float *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;\n\n    for (int i = 0; i < kernel_h; ++i)\n    {\n      for (int j = 0; j < kernel_w; ++j)\n      {\n        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;\n        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;\n        const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;\n        const float offset_h = data_offset_ptr[data_offset_h_ptr];\n        const float offset_w = data_offset_ptr[data_offset_w_ptr];\n        const float mask = data_mask_ptr[data_mask_hw_ptr];\n        float val = static_cast<float>(0);\n        const float h_im = h_in + i * dilation_h + offset_h;\n        const 
float w_im = w_in + j * dilation_w + offset_w;\n        //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {\n        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)\n        {\n          //const float map_h = i * dilation_h + offset_h;\n          //const float map_w = j * dilation_w + offset_w;\n          //const int cur_height = height - h_in;\n          //const int cur_width = width - w_in;\n          //val = dmcn_im2col_bilinear_cpu(data_im_ptr, width, cur_height, cur_width, map_h, map_w);\n          val = dmcn_im2col_bilinear_cpu(data_im_ptr, width, height, width, h_im, w_im);\n        }\n        *data_col_ptr = val * mask;\n        // data_col_ptr += batch_size * height_col * width_col;\n        data_col_ptr += height_col * width_col;\n      }\n    }\n  }\n}\n\nvoid modulated_deformable_col2im_cpu_kernel(const int n, const float *data_col, const float *data_offset, const float *data_mask,\n                                                       const int channels, const int height, const int width,\n                                                       const int kernel_h, const int kernel_w,\n                                                       const int pad_h, const int pad_w,\n                                                       const int stride_h, const int stride_w,\n                                                       const int dilation_h, const int dilation_w,\n                                                       const int channel_per_deformable_group,\n                                                       const int batch_size, const int deformable_group,\n                                                       const int height_col, const int width_col,\n                                                       float *grad_im)\n{\n  for(int index = 0; index < n; index++)\n  {\n    const int j = (index / width_col / height_col / batch_size) % kernel_w;\n    const int i = (index / width_col / height_col / batch_size / 
kernel_w) % kernel_h;\n    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;\n    // compute the start and end of the output\n\n    const int deformable_group_index = c / channel_per_deformable_group;\n\n    int w_out = index % width_col;\n    int h_out = (index / width_col) % height_col;\n    int b = (index / width_col / height_col) % batch_size;\n    int w_in = w_out * stride_w - pad_w;\n    int h_in = h_out * stride_h - pad_h;\n\n    const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;\n    const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;\n    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;\n    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;\n    const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;\n    const float offset_h = data_offset_ptr[data_offset_h_ptr];\n    const float offset_w = data_offset_ptr[data_offset_w_ptr];\n    const float mask = data_mask_ptr[data_mask_hw_ptr];\n    const float cur_inv_h_data = h_in + i * dilation_h + offset_h;\n    const float cur_inv_w_data = w_in + j * dilation_w + offset_w;\n\n    const float cur_top_grad = data_col[index] * mask;\n    const int cur_h = (int)cur_inv_h_data;\n    const int cur_w = (int)cur_inv_w_data;\n    \n    for (int dy = -2; dy <= 2; dy++)\n    {\n      for (int dx = -2; dx <= 2; dx++)\n      {\n        if (cur_h + dy >= 0 && cur_h + dy < height &&\n            cur_w + dx >= 0 && cur_w + dx < width &&\n            abs(cur_inv_h_data - (cur_h + dy)) < 1 &&\n            abs(cur_inv_w_data - (cur_w + dx)) < 1)\n        {\n          int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;\n          float weight = 
dmcn_get_gradient_weight_cpu(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);\n          //atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);\n          *(grad_im + cur_bottom_grad_pos) += weight * cur_top_grad;\n\n        }\n      }\n    }\n  }\n}\n\nvoid modulated_deformable_col2im_coord_cpu_kernel(const int n, const float *data_col, const float *data_im,\n                                                             const float *data_offset, const float *data_mask,\n                                                             const int channels, const int height, const int width,\n                                                             const int kernel_h, const int kernel_w,\n                                                             const int pad_h, const int pad_w,\n                                                             const int stride_h, const int stride_w,\n                                                             const int dilation_h, const int dilation_w,\n                                                             const int channel_per_deformable_group,\n                                                             const int batch_size, const int offset_channels, const int deformable_group,\n                                                             const int height_col, const int width_col,\n                                                             float *grad_offset, float *grad_mask)\n{\n  for(int index = 0; index < n; index++)\n  {\n    float val = 0, mval = 0;\n    int w = index % width_col;\n    int h = (index / width_col) % height_col;\n    int c = (index / width_col / height_col) % offset_channels;\n    int b = (index / width_col / height_col) / offset_channels;\n    // compute the start and end of the output\n\n    const int deformable_group_index = c / (2 * kernel_h * kernel_w);\n    const int col_step = kernel_h * kernel_w;\n    int cnt = 0;\n    const float *data_col_ptr = data_col + 
deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;\n    const float *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;\n    const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;\n    const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;\n\n    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;\n\n    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)\n    {\n      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;\n      const int bp_dir = offset_c % 2;\n\n      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;\n      int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;\n      int w_out = col_pos % width_col;\n      int h_out = (col_pos / width_col) % height_col;\n      int w_in = w_out * stride_w - pad_w;\n      int h_in = h_out * stride_h - pad_h;\n      const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);\n      const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);\n      const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);\n      const float offset_h = data_offset_ptr[data_offset_h_ptr];\n      const float offset_w = data_offset_ptr[data_offset_w_ptr];\n      const float mask = data_mask_ptr[data_mask_hw_ptr];\n      float inv_h = h_in + i * dilation_h + offset_h;\n      float inv_w = w_in + j * dilation_w + offset_w;\n      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)\n      {\n        inv_h = inv_w = -2;\n      }\n      else\n      {\n        mval += 
data_col_ptr[col_pos] * dmcn_im2col_bilinear_cpu(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);\n      }\n      const float weight = dmcn_get_coordinate_weight_cpu(\n          inv_h, inv_w,\n          height, width, data_im_ptr + cnt * height * width, width, bp_dir);\n      val += weight * data_col_ptr[col_pos] * mask;\n      cnt += 1;\n    }\n    // KERNEL_ASSIGN(grad_offset[index], offset_req, val);\n    grad_offset[index] = val;\n    if (offset_c % 2 == 0)\n      // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);\n      grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;\n  }\n}\n\nvoid modulated_deformable_im2col_cpu(const float* data_im, const float* data_offset, const float* data_mask,\n  const int batch_size, const int channels, const int height_im, const int width_im, \n  const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n  const int pad_h, const int pad_w, const int stride_h, const int stride_w, \n  const int dilation_h, const int dilation_w,\n  const int deformable_group, float* data_col) {\n  // num_axes should be smaller than block size\n  const int channel_per_deformable_group = channels / deformable_group;\n  const int num_kernels = channels * batch_size * height_col * width_col;\n  modulated_deformable_im2col_cpu_kernel(\n      num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kernel_w,\n      pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,\n      batch_size, channels, deformable_group, height_col, width_col, data_col);\n  \n  /*cudaError_t err = cudaGetLastError();\n  if (err != cudaSuccess)\n  {\n    printf(\"error in modulated_deformable_im2col_cuda: %s\\n\", cudaGetErrorString(err));\n  }*/\n\n}\n\nvoid 
modulated_deformable_col2im_cpu(const float* data_col, const float* data_offset, const float* data_mask,\n  const int batch_size, const int channels, const int height_im, const int width_im, \n  const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n  const int pad_h, const int pad_w, const int stride_h, const int stride_w, \n  const int dilation_h, const int dilation_w, \n  const int deformable_group, float* grad_im){\n\n  const int channel_per_deformable_group = channels / deformable_group;\n  const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;\n  modulated_deformable_col2im_cpu_kernel(\n        num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im,\n        kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,\n        dilation_h, dilation_w, channel_per_deformable_group,\n        batch_size, deformable_group, height_col, width_col, grad_im);\n  /*cudaError_t err = cudaGetLastError();\n  if (err != cudaSuccess)\n  {\n    printf(\"error in modulated_deformable_col2im_cuda: %s\\n\", cudaGetErrorString(err));\n  }*/\n\n}\n\nvoid modulated_deformable_col2im_coord_cpu(const float* data_col, const float* data_im, const float* data_offset, const float* data_mask,\n  const int batch_size, const int channels, const int height_im, const int width_im, \n  const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n  const int pad_h, const int pad_w, const int stride_h, const int stride_w, \n  const int dilation_h, const int dilation_w, \n  const int deformable_group,\n  float* grad_offset, float* grad_mask) {\n  const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;\n  const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;\n  modulated_deformable_col2im_coord_cpu_kernel(\n        num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im,\n        kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,\n        dilation_h, dilation_w, channel_per_deformable_group,\n        batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, \n        grad_offset, grad_mask);\n  /*cudaError_t err = cudaGetLastError();\n  if (err != cudaSuccess)\n  {\n    printf(\"error in modulated_deformable_col2im_coord_cuda: %s\\n\", cudaGetErrorString(err));\n  }*/\n}"
  },
  {
    "path": "code/real/bsrt/model/DCNv2/src/cpu/dcn_v2_im2col_cpu.h",
    "content": "\n/*!\n ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************\n *\n * COPYRIGHT\n *\n * All contributions by the University of California:\n * Copyright (c) 2014-2017 The Regents of the University of California (Regents)\n * All rights reserved.\n *\n * All other contributions:\n * Copyright (c) 2014-2017, the respective contributors\n * All rights reserved.\n *\n * Caffe uses a shared copyright model: each contributor holds copyright over\n * their contributions to Caffe. The project versioning records all such\n * contribution and copyright details. If a contributor wants to further mark\n * their specific copyright on a particular contribution, they should indicate\n * their copyright solely in the commit message of the change when it is\n * committed.\n *\n * LICENSE\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR\n * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n * CONTRIBUTION AGREEMENT\n *\n * By contributing to the BVLC/caffe repository through pull-request, comment,\n * or otherwise, the contributor releases their content to the\n * license and copyright terms herein.\n *\n ***************** END Caffe Copyright Notice and Disclaimer ********************\n *\n * Copyright (c) 2018 Microsoft\n * Licensed under The MIT License [see LICENSE for details]\n * \\file modulated_deformable_im2col.h\n * \\brief Function definitions of converting an image to\n * column matrix based on kernel, padding, dilation, and offset.\n * These functions are mainly used in deformable convolution operators.\n * \\ref: https://arxiv.org/abs/1811.11168\n * \\author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu\n */\n\n/***************** Adapted by Charles Shang *********************/\n// modified from the CUDA version for CPU use by Daniel K. 
Suhendro\n\n#ifndef DCN_V2_IM2COL_CPU\n#define DCN_V2_IM2COL_CPU\n\n#ifdef __cplusplus\nextern \"C\"\n{\n#endif\n\n  void modulated_deformable_im2col_cpu(const float *data_im, const float *data_offset, const float *data_mask,\n                                        const int batch_size, const int channels, const int height_im, const int width_im,\n                                        const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n                                        const int pad_h, const int pad_w, const int stride_h, const int stride_w,\n                                        const int dilation_h, const int dilation_w,\n                                        const int deformable_group, float *data_col);\n\n  void modulated_deformable_col2im_cpu(const float *data_col, const float *data_offset, const float *data_mask,\n                                        const int batch_size, const int channels, const int height_im, const int width_im,\n                                        const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n                                        const int pad_h, const int pad_w, const int stride_h, const int stride_w,\n                                        const int dilation_h, const int dilation_w,\n                                        const int deformable_group, float *grad_im);\n\n  void modulated_deformable_col2im_coord_cpu(const float *data_col, const float *data_im, const float *data_offset, const float *data_mask,\n                                         const int batch_size, const int channels, const int height_im, const int width_im,\n                                         const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n                                         const int pad_h, const int pad_w, const int stride_h, const int stride_w,\n                                         const int dilation_h, const int dilation_w,\n                                         const int deformable_group,\n                                         float *grad_offset, float *grad_mask);\n\n#ifdef __cplusplus\n}\n#endif\n\n#endif"
  },
  {
    "path": "code/real/bsrt/model/DCNv2/src/cpu/dcn_v2_psroi_pooling_cpu.cpp",
    "content": "/*!\n * Copyright (c) 2017 Microsoft\n * Licensed under The MIT License [see LICENSE for details]\n * \\file deformable_psroi_pooling.cu\n * \\brief\n * \\author Yi Li, Guodong Zhang, Jifeng Dai\n*/\n/***************** Adapted by Charles Shang *********************/\n// modified from the CUDA version for CPU use by Daniel K. Suhendro\n\n#include <cstdio>\n#include <algorithm>\n#include <cstring>\n\n#include <ATen/ATen.h>\n//#include <ATen/cuda/CUDAContext.h>\n\n#include <TH/TH.h>\n//#include <THC/THCAtomics.cuh>\n//#include <THC/THCDeviceUtils.cuh>\n\n/*#define CUDA_KERNEL_LOOP(i, n)                        \\\n  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \\\n       i < (n);                                       \\\n       i += blockDim.x * gridDim.x)\n\nconst int CUDA_NUM_THREADS = 1024;\ninline int GET_BLOCKS(const int N)\n{\n  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;\n}*/\n\ntemplate <typename T>\nT bilinear_interp_cpu(\n    const T *data,\n    const T x,\n    const T y,\n    const int width,\n    const int height)\n{\n  int x1 = floor(x);\n  int x2 = ceil(x);\n  int y1 = floor(y);\n  int y2 = ceil(y);\n  T dist_x = static_cast<T>(x - x1);\n  T dist_y = static_cast<T>(y - y1);\n  T value11 = data[y1 * width + x1];\n  T value12 = data[y2 * width + x1];\n  T value21 = data[y1 * width + x2];\n  T value22 = data[y2 * width + x2];\n  T value = (1 - dist_x) * (1 - dist_y) * value11 +\n            (1 - dist_x) * dist_y * value12 +\n            dist_x * (1 - dist_y) * value21 +\n            dist_x * dist_y * value22;\n  return value;\n}\n\ntemplate <typename T>\n void DeformablePSROIPoolForwardKernelCpu(\n    const int count,\n    const T *bottom_data,\n    const T spatial_scale,\n    const int channels,\n    const int height, const int width,\n    const int pooled_height, const int pooled_width,\n    const T *bottom_rois, const T *bottom_trans,\n    const int no_trans,\n    const T trans_std,\n    const int sample_per_part,\n    
const int output_dim,\n    const int group_size,\n    const int part_size,\n    const int num_classes,\n    const int channels_each_class,\n    T *top_data,\n    T *top_count)\n{\n  for(int index = 0; index < count; index++)\n  {\n    // The output is in order (n, ctop, ph, pw)\n    int pw = index % pooled_width;\n    int ph = (index / pooled_width) % pooled_height;\n    int ctop = (index / pooled_width / pooled_height) % output_dim;\n    int n = index / pooled_width / pooled_height / output_dim;\n\n    // [start, end) interval for spatial sampling\n    const T *offset_bottom_rois = bottom_rois + n * 5;\n    int roi_batch_ind = offset_bottom_rois[0];\n    T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;\n    T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;\n    T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;\n    T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;\n\n    // Force too small ROIs to be 1x1\n    T roi_width = std::max(roi_end_w - roi_start_w, T(0.1)); //avoid 0\n    T roi_height = std::max(roi_end_h - roi_start_h, T(0.1));\n\n    // Compute w and h at bottom\n    T bin_size_h = roi_height / static_cast<T>(pooled_height);\n    T bin_size_w = roi_width / static_cast<T>(pooled_width);\n\n    T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);\n    T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);\n\n    int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);\n    int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);\n    int class_id = ctop / channels_each_class;\n    T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;\n    T trans_y = no_trans ? 
static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;\n\n    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;\n    wstart += trans_x * roi_width;\n    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;\n    hstart += trans_y * roi_height;\n\n    T sum = 0;\n    int count = 0;\n    int gw = floor(static_cast<T>(pw) * group_size / pooled_width);\n    int gh = floor(static_cast<T>(ph) * group_size / pooled_height);\n    gw = std::min(std::max(gw, 0), group_size - 1);\n    gh = std::min(std::max(gh, 0), group_size - 1);\n\n    const T *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;\n    for (int ih = 0; ih < sample_per_part; ih++)\n    {\n      for (int iw = 0; iw < sample_per_part; iw++)\n      {\n        T w = wstart + iw * sub_bin_size_w;\n        T h = hstart + ih * sub_bin_size_h;\n        // bilinear interpolation\n        if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)\n        {\n          continue;\n        }\n        w = std::min(std::max(w, T(0.)), width - T(1.));\n        h = std::min(std::max(h, T(0.)), height - T(1.));\n        int c = (ctop * group_size + gh) * group_size + gw;\n        T val = bilinear_interp_cpu(offset_bottom_data + c * height * width, w, h, width, height);\n        sum += val;\n        count++;\n      }\n    }\n    top_data[index] = count == 0 ? 
static_cast<T>(0) : sum / count;\n    top_count[index] = count;\n  }\n}\n\ntemplate <typename T>\nvoid DeformablePSROIPoolBackwardAccKernelCpu(\n    const int count,\n    const T *top_diff,\n    const T *top_count,\n    const int num_rois,\n    const T spatial_scale,\n    const int channels,\n    const int height, const int width,\n    const int pooled_height, const int pooled_width,\n    const int output_dim,\n    T *bottom_data_diff, T *bottom_trans_diff,\n    const T *bottom_data,\n    const T *bottom_rois,\n    const T *bottom_trans,\n    const int no_trans,\n    const T trans_std,\n    const int sample_per_part,\n    const int group_size,\n    const int part_size,\n    const int num_classes,\n    const int channels_each_class)\n{\n  for(int index = 0; index < count; index++)\n  {\n    // The output is in order (n, ctop, ph, pw)\n    int pw = index % pooled_width;\n    int ph = (index / pooled_width) % pooled_height;\n    int ctop = (index / pooled_width / pooled_height) % output_dim;\n    int n = index / pooled_width / pooled_height / output_dim;\n\n    // [start, end) interval for spatial sampling\n    const T *offset_bottom_rois = bottom_rois + n * 5;\n    int roi_batch_ind = offset_bottom_rois[0];\n    T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;\n    T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;\n    T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;\n    T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) 
* spatial_scale - 0.5;\n    \n    // Force too small ROIs to be 1x1\n    T roi_width = std::max(roi_end_w - roi_start_w, T(0.1)); //avoid 0\n    T roi_height = std::max(roi_end_h - roi_start_h, T(0.1));\n\n    // Compute w and h at bottom\n    T bin_size_h = roi_height / static_cast<T>(pooled_height);\n    T bin_size_w = roi_width / static_cast<T>(pooled_width);\n\n    T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);\n    T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);\n\n    int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);\n    int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);\n    int class_id = ctop / channels_each_class;\n    T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;\n    T trans_y = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;\n\n    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;\n    wstart += trans_x * roi_width;\n    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;\n    hstart += trans_y * roi_height;\n\n    if (top_count[index] <= 0)\n    {\n      continue;\n    }\n    T diff_val = top_diff[index] / top_count[index];\n    const T *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;\n    T *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;\n    int gw = floor(static_cast<T>(pw) * group_size / pooled_width);\n    int gh = floor(static_cast<T>(ph) * group_size / pooled_height);\n    gw = std::min(std::max(gw, 0), group_size - 1);\n    gh = std::min(std::max(gh, 0), group_size - 1);\n\n    for (int ih = 0; ih < sample_per_part; ih++)\n    {\n      for (int iw = 0; iw < sample_per_part; iw++)\n      {\n        T w = wstart + iw * sub_bin_size_w;\n        T h = hstart + ih * sub_bin_size_h;\n        // bilinear 
interpolation\n        if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)\n        {\n          continue;\n        }\n        w = std::min(std::max(w, T(0.)), width - T(1.));\n        h = std::min(std::max(h, T(0.)), height - T(1.));\n        int c = (ctop * group_size + gh) * group_size + gw;\n        // backward on feature\n        int x0 = floor(w);\n        int x1 = ceil(w);\n        int y0 = floor(h);\n        int y1 = ceil(h);\n        T dist_x = w - x0, dist_y = h - y0;\n        T q00 = (1 - dist_x) * (1 - dist_y);\n        T q01 = (1 - dist_x) * dist_y;\n        T q10 = dist_x * (1 - dist_y);\n        T q11 = dist_x * dist_y;\n        int bottom_index_base = c * height * width;\n        /*atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);\n        atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);\n        atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);\n        atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);*/\n       *(offset_bottom_data_diff + bottom_index_base + y0 * width + x0) += q00 * diff_val;\n       *(offset_bottom_data_diff + bottom_index_base + y1 * width + x0) += q01 * diff_val;\n       *(offset_bottom_data_diff + bottom_index_base + y0 * width + x1) += q10 * diff_val;\n       *(offset_bottom_data_diff + bottom_index_base + y1 * width + x1) += q11 * diff_val;\n\n\n        if (no_trans)\n        {\n          continue;\n        }\n        T U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];\n        T U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];\n        T U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];\n        T U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];\n        T diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;\n        diff_x *= roi_width;\n        T 
diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;\n        diff_y *= roi_height;\n\n        /*atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x);\n        atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);*/\n        *(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w) += diff_x;\n        *(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w) += diff_y;\n      }\n    }\n  }\n}\n\nstd::tuple<at::Tensor, at::Tensor>\ndcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,\n                                  const at::Tensor &bbox,\n                                  const at::Tensor &trans,\n                                  const int no_trans,\n                                  const float spatial_scale,\n                                  const int output_dim,\n                                  const int group_size,\n                                  const int pooled_size,\n                                  const int part_size,\n                                  const int sample_per_part,\n                                  const float trans_std)\n{\n  /*AT_ASSERTM(input.type().is_cuda(), \"input must be a CUDA tensor\");\n  AT_ASSERTM(bbox.type().is_cuda(), \"rois must be a CUDA tensor\");\n  AT_ASSERTM(trans.type().is_cuda(), \"trans must be a CUDA tensor\");*/\n\n  const int batch = input.size(0);\n  const int channels = input.size(1);\n  const int height = input.size(2);\n  const int width = input.size(3);\n  const int channels_trans = no_trans ? 
2 : trans.size(1);\n  const int num_bbox = bbox.size(0);\n\n  AT_ASSERTM(channels == output_dim, \"input channels and output channels must be equal\");\n  auto pooled_height = pooled_size;\n  auto pooled_width = pooled_size;\n\n  auto out = at::empty({num_bbox, output_dim, pooled_height, pooled_width}, input.options());\n  long out_size = num_bbox * output_dim * pooled_height * pooled_width;\n  auto top_count = at::zeros({num_bbox, output_dim, pooled_height, pooled_width}, input.options());\n\n  const int num_classes = no_trans ? 1 : channels_trans / 2;\n  const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;\n\n  //cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n  if (out.numel() == 0)\n  {\n    //THCudaCheck(cudaGetLastError());\n    return std::make_tuple(out, top_count);\n  }\n\n  /*dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));\n  dim3 block(512);*/\n\n  AT_DISPATCH_FLOATING_TYPES(input.type(), \"dcn_v2_psroi_pooling_cpu_forward\", [&] {\n    DeformablePSROIPoolForwardKernelCpu<scalar_t>(\n        out_size,\n        input.contiguous().data<scalar_t>(),\n        spatial_scale,\n        channels,\n        height, width,\n        pooled_height,\n        pooled_width,\n        bbox.contiguous().data<scalar_t>(),\n        trans.contiguous().data<scalar_t>(),\n        no_trans,\n        trans_std,\n        sample_per_part,\n        output_dim,\n        group_size,\n        part_size,\n        num_classes,\n        channels_each_class,\n        out.data<scalar_t>(),\n        top_count.data<scalar_t>());\n  });\n  //THCudaCheck(cudaGetLastError());\n  return std::make_tuple(out, top_count);\n}\n\nstd::tuple<at::Tensor, at::Tensor>\ndcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad,\n                                   const at::Tensor &input,\n                                   const at::Tensor &bbox,\n                                   const at::Tensor &trans,\n                                   const at::Tensor 
&top_count,\n                                   const int no_trans,\n                                   const float spatial_scale,\n                                   const int output_dim,\n                                   const int group_size,\n                                   const int pooled_size,\n                                   const int part_size,\n                                   const int sample_per_part,\n                                   const float trans_std)\n{\n  /*AT_ASSERTM(out_grad.type().is_cuda(), \"out_grad must be a CUDA tensor\");\n  AT_ASSERTM(input.type().is_cuda(), \"input must be a CUDA tensor\");\n  AT_ASSERTM(bbox.type().is_cuda(), \"bbox must be a CUDA tensor\");\n  AT_ASSERTM(trans.type().is_cuda(), \"trans must be a CUDA tensor\");\n  AT_ASSERTM(top_count.type().is_cuda(), \"top_count must be a CUDA tensor\");*/\n\n  const int batch = input.size(0);\n  const int channels = input.size(1);\n  const int height = input.size(2);\n  const int width = input.size(3);\n  const int channels_trans = no_trans ? 2 : trans.size(1);\n  const int num_bbox = bbox.size(0);\n\n  AT_ASSERTM(channels == output_dim, \"input channels and output channels must be equal\");\n  auto pooled_height = pooled_size;\n  auto pooled_width = pooled_size;\n  long out_size = num_bbox * output_dim * pooled_height * pooled_width;\n  const int num_classes = no_trans ? 1 : channels_trans / 2;\n  const int channels_each_class = no_trans ? 
output_dim : output_dim / num_classes;\n\n  auto input_grad = at::zeros({batch, channels, height, width}, out_grad.options());\n  auto trans_grad = at::zeros_like(trans);\n\n  if (input_grad.numel() == 0)\n  {\n    //THCudaCheck(cudaGetLastError());\n    return std::make_tuple(input_grad, trans_grad);\n  }\n\n  /*dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));\n  dim3 block(512);\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();*/\n\n  AT_DISPATCH_FLOATING_TYPES(out_grad.type(), \"dcn_v2_psroi_pooling_cpu_backward\", [&] {\n    DeformablePSROIPoolBackwardAccKernelCpu<scalar_t>(\n        out_size,\n        out_grad.contiguous().data<scalar_t>(),\n        top_count.contiguous().data<scalar_t>(),\n        num_bbox,\n        spatial_scale,\n        channels,\n        height,\n        width,\n        pooled_height,\n        pooled_width,\n        output_dim,\n        input_grad.contiguous().data<scalar_t>(),\n        trans_grad.contiguous().data<scalar_t>(),\n        input.contiguous().data<scalar_t>(),\n        bbox.contiguous().data<scalar_t>(),\n        trans.contiguous().data<scalar_t>(),\n        no_trans,\n        trans_std,\n        sample_per_part,\n        group_size,\n        part_size,\n        num_classes,\n        channels_each_class);\n  });\n  //THCudaCheck(cudaGetLastError());\n  return std::make_tuple(input_grad, trans_grad);\n}"
  },
  {
    "path": "code/real/bsrt/model/DCNv2/src/cpu/vision.h",
    "content": "#pragma once\n#include <torch/extension.h>\n\nat::Tensor\ndcn_v2_cpu_forward(const at::Tensor &input,\n                    const at::Tensor &weight,\n                    const at::Tensor &bias,\n                    const at::Tensor &offset,\n                    const at::Tensor &mask,\n                    const int kernel_h,\n                    const int kernel_w,\n                    const int stride_h,\n                    const int stride_w,\n                    const int pad_h,\n                    const int pad_w,\n                    const int dilation_h,\n                    const int dilation_w,\n                    const int deformable_group);\n\nstd::vector<at::Tensor>\ndcn_v2_cpu_backward(const at::Tensor &input,\n                     const at::Tensor &weight,\n                     const at::Tensor &bias,\n                     const at::Tensor &offset,\n                     const at::Tensor &mask,\n                     const at::Tensor &grad_output,\n                     int kernel_h, int kernel_w,\n                     int stride_h, int stride_w,\n                     int pad_h, int pad_w,\n                     int dilation_h, int dilation_w,\n                     int deformable_group);\n\n\nstd::tuple<at::Tensor, at::Tensor>\ndcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,\n                                  const at::Tensor &bbox,\n                                  const at::Tensor &trans,\n                                  const int no_trans,\n                                  const float spatial_scale,\n                                  const int output_dim,\n                                  const int group_size,\n                                  const int pooled_size,\n                                  const int part_size,\n                                  const int sample_per_part,\n                                  const float trans_std);\n\nstd::tuple<at::Tensor, at::Tensor>\ndcn_v2_psroi_pooling_cpu_backward(const 
at::Tensor &out_grad,\n                                   const at::Tensor &input,\n                                   const at::Tensor &bbox,\n                                   const at::Tensor &trans,\n                                   const at::Tensor &top_count,\n                                   const int no_trans,\n                                   const float spatial_scale,\n                                   const int output_dim,\n                                   const int group_size,\n                                   const int pooled_size,\n                                   const int part_size,\n                                   const int sample_per_part,\n                                   const float trans_std);"
  },
  {
    "path": "code/real/bsrt/model/DCNv2/src/cuda/dcn_v2_cuda.cu",
    "content": "#include <vector>\n#include \"cuda/dcn_v2_im2col_cuda.h\"\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/CUDABlas.h>\n#include <ATen/Dispatch.h>\n#include <ATen/div_rtn.h>\n#include <THC/THC.h>\n#include <THC/THCAtomics.cuh>\n#include <THC/THCDeviceUtils.cuh>\n#include <ATen/cuda/Exceptions.h>\n\nTHCState *state = at::globalContext().lazyInitCUDA();\n\nstatic cublasOperation_t _cublasOpFromChar(char op) {\n    switch (op) {\n      case 'n':\n      case 'N':\n        return CUBLAS_OP_N;\n      case 't':\n      case 'T':\n        return CUBLAS_OP_T;\n      case 'c':\n      case 'C':\n        return CUBLAS_OP_C;\n    }\n    AT_ERROR(\n        \"_cublasOpFromChar input should be 't', 'n' or 'c' but got `\", op, \"`\");\n  }\n\n  static void _cublasAdjustLdLevel2(int64_t m, int64_t n, int64_t* lda) {\n    // Note: leading dimensions generally are checked that they are > 0\n    // and at least as big as the result requires (even if the value won't\n    // be used).\n  \n    // Q: Why does Level3 check trans but this doesn't?\n    // A: In level 2, the sizes (m, n) specify the size of A\n    // (independent of trans value). In level 3, 
the sizes (m, n, k)\n    // specify the sizes of op(A), op(B) where op depend on trans\n    // values.\n    if (n <= 1)\n      *lda = std::max<int64_t>(m, 1);\n  }\n\n\n\n// author: Charles Shang\n// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu\n\n// [batch gemm]\n// https://github.com/pytorch/pytorch/blob/master/aten/src/THC/generic/THCTensorMathBlas.cu\n\n__global__ void createBatchGemmBuffer(const float **input_b, float **output_b,\n                                      float **columns_b, const float **ones_b,\n                                      const float **weight_b, const float **bias_b,\n                                      float *input, float *output,\n                                      float *columns, float *ones,\n                                      float *weight, float *bias,\n                                      const int input_stride, const int output_stride,\n                                      const int columns_stride, const int ones_stride,\n                                      const int num_batches)\n{\n    const int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx < num_batches)\n    {\n        input_b[idx] = input + idx * input_stride;\n        output_b[idx] = output + idx * output_stride;\n        columns_b[idx] = columns + idx * columns_stride;\n        ones_b[idx] = ones + idx * ones_stride;\n        // share weights and bias within a Mini-Batch\n        weight_b[idx] = weight;\n        bias_b[idx] = bias;\n    }\n}\n\nat::Tensor\ndcn_v2_cuda_forward(const at::Tensor &input,\n                    const at::Tensor &weight,\n                    const at::Tensor &bias,\n                    const at::Tensor &offset,\n                    const at::Tensor &mask,\n                    const int kernel_h,\n                    const int kernel_w,\n                    const int stride_h,\n                    const int stride_w,\n                    const int pad_h,\n                    const 
int pad_w,\n                    const int dilation_h,\n                    const int dilation_w,\n                    const int deformable_group)\n{\n    using scalar_t = float;\n    // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask));\n    AT_ASSERTM(input.type().is_cuda(), \"input must be a CUDA tensor\");\n    AT_ASSERTM(weight.type().is_cuda(), \"weight must be a CUDA tensor\");\n    AT_ASSERTM(bias.type().is_cuda(), \"bias must be a CUDA tensor\");\n    AT_ASSERTM(offset.type().is_cuda(), \"offset must be a CUDA tensor\");\n    AT_ASSERTM(mask.type().is_cuda(), \"mask must be a CUDA tensor\");\n\n    const int batch = input.size(0);\n    const int channels = input.size(1);\n    const int height = input.size(2);\n    const int width = input.size(3);\n\n    const int channels_out = weight.size(0);\n    const int channels_kernel = weight.size(1);\n    const int kernel_h_ = weight.size(2);\n    const int kernel_w_ = weight.size(3);\n\n    // printf(\"Kernels: %d %d %d %d\\n\", kernel_h_, kernel_w_, kernel_w, kernel_h);\n    // printf(\"Channels: %d %d\\n\", channels, channels_kernel);\n    // printf(\"Channels: %d %d\\n\", channels_out, channels_kernel);\n\n    AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,\n               \"Input shape and kernel shape won't match: (%d x %d vs %d x %d).\", kernel_h, kernel_w, kernel_h_, kernel_w_);\n\n    AT_ASSERTM(channels == channels_kernel,\n               \"Input shape and kernel channels won't match: (%d vs %d).\", channels, channels_kernel);\n\n    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;\n    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;\n\n    auto ones = at::ones({batch, height_out, width_out}, input.options());\n    auto columns = at::empty({batch, channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());\n    auto output = at::empty({batch, 
channels_out, height_out, width_out}, input.options());\n\n    // prepare for batch-wise computing, which is significantly faster than instance-wise computing\n    // when batch size is large.\n    // launch batch threads\n    int matrices_size = batch * sizeof(float *);\n    auto input_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));\n    auto output_b = static_cast<float **>(THCudaMalloc(state, matrices_size));\n    auto columns_b = static_cast<float **>(THCudaMalloc(state, matrices_size));\n    auto ones_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));\n    auto weight_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));\n    auto bias_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));\n\n    const int block = 128;\n    const int grid = (batch + block - 1) / block;\n\n    createBatchGemmBuffer<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(\n        input_b, output_b,\n        columns_b, ones_b,\n        weight_b, bias_b,\n        input.data_ptr<scalar_t>(),\n        output.data_ptr<scalar_t>(),\n        columns.data_ptr<scalar_t>(),\n        ones.data_ptr<scalar_t>(),\n        weight.data_ptr<scalar_t>(),\n        bias.data_ptr<scalar_t>(),\n        channels * width * height,\n        channels_out * width_out * height_out,\n        channels * kernel_h * kernel_w * height_out * width_out,\n        height_out * width_out,\n        batch);\n\n    long m_ = channels_out;\n    long n_ = height_out * width_out;\n    long k_ = 1;\n    THCudaBlas_SgemmBatched(state,\n                            't',\n                            'n',\n                            n_,\n                            m_,\n                            k_,\n                            1.0f,\n                            ones_b, k_,\n                            bias_b, k_,\n                            0.0f,\n                            output_b, n_,\n                            batch);\n\n    
modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(),\n                                     input.data_ptr<scalar_t>(),\n                                     offset.data_ptr<scalar_t>(),\n                                     mask.data_ptr<scalar_t>(),\n                                     batch, channels, height, width,\n                                     height_out, width_out, kernel_h, kernel_w,\n                                     pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,\n                                     deformable_group,\n                                     columns.data_ptr<scalar_t>());\n\n    long m = channels_out;\n    long n = height_out * width_out;\n    long k = channels * kernel_h * kernel_w;\n    THCudaBlas_SgemmBatched(state,\n                            'n',\n                            'n',\n                            n,\n                            m,\n                            k,\n                            1.0f,\n                            (const float **)columns_b, n,\n                            weight_b, k,\n                            1.0f,\n                            output_b, n,\n                            batch);\n\n    THCudaFree(state, input_b);\n    THCudaFree(state, output_b);\n    THCudaFree(state, columns_b);\n    THCudaFree(state, ones_b);\n    THCudaFree(state, weight_b);\n    THCudaFree(state, bias_b);\n    return output;\n}\n\n__global__ void createBatchGemmBufferBackward(\n    float **grad_output_b,\n    float **columns_b,\n    float **ones_b,\n    float **weight_b,\n    float **grad_weight_b,\n    float **grad_bias_b,\n    float *grad_output,\n    float *columns,\n    float *ones,\n    float *weight,\n    float *grad_weight,\n    float *grad_bias,\n    const int grad_output_stride,\n    const int columns_stride,\n    const int ones_stride,\n    const int num_batches)\n{\n    const int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx < num_batches)\n    {\n        
grad_output_b[idx] = grad_output + idx * grad_output_stride;\n        columns_b[idx] = columns + idx * columns_stride;\n        ones_b[idx] = ones + idx * ones_stride;\n\n        // share weights and bias within a Mini-Batch\n        weight_b[idx] = weight;\n        grad_weight_b[idx] = grad_weight;\n        grad_bias_b[idx] = grad_bias;\n    }\n}\n\nstd::vector<at::Tensor> dcn_v2_cuda_backward(const at::Tensor &input,\n                                             const at::Tensor &weight,\n                                             const at::Tensor &bias,\n                                             const at::Tensor &offset,\n                                             const at::Tensor &mask,\n                                             const at::Tensor &grad_output,\n                                             int kernel_h, int kernel_w,\n                                             int stride_h, int stride_w,\n                                             int pad_h, int pad_w,\n                                             int dilation_h, int dilation_w,\n                                             int deformable_group)\n{\n\n    THArgCheck(input.is_contiguous(), 1, \"input tensor has to be contiguous\");\n    THArgCheck(weight.is_contiguous(), 2, \"weight tensor has to be contiguous\");\n\n    AT_ASSERTM(input.type().is_cuda(), \"input must be a CUDA tensor\");\n    AT_ASSERTM(weight.type().is_cuda(), \"weight must be a CUDA tensor\");\n    AT_ASSERTM(bias.type().is_cuda(), \"bias must be a CUDA tensor\");\n    AT_ASSERTM(offset.type().is_cuda(), \"offset must be a CUDA tensor\");\n    AT_ASSERTM(mask.type().is_cuda(), \"mask must be a CUDA tensor\");\n\n    const int batch = input.size(0);\n    const int channels = input.size(1);\n    const int height = input.size(2);\n    const int width = input.size(3);\n\n    const int channels_out = weight.size(0);\n    const int channels_kernel = weight.size(1);\n    const int kernel_h_ = weight.size(2);\n    const 
int kernel_w_ = weight.size(3);\n\n    AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,\n               \"Input shape and kernel shape won't match: (%d x %d vs %d x %d).\", kernel_h, kernel_w, kernel_h_, kernel_w_);\n\n    AT_ASSERTM(channels == channels_kernel,\n               \"Input shape and kernel channels won't match: (%d vs %d).\", channels, channels_kernel);\n\n    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;\n    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;\n\n    auto ones = at::ones({height_out, width_out}, input.options());\n    auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());\n    auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());\n\n    auto grad_input = at::zeros_like(input);\n    auto grad_weight = at::zeros_like(weight);\n    auto grad_bias = at::zeros_like(bias);\n    auto grad_offset = at::zeros_like(offset);\n    auto grad_mask = at::zeros_like(mask);\n\n    using scalar_t = float;\n\n    for (int b = 0; b < batch; b++)\n    {\n        auto input_n = input.select(0, b);\n        auto offset_n = offset.select(0, b);\n        auto mask_n = mask.select(0, b);\n        auto grad_output_n = grad_output.select(0, b);\n        auto grad_input_n = grad_input.select(0, b);\n        auto grad_offset_n = grad_offset.select(0, b);\n        auto grad_mask_n = grad_mask.select(0, b);\n\n        long m = channels * kernel_h * kernel_w;\n        long n = height_out * width_out;\n        long k = channels_out;\n\n        THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f,\n                         grad_output_n.data_ptr<scalar_t>(), n,\n                         weight.data_ptr<scalar_t>(), m, 0.0f,\n                         columns.data_ptr<scalar_t>(), n);\n\n        // gradient w.r.t. 
input coordinate data\n        modulated_deformable_col2im_coord_cuda(c10::cuda::getCurrentCUDAStream(),\n                                               columns.data_ptr<scalar_t>(),\n                                               input_n.data_ptr<scalar_t>(),\n                                               offset_n.data_ptr<scalar_t>(),\n                                               mask_n.data_ptr<scalar_t>(),\n                                               1, channels, height, width,\n                                               height_out, width_out, kernel_h, kernel_w,\n                                               pad_h, pad_w, stride_h, stride_w,\n                                               dilation_h, dilation_w, deformable_group,\n                                               grad_offset_n.data_ptr<scalar_t>(),\n                                               grad_mask_n.data_ptr<scalar_t>());\n        // gradient w.r.t. input data\n        modulated_deformable_col2im_cuda(c10::cuda::getCurrentCUDAStream(),\n                                         columns.data_ptr<scalar_t>(),\n                                         offset_n.data_ptr<scalar_t>(),\n                                         mask_n.data_ptr<scalar_t>(),\n                                         1, channels, height, width,\n                                         height_out, width_out, kernel_h, kernel_w,\n                                         pad_h, pad_w, stride_h, stride_w,\n                                         dilation_h, dilation_w, deformable_group,\n                                         grad_input_n.data_ptr<scalar_t>());\n\n        // gradient w.r.t. 
weight, dWeight should accumulate across the batch and group\n        modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(),\n                                         input_n.data_ptr<scalar_t>(),\n                                         offset_n.data_ptr<scalar_t>(),\n                                         mask_n.data_ptr<scalar_t>(),\n                                         1, channels, height, width,\n                                         height_out, width_out, kernel_h, kernel_w,\n                                         pad_h, pad_w, stride_h, stride_w,\n                                         dilation_h, dilation_w, deformable_group,\n                                         columns.data_ptr<scalar_t>());\n\n        long m_ = channels_out;\n        long n_ = channels * kernel_h * kernel_w;\n        long k_ = height_out * width_out;\n\n        THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f,\n                         columns.data_ptr<scalar_t>(), k_,\n                         grad_output_n.data_ptr<scalar_t>(), k_, 1.0f,\n                         grad_weight.data_ptr<scalar_t>(), n_);\n\n        cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n        cublasOperation_t op = _cublasOpFromChar('t');\n        _cublasAdjustLdLevel2(k_, m_, &k_);\n        scalar_t* grad_output_n_float = grad_output_n.data_ptr<scalar_t>();\n        scalar_t* one_float = ones.data_ptr<scalar_t>();\n        scalar_t alpha = 1.0;\n        scalar_t beta = 1.0;\n        cublasSgemv(handle, op, k_, m_, &alpha, grad_output_n_float, k_, one_float, 1, &beta, grad_bias.data_ptr<scalar_t>(), 1);\n    }\n\n    return {\n        grad_input, grad_offset, grad_mask, grad_weight, grad_bias\n    };\n}\n"
  },
  {
    "path": "code/real/bsrt/model/DCNv2/src/cuda/dcn_v2_im2col_cuda.cu",
    "content": "#include \"dcn_v2_im2col_cuda.h\"\n#include <cstdio>\n#include <algorithm>\n#include <cstring>\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n\n#include <THC/THC.h>\n#include <THC/THCAtomics.cuh>\n#include <THC/THCDeviceUtils.cuh>\n\n#define CUDA_KERNEL_LOOP(i, n)                          \\\n  for (int i = blockIdx.x * blockDim.x + threadIdx.x;   \\\n      i < (n);                                          \\\n      i += blockDim.x * gridDim.x)\n\nconst int CUDA_NUM_THREADS = 1024;\ninline int GET_BLOCKS(const int N)\n{\n  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;\n}\n\n\n__device__ float dmcn_im2col_bilinear_cuda(const float *bottom_data, const int data_width,\n                                      const int height, const int width, float h, float w)\n{\n  int h_low = floor(h);\n  int w_low = floor(w);\n  int h_high = h_low + 1;\n  int w_high = w_low + 1;\n\n  float lh = h - h_low;\n  float lw = w - w_low;\n  float hh = 1 - lh, hw = 1 - lw;\n\n  float v1 = 0;\n  if (h_low >= 0 && w_low >= 0)\n    v1 = bottom_data[h_low * data_width + w_low];\n  float v2 = 0;\n  if (h_low >= 0 && w_high <= width - 1)\n    v2 = bottom_data[h_low * data_width + w_high];\n  float v3 = 0;\n  if (h_high <= height - 1 && w_low >= 0)\n    v3 = bottom_data[h_high * data_width + w_low];\n  float v4 = 0;\n  if (h_high <= height - 1 && w_high <= width - 1)\n    v4 = bottom_data[h_high * data_width + w_high];\n\n  float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;\n\n  float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);\n  return val;\n}\n\n__device__ float dmcn_get_gradient_weight_cuda(float argmax_h, float argmax_w,\n                                          const int h, const int w, const int height, const int width)\n{\n  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)\n  {\n    //empty\n    return 0;\n  }\n\n  int argmax_h_low = floor(argmax_h);\n  int argmax_w_low = floor(argmax_w);\n  int 
argmax_h_high = argmax_h_low + 1;\n  int argmax_w_high = argmax_w_low + 1;\n\n  float weight = 0;\n  if (h == argmax_h_low && w == argmax_w_low)\n    weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);\n  if (h == argmax_h_low && w == argmax_w_high)\n    weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);\n  if (h == argmax_h_high && w == argmax_w_low)\n    weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);\n  if (h == argmax_h_high && w == argmax_w_high)\n    weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);\n  return weight;\n}\n\n__device__ float dmcn_get_coordinate_weight_cuda(float argmax_h, float argmax_w,\n                                            const int height, const int width, const float *im_data,\n                                            const int data_width, const int bp_dir)\n{\n  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)\n  {\n    //empty\n    return 0;\n  }\n\n  int argmax_h_low = floor(argmax_h);\n  int argmax_w_low = floor(argmax_w);\n  int argmax_h_high = argmax_h_low + 1;\n  int argmax_w_high = argmax_w_low + 1;\n\n  float weight = 0;\n\n  if (bp_dir == 0)\n  {\n    if (argmax_h_low >= 0 && argmax_w_low >= 0)\n      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];\n    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)\n      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];\n    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)\n      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];\n    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)\n      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];\n  }\n  else if (bp_dir == 1)\n  {\n    if (argmax_h_low >= 0 && argmax_w_low >= 0)\n      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];\n    if (argmax_h_low >= 0 && 
argmax_w_high <= width - 1)\n      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];\n    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)\n      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];\n    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)\n      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];\n  }\n\n  return weight;\n}\n\n__global__ void modulated_deformable_im2col_gpu_kernel(const int n,\n                                                       const float *data_im, const float *data_offset, const float *data_mask,\n                                                       const int height, const int width, const int kernel_h, const int kernel_w,\n                                                       const int pad_h, const int pad_w,\n                                                       const int stride_h, const int stride_w,\n                                                       const int dilation_h, const int dilation_w,\n                                                       const int channel_per_deformable_group,\n                                                       const int batch_size, const int num_channels, const int deformable_group,\n                                                       const int height_col, const int width_col,\n                                                       float *data_col)\n{\n  // launch channels * batch_size * height_col * width_col cores\n  CUDA_KERNEL_LOOP(index, n)\n  {\n    // NOTE(CharlesShang): different from Dai Jifeng's MXNet implementation, col_buffer is of shape (c*kw*kh, N, oh, ow)\n    // here columns is of shape (N, c*kw*kh, oh * ow), need to adapt axis\n\n    // index index of output matrix\n    const int w_col = index % width_col;\n    const int h_col = (index / width_col) % height_col;\n    // const int b_col = (index / width_col / height_col) % 
batch_size;\n    const int b_col = (index / width_col / height_col / num_channels) % batch_size;\n    // const int c_im = (index / width_col / height_col) / batch_size;\n    const int c_im = (index / width_col / height_col) % num_channels;\n    // const int c_col = c_im * kernel_h * kernel_w;\n    const int c_col = c_im * kernel_h * kernel_w;\n\n    // compute deformable group index\n    const int deformable_group_index = c_im / channel_per_deformable_group;\n\n    const int h_in = h_col * stride_h - pad_h;\n    const int w_in = w_col * stride_w - pad_w;\n\n    //  float *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;\n    float *data_col_ptr = data_col + ((b_col * num_channels * kernel_w * kernel_h + c_col) * height_col + h_col) * width_col + w_col;\n    //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;\n    const float *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;\n    const float *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;\n\n    const float *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;\n\n    for (int i = 0; i < kernel_h; ++i)\n    {\n      for (int j = 0; j < kernel_w; ++j)\n      {\n        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;\n        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;\n        const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;\n        const float offset_h = data_offset_ptr[data_offset_h_ptr];\n        const float offset_w = data_offset_ptr[data_offset_w_ptr];\n        const float mask = data_mask_ptr[data_mask_hw_ptr];\n        float val = static_cast<float>(0);\n        const float 
h_im = h_in + i * dilation_h + offset_h;\n        const float w_im = w_in + j * dilation_w + offset_w;\n        //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {\n        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)\n        {\n          //const float map_h = i * dilation_h + offset_h;\n          //const float map_w = j * dilation_w + offset_w;\n          //const int cur_height = height - h_in;\n          //const int cur_width = width - w_in;\n          //val = dmcn_im2col_bilinear_cuda(data_im_ptr, width, cur_height, cur_width, map_h, map_w);\n          val = dmcn_im2col_bilinear_cuda(data_im_ptr, width, height, width, h_im, w_im);\n        }\n        *data_col_ptr = val * mask;\n        // data_col_ptr += batch_size * height_col * width_col;\n        data_col_ptr += height_col * width_col;\n      }\n    }\n  }\n}\n\n__global__ void modulated_deformable_col2im_gpu_kernel(const int n,\n                                                       const float *data_col, const float *data_offset, const float *data_mask,\n                                                       const int channels, const int height, const int width,\n                                                       const int kernel_h, const int kernel_w,\n                                                       const int pad_h, const int pad_w,\n                                                       const int stride_h, const int stride_w,\n                                                       const int dilation_h, const int dilation_w,\n                                                       const int channel_per_deformable_group,\n                                                       const int batch_size, const int deformable_group,\n                                                       const int height_col, const int width_col,\n                                                       float *grad_im)\n{\n  CUDA_KERNEL_LOOP(index, n)\n  {\n    const int j = (index / width_col 
/ height_col / batch_size) % kernel_w;\n    const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;\n    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;\n    // compute the start and end of the output\n\n    const int deformable_group_index = c / channel_per_deformable_group;\n\n    int w_out = index % width_col;\n    int h_out = (index / width_col) % height_col;\n    int b = (index / width_col / height_col) % batch_size;\n    int w_in = w_out * stride_w - pad_w;\n    int h_in = h_out * stride_h - pad_h;\n\n    const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;\n    const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;\n    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;\n    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;\n    const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;\n    const float offset_h = data_offset_ptr[data_offset_h_ptr];\n    const float offset_w = data_offset_ptr[data_offset_w_ptr];\n    const float mask = data_mask_ptr[data_mask_hw_ptr];\n    const float cur_inv_h_data = h_in + i * dilation_h + offset_h;\n    const float cur_inv_w_data = w_in + j * dilation_w + offset_w;\n\n    const float cur_top_grad = data_col[index] * mask;\n    const int cur_h = (int)cur_inv_h_data;\n    const int cur_w = (int)cur_inv_w_data;\n    for (int dy = -2; dy <= 2; dy++)\n    {\n      for (int dx = -2; dx <= 2; dx++)\n      {\n        if (cur_h + dy >= 0 && cur_h + dy < height &&\n            cur_w + dx >= 0 && cur_w + dx < width &&\n            abs(cur_inv_h_data - (cur_h + dy)) < 1 &&\n            abs(cur_inv_w_data - (cur_w + dx)) < 1)\n        {\n          int cur_bottom_grad_pos = 
((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;\n          float weight = dmcn_get_gradient_weight_cuda(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);\n          atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);\n        }\n      }\n    }\n  }\n}\n\n__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,\n                                                             const float *data_col, const float *data_im,\n                                                             const float *data_offset, const float *data_mask,\n                                                             const int channels, const int height, const int width,\n                                                             const int kernel_h, const int kernel_w,\n                                                             const int pad_h, const int pad_w,\n                                                             const int stride_h, const int stride_w,\n                                                             const int dilation_h, const int dilation_w,\n                                                             const int channel_per_deformable_group,\n                                                             const int batch_size, const int offset_channels, const int deformable_group,\n                                                             const int height_col, const int width_col,\n                                                             float *grad_offset, float *grad_mask)\n{\n  CUDA_KERNEL_LOOP(index, n)\n  {\n    float val = 0, mval = 0;\n    int w = index % width_col;\n    int h = (index / width_col) % height_col;\n    int c = (index / width_col / height_col) % offset_channels;\n    int b = (index / width_col / height_col) / offset_channels;\n    // compute the start and end of the output\n\n    const int deformable_group_index = c / (2 * kernel_h * kernel_w);\n    const int col_step = 
kernel_h * kernel_w;\n    int cnt = 0;\n    const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;\n    const float *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;\n    const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;\n    const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;\n\n    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;\n\n    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)\n    {\n      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;\n      const int bp_dir = offset_c % 2;\n\n      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;\n      int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;\n      int w_out = col_pos % width_col;\n      int h_out = (col_pos / width_col) % height_col;\n      int w_in = w_out * stride_w - pad_w;\n      int h_in = h_out * stride_h - pad_h;\n      const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);\n      const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);\n      const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);\n      const float offset_h = data_offset_ptr[data_offset_h_ptr];\n      const float offset_w = data_offset_ptr[data_offset_w_ptr];\n      const float mask = data_mask_ptr[data_mask_hw_ptr];\n      float inv_h = h_in + i * dilation_h + offset_h;\n      float inv_w = w_in + j * dilation_w + offset_w;\n      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)\n      {\n 
       inv_h = inv_w = -2;\n      }\n      else\n      {\n        mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear_cuda(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);\n      }\n      const float weight = dmcn_get_coordinate_weight_cuda(\n          inv_h, inv_w,\n          height, width, data_im_ptr + cnt * height * width, width, bp_dir);\n      val += weight * data_col_ptr[col_pos] * mask;\n      cnt += 1;\n    }\n    // KERNEL_ASSIGN(grad_offset[index], offset_req, val);\n    grad_offset[index] = val;\n    if (offset_c % 2 == 0)\n      // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);\n      grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;\n  }\n}\n\nvoid modulated_deformable_im2col_cuda(cudaStream_t stream,\n  const float* data_im, const float* data_offset, const float* data_mask,\n  const int batch_size, const int channels, const int height_im, const int width_im, \n  const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n  const int pad_h, const int pad_w, const int stride_h, const int stride_w, \n  const int dilation_h, const int dilation_w,\n  const int deformable_group, float* data_col) {\n  // num_axes should be smaller than block size\n  const int channel_per_deformable_group = channels / deformable_group;\n  const int num_kernels = channels * batch_size * height_col * width_col;\n  modulated_deformable_im2col_gpu_kernel\n      <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS,\n          0, stream>>>(\n      num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kernel_w,\n      pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,\n      batch_size, channels, deformable_group, height_col, width_col, data_col);\n  \n  cudaError_t err = 
cudaGetLastError();\n  if (err != cudaSuccess)\n  {\n    printf(\"error in modulated_deformable_im2col_cuda: %s\\n\", cudaGetErrorString(err));\n  }\n\n}\n\nvoid modulated_deformable_col2im_cuda(cudaStream_t stream,\n  const float* data_col, const float* data_offset, const float* data_mask,\n  const int batch_size, const int channels, const int height_im, const int width_im, \n  const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n  const int pad_h, const int pad_w, const int stride_h, const int stride_w, \n  const int dilation_h, const int dilation_w, \n  const int deformable_group, float* grad_im){\n\n  const int channel_per_deformable_group = channels / deformable_group;\n  const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;\n  modulated_deformable_col2im_gpu_kernel\n      <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS,\n          0, stream>>>(\n        num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im,\n        kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,\n        dilation_h, dilation_w, channel_per_deformable_group,\n        batch_size, deformable_group, height_col, width_col, grad_im);\n  cudaError_t err = cudaGetLastError();\n  if (err != cudaSuccess)\n  {\n    printf(\"error in modulated_deformable_col2im_cuda: %s\\n\", cudaGetErrorString(err));\n  }\n\n}\n\nvoid modulated_deformable_col2im_coord_cuda(cudaStream_t stream,\n  const float* data_col, const float* data_im, const float* data_offset, const float* data_mask,\n  const int batch_size, const int channels, const int height_im, const int width_im, \n  const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n  const int pad_h, const int pad_w, const int stride_h, const int stride_w, \n  const int dilation_h, const int dilation_w, \n  const int deformable_group,\n  float* grad_offset, float* grad_mask) {\n  const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;\n  const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;\n  modulated_deformable_col2im_coord_gpu_kernel\n      <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS,\n        0, stream>>>(\n        num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im,\n        kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,\n        dilation_h, dilation_w, channel_per_deformable_group,\n        batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, \n        grad_offset, grad_mask);\n  cudaError_t err = cudaGetLastError();\n  if (err != cudaSuccess)\n  {\n    printf(\"error in modulated_deformable_col2im_coord_cuda: %s\\n\", cudaGetErrorString(err));\n  }\n}"
  },
  {
    "path": "code/real/bsrt/model/DCNv2/src/cuda/dcn_v2_im2col_cuda.h",
    "content": "\n/*!\n ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************\n *\n * COPYRIGHT\n *\n * All contributions by the University of California:\n * Copyright (c) 2014-2017 The Regents of the University of California (Regents)\n * All rights reserved.\n *\n * All other contributions:\n * Copyright (c) 2014-2017, the respective contributors\n * All rights reserved.\n *\n * Caffe uses a shared copyright model: each contributor holds copyright over\n * their contributions to Caffe. The project versioning records all such\n * contribution and copyright details. If a contributor wants to further mark\n * their specific copyright on a particular contribution, they should indicate\n * their copyright solely in the commit message of the change when it is\n * committed.\n *\n * LICENSE\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR\n * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n * CONTRIBUTION AGREEMENT\n *\n * By contributing to the BVLC/caffe repository through pull-request, comment,\n * or otherwise, the contributor releases their content to the\n * license and copyright terms herein.\n *\n ***************** END Caffe Copyright Notice and Disclaimer ********************\n *\n * Copyright (c) 2018 Microsoft\n * Licensed under The MIT License [see LICENSE for details]\n * \\file modulated_deformable_im2col.h\n * \\brief Function definitions of converting an image to\n * column matrix based on kernel, padding, dilation, and offset.\n * These functions are mainly used in deformable convolution operators.\n * \\ref: https://arxiv.org/abs/1811.11168\n * \\author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu\n */\n\n/***************** Adapted by Charles Shang *********************/\n\n#ifndef DCN_V2_IM2COL_CUDA\n#define DCN_V2_IM2COL_CUDA\n\n#ifdef __cplusplus\nextern \"C\"\n{\n#endif\n\n  void modulated_deformable_im2col_cuda(cudaStream_t stream,\n                                        const float *data_im, const float *data_offset, const float *data_mask,\n                                        const int batch_size, const int channels, const int height_im, const int width_im,\n                                        const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n                                        const int pad_h, const int pad_w, const int stride_h, const 
int stride_w,\n                                        const int dilation_h, const int dilation_w,\n                                        const int deformable_group, float *data_col);\n\n  void modulated_deformable_col2im_cuda(cudaStream_t stream,\n                                        const float *data_col, const float *data_offset, const float *data_mask,\n                                        const int batch_size, const int channels, const int height_im, const int width_im,\n                                        const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n                                        const int pad_h, const int pad_w, const int stride_h, const int stride_w,\n                                        const int dilation_h, const int dilation_w,\n                                        const int deformable_group, float *grad_im);\n\n  void modulated_deformable_col2im_coord_cuda(cudaStream_t stream,\n                                         const float *data_col, const float *data_im, const float *data_offset, const float *data_mask,\n                                         const int batch_size, const int channels, const int height_im, const int width_im,\n                                         const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n                                         const int pad_h, const int pad_w, const int stride_h, const int stride_w,\n                                         const int dilation_h, const int dilation_w,\n                                         const int deformable_group,\n                                         float *grad_offset, float *grad_mask);\n\n#ifdef __cplusplus\n}\n#endif\n\n#endif"
  },
  {
    "path": "code/real/bsrt/model/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda.cu",
    "content": "/*!\n * Copyright (c) 2017 Microsoft\n * Licensed under The MIT License [see LICENSE for details]\n * \\file deformable_psroi_pooling.cu\n * \\brief\n * \\author Yi Li, Guodong Zhang, Jifeng Dai\n*/\n/***************** Adapted by Charles Shang *********************/\n\n#include <cstdio>\n#include <algorithm>\n#include <cstring>\n#include <iostream>\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n\n#include <THC/THC.h>\n#include <THC/THCAtomics.cuh>\n#include <THC/THCDeviceUtils.cuh>\n\n#define CUDA_KERNEL_LOOP(i, n)                        \\\n  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \\\n       i < (n);                                       \\\n       i += blockDim.x * gridDim.x)\n\nconst int CUDA_NUM_THREADS = 1024;\ninline int GET_BLOCKS(const int N)\n{\n  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;\n}\n\ntemplate <typename T>\n__device__ T bilinear_interp_cuda(\n    const T *data,\n    const T x,\n    const T y,\n    const int width,\n    const int height)\n{\n  int x1 = floor(x);\n  int x2 = ceil(x);\n  int y1 = floor(y);\n  int y2 = ceil(y);\n  T dist_x = static_cast<T>(x - x1);\n  T dist_y = static_cast<T>(y - y1);\n  T value11 = data[y1 * width + x1];\n  T value12 = data[y2 * width + x1];\n  T value21 = data[y1 * width + x2];\n  T value22 = data[y2 * width + x2];\n  T value = (1 - dist_x) * (1 - dist_y) * value11 +\n            (1 - dist_x) * dist_y * value12 +\n            dist_x * (1 - dist_y) * value21 +\n            dist_x * dist_y * value22;\n  return value;\n}\n\ntemplate <typename T>\n__global__ void DeformablePSROIPoolForwardKernelCuda(\n    const int count,\n    const T *bottom_data,\n    const T spatial_scale,\n    const int channels,\n    const int height, const int width,\n    const int pooled_height, const int pooled_width,\n    const T *bottom_rois, const T *bottom_trans,\n    const int no_trans,\n    const T trans_std,\n    const int sample_per_part,\n    const int output_dim,\n    const int 
group_size,\n    const int part_size,\n    const int num_classes,\n    const int channels_each_class,\n    T *top_data,\n    T *top_count)\n{\n  CUDA_KERNEL_LOOP(index, count)\n  {\n    // The output is in order (n, ctop, ph, pw)\n    int pw = index % pooled_width;\n    int ph = (index / pooled_width) % pooled_height;\n    int ctop = (index / pooled_width / pooled_height) % output_dim;\n    int n = index / pooled_width / pooled_height / output_dim;\n\n    // [start, end) interval for spatial sampling\n    const T *offset_bottom_rois = bottom_rois + n * 5;\n    int roi_batch_ind = offset_bottom_rois[0];\n    T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;\n    T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;\n    T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;\n    T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;\n\n    // Force too small ROIs to be 1x1\n    T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0\n    T roi_height = max(roi_end_h - roi_start_h, 0.1);\n\n    // Compute w and h at bottom\n    T bin_size_h = roi_height / static_cast<T>(pooled_height);\n    T bin_size_w = roi_width / static_cast<T>(pooled_width);\n\n    T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);\n    T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);\n\n    int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);\n    int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);\n    int class_id = ctop / channels_each_class;\n    T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;\n    T trans_y = no_trans ? 
static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;\n\n    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;\n    wstart += trans_x * roi_width;\n    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;\n    hstart += trans_y * roi_height;\n\n    T sum = 0;\n    int count = 0;\n    int gw = floor(static_cast<T>(pw) * group_size / pooled_width);\n    int gh = floor(static_cast<T>(ph) * group_size / pooled_height);\n    gw = min(max(gw, 0), group_size - 1);\n    gh = min(max(gh, 0), group_size - 1);\n\n    const T *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;\n    for (int ih = 0; ih < sample_per_part; ih++)\n    {\n      for (int iw = 0; iw < sample_per_part; iw++)\n      {\n        T w = wstart + iw * sub_bin_size_w;\n        T h = hstart + ih * sub_bin_size_h;\n        // bilinear interpolation\n        if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)\n        {\n          continue;\n        }\n        w = min(max(w, 0.), width - 1.);\n        h = min(max(h, 0.), height - 1.);\n        int c = (ctop * group_size + gh) * group_size + gw;\n        T val = bilinear_interp_cuda(offset_bottom_data + c * height * width, w, h, width, height);\n        sum += val;\n        count++;\n      }\n    }\n    top_data[index] = count == 0 ? 
static_cast<T>(0) : sum / count;\n    top_count[index] = count;\n  }\n}\n\ntemplate <typename T>\n__global__ void DeformablePSROIPoolBackwardAccKernelCuda(\n    const int count,\n    const T *top_diff,\n    const T *top_count,\n    const int num_rois,\n    const T spatial_scale,\n    const int channels,\n    const int height, const int width,\n    const int pooled_height, const int pooled_width,\n    const int output_dim,\n    T *bottom_data_diff, T *bottom_trans_diff,\n    const T *bottom_data,\n    const T *bottom_rois,\n    const T *bottom_trans,\n    const int no_trans,\n    const T trans_std,\n    const int sample_per_part,\n    const int group_size,\n    const int part_size,\n    const int num_classes,\n    const int channels_each_class)\n{\n  CUDA_KERNEL_LOOP(index, count)\n  {\n    // The output is in order (n, ctop, ph, pw)\n    int pw = index % pooled_width;\n    int ph = (index / pooled_width) % pooled_height;\n    int ctop = (index / pooled_width / pooled_height) % output_dim;\n    int n = index / pooled_width / pooled_height / output_dim;\n\n    // [start, end) interval for spatial sampling\n    const T *offset_bottom_rois = bottom_rois + n * 5;\n    int roi_batch_ind = offset_bottom_rois[0];\n    T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;\n    T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;\n    T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;\n    T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) 
* spatial_scale - 0.5;\n\n    // Force too small ROIs to be 1x1\n    T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0\n    T roi_height = max(roi_end_h - roi_start_h, 0.1);\n\n    // Compute w and h at bottom\n    T bin_size_h = roi_height / static_cast<T>(pooled_height);\n    T bin_size_w = roi_width / static_cast<T>(pooled_width);\n\n    T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);\n    T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);\n\n    int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);\n    int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);\n    int class_id = ctop / channels_each_class;\n    T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;\n    T trans_y = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;\n\n    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;\n    wstart += trans_x * roi_width;\n    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;\n    hstart += trans_y * roi_height;\n\n    if (top_count[index] <= 0)\n    {\n      continue;\n    }\n    T diff_val = top_diff[index] / top_count[index];\n    const T *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;\n    T *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;\n    int gw = floor(static_cast<T>(pw) * group_size / pooled_width);\n    int gh = floor(static_cast<T>(ph) * group_size / pooled_height);\n    gw = min(max(gw, 0), group_size - 1);\n    gh = min(max(gh, 0), group_size - 1);\n\n    for (int ih = 0; ih < sample_per_part; ih++)\n    {\n      for (int iw = 0; iw < sample_per_part; iw++)\n      {\n        T w = wstart + iw * sub_bin_size_w;\n        T h = hstart + ih * sub_bin_size_h;\n        // bilinear interpolation\n        if (w < -0.5 || w 
> width - 0.5 || h < -0.5 || h > height - 0.5)\n        {\n          continue;\n        }\n        w = min(max(w, 0.), width - 1.);\n        h = min(max(h, 0.), height - 1.);\n        int c = (ctop * group_size + gh) * group_size + gw;\n        // backward on feature\n        int x0 = floor(w);\n        int x1 = ceil(w);\n        int y0 = floor(h);\n        int y1 = ceil(h);\n        T dist_x = w - x0, dist_y = h - y0;\n        T q00 = (1 - dist_x) * (1 - dist_y);\n        T q01 = (1 - dist_x) * dist_y;\n        T q10 = dist_x * (1 - dist_y);\n        T q11 = dist_x * dist_y;\n        int bottom_index_base = c * height * width;\n        atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);\n        atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);\n        atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);\n        atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);\n\n        if (no_trans)\n        {\n          continue;\n        }\n        T U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];\n        T U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];\n        T U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];\n        T U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];\n        T diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;\n        diff_x *= roi_width;\n        T diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;\n        diff_y *= roi_height;\n\n        atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x);\n        atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);\n      }\n    }\n  
}\n}\n\nstd::tuple<at::Tensor, at::Tensor>\ndcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input,\n                                  const at::Tensor &bbox,\n                                  const at::Tensor &trans,\n                                  const int no_trans,\n                                  const float spatial_scale,\n                                  const int output_dim,\n                                  const int group_size,\n                                  const int pooled_size,\n                                  const int part_size,\n                                  const int sample_per_part,\n                                  const float trans_std)\n{\n  AT_ASSERTM(input.type().is_cuda(), \"input must be a CUDA tensor\");\n  AT_ASSERTM(bbox.type().is_cuda(), \"rois must be a CUDA tensor\");\n  AT_ASSERTM(trans.type().is_cuda(), \"trans must be a CUDA tensor\");\n\n  const int batch = input.size(0);\n  const int channels = input.size(1);\n  const int height = input.size(2);\n  const int width = input.size(3);\n  const int channels_trans = no_trans ? 2 : trans.size(1);\n  const int num_bbox = bbox.size(0);\n\n  AT_ASSERTM(channels == output_dim, \"input channels and output channels must equal\");\n  auto pooled_height = pooled_size;\n  auto pooled_width = pooled_size;\n\n  auto out = at::empty({num_bbox, output_dim, pooled_height, pooled_width}, input.options());\n  long out_size = num_bbox * output_dim * pooled_height * pooled_width;\n  auto top_count = at::zeros({num_bbox, output_dim, pooled_height, pooled_width}, input.options());\n\n  const int num_classes = no_trans ? 1 : channels_trans / 2;\n  const int channels_each_class = no_trans ? 
output_dim : output_dim / num_classes;\n\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n  if (out.numel() == 0)\n  {\n    THCudaCheck(cudaGetLastError());\n    return std::make_tuple(out, top_count);\n  }\n\n  dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));\n  dim3 block(512);\n\n  AT_DISPATCH_FLOATING_TYPES(input.type(), \"dcn_v2_psroi_pooling_cuda_forward\", [&] {\n    DeformablePSROIPoolForwardKernelCuda<scalar_t><<<grid, block, 0, stream>>>(\n        out_size,\n        input.contiguous().data_ptr<scalar_t>(),\n        spatial_scale,\n        channels,\n        height, width,\n        pooled_height,\n        pooled_width,\n        bbox.contiguous().data_ptr<scalar_t>(),\n        trans.contiguous().data_ptr<scalar_t>(),\n        no_trans,\n        trans_std,\n        sample_per_part,\n        output_dim,\n        group_size,\n        part_size,\n        num_classes,\n        channels_each_class,\n        out.data_ptr<scalar_t>(),\n        top_count.data_ptr<scalar_t>());\n  });\n  THCudaCheck(cudaGetLastError());\n  return std::make_tuple(out, top_count);\n}\n\nstd::tuple<at::Tensor, at::Tensor>\ndcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad,\n                                   const at::Tensor &input,\n                                   const at::Tensor &bbox,\n                                   const at::Tensor &trans,\n                                   const at::Tensor &top_count,\n                                   const int no_trans,\n                                   const float spatial_scale,\n                                   const int output_dim,\n                                   const int group_size,\n                                   const int pooled_size,\n                                   const int part_size,\n                                   const int sample_per_part,\n                                   const float trans_std)\n{\n  AT_ASSERTM(out_grad.type().is_cuda(), \"out_grad must be a CUDA 
tensor\");\n  AT_ASSERTM(input.type().is_cuda(), \"input must be a CUDA tensor\");\n  AT_ASSERTM(bbox.type().is_cuda(), \"bbox must be a CUDA tensor\");\n  AT_ASSERTM(trans.type().is_cuda(), \"trans must be a CUDA tensor\");\n  AT_ASSERTM(top_count.type().is_cuda(), \"top_count must be a CUDA tensor\");\n\n  const int batch = input.size(0);\n  const int channels = input.size(1);\n  const int height = input.size(2);\n  const int width = input.size(3);\n  const int channels_trans = no_trans ? 2 : trans.size(1);\n  const int num_bbox = bbox.size(0);\n\n  AT_ASSERTM(channels == output_dim, \"input channels and output channels must equal\");\n  auto pooled_height = pooled_size;\n  auto pooled_width = pooled_size;\n  long out_size = num_bbox * output_dim * pooled_height * pooled_width;\n  const int num_classes = no_trans ? 1 : channels_trans / 2;\n  const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;\n\n  auto input_grad = at::zeros({batch, channels, height, width}, out_grad.options());\n  auto trans_grad = at::zeros_like(trans);\n\n  if (input_grad.numel() == 0)\n  {\n    THCudaCheck(cudaGetLastError());\n    return std::make_tuple(input_grad, trans_grad);\n  }\n\n  dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));\n  dim3 block(512);\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n  AT_DISPATCH_FLOATING_TYPES(out_grad.type(), \"dcn_v2_psroi_pooling_cuda_backward\", [&] {\n    DeformablePSROIPoolBackwardAccKernelCuda<scalar_t><<<grid, block, 0, stream>>>(\n        out_size,\n        out_grad.contiguous().data_ptr<scalar_t>(),\n        top_count.contiguous().data_ptr<scalar_t>(),\n        num_bbox,\n        spatial_scale,\n        channels,\n        height,\n        width,\n        pooled_height,\n        pooled_width,\n        output_dim,\n        input_grad.contiguous().data_ptr<scalar_t>(),\n        trans_grad.contiguous().data_ptr<scalar_t>(),\n        input.contiguous().data_ptr<scalar_t>(),\n        
bbox.contiguous().data_ptr<scalar_t>(),\n        trans.contiguous().data_ptr<scalar_t>(),\n        no_trans,\n        trans_std,\n        sample_per_part,\n        group_size,\n        part_size,\n        num_classes,\n        channels_each_class);\n  });\n  THCudaCheck(cudaGetLastError());\n  return std::make_tuple(input_grad, trans_grad);\n}"
  },
  {
    "path": "code/real/bsrt/model/DCNv2/src/cuda/vision.h",
    "content": "#pragma once\n#include <torch/extension.h>\n#include <ATen/div_rtn.h>\nat::Tensor\ndcn_v2_cuda_forward(const at::Tensor &input,\n                    const at::Tensor &weight,\n                    const at::Tensor &bias,\n                    const at::Tensor &offset,\n                    const at::Tensor &mask,\n                    const int kernel_h,\n                    const int kernel_w,\n                    const int stride_h,\n                    const int stride_w,\n                    const int pad_h,\n                    const int pad_w,\n                    const int dilation_h,\n                    const int dilation_w,\n                    const int deformable_group);\n\nstd::vector<at::Tensor>\ndcn_v2_cuda_backward(const at::Tensor &input,\n                     const at::Tensor &weight,\n                     const at::Tensor &bias,\n                     const at::Tensor &offset,\n                     const at::Tensor &mask,\n                     const at::Tensor &grad_output,\n                     int kernel_h, int kernel_w,\n                     int stride_h, int stride_w,\n                     int pad_h, int pad_w,\n                     int dilation_h, int dilation_w,\n                     int deformable_group);\n\n\nstd::tuple<at::Tensor, at::Tensor>\ndcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input,\n                                  const at::Tensor &bbox,\n                                  const at::Tensor &trans,\n                                  const int no_trans,\n                                  const float spatial_scale,\n                                  const int output_dim,\n                                  const int group_size,\n                                  const int pooled_size,\n                                  const int part_size,\n                                  const int sample_per_part,\n                                  const float trans_std);\n\nstd::tuple<at::Tensor, 
at::Tensor>\ndcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad,\n                                   const at::Tensor &input,\n                                   const at::Tensor &bbox,\n                                   const at::Tensor &trans,\n                                   const at::Tensor &top_count,\n                                   const int no_trans,\n                                   const float spatial_scale,\n                                   const int output_dim,\n                                   const int group_size,\n                                   const int pooled_size,\n                                   const int part_size,\n                                   const int sample_per_part,\n                                   const float trans_std);"
  },
  {
    "path": "code/real/bsrt/model/DCNv2/src/dcn_v2.h",
    "content": "#pragma once\n\n#include \"cpu/vision.h\"\n\n#ifdef WITH_CUDA\n#include \"cuda/vision.h\"\n#endif\n\nat::Tensor\ndcn_v2_forward(const at::Tensor &input,\n               const at::Tensor &weight,\n               const at::Tensor &bias,\n               const at::Tensor &offset,\n               const at::Tensor &mask,\n               const int kernel_h,\n               const int kernel_w,\n               const int stride_h,\n               const int stride_w,\n               const int pad_h,\n               const int pad_w,\n               const int dilation_h,\n               const int dilation_w,\n               const int deformable_group)\n{\n    if (input.type().is_cuda())\n    {\n#ifdef WITH_CUDA\n        return dcn_v2_cuda_forward(input, weight, bias, offset, mask,\n                                   kernel_h, kernel_w,\n                                   stride_h, stride_w,\n                                   pad_h, pad_w,\n                                   dilation_h, dilation_w,\n                                   deformable_group);\n#else\n        AT_ERROR(\"Not compiled with GPU support\");\n#endif\n    }\n    else{\n        return dcn_v2_cpu_forward(input, weight, bias, offset, mask,\n                                   kernel_h, kernel_w,\n                                   stride_h, stride_w,\n                                   pad_h, pad_w,\n                                   dilation_h, dilation_w,\n                                   deformable_group);\n    }\n}\n\nstd::vector<at::Tensor>\ndcn_v2_backward(const at::Tensor &input,\n                const at::Tensor &weight,\n                const at::Tensor &bias,\n                const at::Tensor &offset,\n                const at::Tensor &mask,\n                const at::Tensor &grad_output,\n                int kernel_h, int kernel_w,\n                int stride_h, int stride_w,\n                int pad_h, int pad_w,\n                int dilation_h, int dilation_w,\n                int 
deformable_group)\n{\n    if (input.type().is_cuda())\n    {\n#ifdef WITH_CUDA\n        return dcn_v2_cuda_backward(input,\n                                    weight,\n                                    bias,\n                                    offset,\n                                    mask,\n                                    grad_output,\n                                    kernel_h, kernel_w,\n                                    stride_h, stride_w,\n                                    pad_h, pad_w,\n                                    dilation_h, dilation_w,\n                                    deformable_group);\n#else\n        AT_ERROR(\"Not compiled with GPU support\");\n#endif\n    }\n    else{\n        return dcn_v2_cpu_backward(input,\n                                    weight,\n                                    bias,\n                                    offset,\n                                    mask,\n                                    grad_output,\n                                    kernel_h, kernel_w,\n                                    stride_h, stride_w,\n                                    pad_h, pad_w,\n                                    dilation_h, dilation_w,\n                                    deformable_group);\n    }\n}\n\nstd::tuple<at::Tensor, at::Tensor>\ndcn_v2_psroi_pooling_forward(const at::Tensor &input,\n                             const at::Tensor &bbox,\n                             const at::Tensor &trans,\n                             const int no_trans,\n                             const float spatial_scale,\n                             const int output_dim,\n                             const int group_size,\n                             const int pooled_size,\n                             const int part_size,\n                             const int sample_per_part,\n                             const float trans_std)\n{\n    if (input.type().is_cuda())\n    {\n#ifdef WITH_CUDA\n        return 
dcn_v2_psroi_pooling_cuda_forward(input,\n                                                 bbox,\n                                                 trans,\n                                                 no_trans,\n                                                 spatial_scale,\n                                                 output_dim,\n                                                 group_size,\n                                                 pooled_size,\n                                                 part_size,\n                                                 sample_per_part,\n                                                 trans_std);\n#else\n        AT_ERROR(\"Not compiled with GPU support\");\n#endif\n    }\n    else{\n        return dcn_v2_psroi_pooling_cpu_forward(input,\n                                                 bbox,\n                                                 trans,\n                                                 no_trans,\n                                                 spatial_scale,\n                                                 output_dim,\n                                                 group_size,\n                                                 pooled_size,\n                                                 part_size,\n                                                 sample_per_part,\n                                                 trans_std);\n    }\n}\n\nstd::tuple<at::Tensor, at::Tensor>\ndcn_v2_psroi_pooling_backward(const at::Tensor &out_grad,\n                              const at::Tensor &input,\n                              const at::Tensor &bbox,\n                              const at::Tensor &trans,\n                              const at::Tensor &top_count,\n                              const int no_trans,\n                              const float spatial_scale,\n                              const int output_dim,\n                              const int group_size,\n                              const 
int pooled_size,\n                              const int part_size,\n                              const int sample_per_part,\n                              const float trans_std)\n{\n    if (input.type().is_cuda())\n    {\n#ifdef WITH_CUDA\n        return dcn_v2_psroi_pooling_cuda_backward(out_grad,\n                                                  input,\n                                                  bbox,\n                                                  trans,\n                                                  top_count,\n                                                  no_trans,\n                                                  spatial_scale,\n                                                  output_dim,\n                                                  group_size,\n                                                  pooled_size,\n                                                  part_size,\n                                                  sample_per_part,\n                                                  trans_std);\n#else\n        AT_ERROR(\"Not compiled with GPU support\");\n#endif\n    }\n    else{\n        return dcn_v2_psroi_pooling_cpu_backward(out_grad,\n                                                  input,\n                                                  bbox,\n                                                  trans,\n                                                  top_count,\n                                                  no_trans,\n                                                  spatial_scale,\n                                                  output_dim,\n                                                  group_size,\n                                                  pooled_size,\n                                                  part_size,\n                                                  sample_per_part,\n                                                  trans_std);\n    }\n}"
  },
  {
    "path": "code/real/bsrt/model/DCNv2/src/vision.cpp",
    "content": "\n#include \"dcn_v2.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"dcn_v2_forward\", &dcn_v2_forward, \"dcn_v2_forward\");\n  m.def(\"dcn_v2_backward\", &dcn_v2_backward, \"dcn_v2_backward\");\n  m.def(\"dcn_v2_psroi_pooling_forward\", &dcn_v2_psroi_pooling_forward, \"dcn_v2_psroi_pooling_forward\");\n  m.def(\"dcn_v2_psroi_pooling_backward\", &dcn_v2_psroi_pooling_backward, \"dcn_v2_psroi_pooling_backward\");\n}\n"
  },
  {
    "path": "code/real/bsrt/model/DCNv2/test.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import absolute_import\nfrom __future__ import print_function\nfrom __future__ import division\n\nimport time\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import gradcheck\n\nfrom dcn_v2 import dcn_v2_conv, DCNv2, DCN\nfrom dcn_v2 import dcn_v2_pooling, DCNv2Pooling, DCNPooling\n\ndeformable_groups = 1\nN, inC, inH, inW = 2, 2, 4, 4\noutC = 2\nkH, kW = 3, 3\n\n\ndef conv_identify(weight, bias):\n    weight.data.zero_()\n    bias.data.zero_()\n    o, i, h, w = weight.shape\n    y = h//2\n    x = w//2\n    for p in range(i):\n        for q in range(o):\n            if p == q:\n                weight.data[q, p, y, x] = 1.0\n\n\ndef check_zero_offset():\n    conv_offset = nn.Conv2d(inC, deformable_groups * 2 * kH * kW,\n                            kernel_size=(kH, kW),\n                            stride=(1, 1),\n                            padding=(1, 1),\n                            bias=True).cuda()\n\n    conv_mask = nn.Conv2d(inC, deformable_groups * 1 * kH * kW,\n                          kernel_size=(kH, kW),\n                          stride=(1, 1),\n                          padding=(1, 1),\n                          bias=True).cuda()\n\n    dcn_v2 = DCNv2(inC, outC, (kH, kW),\n                   stride=1, padding=1, dilation=1,\n                   deformable_groups=deformable_groups).cuda()\n\n    conv_offset.weight.data.zero_()\n    conv_offset.bias.data.zero_()\n    conv_mask.weight.data.zero_()\n    conv_mask.bias.data.zero_()\n    conv_identify(dcn_v2.weight, dcn_v2.bias)\n\n    input = torch.randn(N, inC, inH, inW).cuda()\n    offset = conv_offset(input)\n    mask = conv_mask(input)\n    mask = torch.sigmoid(mask)\n    output = dcn_v2(input, offset, mask)\n    output *= 2\n    d = (input - output).abs().max()\n    if d < 1e-10:\n        print('Zero offset passed')\n    else:\n        print('Zero offset failed')\n        print(input)\n        print(output)\n\ndef 
check_gradient_dconv():\n\n    input = torch.rand(N, inC, inH, inW).cuda() * 0.01\n    input.requires_grad = True\n\n    offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW).cuda() * 2\n    # offset.data.zero_()\n    # offset.data -= 0.5\n    offset.requires_grad = True\n\n    mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW).cuda()\n    # mask.data.zero_()\n    mask.requires_grad = True\n    mask = torch.sigmoid(mask)\n\n    weight = torch.randn(outC, inC, kH, kW).cuda()\n    weight.requires_grad = True\n\n    bias = torch.rand(outC).cuda()\n    bias.requires_grad = True\n\n    stride = 1\n    padding = 1\n    dilation = 1\n\n    print('check_gradient_dconv: ',\n          gradcheck(dcn_v2_conv, (input, offset, mask, weight, bias,\n                    stride, padding, dilation, deformable_groups),\n                    eps=1e-3, atol=1e-4, rtol=1e-2))\n\n\ndef check_pooling_zero_offset():\n\n    input = torch.randn(2, 16, 64, 64).cuda().zero_()\n    input[0, :, 16:26, 16:26] = 1.\n    input[1, :, 10:20, 20:30] = 2.\n    rois = torch.tensor([\n        [0, 65, 65, 103, 103],\n        [1, 81, 41, 119, 79],\n    ]).cuda().float()\n    pooling = DCNv2Pooling(spatial_scale=1.0 / 4,\n                           pooled_size=7,\n                           output_dim=16,\n                           no_trans=True,\n                           group_size=1,\n                           trans_std=0.0).cuda()\n\n    out = pooling(input, rois, input.new())\n    s = ', '.join(['%f' % out[i, :, :, :].mean().item()\n                   for i in range(rois.shape[0])])\n    print(s)\n\n    dpooling = DCNv2Pooling(spatial_scale=1.0 / 4,\n                            pooled_size=7,\n                            output_dim=16,\n                            no_trans=False,\n                            group_size=1,\n                            trans_std=0.0).cuda()\n    offset = torch.randn(20, 2, 7, 7).cuda().zero_()\n    dout = dpooling(input, rois, offset)\n    s = 
', '.join(['%f' % dout[i, :, :, :].mean().item()\n                   for i in range(rois.shape[0])])\n    print(s)\n\n\ndef check_gradient_dpooling():\n    input = torch.randn(2, 3, 5, 5).cuda() * 0.01\n    N = 4\n    batch_inds = torch.randint(2, (N, 1)).cuda().float()\n    x = torch.rand((N, 1)).cuda().float() * 15\n    y = torch.rand((N, 1)).cuda().float() * 15\n    w = torch.rand((N, 1)).cuda().float() * 10\n    h = torch.rand((N, 1)).cuda().float() * 10\n    rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)\n    offset = torch.randn(N, 2, 3, 3).cuda()\n    input.requires_grad = True\n    offset.requires_grad = True\n\n    spatial_scale = 1.0 / 4\n    pooled_size = 3\n    output_dim = 3\n    no_trans = 0\n    group_size = 1\n    trans_std = 0.0\n    sample_per_part = 4\n    part_size = pooled_size\n\n    print('check_gradient_dpooling:',\n          gradcheck(dcn_v2_pooling, (input, rois, offset,\n                                     spatial_scale,\n                                     pooled_size,\n                                     output_dim,\n                                     no_trans,\n                                     group_size,\n                                     part_size,\n                                     sample_per_part,\n                                     trans_std),\n                    eps=1e-4))\n\n\ndef example_dconv():\n    input = torch.randn(2, 64, 128, 128).cuda()\n    # wrap all things (offset and mask) in DCN\n    dcn = DCN(64, 64, kernel_size=(3, 3), stride=1,\n              padding=1, deformable_groups=2).cuda()\n    # print(dcn.weight.shape, input.shape)\n    output = dcn(input)\n    targert = output.new(*output.size())\n    targert.data.uniform_(-0.01, 0.01)\n    error = (targert - output).mean()\n    error.backward()\n    print(output.shape)\n\n\ndef example_dpooling():\n    input = torch.randn(2, 32, 64, 64).cuda()\n    batch_inds = torch.randint(2, (20, 1)).cuda().float()\n    x = torch.randint(256, (20, 
1)).cuda().float()\n    y = torch.randint(256, (20, 1)).cuda().float()\n    w = torch.randint(64, (20, 1)).cuda().float()\n    h = torch.randint(64, (20, 1)).cuda().float()\n    rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)\n    offset = torch.randn(20, 2, 7, 7).cuda()\n    input.requires_grad = True\n    offset.requires_grad = True\n\n    # normal roi_align\n    pooling = DCNv2Pooling(spatial_scale=1.0 / 4,\n                           pooled_size=7,\n                           output_dim=32,\n                           no_trans=True,\n                           group_size=1,\n                           trans_std=0.1).cuda()\n\n    # deformable pooling\n    dpooling = DCNv2Pooling(spatial_scale=1.0 / 4,\n                            pooled_size=7,\n                            output_dim=32,\n                            no_trans=False,\n                            group_size=1,\n                            trans_std=0.1).cuda()\n\n    out = pooling(input, rois, offset)\n    dout = dpooling(input, rois, offset)\n    print(out.shape)\n    print(dout.shape)\n\n    target_out = out.new(*out.size())\n    target_out.data.uniform_(-0.01, 0.01)\n    target_dout = dout.new(*dout.size())\n    target_dout.data.uniform_(-0.01, 0.01)\n    e = (target_out - out).mean()\n    e.backward()\n    e = (target_dout - dout).mean()\n    e.backward()\n\n\ndef example_mdpooling():\n    input = torch.randn(2, 32, 64, 64).cuda()\n    input.requires_grad = True\n    batch_inds = torch.randint(2, (20, 1)).cuda().float()\n    x = torch.randint(256, (20, 1)).cuda().float()\n    y = torch.randint(256, (20, 1)).cuda().float()\n    w = torch.randint(64, (20, 1)).cuda().float()\n    h = torch.randint(64, (20, 1)).cuda().float()\n    rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)\n\n    # mdformable pooling (V2)\n    dpooling = DCNPooling(spatial_scale=1.0 / 4,\n                          pooled_size=7,\n                          output_dim=32,\n                          
no_trans=False,\n                          group_size=1,\n                          trans_std=0.1,\n                          deform_fc_dim=1024).cuda()\n\n    dout = dpooling(input, rois)\n    target = dout.new(*dout.size())\n    target.data.uniform_(-0.1, 0.1)\n    error = (target - dout).mean()\n    error.backward()\n    print(dout.shape)\n\n\nif __name__ == '__main__':\n\n    example_dconv()\n    example_dpooling()\n    example_mdpooling()\n\n    check_pooling_zero_offset()\n    # zero offset check\n    if inC == outC:\n        check_zero_offset()\n\n    check_gradient_dpooling()\n    check_gradient_dconv()\n    # ****** Note: the 'backward is not reentrant' error may not be a serious problem,\n    # ****** since the max error is less than 1e-7.\n    # ****** Still looking for what triggers this problem.\n"
  },
  {
    "path": "code/real/bsrt/model/__init__.py",
    "content": "import os\nfrom importlib import import_module\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel as P\nimport torch.utils.model_zoo\nimport time\n\nclass Model(nn.Module):\n    def __init__(self, args, ckp):\n        super(Model, self).__init__()\n        self.args = args\n        if args.local_rank == 0:\n            print(\"Making model: \", args.model)\n            print(\"Patch size: \", args.patch_size)\n            \n        self.scale = args.scale\n        self.idx_scale = 0\n        self.input_large = (args.model == 'VDSR')\n        self.self_ensemble = args.self_ensemble\n        self.chop = args.chop\n        self.precision = args.precision\n        self.cpu = args.cpu\n        self.device = torch.device('cpu' if args.cpu else 'cuda:%d' % args.local_rank)\n        self.n_GPUs = args.n_GPUs\n        self.save_models = args.save_models\n\n        module = import_module('model.' + args.model.lower())\n        self.model = module.make_model(args).to(self.device)\n\n        if args.precision == 'half':\n            self.model.half()\n\n        self.load(\n            ckp.get_path('model'),\n            pre_train=args.pre_train,\n            resume=args.resume,\n            cpu=args.cpu\n        )\n\n        # time.sleep(3)\n\n        if args.n_GPUs > 1:\n            self.model = nn.parallel.DistributedDataParallel(self.model,\n                device_ids=[args.local_rank],\n                find_unused_parameters=True\n                )\n\n        print(self.model, file=ckp.log_file)\n\n    def forward(self, x, idx_scale):\n        self.idx_scale = idx_scale\n        if hasattr(self.model, 'set_scale'):\n            self.model.set_scale(idx_scale)\n\n        if self.training:\n            # if self.n_GPUs > 1:\n            return self.model(x)\n        else:\n            if self.chop:\n                forward_function = self.forward_chop\n            else:\n                forward_function = self.model.forward\n\n            if 
self.self_ensemble:\n                return self.forward_x8(x, forward_function=forward_function)\n            else:\n                # return self.model(x)\n                return forward_function(x)\n\n    def save(self, apath, epoch, is_best=False):\n        save_dirs = [os.path.join(apath, 'model_latest.pt')]\n\n        if is_best:\n            save_dirs.append(os.path.join(apath, 'model_best.pt'))\n        if self.save_models:\n            save_dirs.append(\n                os.path.join(apath, 'model_{}.pt'.format(epoch))\n            )\n        if self.n_GPUs > 1:\n            model = self.model.module\n        else:\n            model = self.model\n\n        for s in save_dirs:\n            # Save the unwrapped module so checkpoint keys are not prefixed with 'module.'\n            torch.save(model.state_dict(), s)\n\n    def load(self, apath, pre_train='', resume=-1, cpu=False):\n        load_from = None\n        kwargs = {}\n        if cpu:\n            kwargs = {'map_location': lambda storage, loc: storage}\n\n        if resume == -1:\n            load_from = torch.load(\n                os.path.join(apath, 'model_latest.pt'),\n                **kwargs\n            )\n        elif resume == 0:\n            if pre_train == 'download':\n                print('Download the model')\n                dir_model = os.path.join('..', 'models')\n                os.makedirs(dir_model, exist_ok=True)\n                load_from = torch.utils.model_zoo.load_url(\n                    self.model.url,\n                    model_dir=dir_model,\n                    **kwargs\n                )\n            elif pre_train:\n                if self.args.local_rank == 0:\n                    print('Load the model from {}'.format(pre_train))\n                map_location = {'cuda:%d' % 0: 'cuda:%d' % self.args.local_rank}\n                load_from = torch.load(pre_train, map_location=map_location)\n                # print(load_from.keys())\n        else:\n            load_from = torch.load(\n                os.path.join(apath, 'model_{}.pt'.format(resume)),\n      
          **kwargs\n            )\n\n        if load_from:\n            self.model.load_state_dict(load_from, strict=True)\n            del load_from\n\n        \n        if self.args.finetune:\n            if self.args.local_rank == 0:\n                print('finetune')\n            for param in self.model.parameters():\n                param.requires_grad = False\n\n            for param in self.model.HRconv.parameters():\n                param.requires_grad = True\n            for param in self.model.conv_last.parameters():\n                param.requires_grad = True\n\n        if self.args.finetune_prelayer:\n            if self.args.local_rank == 0:\n                print('finetune_prelayer')\n            if self.args.swinfeature:\n                if self.args.model == 'MBSRT':\n                    for param in self.model.pre_layer1.parameters():\n                        param.requires_grad = True\n                    for param in self.model.pre_layer2.parameters():\n                        param.requires_grad = True\n                else:\n                    for param in self.model.pre_layers.parameters():\n                        param.requires_grad = True\n            else:\n                for param in self.model.feature_extraction.parameters():\n                    param.requires_grad = True\n\n            for param in self.model.conv_after_pre_layer.parameters():\n                param.requires_grad = True\n\n        if self.args.finetune_align:\n            if self.args.local_rank == 0:\n                print('finetune_align')\n            for param in self.model.align.parameters():\n                param.requires_grad = True\n\n        if self.args.finetune_spynet:\n            if self.args.local_rank == 0:\n                print('finetune_spynet')\n            for param in self.model.spynet.parameters():\n                param.requires_grad = True\n\n        if self.args.finetune_swin:\n            if self.args.local_rank == 0:\n                
print('finetune_swin')\n            for param in self.model.layers.parameters():\n                param.requires_grad = True\n            for param in self.model.conv_after_body.parameters():\n                param.requires_grad = True\n\n        if self.args.finetune_upconv:\n            if self.args.local_rank == 0:\n                print('finetune_upconv')\n            for param in self.model.upconv1.parameters():\n                param.requires_grad = True\n            for param in self.model.upconv2.parameters():\n                param.requires_grad = True\n            for param in self.model.skipup1.parameters():\n                param.requires_grad = True\n            for param in self.model.skipup2.parameters():\n                param.requires_grad = True\n\n        if self.args.finetune_conv:\n            if self.args.local_rank == 0:\n                print('finetune_conv')\n            # for param in self.model.conv_first.parameters():\n            #     param.requires_grad = True\n            # for param in self.model.conv_flow.parameters():\n            #     param.requires_grad = True\n            # for param in self.model.fea_L2_conv1.parameters():\n            #     param.requires_grad = True\n            # for param in self.model.fea_L3_conv1.parameters():\n            #     param.requires_grad = True\n            # for param in self.model.toplayer.parameters():\n            #     param.requires_grad = True\n            # for param in self.model.smooth1.parameters():\n            #     param.requires_grad = True\n            # for param in self.model.smooth2.parameters():\n            #     param.requires_grad = True\n            # for param in self.model.latlayer1.parameters():\n            #     param.requires_grad = True\n            # for param in self.model.latlayer2.parameters():\n            #     param.requires_grad = True\n            # for param in self.model.fusion.parameters():\n            #     param.requires_grad = True\n            # 
for param in self.model.conv_after_pre_layer.parameters():\n            #     param.requires_grad = True\n            for param in self.model.conv_after_body.parameters():\n                param.requires_grad = True\n            \n            \n\n    def forward_chop(self, *args, shave=10, min_size=160000):\n        scale = 1 if self.input_large else self.scale[self.idx_scale]\n        n_GPUs = min(self.n_GPUs, 4)\n        # height, width\n        h, w = args[0].size()[-2:]\n\n        top = slice(0, h//2 + shave)\n        bottom = slice(h - h//2 - shave, h)\n        left = slice(0, w//2 + shave)\n        right = slice(w - w//2 - shave, w)\n        x_chops = [torch.cat([\n            a[..., top, left],\n            a[..., top, right],\n            a[..., bottom, left],\n            a[..., bottom, right]\n        ]) for a in args]\n\n        y_chops = []\n        if h * w < 4 * min_size:\n            for i in range(0, 4, n_GPUs):\n                x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops]\n                y = P.data_parallel(self.model, *x, range(n_GPUs))\n                if not isinstance(y, list): y = [y]\n                if not y_chops:\n                    y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]\n                else:\n                    for y_chop, _y in zip(y_chops, y):\n                        y_chop.extend(_y.chunk(n_GPUs, dim=0))\n        else:\n            for p in zip(*x_chops):\n                y = self.forward_chop(*p, shave=shave, min_size=min_size)\n                if not isinstance(y, list): y = [y]\n                if not y_chops:\n                    y_chops = [[_y] for _y in y]\n                else:\n                    for y_chop, _y in zip(y_chops, y): y_chop.append(_y)\n\n        h *= scale\n        w *= scale\n        top = slice(0, h//2)\n        bottom = slice(h - h//2, h)\n        bottom_r = slice(h//2 - h, None)\n        left = slice(0, w//2)\n        right = slice(w - w//2, w)\n        right_r = slice(w//2 - 
w, None)\n\n        # batch size, number of color channels\n        b, c = y_chops[0][0].size()[:-2]\n        y = [y_chop[0].new(b, c, h, w) for y_chop in y_chops]\n        for y_chop, _y in zip(y_chops, y):\n            _y[..., top, left] = y_chop[0][..., top, left]\n            _y[..., top, right] = y_chop[1][..., top, right_r]\n            _y[..., bottom, left] = y_chop[2][..., bottom_r, left]\n            _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r]\n\n        if len(y) == 1: y = y[0]\n\n        return y\n\n    def forward_x8(self, *args, forward_function=None):\n        def _transform(v, op):\n            if self.precision != 'single': v = v.float()\n\n            v2np = v.data.cpu().numpy()\n            if op == 'v':\n                tfnp = v2np[:, :, :, ::-1].copy()\n            elif op == 'h':\n                tfnp = v2np[:, :, ::-1, :].copy()\n            elif op == 't':\n                tfnp = v2np.transpose((0, 1, 3, 2)).copy()\n\n            ret = torch.Tensor(tfnp).to(self.device)\n            if self.precision == 'half': ret = ret.half()\n\n            return ret\n\n        list_x = []\n        for a in args:\n            x = [a]\n            for tf in 'v', 'h', 't': x.extend([_transform(_x, tf) for _x in x])\n\n            list_x.append(x)\n\n        list_y = []\n        for x in zip(*list_x):\n            y = forward_function(*x)\n            if not isinstance(y, list): y = [y]\n            if not list_y:\n                list_y = [[_y] for _y in y]\n            else:\n                for _list_y, _y in zip(list_y, y): _list_y.append(_y)\n\n        for _list_y in list_y:\n            for i in range(len(_list_y)):\n                if i > 3:\n                    _list_y[i] = _transform(_list_y[i], 't')\n                if i % 4 > 1:\n                    _list_y[i] = _transform(_list_y[i], 'h')\n                if (i % 4) % 2 == 1:\n                    _list_y[i] = _transform(_list_y[i], 'v')\n\n        y = [torch.cat(_y, 
dim=0).mean(dim=0, keepdim=True) for _y in list_y]\n        if len(y) == 1: y = y[0]\n\n        return y\n"
  },
  {
    "path": "code/real/bsrt/model/arch_util.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.init as init\nimport torch.nn.functional as F\nfrom model import common\nfrom model.utils.psconv import PSGConv2d as PSConv2d, PyConv2d\n\n\ndef initialize_weights(net_l, scale=1):\n    if not isinstance(net_l, list):\n        net_l = [net_l]\n    for net in net_l:\n        for m in net.modules():\n            if isinstance(m, nn.Conv2d):\n                init.kaiming_normal_(m.weight, a=0, mode='fan_in')\n                m.weight.data *= scale  # for residual block\n                if m.bias is not None:\n                    m.bias.data.zero_()\n            elif isinstance(m, nn.Linear):\n                init.kaiming_normal_(m.weight, a=0, mode='fan_in')\n                m.weight.data *= scale\n                if m.bias is not None:\n                    m.bias.data.zero_()\n            elif isinstance(m, nn.BatchNorm2d):\n                init.constant_(m.weight, 1)\n                init.constant_(m.bias.data, 0.0)\n\n\ndef make_layer(block, n_layers):\n    layers = []\n    for _ in range(n_layers):\n        layers.append(block())\n    return nn.Sequential(*layers)\n\n\n###########################\n\ndef conv_layer(in_channels, out_channels, kernel_size, stride=1, padding=0):\n    return nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=True)\n    \n\nclass ESA(nn.Module):\n    def __init__(self, n_feats, conv=conv_layer):\n        super(ESA, self).__init__()\n        f = n_feats // 4\n        self.conv1 = conv(n_feats, f, kernel_size=1)\n        self.conv_f = conv(f, f, kernel_size=1)\n        self.conv_max = conv(f, f, kernel_size=3, padding=1)\n        self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0)\n        self.conv3 = conv(f, f, kernel_size=3, padding=1)\n        self.conv3_ = conv(f, f, kernel_size=3, padding=1)\n        self.conv4 = conv(f, n_feats, kernel_size=1)\n        self.sigmoid = nn.Sigmoid()\n        self.relu = nn.ReLU(inplace=True)\n\n   
 def forward(self, x):\n        c1_ = (self.conv1(x))\n        c1 = self.conv2(c1_)\n        v_max = F.max_pool2d(c1, kernel_size=7, stride=3)\n        v_range = self.relu(self.conv_max(v_max))\n        c3 = self.relu(self.conv3(v_range))\n        c3 = self.conv3_(c3)\n        c3 = F.interpolate(c3, (x.size(2), x.size(3)), mode='bilinear', align_corners=False) \n        cf = self.conv_f(c1_)\n        c4 = self.conv4(c3+cf)\n        m = self.sigmoid(c4)\n        \n        return x * m\n\n\nclass DWConv(nn.Module):\n    def __init__(self, dim=768):\n        super(DWConv, self).__init__()\n        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)\n\n    def forward(self, x):\n        x = self.dwconv(x)\n        return x\n\n##########################\n\nclass SELayer(nn.Module):\n    '''\n    SE-block\n    '''\n    def __init__(self, channel, reduction=16):\n        super(SELayer, self).__init__()\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.fc = nn.Sequential(\n            nn.Linear(channel, channel // reduction, bias=False),\n            nn.ReLU(inplace=True),\n            nn.Linear(channel // reduction, channel, bias=False),\n            # nn.Sigmoid()\n        )\n\n    def forward(self, x):\n        b, c, _, _ = x.size()\n        y = self.avg_pool(x).view(b, c)\n        y = self.fc(y).view(b, c, 1, 1)\n        return x * y.expand_as(x)\n\nclass ResidualBlock_noBN(nn.Module):\n    '''Residual block w/o BN\n    ---Conv-ReLU-Conv-+-\n     |________________|\n    '''\n\n    def __init__(self, nf=64):\n        super(ResidualBlock_noBN, self).__init__()\n        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n\n        # initialization\n        initialize_weights([self.conv1, self.conv2], 0.1)\n\n    def forward(self, x):\n        identity = x\n        out = F.relu(self.conv1(x), inplace=True)\n        out = self.conv2(out)\n        return identity + out\n\n\nclass 
ResidualBlock_SE(nn.Module):\n    '''Residual block w/ SE, w/o BN\n    ---Conv-ReLU-Conv-+-\n     |________________|\n    '''\n\n    def __init__(self, nf=64, reduction=16):\n        super(ResidualBlock_SE, self).__init__()\n        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n        self.conv3 = nn.Conv2d(3 * nf, nf, 1, padding=0, dilation=1, bias=True)\n        self.se = SELayer(nf, reduction)\n        # initialization\n        initialize_weights([self.conv1, self.conv2, self.conv3], 0.1)\n\n    def forward(self, x):\n        identity = x\n        basic_out = F.relu(self.conv1(x), inplace=True)\n        basic_out = self.conv2(basic_out)\n        se_out = self.se(basic_out)\n        out = torch.cat((identity, basic_out, se_out), 1)\n        out = self.conv3(out)\n        return out\n\n\nclass _PositionAttentionModule(nn.Module):\n    \"\"\" Position attention module\"\"\"\n\n    def __init__(self, in_channels, **kwargs):\n        super(_PositionAttentionModule, self).__init__()\n        self.conv_b = nn.Conv2d(in_channels, in_channels // 8, 1)\n        self.conv_c = nn.Conv2d(in_channels, in_channels // 8, 1)\n        self.conv_d = nn.Conv2d(in_channels, in_channels, 1)\n        self.alpha = nn.Parameter(torch.zeros(1))\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x):\n        batch_size, _, height, width = x.size()\n        feat_b = self.conv_b(x).view(batch_size, -1, height * width).permute(0, 2, 1)\n        feat_c = self.conv_c(x).view(batch_size, -1, height * width)\n        attention_s = self.softmax(torch.bmm(feat_b, feat_c))\n        feat_d = self.conv_d(x).view(batch_size, -1, height * width)\n        feat_e = torch.bmm(feat_d, attention_s.permute(0, 2, 1)).view(batch_size, -1, height, width)\n        out = self.alpha * feat_e + x\n\n        return out\n\n## Spatial Attention (SA) Layer\nclass SALayer(nn.Module):\n    def __init__(self, wn=None):\n        super(SALayer, self).__init__()\n        self.body = nn.Sequential(\n            wn(nn.Conv2d(2, 1, 7, 1, 3, bias=False)),\n            nn.Sigmoid()\n        )\n    def forward(self, x):\n        avg_f = torch.mean(x, dim=1, keepdim=True)\n        max_f = torch.max(x, dim=1, keepdim=True)[0]\n        y = torch.cat([avg_f, max_f], dim=1)\n        return self.body(y).expand_as(x) * x\n\n\n## Channel Attention (CA) Layer\nclass CALayerV2(nn.Module):\n    def __init__(self, n_feat, reduction=16, wn=None):\n        super(CALayerV2, self).__init__()\n        # global average pooling: feature --> point\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.max_pool = nn.AdaptiveMaxPool2d(1)\n        # feature channel downscale and upscale --> channel weight\n        self.conv_du = nn.Sequential(\n                wn(nn.Conv2d(n_feat, n_feat//reduction, 1, padding=0, bias=False)),\n                nn.ReLU(inplace=True),\n                wn(nn.Conv2d(n_feat//reduction, n_feat, 1, padding=0, bias=False)),\n                # nn.Sigmoid()\n        )\n\n    def forward(self, x):\n        y1 = self.avg_pool(x)\n        y2 = self.max_pool(x)\n        y1 = self.conv_du(y1)\n        y2 = self.conv_du(y2)\n        return x * torch.sigmoid(y1+y2)\n\nclass DALayer(nn.Module):\n    def __init__(self, channel, reduction, wn):\n        super(DALayer, self).__init__()\n        # global average pooling: feature --> point\n        self.ca = CALayer(channel, reduction, wn)\n        self.sa = SALayer(wn)\n        self.conv = wn(nn.Conv2d(channel*2, channel, 1))\n\n    def forward(self, x):\n        ca = self.ca(x)\n        sa = self.sa(x)\n        res = self.conv(torch.cat([ca, sa], dim=1))\n        return res + x\n\n\n## Channel Attention (CA) Layer\nclass CALayer(nn.Module):\n    def __init__(self, channel, reduction, wn):\n        super(CALayer, self).__init__()\n        # global average pooling: feature --> point\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        # feature 
channel downscale and upscale --> channel weight\n        self.conv_du = nn.Sequential(\n                wn(nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True)),\n                nn.ReLU(inplace=True),\n                wn(nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True)),\n                nn.Sigmoid()\n        )\n\n    def forward(self, x):\n        y = self.avg_pool(x)\n        y = self.conv_du(y)\n        return x * y\n\n\n## Residual Channel Attention Block (RCAB)\nclass RCAB(nn.Module):\n    def __init__(\n        self, conv, n_feat, kernel_size, reduction, wn,\n        bias=True, bn=False, act=nn.ReLU(True), res_scale=1, da=False):\n\n        super(RCAB, self).__init__()\n\n        expand = 6\n        linear = 0.75\n        modules_body = []\n        # for i in range(2):\n        modules_body.append(wn(nn.Conv2d(n_feat, n_feat*expand, 1, bias=bias)))\n        modules_body.append(act)\n        modules_body.append(wn(nn.Conv2d(n_feat*expand, int(n_feat*linear), 1, bias=bias)))\n        modules_body.append(conv(int(n_feat*linear), n_feat, kernel_size, bias=bias))\n        if da:\n            modules_body.append(DALayer(n_feat, reduction, wn))\n        else:\n            modules_body.append(CALayer(n_feat, reduction, wn))\n\n        self.body = nn.Sequential(*modules_body)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        res = self.body(x)\n        #res = self.body(x).mul(self.res_scale)\n        res += x\n        return res\n\n## Residual Group (RG)\nclass ResidualGroup(nn.Module):\n    def __init__(self, n_feat, n_resblocks, da=False):\n        super(ResidualGroup, self).__init__()\n        kernel_size = 3\n        res_scale = 1\n        reduction = 16\n\n        conv = common.default_conv\n        wn = lambda x: torch.nn.utils.weight_norm(x)\n\n        modules_body = []\n        modules_body = [\n            RCAB(\n                conv, n_feat, kernel_size, reduction, wn=wn, bias=True,\n                
bn=False, act=nn.ReLU(True), res_scale=res_scale, da=da) \\\n            for _ in range(n_resblocks)]\n        modules_body.append(wn(conv(n_feat, n_feat, kernel_size)))\n        self.body = nn.Sequential(*modules_body)\n\n    def forward(self, x):\n        res = self.body(x)\n        res += x\n        return res\n\n\n################################################################\n################################################################\n################################################################\n\ndef make_layer_idx(block, n_layers):\n    layers = []\n    for i in range(n_layers):\n        layers.append(block(idx=i))\n    return nn.Sequential(*layers)\n\n## Residual Channel Attention Block (RCAB)\nclass LRSCRCAB(nn.Module):\n    def __init__(\n        self, conv, n_feat, kernel_size, reduction, wn,\n        bias=True, bn=False, act=nn.ReLU(True), res_scale=1, da=False, idx=0):\n        super(LRSCRCAB, self).__init__()\n\n        expand = 6\n        linear = 0.75\n\n        modules_body = [wn(nn.Conv2d(n_feat*(idx+1), n_feat, 1, 1, 0, bias=True))] if idx > 0 else []\n        # for i in range(2):\n        modules_body.append(wn(nn.Conv2d(n_feat, n_feat*expand, 1, bias=bias)))\n        modules_body.append(act)\n        modules_body.append(wn(nn.Conv2d(n_feat*expand, int(n_feat*linear), 1, bias=bias)))\n        modules_body.append(wn(conv(int(n_feat*linear), n_feat, kernel_size, bias=bias)))\n        if da:\n            modules_body.append(DALayer(n_feat, reduction, wn))\n        else:\n            modules_body.append(CALayer(n_feat, reduction, wn))\n\n        self.body = nn.Sequential(*modules_body)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        res = self.body(x)\n        res  = torch.cat([res, x], dim=1)\n        return res\n\n\n## Residual Channel Attention Block (RCAB)\nclass LRSCPYRCAB(nn.Module):\n    def __init__(\n        self, conv, n_feat, kernel_size, reduction, wn,\n        bias=True, bn=False, act=nn.ReLU(True), 
res_scale=1, da=False, idx=0):\n        super(LRSCPYRCAB, self).__init__()\n\n        expand = 6\n        linear = 0.75\n\n        modules_body = [wn(nn.Conv2d(n_feat*(idx+1), n_feat, 1, 1, 0, bias=True))] if idx > 0 else []\n        # for i in range(2):\n        modules_body.append(wn(nn.Conv2d(n_feat, n_feat*expand, 1, bias=bias)))\n        modules_body.append(act)\n        modules_body.append(wn(nn.Conv2d(n_feat*expand, int(n_feat*linear), 1, bias=bias)))\n        modules_body.append(\n            PyConv2d(in_channels=int(n_feat*linear),\n                out_channels=[n_feat//4, n_feat//4, n_feat//2],\n                pyconv_kernels=[3, 5, 7],\n                pyconv_groups=[1, 4, 8]))\n        if da:\n            modules_body.append(DALayer(n_feat, reduction, wn))\n        else:\n            modules_body.append(CALayer(n_feat, reduction, wn))\n\n        self.body = nn.Sequential(*modules_body)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        res = self.body(x)\n        res  = torch.cat([res, x], dim=1)\n        return res\n\n## Long-Range Skip-connect Residual Group (RG)\nclass LRSCResidualGroup(nn.Module):\n    def __init__(self, n_feat, n_resblocks, da=False, idx=0):\n        super(LRSCResidualGroup, self).__init__()\n        kernel_size = 3\n        res_scale = 1\n        reduction = 16\n\n        conv = common.default_conv\n        wn = lambda x: torch.nn.utils.weight_norm(x)\n\n        modules_head = [wn(conv(n_feat*(idx+1), n_feat, 1, bias=True))] if idx > 0 else []\n        modules_body = [\n            LRSCRCAB(\n                conv, n_feat, kernel_size, reduction, wn=wn, bias=True,\n                bn=False, act=nn.ReLU(True), res_scale=res_scale, da=da, idx=i) \\\n            for i in range(n_resblocks)]\n        modules_body.append(wn(conv(n_feat*(n_resblocks+1), n_feat, kernel_size)))\n        self.head = nn.Sequential(*modules_head)\n        self.body = nn.Sequential(*modules_body)\n\n    def forward(self, x):\n        res 
= self.head(x)\n        res = self.body(res)\n        res  = torch.cat([res, x], dim=1)\n        return res\n\n\n## Long-Range Skip-connect Residual Group (RG)\nclass LRSCPSResidualGroup(nn.Module):\n    def __init__(self, n_feat, n_resblocks, da=False, idx=0):\n        super(LRSCPSResidualGroup, self).__init__()\n        kernel_size = 3\n        res_scale = 1\n        reduction = 16\n\n        conv = PSConv2d\n        wn = lambda x: torch.nn.utils.weight_norm(x)\n\n        modules_head = [wn(nn.Conv2d(n_feat*(idx+1), n_feat, 1, 1, 0, bias=True))] if idx > 0 else []\n        modules_body = [\n            LRSCRCAB(\n                conv, n_feat, kernel_size, reduction, wn=wn, bias=True,\n                bn=False, act=nn.ReLU(True), res_scale=res_scale, da=da, idx=i) \\\n            for i in range(n_resblocks)]\n        modules_tail = [wn(conv(n_feat*(n_resblocks+1), n_feat, kernel_size))]\n        self.head = nn.Sequential(*modules_head)\n        self.body = nn.Sequential(*modules_body)\n        self.tail = nn.Sequential(*modules_tail)\n\n    def forward(self, x):\n        res = self.head(x)\n        res = self.body(res)\n        res = self.tail(res)\n        res  = torch.cat([res, x], dim=1)\n        return res\n\n\n## Long-Range Skip-connect Residual Group (RG)\nclass LRSCPyResidualGroup(nn.Module):\n    def __init__(self, n_feat, n_resblocks, da=False, idx=0):\n        super(LRSCPyResidualGroup, self).__init__()\n        kernel_size = 3\n        res_scale = 1\n        reduction = 16\n\n        conv = PyConv2d\n        wn = lambda x: torch.nn.utils.weight_norm(x)\n\n        modules_head = [wn(nn.Conv2d(n_feat*(idx+1), n_feat, 1, 1, 0, bias=True))] if idx > 0 else []\n        modules_body = [\n            LRSCPYRCAB(\n                conv, n_feat, kernel_size, reduction, wn=wn, bias=True,\n                bn=False, act=nn.ReLU(True), res_scale=res_scale, da=da, idx=i) \\\n            for i in range(n_resblocks)]\n        modules_tail = 
[wn(nn.Conv2d(n_feat*(n_resblocks+1), n_feat, 1))]\n        self.head = nn.Sequential(*modules_head)\n        self.body = nn.Sequential(*modules_body)\n        self.tail = nn.Sequential(*modules_tail)\n\n    def forward(self, x):\n        res = self.head(x)\n        res = self.body(res)\n        res = self.tail(res)\n        res  = torch.cat([res, x], dim=1)\n        return res\n\nclass LRSCWideActResBlock(nn.Module):\n    def __init__(self, nf=64, idx=0):\n        super(LRSCWideActResBlock, self).__init__()\n        self.res_scale = 1\n\n        expand = 6\n        linear = 0.8\n        kernel_size = 3\n        wn = lambda x: torch.nn.utils.weight_norm(x)\n        act=nn.ReLU(True)\n        head = [wn(nn.Conv2d(nf*(idx+1), nf, 1, bias=True))] if idx > 0 else []\n\n        body = []\n        body.append(\n            wn(nn.Conv2d(nf, nf*expand, 1, padding=1//2)))\n        body.append(act)\n        body.append(\n            wn(nn.Conv2d(nf*expand, int(nf*linear), 1, padding=1//2)))\n        body.append(\n            wn(nn.Conv2d(int(nf*linear), nf, kernel_size, padding=kernel_size//2)))\n\n        self.head = nn.Sequential(*head)\n        self.body = nn.Sequential(*body)\n\n    def forward(self, x):\n        res = self.head(x)\n        res = self.body(res)\n        res  = torch.cat([res, x], dim=1)\n        return res\n\nclass LRSCPyWideActResBlock(nn.Module):\n    def __init__(self, nf=64, idx=0):\n        super(LRSCPyWideActResBlock, self).__init__()\n        self.res_scale = 1\n\n        expand = 6\n        linear = 0.75\n        kernel_size = 3\n        wn = lambda x: torch.nn.utils.weight_norm(x)\n        act=nn.ReLU(True)\n        head = [wn(nn.Conv2d(nf*(idx+1), nf, 1, bias=True))] if idx > 0 else []\n\n        body = []\n        body.append(\n            wn(nn.Conv2d(nf, nf*expand, 1, padding=1//2)))\n        body.append(act)\n        body.append(\n            wn(nn.Conv2d(nf*expand, int(nf*linear), 1, padding=1//2)))\n        body.append(\n            
PyConv2d(in_channels=int(nf*linear),\n                out_channels=[nf//4, nf//4, nf//2],\n                pyconv_kernels=[3, 5, 7],\n                pyconv_groups=[1, 4, 8]))\n\n        self.head = nn.Sequential(*head)\n        self.body = nn.Sequential(*body)\n\n    def forward(self, x):\n        res = self.head(x)\n        res = self.body(res)\n        res  = torch.cat([res, x], dim=1)\n        return res\n\n\n## Long-Range Skip-connect Residual Group (RG)\nclass LRSCPyWideActResGroup(nn.Module):\n    def __init__(self, nf, n_resblocks, idx=0):\n        super(LRSCPyWideActResGroup, self).__init__()\n        kernel_size = 3\n\n        conv = PyConv2d\n        wn = lambda x: torch.nn.utils.weight_norm(x)\n\n        modules_head = [wn(nn.Conv2d(nf*(idx+1), nf, 1, 1, 0, bias=True))] if idx > 0 else []\n        modules_body = [\n            LRSCPyWideActResBlock(nf=nf, idx=i) for i in range(n_resblocks)]\n        modules_tail = [wn(nn.Conv2d(nf*(n_resblocks+1), nf, 1))]\n        self.head = nn.Sequential(*modules_head)\n        self.body = nn.Sequential(*modules_body)\n        self.tail = nn.Sequential(*modules_tail)\n\n    def forward(self, x):\n        res = self.head(x)\n        res = self.body(res)\n        res = self.tail(res)\n        res  = torch.cat([res, x], dim=1)\n        return res\n\n\n## Long-Range Skip-connect Residual Group (RG)\nclass LRSCWideActResGroup(nn.Module):\n    def __init__(self, nf, n_resblocks, idx=0):\n        super(LRSCWideActResGroup, self).__init__()\n        kernel_size = 3\n\n        conv = PyConv2d\n        wn = lambda x: torch.nn.utils.weight_norm(x)\n\n        modules_head = [wn(nn.Conv2d(nf*(idx+1), nf, 1, 1, 0, bias=True))] if idx > 0 else []\n        modules_body = [\n            LRSCWideActResBlock(nf=nf, idx=i) for i in range(n_resblocks)]\n        modules_tail = [wn(nn.Conv2d(nf*(n_resblocks+1), nf, 1))]\n        self.head = nn.Sequential(*modules_head)\n        self.body = nn.Sequential(*modules_body)\n        self.tail = 
nn.Sequential(*modules_tail)\n\n    def forward(self, x):\n        res = self.head(x)\n        res = self.body(res)\n        res = self.tail(res)\n        res  = torch.cat([res, x], dim=1)\n        return res\n\n################################################################\n################################################################\n################################################################\n\n\n## Residual Channel Attention Block (RCAB)\nclass PYRCAB(nn.Module):\n    def __init__(\n        self, conv, n_feat, kernel_size, reduction, wn,\n        bias=True, bn=False, act=nn.ReLU(True), res_scale=1, da=False):\n        super(PYRCAB, self).__init__()\n\n        expand = 6\n        linear = 0.75\n        modules_body = []\n        # for i in range(2):\n        modules_body.append(wn(nn.Conv2d(n_feat, n_feat*expand, 1, bias=bias)))\n        modules_body.append(act)\n        modules_body.append(wn(nn.Conv2d(n_feat*expand, int(n_feat*linear), 1, bias=bias)))\n        # modules_body.append(conv(, n_feat, kernel_size, bias=bias))\n        modules_body.append(PyConv2d(in_channels=int(n_feat*linear),\n                out_channels=[n_feat//4, n_feat//4, n_feat//2],\n                pyconv_kernels=[3, 5, 7],\n                pyconv_groups=[1, 4, 8], bias=bias))\n        if da:\n            modules_body.append(DALayer(n_feat, reduction, wn))\n        else:\n            modules_body.append(CALayer(n_feat, reduction, wn))\n\n        self.body = nn.Sequential(*modules_body)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        res = self.body(x)\n        res += x\n        return res\n\n## Residual Group (RG)\nclass PyResidualGroup(nn.Module):\n    def __init__(self, n_feat, n_resblocks, da=False):\n        super(PyResidualGroup, self).__init__()\n        kernel_size = 3\n        res_scale = 1\n        reduction = 16\n\n        conv = PyConv2d\n        wn = lambda x: torch.nn.utils.weight_norm(x)\n\n        modules_body = []\n        modules_body 
= [\n            PYRCAB(\n                conv, n_feat, kernel_size, reduction, wn=wn, bias=True,\n                bn=False, act=nn.ReLU(True), res_scale=res_scale, da=da) \\\n            for _ in range(n_resblocks)]\n        modules_body.append(\n            PyConv2d(in_channels=n_feat,\n                out_channels=[n_feat//4, n_feat//4, n_feat//2],\n                pyconv_kernels=[3, 5, 7],\n                pyconv_groups=[1, 4, 8]))\n        self.body = nn.Sequential(*modules_body)\n\n    def forward(self, x):\n        res = self.body(x)\n        res += x\n        return res\n\nclass WideActResBlock(nn.Module):\n    def __init__(self, nf=64):\n        super(WideActResBlock, self).__init__()\n        self.res_scale = 1\n        body = []\n        expand = 6\n        linear = 0.8\n        kernel_size = 3\n        wn = lambda x: torch.nn.utils.weight_norm(x)\n        act=nn.ReLU(True)\n\n        body.append(\n            wn(nn.Conv2d(nf, nf*expand, 1, padding=1//2)))\n        body.append(act)\n        body.append(\n            wn(nn.Conv2d(nf*expand, int(nf*linear), 1, padding=1//2)))\n        body.append(\n            wn(nn.Conv2d(int(nf*linear), nf, kernel_size, padding=kernel_size//2)))\n\n        self.body = nn.Sequential(*body)\n\n    def forward(self, x):\n        res = self.body(x) * self.res_scale\n        res += x\n        return res\n\n\nclass PSWideActResBlock(nn.Module):\n    def __init__(self, nf=64):\n        super(PSWideActResBlock, self).__init__()\n        self.res_scale = 1\n        body = []\n        expand = 6\n        linear = 0.75\n        kernel_size = 3\n        wn = lambda x: torch.nn.utils.weight_norm(x)\n        act=nn.ReLU(True)\n\n        body.append(\n            wn(nn.Conv2d(nf, nf*expand, 1, padding=1//2)))\n        body.append(act)\n        body.append(\n            wn(nn.Conv2d(nf*expand, int(nf*linear), 1, padding=1//2)))\n        body.append(\n            wn(PSConv2d(int(nf*linear), nf, kernel_size, padding=kernel_size//2)))\n\n  
      self.body = nn.Sequential(*body)\n\n    def forward(self, x):\n        res = self.body(x) * self.res_scale\n        res += x\n        return res\n\n\nclass PyWideActResBlock(nn.Module):\n    def __init__(self, nf=64):\n        super(PyWideActResBlock, self).__init__()\n        self.res_scale = 1\n        body = []\n        expand = 6\n        linear = 0.75\n        kernel_size = 3\n        wn = lambda x: torch.nn.utils.weight_norm(x)\n        act=nn.ReLU(True)\n        expand_nf = nf*expand\n        linear_nf = int(nf * linear)\n\n        body.append(\n            wn(nn.Conv2d(nf, nf*expand, 1, padding=1//2)))\n        body.append(act)\n        body.append(\n            wn(nn.Conv2d(nf*expand, int(nf*linear), 1, padding=1//2)))\n        body.append(\n            PyConv2d(in_channels=linear_nf,\n                out_channels=[nf//4, nf//4, nf//2],\n                pyconv_kernels=[3, 5, 7],\n                pyconv_groups=[1, 4, 8]))\n\n        self.body = nn.Sequential(*body)\n\n    def forward(self, x):\n        res = self.body(x) * self.res_scale\n        res += x\n        return res\n\n\ndef flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True, use_pad_mask=False):\n    \"\"\"Warp an image or feature map with optical flow.\n\n    Args:\n        x (Tensor): Tensor with size (n, c, h, w).\n        flow (Tensor): Tensor with size (n, h, w, 2), normal value.\n        interp_mode (str): 'nearest' or 'bilinear' or 'nearest4'. Default: 'bilinear'.\n        padding_mode (str): 'zeros' or 'border' or 'reflection'.\n            Default: 'zeros'.\n        align_corners (bool): Before pytorch 1.3, the default value is\n            align_corners=True. After pytorch 1.3, the default value is\n            align_corners=False. 
Here, we use True as the default.\n        use_pad_mask (bool): only used for PWCNet, x is first padded with ones along the channel dimension.\n            The mask is generated according to the grid_sample results of the padded dimension.\n\n\n    Returns:\n        Tensor: Warped image or feature map.\n    \"\"\"\n    # assert x.size()[-2:] == flow.size()[1:3] # temporarily turned off for image-wise shift\n    n, _, h, w = x.size()\n    x = x.float()\n    # create mesh grid\n    # grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x)) # an illegal memory access on TITAN RTX + PyTorch1.9.1\n    grid_y, grid_x = torch.meshgrid(torch.arange(0, h, dtype=x.dtype, device=x.device), torch.arange(0, w, dtype=x.dtype, device=x.device))\n    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2\n    grid.requires_grad = False\n    grid = grid.type_as(x)\n    vgrid = grid + flow\n\n    # if use_pad_mask: # for PWCNet\n    #     x = F.pad(x, (0,0,0,0,0,1), mode='constant', value=1)\n\n    # scale grid to [-1,1]\n    if interp_mode == 'nearest4': # todo: bug, no gradient for flow model in this case!!! 
but the result is good\n        vgrid_x_floor = 2.0 * torch.floor(vgrid[:, :, :, 0]) / max(w - 1, 1) - 1.0\n        vgrid_x_ceil = 2.0 * torch.ceil(vgrid[:, :, :, 0]) / max(w - 1, 1) - 1.0\n        vgrid_y_floor = 2.0 * torch.floor(vgrid[:, :, :, 1]) / max(h - 1, 1) - 1.0\n        vgrid_y_ceil = 2.0 * torch.ceil(vgrid[:, :, :, 1]) / max(h - 1, 1) - 1.0\n\n        output00 = F.grid_sample(x, torch.stack((vgrid_x_floor, vgrid_y_floor), dim=3), mode='nearest', padding_mode=padding_mode, align_corners=align_corners)\n        output01 = F.grid_sample(x, torch.stack((vgrid_x_floor, vgrid_y_ceil), dim=3), mode='nearest', padding_mode=padding_mode, align_corners=align_corners)\n        output10 = F.grid_sample(x, torch.stack((vgrid_x_ceil, vgrid_y_floor), dim=3), mode='nearest', padding_mode=padding_mode, align_corners=align_corners)\n        output11 = F.grid_sample(x, torch.stack((vgrid_x_ceil, vgrid_y_ceil), dim=3), mode='nearest', padding_mode=padding_mode, align_corners=align_corners)\n\n        return torch.cat([output00, output01, output10, output11], 1)\n\n    else:\n        vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0\n        vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0\n        vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)\n        output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)\n\n        # if use_pad_mask: # for PWCNet\n        #     output = _flow_warp_masking(output)\n\n        # TODO, what if align_corners=False\n        return output\n\n\n# def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):\n#     \"\"\"Warp an image or feature map with optical flow\n#     Args:\n#         x (Tensor): size (N, C, H, W)\n#         flow (Tensor): size (N, H, W, 2), normal value\n#         interp_mode (str): 'nearest' or 'bilinear'\n#         padding_mode (str): 'zeros' or 'border' or 'reflection'\n\n#     Returns:\n#         Tensor: warped image or feature 
map\n#     \"\"\"\n#     assert x.size()[-2:] == flow.size()[1:3]\n#     B, C, H, W = x.size()\n#     # mesh grid\n#     grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))\n#     grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2\n#     grid.requires_grad = False\n#     grid = grid.type_as(x)\n#     vgrid = grid + flow\n#     # scale grid to [-1,1]\n#     vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0\n#     vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0\n#     vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)\n#     output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)\n#     return output\n"
  },
  {
    "path": "code/real/bsrt/model/bsrt.py",
    "content": "import functools\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport model.arch_util as arch_util\nfrom torch.cuda.amp import autocast\nimport model.swin_util as swu\nimport time\nimport os\nimport math\nfrom utils.debayer import Debayer3x3\nimport torchvision.utils as tvutils\nfrom datasets.burstsr_dataset import pack_raw_image, flatten_raw_image_batch\n\ntry:\n    from model.non_local.non_local_cross_dot_product import NONLocalBlock2D as NonLocalCross\n    from model.non_local.non_local_dot_product import NONLocalBlock2D as NonLocal\nexcept ImportError:\n    raise ImportError('Failed to import Non_Local module.')\n\ntry:\n    from model.DCNv2.dcn_v2 import DCN_sep as DCN, FlowGuidedDCN, InsideFlowGuidedDCN\nexcept ImportError:\n    raise ImportError('Failed to import DCNv2 module.')\n\n\ndef make_model(args, parent=False):\n    nframes = args.burst_size\n    img_size = args.patch_size * 2\n    patch_size = 1\n    in_chans = args.burst_channel\n    out_chans = args.n_colors\n    \n    if args.model_level == \"S\":\n        depths = [6]*1 + [6] * 4\n        num_heads = [6]*1 + [6] * 4\n        embed_dim = 60\n    elif args.model_level == \"L\":\n        depths = [6]*1 + [8] * 6\n        num_heads = [6]*1 + [6] * 6\n        embed_dim = 180\n    window_size = 8\n    mlp_ratio = 2\n    upscale = args.scale[0]\n    non_local = args.non_local\n    use_checkpoint=args.use_checkpoint\n\n    if args.local_rank <= 0:\n        print(\"depths: \", depths)\n\n    return BSRT(args=args,nframes=nframes,\n                   img_size=img_size,\n                   patch_size=patch_size,\n                   in_chans=in_chans,\n                   out_chans=out_chans,\n                   embed_dim=embed_dim,\n                   depths=depths,\n                   num_heads=num_heads,\n                   window_size=window_size,\n                   mlp_ratio=mlp_ratio,\n                   upscale=upscale,\n                   
non_local=non_local,\n                   use_checkpoint=use_checkpoint)\n\n\nclass BasicModule(nn.Module):\n    \"\"\"Basic Module for SpyNet.\n    \"\"\"\n\n    def __init__(self):\n        super(BasicModule, self).__init__()\n\n        self.basic_module = nn.Sequential(\n            nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),\n            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),\n            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),\n            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),\n            nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))\n\n    def forward(self, tensor_input):\n        return self.basic_module(tensor_input)\n\n\nclass SpyNet(nn.Module):\n    \"\"\"SpyNet architecture.\n\n    Args:\n        load_path (str): path for pretrained SpyNet. Default: None.\n        return_levels (list[int]): return flows of different levels. 
Default: [5].\n    \"\"\"\n\n    def __init__(self, load_path=None, return_levels=[5]):\n        super(SpyNet, self).__init__()\n        self.return_levels = return_levels\n        self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)])\n        if load_path:\n            if not os.path.exists(load_path):\n                import requests\n                url = 'https://github.com/JingyunLiang/VRT/releases/download/v0.0/spynet_sintel_final-3d2a1287.pth'\n                r = requests.get(url, allow_redirects=True)\n                print(f'downloading SpyNet pretrained model from {url}')\n                os.makedirs(os.path.dirname(load_path), exist_ok=True)\n                open(load_path, 'wb').write(r.content)\n\n            self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])\n\n        self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))\n        self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))\n\n    def preprocess(self, tensor_input):\n        tensor_output = (tensor_input - self.mean) / self.std\n        return tensor_output\n\n    def process(self, ref, supp, w, h, w_floor, h_floor):\n        flow_list = []\n\n        ref = [self.preprocess(ref)]\n        supp = [self.preprocess(supp)]\n\n        # ref = [ref]\n        # supp = [supp]\n\n        for level in range(5):\n            ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False))\n            supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False))\n\n        flow = ref[0].new_zeros(\n            [ref[0].size(0), 2,\n             int(math.floor(ref[0].size(2) / 2.0)),\n             int(math.floor(ref[0].size(3) / 2.0))])\n\n        for level in range(len(ref)):\n            upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0\n\n            if upsampled_flow.size(2) 
!= ref[level].size(2):\n                upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate')\n            if upsampled_flow.size(3) != ref[level].size(3):\n                upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate')\n\n            flow = self.basic_module[level](torch.cat([\n                ref[level],\n                arch_util.flow_warp(\n                    supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'),\n                upsampled_flow\n            ], 1)) + upsampled_flow\n\n            if level in self.return_levels:\n                scale = 2**(5-level) # level=5 (scale=1), level=4 (scale=2), level=3 (scale=4), level=2 (scale=8)\n                flow_out = F.interpolate(input=flow, size=(h//scale, w//scale), mode='bilinear', align_corners=False)\n                flow_out[:, 0, :, :] *= float(w//scale) / float(w_floor//scale)\n                flow_out[:, 1, :, :] *= float(h//scale) / float(h_floor//scale)\n                if torch.abs(flow_out).mean() > 200:\n                    print(f\"level {level}, flow > 200: {torch.abs(flow_out).mean():.4f}\")\n                    # return None\n                    # clamp is out-of-place: assign the result back, otherwise this is a no-op\n                    flow_out = flow_out.clamp(-250, 250)\n                flow_list.insert(0, flow_out)\n\n        return flow_list\n\n    def forward(self, ref, supp):\n        assert ref.size() == supp.size()\n\n        h, w = ref.size(2), ref.size(3)\n        w_floor = math.floor(math.ceil(w / 32.0) * 32.0)\n        h_floor = math.floor(math.ceil(h / 32.0) * 32.0)\n\n        ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False)\n        supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False)\n\n        flow_list = self.process(ref, supp, w, h, w_floor, h_floor)\n\n        return flow_list[0] if len(flow_list) == 1 else flow_list\n\n\n\nclass FlowGuidedPCDAlign(nn.Module):\n    ''' Alignment module 
using Pyramid, Cascading and Deformable convolution\n    with 3 pyramid levels. [From EDVR]\n    '''\n\n    def __init__(self, nf=64, groups=8):\n        super(FlowGuidedPCDAlign, self).__init__()\n        # L3: level 3, 1/4 spatial size\n        self.L3_offset_conv1 = nn.Conv2d(nf * 2 + 2, nf, 3, 1, 1, bias=True)  # concat for diff\n        self.L3_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n        self.L3_dcnpack = FlowGuidedDCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)\n\n        # L2: level 2, 1/2 spatial size\n        self.L2_offset_conv1 = nn.Conv2d(nf * 2 + 2, nf, 3, 1, 1, bias=True)  # concat for diff\n        self.L2_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset\n        self.L2_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n        self.L2_dcnpack = FlowGuidedDCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)\n        self.L2_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea\n\n        # L1: level 1, original spatial size\n        self.L1_offset_conv1 = nn.Conv2d(nf * 2 + 2, nf, 3, 1, 1, bias=True)  # concat for diff\n        self.L1_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset\n        self.L1_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n        self.L1_dcnpack = FlowGuidedDCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)\n        self.L1_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea\n\n        # Cascading DCN\n        self.cas_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff\n        self.cas_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n        self.cas_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)\n\n        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)\n\n    def forward(self, nbr_fea_l, nbr_fea_warped_l, ref_fea_l, flows_l):\n        '''align other 
neighboring frames to the reference frame in the feature level\n        nbr_fea_l, ref_fea_l: [L1, L2, L3], each with [B,C,H,W] features\n        '''\n        # L3\n        L3_offset = torch.cat([nbr_fea_warped_l[2], ref_fea_l[2], flows_l[2]], dim=1)\n        L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset))\n        L3_offset = self.lrelu(self.L3_offset_conv2(L3_offset))\n        L3_fea = self.lrelu(self.L3_dcnpack(nbr_fea_l[2], L3_offset, flows_l[2]))\n        # L2\n        L3_offset = F.interpolate(L3_offset, scale_factor=2, mode='bilinear', align_corners=False)\n        L2_offset = torch.cat([nbr_fea_warped_l[1], ref_fea_l[1], flows_l[1]], dim=1)\n        L2_offset = self.lrelu(self.L2_offset_conv1(L2_offset))\n        L2_offset = self.lrelu(self.L2_offset_conv2(torch.cat([L2_offset, L3_offset*2], dim=1)))\n        L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset))\n        L2_fea = self.L2_dcnpack(nbr_fea_l[1], L2_offset, flows_l[1])\n        L3_fea = F.interpolate(L3_fea, scale_factor=2, mode='bilinear', align_corners=False)\n        L2_fea = self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L3_fea], dim=1)))\n        # L1\n        L2_offset = F.interpolate(L2_offset, scale_factor=2, mode='bilinear', align_corners=False)\n        L1_offset = torch.cat([nbr_fea_warped_l[0], ref_fea_l[0], flows_l[0]], dim=1)\n        L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset))\n        L1_offset = self.lrelu(self.L1_offset_conv2(torch.cat([L1_offset, L2_offset * 2], dim=1)))\n        L1_offset = self.lrelu(self.L1_offset_conv3(L1_offset))\n        L1_fea = self.L1_dcnpack(nbr_fea_l[0], L1_offset, flows_l[0])\n        L2_fea = F.interpolate(L2_fea, scale_factor=2, mode='bilinear', align_corners=False)\n        L1_fea = self.L1_fea_conv(torch.cat([L1_fea, L2_fea], dim=1))\n\n        # Cascading\n        offset = torch.cat([L1_fea, ref_fea_l[0]], dim=1)\n        offset = self.lrelu(self.cas_offset_conv1(offset))\n        offset = 
self.lrelu(self.cas_offset_conv2(offset))\n        L1_fea = self.cas_dcnpack(L1_fea, offset)\n\n        return L1_fea\n\n\nclass CrossNonLocal_Fusion(nn.Module):\n    ''' Cross Non Local fusion module\n    '''\n    def __init__(self, nf=64, out_feat=96, nframes=5, center=2):\n        super(CrossNonLocal_Fusion, self).__init__()\n        self.center = center\n\n        self.non_local_T = nn.ModuleList()\n        self.non_local_F = nn.ModuleList()\n\n        for i in range(nframes):\n            self.non_local_T.append(NonLocalCross(nf, inter_channels=nf//2, sub_sample=True, bn_layer=False))\n            self.non_local_F.append(NonLocal(nf, inter_channels=nf//2, sub_sample=True, bn_layer=False))\n\n        # fusion conv: using 1x1 to save parameters and computation\n        self.fea_fusion = nn.Conv2d(nframes * nf*2, out_feat, 3, 1, 1, bias=True)\n\n        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)\n\n    def forward(self, aligned_fea):\n        B, N, C, H, W = aligned_fea.size()  # N video frames\n        ref = aligned_fea[:, self.center, :, :, :].clone()\n\n        cor_l = []\n        non_l = []\n        for i in range(N):\n            nbr = aligned_fea[:, i, :, :, :]\n            non_l.append(self.non_local_F[i](nbr))\n            cor_l.append(self.non_local_T[i](nbr, ref))\n\n        aligned_fea_T = torch.cat(cor_l, dim=1)\n        aligned_fea_F = torch.cat(non_l, dim=1)\n        aligned_fea = torch.cat([aligned_fea_T, aligned_fea_F], dim=1)\n\n        #### fusion\n        fea = self.fea_fusion(aligned_fea)\n\n        return fea\n\n\n\nclass BSRT(nn.Module):\n    def __init__(self, args, nframes=8, img_size=64, patch_size=1, in_chans=3, out_chans=3,\n                 embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],\n                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,\n     
            use_checkpoint=False, upscale=4, non_local=False,\n                 **kwargs):\n        super(BSRT, self).__init__()\n        num_in_ch = in_chans\n        num_out_ch = out_chans\n        num_feat = 64\n        groups = 8\n        # embed_dim = num_feat\n        back_RBs = 5\n        n_resblocks = 6\n\n        self.args = args\n        self.center = 0\n        self.upscale = upscale\n        self.window_size = window_size\n        self.non_local = non_local\n        self.nframes = nframes\n\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.ape = ape\n        self.patch_norm = patch_norm\n        self.num_features = embed_dim\n        self.mlp_ratio = mlp_ratio\n\n        spynet_path='/home/luoziwei/.pretrained_models/spynet_sintel_final-3d2a1287.pth'\n        self.spynet = SpyNet(spynet_path, [3, 4, 5])\n        self.conv_flow = nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1)\n        self.flow_ps = nn.PixelShuffle(2)\n\n        # split image into non-overlapping patches\n        self.patch_embed = swu.PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None)\n        num_patches = self.patch_embed.num_patches\n        patches_resolution = self.patch_embed.patches_resolution\n        self.patches_resolution = patches_resolution\n\n        # merge non-overlapping patches into image\n        self.patch_unembed = swu.PatchUnEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None)\n\n        #####################################################################################################\n        ################################### 1, shallow feature extraction ###################################\n        self.conv_first = nn.Conv2d(num_in_ch*(1+2*0), embed_dim, 3, 1, 1, bias=True)\n       
 \n        # # stochastic depth\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n\n        if args.swinfeature:\n            if self.args.local_rank <= 0:\n                print(\"using swinfeature\")\n            self.pre_layers = nn.ModuleList()\n            for i_layer in range(depths[0]):\n                layer = swu.SwinTransformerBlock(dim=embed_dim, \n                            input_resolution=(patches_resolution[0]//2,\n                                              patches_resolution[1]//2),\n                             num_heads=num_heads[0], window_size=window_size,\n                             shift_size=0 if (i_layer % 2 == 0) else window_size // 2,\n                             mlp_ratio=mlp_ratio,\n                             qkv_bias=qkv_bias, qk_scale=qk_scale,\n                             drop=drop_rate, attn_drop=attn_drop_rate,\n                             drop_path=dpr[i_layer],\n                             norm_layer=norm_layer)\n                self.pre_layers.append(layer)\n\n            self.pre_norm = norm_layer(embed_dim)\n        else:\n            WARB = functools.partial(arch_util.WideActResBlock, nf=embed_dim)\n            self.feature_extraction = arch_util.make_layer(WARB, 5)\n\n        self.conv_after_pre_layer = nn.Conv2d(embed_dim, num_feat*4, 3, 1, 1, bias=True)\n        self.mid_ps = nn.PixelShuffle(2)\n\n        self.fea_L2_conv1 = nn.Conv2d(num_feat, num_feat*2, 3, 2, 1, bias=True)\n        self.fea_L3_conv1 = nn.Conv2d(num_feat*2, num_feat*4, 3, 2, 1, bias=True)\n\n        #####################################################################################################\n        ################################### 2, Feature Enhanced PCD Align ###################################\n\n        # Top layers\n        self.toplayer = nn.Conv2d(num_feat*4, num_feat, kernel_size=1, stride=1, padding=0)\n        # Smooth layers\n        self.smooth1 = 
nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1)\n        self.smooth2 = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1)\n        # Lateral layers\n        self.latlayer1 = nn.Conv2d(num_feat*2, num_feat, kernel_size=1, stride=1, padding=0)\n        self.latlayer2 = nn.Conv2d(num_feat*1, num_feat, kernel_size=1, stride=1, padding=0)\n\n        # self.align = PCD_Align(nf=num_feat, groups=groups)\n        self.align = FlowGuidedPCDAlign(nf=num_feat, groups=groups)\n        #####################################################################################################\n        ################################### 3, Multi-frame Feature Fusion  ##################################\n\n        if self.non_local:\n            if self.args.local_rank <= 0:\n                print(\"using non_local\")\n            self.fusion = CrossNonLocal_Fusion(nf=num_feat, out_feat=embed_dim, nframes=nframes, center=self.center)\n        else:\n            self.fusion = nn.Conv2d(nframes * num_feat, embed_dim, 1, 1, bias=True)\n\n        #####################################################################################################\n        ################################### 4, deep feature extraction ######################################\n\n        # absolute position embedding\n        if self.ape:\n            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))\n            swu.trunc_normal_(self.absolute_pos_embed, std=.02)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # build Residual Swin Transformer blocks (RSTB)\n        self.layers = nn.ModuleList()\n        for i_layer in range(1, self.num_layers):\n            layer = swu.RSTB(dim=embed_dim,\n                         input_resolution=(patches_resolution[0],\n                                           patches_resolution[1]),\n                         depth=depths[i_layer],\n                         num_heads=num_heads[i_layer],\n          
               window_size=window_size,\n                         mlp_ratio=self.mlp_ratio,\n                         qkv_bias=qkv_bias, qk_scale=qk_scale,\n                         drop=drop_rate, attn_drop=attn_drop_rate,\n                         drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results\n                         norm_layer=norm_layer,\n                         downsample=None,\n                         use_checkpoint=use_checkpoint,\n                         img_size=img_size,\n                         patch_size=patch_size\n                         )\n            self.layers.append(layer)\n        \n        self.norm = norm_layer(self.num_features)\n\n        # build the last conv layer in deep feature extraction\n        self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)\n\n        #####################################################################################################\n        ################################ 5, high quality image reconstruction ################################\n\n        self.upconv1 = nn.Conv2d(embed_dim, num_feat * 4, 3, 1, 1, bias=True)\n        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)\n        self.pixel_shuffle = nn.PixelShuffle(2)\n        self.HRconv = nn.Conv2d(64, 64, 3, 1, 1, bias=True)\n        self.conv_last = nn.Conv2d(64, args.n_colors, 3, 1, 1, bias=True)\n\n        #### skip #############\n        self.skip_pixel_shuffle = nn.PixelShuffle(2)\n        self.skipup1 = nn.Conv2d(num_in_ch//4, num_feat * 4, 3, 1, 1, bias=True)\n        self.skipup2 = nn.Conv2d(num_feat, args.n_colors * 4, 3, 1, 1, bias=True)\n\n        #### activation function\n        self.lrelu = nn.LeakyReLU(0.1, inplace=True)\n        self.lrelu2 = nn.LeakyReLU(0.1, inplace=True)\n\n        self.apply(self._init_weights)\n\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            swu.trunc_normal_(m.weight, std=.02)\n            if 
isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'absolute_pos_embed'}\n\n    @torch.jit.ignore\n    def no_weight_decay_keywords(self):\n        return {'relative_position_bias_table'}\n\n    def _upsample_add(self, x, y):\n        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + y\n\n    def check_image_size(self, x):\n        _, _, h, w = x.size()\n        mod_pad_h = (self.window_size - h % self.window_size) % self.window_size\n        mod_pad_w = (self.window_size - w % self.window_size) % self.window_size\n        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')\n        return x\n\n    def pre_forward_features(self, x):\n        if self.args.swinfeature:\n            x_size = (x.shape[-2], x.shape[-1])\n            x = self.patch_embed(x, use_norm=True)\n            if self.ape:\n                x = x + self.absolute_pos_embed\n            x = self.pos_drop(x)\n\n            for idx, layer in enumerate(self.pre_layers):\n                x = layer(x, x_size)\n\n            x = self.pre_norm(x)\n            x = self.patch_unembed(x, x_size)\n\n        else:\n            x = self.feature_extraction(x)\n\n        return x\n\n    def forward_features(self, x):\n        x_size = (x.shape[-2], x.shape[-1])\n        x = self.patch_embed(x)\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n\n        for idx, layer in enumerate(self.layers):\n            x = layer(x, x_size)\n            if torch.any(torch.isinf(x)) or torch.any(torch.isnan(x)):\n                print('layer: ', idx)\n\n        x = self.norm(x)  # B L C\n        x = self.patch_unembed(x, x_size)\n\n        return x\n\n    @autocast()\n    def forward(self, x, 
print_time=False):\n        B, N, C, H, W = x.size()  # N video frames\n        x_center = x[:, self.center, :, :, :].contiguous()\n\n        #### skip module ########\n        skip1 = self.lrelu2(self.skip_pixel_shuffle(self.skipup1(self.skip_pixel_shuffle(x_center))))\n        skip2 = self.skip_pixel_shuffle(self.skipup2(skip1))\n\n        x_ = self.conv_flow(self.flow_ps(x.view(B*N, C, H, W))).view(B, N, -1, H*2, W*2)\n        \n        # calculate flows\n        ref_flows = self.get_ref_flows(x_)\n\n        #### extract LR features\n        x = self.lrelu(self.conv_first(x.view(B*N, -1, H, W)))\n\n        L1_fea = self.mid_ps(self.conv_after_pre_layer(self.pre_forward_features(x)))\n        _, _, H, W = L1_fea.size()\n\n        L2_fea = self.lrelu(self.fea_L2_conv1(L1_fea))\n        L3_fea = self.lrelu(self.fea_L3_conv1(L2_fea))\n\n        # FPN enhance features\n        L3_fea = self.lrelu(self.toplayer(L3_fea))\n        L2_fea = self.smooth1(self._upsample_add(L3_fea, self.latlayer1(L2_fea)))\n        L1_fea = self.smooth2(self._upsample_add(L2_fea, self.latlayer2(L1_fea)))\n\n        L1_fea = L1_fea.view(B, N, -1, H, W).contiguous()\n        L2_fea = L2_fea.view(B, N, -1, H // 2, W // 2 ).contiguous()\n        L3_fea = L3_fea.view(B, N, -1, H // 4, W // 4).contiguous()\n\n        #### PCD align\n        # ref feature list\n        ref_fea_l = [\n            L1_fea[:, self.center, :, :, :].clone(), \n            L2_fea[:, self.center, :, :, :].clone(),\n            L3_fea[:, self.center, :, :, :].clone()\n        ]\n        aligned_fea = []\n        for i in range(N):\n            nbr_fea_l = [\n                L1_fea[:, i, :, :, :].clone(), \n                L2_fea[:, i, :, :, :].clone(),\n                L3_fea[:, i, :, :, :].clone()\n            ]\n            flows_l = [\n                ref_flows[0][:, i, :, :, :].clone(), \n                ref_flows[1][:, i, :, :, :].clone(), \n                ref_flows[2][:, i, :, :, :].clone()\n            ]\n         
   # print(nbr_fea_l[0].shape, flows_l[0].shape)\n            nbr_warped_l = [\n                arch_util.flow_warp(nbr_fea_l[0], flows_l[0].permute(0, 2, 3, 1), 'bilinear'),\n                arch_util.flow_warp(nbr_fea_l[1], flows_l[1].permute(0, 2, 3, 1), 'bilinear'),\n                arch_util.flow_warp(nbr_fea_l[2], flows_l[2].permute(0, 2, 3, 1), 'bilinear')\n            ]\n            aligned_fea.append(self.align(nbr_fea_l, nbr_warped_l, ref_fea_l, flows_l))\n\n        aligned_fea = torch.stack(aligned_fea, dim=1)  # [B, N, C, H, W] --> [B, T, C, H, W]\n\n        if not self.non_local:\n            aligned_fea = aligned_fea.view(B, -1, H, W)\n\n        x = self.lrelu(self.fusion(aligned_fea))\n\n        x = self.lrelu(self.conv_after_body(self.forward_features(x))) + x\n\n        x = self.lrelu(self.pixel_shuffle(self.upconv1(x)))\n        x = skip1 + x\n        x = self.lrelu(self.pixel_shuffle(self.upconv2(x)))\n        x = self.lrelu(self.HRconv(x))\n        x = self.conv_last(x)\n\n        x = skip2 + x\n        return x\n\n\n    def get_ref_flows(self, x):\n        '''Get flow between frames ref and other'''\n\n        b, n, c, h, w = x.size()\n        x_nbr = x.reshape(-1, c, h, w)\n        x_ref = x[:, self.center:self.center+1, :, :, :].repeat(1, n, 1, 1, 1).reshape(-1, c, h, w)\n\n        # backward\n        flows = self.spynet(x_ref, x_nbr)\n        flows_list = [flow.view(b, n, 2, h // (2 ** (i)), w // (2 ** (i))) for flow, i in\n                          zip(flows, range(3))]\n\n        return flows_list\n\n\n\n\n\n\n\n"
  },
  {
    "path": "code/real/bsrt/model/checkpoint.py",
    "content": "import torch\nimport warnings\n\n\ndef detach_variable(inputs):\n    if isinstance(inputs, tuple):\n        out = []\n        for inp in inputs:\n            x = inp.detach()\n            x.requires_grad = inp.requires_grad\n            out.append(x)\n        return tuple(out)\n    else:\n        raise RuntimeError(\n            \"Only tuple of tensors is supported. Got Unsupported input type: \", type(inputs).__name__)\n\n\ndef check_backward_validity(inputs):\n    if not any(inp.requires_grad for inp in inputs):\n        warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n\n\nclass CheckpointFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, run_function, length, *args):\n        ctx.run_function = run_function\n        ctx.input_tensors = list(args[:length])\n        ctx.input_params = list(args[length:])\n        with torch.no_grad():\n            output_tensors = ctx.run_function(*ctx.input_tensors)\n        return output_tensors\n\n    @staticmethod\n    def backward(ctx, *output_grads):\n        for i in range(len(ctx.input_tensors)):\n            temp = ctx.input_tensors[i]\n            ctx.input_tensors[i] = temp.detach()\n            ctx.input_tensors[i].requires_grad = temp.requires_grad\n        with torch.enable_grad():\n            output_tensors = ctx.run_function(*ctx.input_tensors)\n        input_grads = torch.autograd.grad(output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True)\n        return (None, None) + input_grads\n"
  },
  {
    "path": "code/real/bsrt/model/common.py",
    "content": "import math\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef default_conv(in_channels, out_channels, kernel_size, bias=True):\n    return nn.Conv2d(\n        in_channels, out_channels, kernel_size,\n        padding=(kernel_size // 2), bias=bias)\n\n\nclass MeanShift(nn.Conv2d):\n    def __init__(\n            self, rgb_range,\n            rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):\n        super(MeanShift, self).__init__(3, 3, kernel_size=1)\n        std = torch.Tensor(rgb_std)\n        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)\n        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std\n        for p in self.parameters():\n            p.requires_grad = False\n\n\nclass BasicBlock(nn.Sequential):\n    def __init__(\n            self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False,\n            bn=True, act=nn.ReLU(True)):\n\n        m = [conv(in_channels, out_channels, kernel_size, bias=bias)]\n        if bn:\n            m.append(nn.BatchNorm2d(out_channels))\n        if act is not None:\n            m.append(act)\n\n        super(BasicBlock, self).__init__(*m)\n\n\nclass ResBlock(nn.Module):\n    def __init__(\n            self, conv, n_feats, kernel_size,\n            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):\n\n        super(ResBlock, self).__init__()\n        m = []\n        for i in range(2):\n            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))\n            if bn:\n                m.append(nn.BatchNorm2d(n_feats))\n            if i == 0:\n                m.append(act)\n\n        self.body = nn.Sequential(*m)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        res = self.body(x).mul(self.res_scale)\n        res += x\n\n        return res\n\n\nclass Upsampler(nn.Sequential):\n    def __init__(self, conv, scale, n_feats, bn=False, act=False, 
bias=True):\n\n        m = []\n        if (scale & (scale - 1)) == 0:  # Is scale = 2^n?\n            for _ in range(int(math.log(scale, 2))):\n                m.append(conv(n_feats, 4 * n_feats, 3, bias))\n                m.append(nn.PixelShuffle(2))\n                if bn:\n                    m.append(nn.BatchNorm2d(n_feats))\n                if act == 'relu':\n                    m.append(nn.ReLU(True))\n                elif act == 'prelu':\n                    m.append(nn.PReLU(n_feats))\n\n        elif scale == 3:\n            m.append(conv(n_feats, 9 * n_feats, 3, bias))\n            m.append(nn.PixelShuffle(3))\n            if bn:\n                m.append(nn.BatchNorm2d(n_feats))\n            if act == 'relu':\n                m.append(nn.ReLU(True))\n            elif act == 'prelu':\n                m.append(nn.PReLU(n_feats))\n        else:\n            raise NotImplementedError\n\n        super(Upsampler, self).__init__(*m)\n\n\nclass UpOnly(nn.Sequential):\n    def __init__(self, scale):\n\n        m = []\n        if (scale & (scale - 1)) == 0:  # Is scale = 2^n?\n            for _ in range(int(math.log(scale, 2))):\n                m.append(nn.PixelShuffle(2))\n\n\n        elif scale == 3:\n\n            m.append(nn.PixelShuffle(3))\n\n        else:\n            raise NotImplementedError\n\n        super(UpOnly, self).__init__(*m)\n\n\ndef lanczos_kernel(dx, a=3, N=None, dtype=None, device=None):\n    '''\n    Generates 1D Lanczos kernels for translation and interpolation.\n    Args:\n        dx : float, tensor (batch_size, 1), the translation in pixels to shift an image.\n        a : int, number of lobes in the kernel support.\n            If N is None, then the width is the kernel support (length of all lobes),\n            S = 2(a + ceil(dx)) + 1.\n        N : int, width of the kernel.\n            If smaller than S then N is set to S.\n    Returns:\n        k: tensor (?, ?), lanczos kernel\n    '''\n\n    if not torch.is_tensor(dx):\n        dx = 
torch.tensor(dx, dtype=dtype, device=device)\n\n    if device is None:\n        device = dx.device\n\n    if dtype is None:\n        dtype = dx.dtype\n\n    D = dx.abs().ceil().int()\n    S = 2 * (a + D) + 1  # width of kernel support\n\n    S_max = S.max() if hasattr(S, 'shape') else S\n\n    if (N is None) or (N < S_max):\n        N = S\n\n    Z = (N - S) // 2  # width of zeros beyond kernel support\n\n    start = (-(a + D + Z)).min()\n    end = (a + D + Z + 1).max()\n    x = torch.arange(start, end, dtype=dtype, device=device).view(1, -1) - dx\n    px = (np.pi * x) + 1e-3\n\n    sin_px = torch.sin(px)\n    sin_pxa = torch.sin(px / a)\n\n    k = a * sin_px * sin_pxa / px ** 2  # sinc(x) masked by sinc(x/a)\n\n    return k\n\n\ndef lanczos_shift(img, shift, p=5, a=3):\n    '''\n    Shifts an image by convolving it with a Lanczos kernel.\n    Lanczos interpolation is an approximation to ideal sinc interpolation,\n    by windowing a sinc kernel with another sinc function extending over a\n    few of its lobes (typically a=3).\n\n    Args:\n        img : tensor (batch_size, channels, height, width), the images to be shifted\n        shift : tensor (batch_size, 2) of translation parameters (dy, dx)\n        p : int, padding width prior to convolution (default=5)\n        a : int, number of lobes in the Lanczos interpolation kernel (default=3)\n    Returns:\n        I_s: tensor (batch_size, channels, height, width), shifted images\n    '''\n    img = img.transpose(0, 1)\n    dtype = img.dtype\n\n    if len(img.shape) == 2:\n        img = img[None, None].repeat(1, shift.shape[0], 1, 1)  # batch of one image\n    elif len(img.shape) == 3:  # one image per shift\n        assert img.shape[0] == shift.shape[0]\n        img = img[None,]\n\n    # Apply padding\n\n    padder = torch.nn.ReflectionPad2d(p)  # reflect pre-padding\n    I_padded = padder(img)\n\n    # Create 1D shifting kernels\n\n    y_shift = shift[:, [0]]\n    x_shift = shift[:, [1]]\n\n    k_y = 
(lanczos_kernel(y_shift, a=a, N=None, dtype=dtype)\n           .flip(1)  # flip axis of convolution\n           )[:, None, :, None]  # expand dims to get shape (batch, channels, y_kernel, 1)\n    k_x = (lanczos_kernel(x_shift, a=a, N=None, dtype=dtype)\n           .flip(1)\n           )[:, None, None, :]  # shape (batch, channels, 1, x_kernel)\n\n    # Apply kernels\n    # print(I_padded.shape, k_y.shape)\n    I_s = torch.conv1d(I_padded,\n                       groups=k_y.shape[0],\n                       weight=k_y,\n                       padding=[k_y.shape[2] // 2, 0])  # same padding\n    I_s = torch.conv1d(I_s,\n                       groups=k_x.shape[0],\n                       weight=k_x,\n                       padding=[0, k_x.shape[3] // 2])\n\n    I_s = I_s[..., p:-p, p:-p]  # remove padding\n\n    # print(I_s.shape)\n    return I_s.transpose(0, 1)  # , k.squeeze()\n"
  },
  {
    "path": "code/real/bsrt/model/non_local/network.py",
    "content": "from torch import nn\n# from lib.non_local_concatenation import NONLocalBlock2D\n# from lib.non_local_gaussian import NONLocalBlock2D\nfrom lib.non_local_embedded_gaussian import NONLocalBlock2D\n# from lib.non_local_dot_product import NONLocalBlock2D\n\n\nclass Network(nn.Module):\n    def __init__(self):\n        super(Network, self).__init__()\n\n        self.conv_1 = nn.Sequential(\n            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),\n            nn.BatchNorm2d(32),\n            nn.ReLU(),\n            nn.MaxPool2d(2),\n        )\n\n        self.nl_1 = NONLocalBlock2D(in_channels=32)\n        self.conv_2 = nn.Sequential(\n            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),\n            nn.BatchNorm2d(64),\n            nn.ReLU(),\n            nn.MaxPool2d(2),\n        )\n\n        self.nl_2 = NONLocalBlock2D(in_channels=64)\n        self.conv_3 = nn.Sequential(\n            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),\n            nn.BatchNorm2d(128),\n            nn.ReLU(),\n            nn.MaxPool2d(2),\n        )\n\n        self.fc = nn.Sequential(\n            nn.Linear(in_features=128*3*3, out_features=256),\n            nn.ReLU(),\n            nn.Dropout(0.5),\n\n            nn.Linear(in_features=256, out_features=10)\n        )\n\n    def forward(self, x):\n        batch_size = x.size(0)\n\n        feature_1 = self.conv_1(x)\n        nl_feature_1 = self.nl_1(feature_1)\n\n        feature_2 = self.conv_2(nl_feature_1)\n        nl_feature_2 = self.nl_2(feature_2)\n\n        output = self.conv_3(nl_feature_2).view(batch_size, -1)\n        output = self.fc(output)\n\n        return output\n\n    def forward_with_nl_map(self, x):\n        batch_size = x.size(0)\n\n        feature_1 = self.conv_1(x)\n        nl_feature_1, nl_map_1 = self.nl_1(feature_1, return_nl_map=True)\n\n        feature_2 = self.conv_2(nl_feature_1)\n        
nl_feature_2, nl_map_2 = self.nl_2(feature_2, return_nl_map=True)\n\n        output = self.conv_3(nl_feature_2).view(batch_size, -1)\n        output = self.fc(output)\n\n        return output, [nl_map_1, nl_map_2]\n\n\nif __name__ == '__main__':\n    import torch\n\n    img = torch.randn(3, 1, 28, 28)\n    net = Network()\n    out = net(img)\n    print(out.size())\n\n"
  },
  {
    "path": "code/real/bsrt/model/non_local/non_local_concatenation.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass _NonLocalBlockND(nn.Module):\n    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):\n        super(_NonLocalBlockND, self).__init__()\n\n        assert dimension in [1, 2, 3]\n\n        self.dimension = dimension\n        self.sub_sample = sub_sample\n\n        self.in_channels = in_channels\n        self.inter_channels = inter_channels\n\n        if self.inter_channels is None:\n            self.inter_channels = in_channels // 2\n            if self.inter_channels == 0:\n                self.inter_channels = 1\n\n        if dimension == 3:\n            conv_nd = nn.Conv3d\n            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))\n            bn = nn.BatchNorm3d\n        elif dimension == 2:\n            conv_nd = nn.Conv2d\n            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))\n            bn = nn.BatchNorm2d\n        else:\n            conv_nd = nn.Conv1d\n            max_pool_layer = nn.MaxPool1d(kernel_size=(2))\n            bn = nn.BatchNorm1d\n\n        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                         kernel_size=1, stride=1, padding=0)\n\n        if bn_layer:\n            self.W = nn.Sequential(\n                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,\n                        kernel_size=1, stride=1, padding=0),\n                bn(self.in_channels)\n            )\n            nn.init.constant_(self.W[1].weight, 0)\n            nn.init.constant_(self.W[1].bias, 0)\n        else:\n            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,\n                             kernel_size=1, stride=1, padding=0)\n            nn.init.constant_(self.W.weight, 0)\n            nn.init.constant_(self.W.bias, 0)\n\n        self.theta = conv_nd(in_channels=self.in_channels, 
out_channels=self.inter_channels,\n                             kernel_size=1, stride=1, padding=0)\n\n        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                           kernel_size=1, stride=1, padding=0)\n\n        self.concat_project = nn.Sequential(\n            nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),\n            nn.ReLU()\n        )\n\n        if sub_sample:\n            self.g = nn.Sequential(self.g, max_pool_layer)\n            self.phi = nn.Sequential(self.phi, max_pool_layer)\n\n    def forward(self, x, return_nl_map=False):\n        '''\n        :param x: (b, c, t, h, w)\n        :param return_nl_map: if True return z, nl_map, else only return z.\n        :return:\n        '''\n\n        batch_size = x.size(0)\n\n        g_x = self.g(x).view(batch_size, self.inter_channels, -1)\n        g_x = g_x.permute(0, 2, 1)\n\n        # (b, c, N, 1)\n        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)\n        # (b, c, 1, N)\n        phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1)\n\n        h = theta_x.size(2)\n        w = phi_x.size(3)\n        theta_x = theta_x.repeat(1, 1, 1, w)\n        phi_x = phi_x.repeat(1, 1, h, 1)\n\n        concat_feature = torch.cat([theta_x, phi_x], dim=1)\n        f = self.concat_project(concat_feature)\n        b, _, h, w = f.size()\n        f = f.view(b, h, w)\n\n        N = f.size(-1)\n        f_div_C = f / N\n\n        y = torch.matmul(f_div_C, g_x)\n        y = y.permute(0, 2, 1).contiguous()\n        y = y.view(batch_size, self.inter_channels, *x.size()[2:])\n        W_y = self.W(y)\n        z = W_y + x\n\n        if return_nl_map:\n            return z, f_div_C\n        return z\n\n\nclass NONLocalBlock1D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock1D, self).__init__(in_channels,\n                                            
  inter_channels=inter_channels,\n                                              dimension=1, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nclass NONLocalBlock2D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock2D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=2, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nclass NONLocalBlock3D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True,):\n        super(NONLocalBlock3D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=3, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nif __name__ == '__main__':\n    import torch\n\n    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:\n        img = torch.zeros(2, 3, 20)\n        net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n        img = torch.zeros(2, 3, 20, 20)\n        net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n        img = torch.randn(2, 3, 8, 20, 20)\n        net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n"
  },
  {
    "path": "code/real/bsrt/model/non_local/non_local_cross_dot_product.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass _NonLocalBlockND(nn.Module):\n    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):\n        super(_NonLocalBlockND, self).__init__()\n\n        assert dimension in [1, 2, 3]\n\n        self.dimension = dimension\n        self.sub_sample = sub_sample\n\n        self.in_channels = in_channels\n        self.inter_channels = inter_channels\n\n        if self.inter_channels is None:\n            self.inter_channels = in_channels // 2\n            if self.inter_channels == 0:\n                self.inter_channels = 1\n\n        if dimension == 3:\n            conv_nd = nn.Conv3d\n            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 4, 4))\n            bn = nn.BatchNorm3d\n        elif dimension == 2:\n            conv_nd = nn.Conv2d\n            max_pool_layer = nn.MaxPool2d(kernel_size=(4, 4))\n            bn = nn.BatchNorm2d\n        else:\n            conv_nd = nn.Conv1d\n            max_pool_layer = nn.MaxPool1d(kernel_size=(4))\n            bn = nn.BatchNorm1d\n\n        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                         kernel_size=1, stride=1, padding=0)\n\n        if bn_layer:\n            self.W = nn.Sequential(\n                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,\n                        kernel_size=1, stride=1, padding=0),\n                bn(self.in_channels)\n            )\n            nn.init.constant_(self.W[1].weight, 0)\n            nn.init.constant_(self.W[1].bias, 0)\n        else:\n            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,\n                             kernel_size=1, stride=1, padding=0)\n            nn.init.constant_(self.W.weight, 0)\n            nn.init.constant_(self.W.bias, 0)\n\n        self.theta = conv_nd(in_channels=self.in_channels, 
out_channels=self.inter_channels,\n                             kernel_size=1, stride=1, padding=0)\n\n        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                           kernel_size=1, stride=1, padding=0)\n\n        if sub_sample:\n            self.g = nn.Sequential(self.g, max_pool_layer)\n            self.phi = nn.Sequential(self.phi, max_pool_layer)\n\n    def forward(self, x, ref, return_nl_map=False):\n        \"\"\"\n        :param x: (b, c, t, h, w)\n        :param ref: same shape as x; reference features used to compute the attention queries.\n        :param return_nl_map: if True return z, nl_map, else only return z.\n        :return:\n        \"\"\"\n\n        batch_size = x.size(0)\n\n        g_x = self.g(x).view(batch_size, self.inter_channels, -1)\n        g_x = g_x.permute(0, 2, 1)\n\n        theta_ref = self.theta(ref).view(batch_size, self.inter_channels, -1)\n        theta_ref = theta_ref.permute(0, 2, 1)\n        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)\n        f = torch.matmul(theta_ref, phi_x)\n        N = f.size(-1)\n        f_div_C = f / N\n\n        y = torch.matmul(f_div_C, g_x)\n        y = y.permute(0, 2, 1).contiguous()\n        y = y.view(batch_size, self.inter_channels, *x.size()[2:])\n        W_y = self.W(y)\n        z = W_y + x\n\n        if return_nl_map:\n            return z, f_div_C\n        return z\n\n\nclass NONLocalBlock1D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock1D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=1, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nclass NONLocalBlock2D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock2D, self).__init__(in_channels,\n                                          
    inter_channels=inter_channels,\n                                              dimension=2, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nclass NONLocalBlock3D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock3D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=3, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nif __name__ == '__main__':\n    import torch\n\n    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:\n        img = torch.zeros(2, 3, 20)\n        net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n        img = torch.zeros(2, 3, 20, 20)\n        net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n        img = torch.randn(2, 3, 8, 20, 20)\n        net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n\n\n"
  },
  {
    "path": "code/real/bsrt/model/non_local/non_local_dot_product.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass _NonLocalBlockND(nn.Module):\n    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):\n        super(_NonLocalBlockND, self).__init__()\n\n        assert dimension in [1, 2, 3]\n\n        self.dimension = dimension\n        self.sub_sample = sub_sample\n\n        self.in_channels = in_channels\n        self.inter_channels = inter_channels\n\n        if self.inter_channels is None:\n            self.inter_channels = in_channels // 2\n            if self.inter_channels == 0:\n                self.inter_channels = 1\n\n        if dimension == 3:\n            conv_nd = nn.Conv3d\n            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 4, 4))\n            bn = nn.BatchNorm3d\n        elif dimension == 2:\n            conv_nd = nn.Conv2d\n            max_pool_layer = nn.MaxPool2d(kernel_size=(4, 4))\n            bn = nn.BatchNorm2d\n        else:\n            conv_nd = nn.Conv1d\n            max_pool_layer = nn.MaxPool1d(kernel_size=(2))\n            bn = nn.BatchNorm1d\n\n        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                         kernel_size=1, stride=1, padding=0)\n\n        if bn_layer:\n            self.W = nn.Sequential(\n                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,\n                        kernel_size=1, stride=1, padding=0),\n                bn(self.in_channels)\n            )\n            nn.init.constant_(self.W[1].weight, 0)\n            nn.init.constant_(self.W[1].bias, 0)\n        else:\n            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,\n                             kernel_size=1, stride=1, padding=0)\n            nn.init.constant_(self.W.weight, 0)\n            nn.init.constant_(self.W.bias, 0)\n\n        self.theta = conv_nd(in_channels=self.in_channels, 
out_channels=self.inter_channels,\n                             kernel_size=1, stride=1, padding=0)\n\n        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                           kernel_size=1, stride=1, padding=0)\n\n        if sub_sample:\n            self.g = nn.Sequential(self.g, max_pool_layer)\n            self.phi = nn.Sequential(self.phi, max_pool_layer)\n\n    def forward(self, x, return_nl_map=False):\n        \"\"\"\n        :param x: (b, c, t, h, w)\n        :param return_nl_map: if True return z, nl_map, else only return z.\n        :return:\n        \"\"\"\n\n        batch_size = x.size(0)\n\n        g_x = self.g(x).view(batch_size, self.inter_channels, -1)\n        g_x = g_x.permute(0, 2, 1)\n\n        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)\n        theta_x = theta_x.permute(0, 2, 1)\n        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)\n        f = torch.matmul(theta_x, phi_x)\n        N = f.size(-1)\n        f_div_C = f / N\n\n        y = torch.matmul(f_div_C, g_x)\n        y = y.permute(0, 2, 1).contiguous()\n        y = y.view(batch_size, self.inter_channels, *x.size()[2:])\n        W_y = self.W(y)\n        z = W_y + x\n\n        if return_nl_map:\n            return z, f_div_C\n        return z\n\n\nclass NONLocalBlock1D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock1D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=1, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nclass NONLocalBlock2D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock2D, self).__init__(in_channels,\n                                              
inter_channels=inter_channels,\n                                              dimension=2, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nclass NONLocalBlock3D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock3D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=3, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nif __name__ == '__main__':\n    import torch\n\n    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:\n        img = torch.zeros(2, 3, 20)\n        net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n        img = torch.zeros(2, 3, 20, 20)\n        net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n        img = torch.randn(2, 3, 8, 20, 20)\n        net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n\n\n"
  },
  {
    "path": "code/real/bsrt/model/non_local/non_local_embedded_gaussian.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass _NonLocalBlockND(nn.Module):\n    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):\n        \"\"\"\n        :param in_channels:\n        :param inter_channels:\n        :param dimension:\n        :param sub_sample:\n        :param bn_layer:\n        \"\"\"\n\n        super(_NonLocalBlockND, self).__init__()\n\n        assert dimension in [1, 2, 3]\n\n        self.dimension = dimension\n        self.sub_sample = sub_sample\n\n        self.in_channels = in_channels\n        self.inter_channels = inter_channels\n\n        if self.inter_channels is None:\n            self.inter_channels = in_channels // 2\n            if self.inter_channels == 0:\n                self.inter_channels = 1\n\n        if dimension == 3:\n            conv_nd = nn.Conv3d\n            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))\n            bn = nn.BatchNorm3d\n        elif dimension == 2:\n            conv_nd = nn.Conv2d\n            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))\n            bn = nn.BatchNorm2d\n        else:\n            conv_nd = nn.Conv1d\n            max_pool_layer = nn.MaxPool1d(kernel_size=(2))\n            bn = nn.BatchNorm1d\n\n        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                         kernel_size=1, stride=1, padding=0)\n\n        if bn_layer:\n            self.W = nn.Sequential(\n                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,\n                        kernel_size=1, stride=1, padding=0),\n                bn(self.in_channels)\n            )\n            nn.init.constant_(self.W[1].weight, 0)\n            nn.init.constant_(self.W[1].bias, 0)\n        else:\n            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,\n                             kernel_size=1, stride=1, 
padding=0)\n            nn.init.constant_(self.W.weight, 0)\n            nn.init.constant_(self.W.bias, 0)\n\n        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                             kernel_size=1, stride=1, padding=0)\n        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                           kernel_size=1, stride=1, padding=0)\n\n        if sub_sample:\n            self.g = nn.Sequential(self.g, max_pool_layer)\n            self.phi = nn.Sequential(self.phi, max_pool_layer)\n\n    def forward(self, x, return_nl_map=False):\n        \"\"\"\n        :param x: (b, c, t, h, w)\n        :param return_nl_map: if True return z, nl_map, else only return z.\n        :return:\n        \"\"\"\n\n        batch_size = x.size(0)\n\n        g_x = self.g(x).view(batch_size, self.inter_channels, -1)\n        g_x = g_x.permute(0, 2, 1)\n\n        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)\n        theta_x = theta_x.permute(0, 2, 1)\n        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)\n        f = torch.matmul(theta_x, phi_x)\n        f_div_C = F.softmax(f, dim=-1)\n\n        y = torch.matmul(f_div_C, g_x)\n        y = y.permute(0, 2, 1).contiguous()\n        y = y.view(batch_size, self.inter_channels, *x.size()[2:])\n        W_y = self.W(y)\n        z = W_y + x\n\n        if return_nl_map:\n            return z, f_div_C\n        return z\n\n\nclass NONLocalBlock1D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock1D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=1, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nclass NONLocalBlock2D(_NonLocalBlockND):\n    def __init__(self, in_channels, 
inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock2D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=2, sub_sample=sub_sample,\n                                              bn_layer=bn_layer,)\n\n\nclass NONLocalBlock3D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock3D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=3, sub_sample=sub_sample,\n                                              bn_layer=bn_layer,)\n\n\nif __name__ == '__main__':\n    import torch\n\n    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:\n        img = torch.zeros(2, 3, 20)\n        net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n        img = torch.zeros(2, 3, 20, 20)\n        net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n        img = torch.randn(2, 3, 8, 20, 20)\n        net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n\n"
  },
  {
    "path": "code/real/bsrt/model/non_local/non_local_gaussian.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass _NonLocalBlockND(nn.Module):\n    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):\n        super(_NonLocalBlockND, self).__init__()\n\n        assert dimension in [1, 2, 3]\n\n        self.dimension = dimension\n        self.sub_sample = sub_sample\n\n        self.in_channels = in_channels\n        self.inter_channels = inter_channels\n\n        if self.inter_channels is None:\n            self.inter_channels = in_channels // 2\n            if self.inter_channels == 0:\n                self.inter_channels = 1\n\n        if dimension == 3:\n            conv_nd = nn.Conv3d\n            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))\n            bn = nn.BatchNorm3d\n        elif dimension == 2:\n            conv_nd = nn.Conv2d\n            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))\n            bn = nn.BatchNorm2d\n        else:\n            conv_nd = nn.Conv1d\n            max_pool_layer = nn.MaxPool1d(kernel_size=(2))\n            bn = nn.BatchNorm1d\n\n        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                         kernel_size=1, stride=1, padding=0)\n\n        if bn_layer:\n            self.W = nn.Sequential(\n                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,\n                        kernel_size=1, stride=1, padding=0),\n                bn(self.in_channels)\n            )\n            nn.init.constant_(self.W[1].weight, 0)\n            nn.init.constant_(self.W[1].bias, 0)\n        else:\n            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,\n                             kernel_size=1, stride=1, padding=0)\n            nn.init.constant_(self.W.weight, 0)\n            nn.init.constant_(self.W.bias, 0)\n\n        if sub_sample:\n            self.g = nn.Sequential(self.g, 
max_pool_layer)\n            self.phi = max_pool_layer\n\n    def forward(self, x, return_nl_map=False):\n        \"\"\"\n        :param x: (b, c, t, h, w)\n        :param return_nl_map: if True return z, nl_map, else only return z.\n        :return:\n        \"\"\"\n\n        batch_size = x.size(0)\n\n        g_x = self.g(x).view(batch_size, self.inter_channels, -1)\n\n        g_x = g_x.permute(0, 2, 1)\n\n        theta_x = x.view(batch_size, self.in_channels, -1)\n        theta_x = theta_x.permute(0, 2, 1)\n\n        if self.sub_sample:\n            phi_x = self.phi(x).view(batch_size, self.in_channels, -1)\n        else:\n            phi_x = x.view(batch_size, self.in_channels, -1)\n\n        f = torch.matmul(theta_x, phi_x)\n        f_div_C = F.softmax(f, dim=-1)\n\n        # if self.store_last_batch_nl_map:\n        #     self.nl_map = f_div_C\n\n        y = torch.matmul(f_div_C, g_x)\n        y = y.permute(0, 2, 1).contiguous()\n        y = y.view(batch_size, self.inter_channels, *x.size()[2:])\n        W_y = self.W(y)\n        z = W_y + x\n\n        if return_nl_map:\n            return z, f_div_C\n        return z\n\n\nclass NONLocalBlock1D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock1D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=1, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nclass NONLocalBlock2D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock2D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=2, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nclass 
NONLocalBlock3D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock3D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=3, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nif __name__ == '__main__':\n    import torch\n\n    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:\n        img = torch.zeros(2, 3, 20)\n        net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n        img = torch.zeros(2, 3, 20, 20)\n        net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n        img = torch.randn(2, 3, 8, 20, 20)\n        net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n\n\n\n\n\n"
  },
  {
    "path": "code/real/bsrt/model/swin_util.py",
    "content": "# -----------------------------------------------------------------------------------\n# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257\n# Originally Written by Ze Liu, Modified by Jingyun Liang.\n# -----------------------------------------------------------------------------------\n\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n# import torch.utils.checkpoint as checkpoint\nfrom model.checkpoint import CheckpointFunction as checkpoint\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\nimport time\nfrom functools import reduce, lru_cache\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\nclass Mlp_GEGLU(nn.Module):\n    \"\"\" Multilayer perceptron with gated linear unit (GEGLU). Ref. 
\"GLU Variants Improve Transformer\".\n\n    Args:\n        x: (B, D, H, W, C)\n\n    Returns:\n        x: (B, D, H, W, C)\n    \"\"\"\n\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n\n        self.fc11 = nn.Linear(in_features, hidden_features)\n        self.fc12 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.act(self.fc11(x)) * self.fc12(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n\n        return x\n\n\ndef window_partition(x, window_size):\n    \"\"\"\n    Args:\n        x: (B, H, W, C)\n        window_size (int): window size\n\n    Returns:\n        windows: (num_windows*B, window_size, window_size, C)\n    \"\"\"\n    B, H, W, C = x.shape\n    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n    return windows\n\n\ndef window_reverse(windows, window_size, H, W):\n    \"\"\"\n    Args:\n        windows: (num_windows*B, window_size, window_size, C)\n        window_size (int): Window size\n        H (int): Height of image\n        W (int): Width of image\n\n    Returns:\n        x: (B, H, W, C)\n    \"\"\"\n    B = int(windows.shape[0] / (H * W / window_size / window_size))\n    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n    return x\n\n\nclass WindowAttention(nn.Module):\n    r\"\"\" Window based multi-head self attention (W-MSA) module with relative position bias.\n    It supports both of shifted and non-shifted window.\n\n    Args:\n        dim (int): 
Number of input channels.\n        window_size (tuple[int]): The height and width of the window.\n        num_heads (int): Number of attention heads.\n        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. Default: 0.0\n    \"\"\"\n\n    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):\n\n        super().__init__()\n        self.dim = dim\n        self.window_size = window_size  # Wh, Ww\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        # define a parameter table of relative position bias\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(self.window_size[0])\n        coords_w = torch.arange(self.window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += self.window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        self.register_buffer(\"relative_position_index\", 
relative_position_index)\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        trunc_normal_(self.relative_position_bias_table, std=.02)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x, mask=None):\n        \"\"\"\n        Args:\n            x: input features with shape of (num_windows*B, N, C)\n            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n        \"\"\"\n\n        B_, N, C = x.shape\n        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))\n\n        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH\n        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n        attn = attn + relative_position_bias.unsqueeze(0)\n\n        if mask is not None:\n            nW = mask.shape[0]\n            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, N, N)\n            attn = self.softmax(attn)\n        else:\n            attn = self.softmax(attn)\n\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'\n\n    def flops(self, N):\n        # calculate flops for 1 window with token length of N\n        
flops = 0\n        # qkv = self.qkv(x)\n        flops += N * self.dim * 3 * self.dim\n        # attn = (q @ k.transpose(-2, -1))\n        flops += self.num_heads * N * (self.dim // self.num_heads) * N\n        #  x = (attn @ v)\n        flops += self.num_heads * N * N * (self.dim // self.num_heads)\n        # x = self.proj(x)\n        flops += N * self.dim * self.dim\n        return flops\n\n@lru_cache()\ndef calculate_mask(x_size, window_size, shift_size):\n    # calculate attention mask for SW-MSA\n    H, W = x_size\n    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1\n    h_slices = (slice(0, -window_size),\n                slice(-window_size, -shift_size),\n                slice(-shift_size, None))\n    w_slices = (slice(0, -window_size),\n                slice(-window_size, -shift_size),\n                slice(-shift_size, None))\n    cnt = 0\n    for h in h_slices:\n        for w in w_slices:\n            img_mask[:, h, w, :] = cnt\n            cnt += 1\n\n    mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1\n    mask_windows = mask_windows.view(-1, window_size * window_size)\n    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n\n    return attn_mask\n\n\nclass SwinTransformerBlock(nn.Module):\n    r\"\"\" Swin Transformer Block.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        num_heads (int): Number of attention heads.\n        window_size (int): Window size.\n        shift_size (int): Shift size for SW-MSA.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,\n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_checkpoint=False):\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.shift_size = shift_size\n        self.mlp_ratio = mlp_ratio\n        self.use_checkpoint = use_checkpoint\n        if min(self.input_resolution) <= self.window_size:\n            # if window size is larger than input resolution, we don't partition windows\n            self.shift_size = 0\n            self.window_size = min(self.input_resolution)\n        assert 0 <= self.shift_size < self.window_size, \"shift_size must be in [0, window_size)\"\n\n        self.norm1 = norm_layer(dim)\n        self.attn = WindowAttention(\n            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,\n            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        # if self.shift_size > 0:\n        #     attn_mask = self.calculate_mask(self.input_resolution)\n        # else:\n        #     attn_mask = None\n\n        # self.register_buffer(\"attn_mask\", attn_mask)\n\n\n    def forward(self, x, x_size):\n        H, W = x_size\n        B, L, C = x.shape\n        # assert L == H * W, \"input feature has wrong size\"\n\n        # if self.input_resolution != x_size:\n        #     self.input_resolution = x_size\n        #     if self.attn_mask is not None:\n        #         self.attn_mask = self.calculate_mask(x_size).to(x.device)\n\n        shortcut = x\n        x = self.norm1(x)\n        x = x.view(B, H, W, C)\n\n        # cyclic shift\n        if self.shift_size > 0:\n            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n        else:\n            shifted_x = x\n\n        # partition windows\n        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C\n        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C\n\n        # W-MSA/SW-MSA (to be compatible with testing on images whose shapes are a multiple of the window size)\n\n        # if self.input_resolution == x_size:\n        #     if self.use_checkpoint:\n        #         attn_windows = checkpoint.apply(self.attn, x_windows, self.attn_mask)  # nW*B, window_size*window_size, C\n        #     else:\n        #         attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C\n        # else:\n        #     if self.use_checkpoint:\n        #         attn_windows = checkpoint.apply(self.attn, x_windows, self.calculate_mask(x_size).to(x.device))\n        #     else:\n        #         
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))\n\n        attn_mask = calculate_mask(x_size, self.window_size, self.shift_size).to(x.device)\n        attn_windows = self.attn(x_windows, mask=attn_mask)\n\n        # merge windows\n        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)\n        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C\n\n        # reverse cyclic shift\n        if self.shift_size > 0:\n            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))\n        else:\n            x = shifted_x\n        x = x.view(B, H * W, C)\n\n        # FFN\n        x = shortcut + self.drop_path(x)\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, \" \\\n               f\"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}\"\n\n    def flops(self):\n        flops = 0\n        H, W = self.input_resolution\n        # norm1\n        flops += self.dim * H * W\n        # W-MSA/SW-MSA\n        nW = H * W / self.window_size / self.window_size\n        flops += nW * self.attn.flops(self.window_size * self.window_size)\n        # mlp\n        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio\n        # norm2\n        flops += self.dim * H * W\n        return flops\n\n\nclass PatchMerging(nn.Module):\n    r\"\"\" Patch Merging Layer.\n\n    Args:\n        input_resolution (tuple[int]): Resolution of input feature.\n        dim (int): Number of input channels.\n        norm_layer (nn.Module, optional): Normalization layer.  
Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n        self.norm = norm_layer(4 * dim)\n\n    def forward(self, x):\n        \"\"\"\n        x: B, H*W, C\n        \"\"\"\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n        assert H % 2 == 0 and W % 2 == 0, f\"x size ({H}*{W}) are not even.\"\n\n        x = x.view(B, H, W, C)\n\n        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C\n        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C\n        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C\n        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C\n        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C\n        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C\n\n        x = self.norm(x)\n        x = self.reduction(x)\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"input_resolution={self.input_resolution}, dim={self.dim}\"\n\n    def flops(self):\n        H, W = self.input_resolution\n        flops = H * W * self.dim\n        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim\n        return flops\n\n\nclass BasicLayer(nn.Module):\n    \"\"\" A basic Swin Transformer layer for one stage.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        window_size (int): Local window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. 
Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, depth, num_heads, window_size,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):\n\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.depth = depth\n        self.use_checkpoint = False\n\n        # build blocks\n        self.blocks = nn.ModuleList([\n            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,\n                                 num_heads=num_heads, window_size=window_size,\n                                 shift_size=0 if (i % 2 == 0) else window_size // 2,\n                                 mlp_ratio=mlp_ratio,\n                                 qkv_bias=qkv_bias, qk_scale=qk_scale,\n                                 drop=drop, attn_drop=attn_drop,\n                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                                 norm_layer=norm_layer, use_checkpoint=use_checkpoint)\n            for i in range(depth)])\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)\n        else:\n            self.downsample = None\n\n    def forward(self, x, x_size):\n        for i, blk in enumerate(self.blocks):\n            if self.use_checkpoint:\n                # x = 
checkpoint.checkpoint(blk, x, x_size)\n                x = checkpoint.apply(blk, 2, x, x_size)\n            else:\n                x = blk(x, x_size)\n        if self.downsample is not None:\n            x = self.downsample(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n\n    def flops(self):\n        flops = 0\n        for blk in self.blocks:\n            flops += blk.flops()\n        if self.downsample is not None:\n            flops += self.downsample.flops()\n        return flops\n\n\nclass RSTB(nn.Module):\n    \"\"\"Residual Swin Transformer Block (RSTB).\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        window_size (int): Local window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False.\n        img_size: Input image size.\n        patch_size: Patch size.\n        resi_connection: The convolutional block before residual connection.\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, depth, num_heads, window_size,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,\n                 img_size=224, patch_size=4, resi_connection='1conv'):\n        super(RSTB, self).__init__()\n\n        # print(f'dim: {dim}, input_resolution: {input_resolution}, depth: {depth}, num_heads: {num_heads}, window_size: {window_size}, img_size: {img_size}. patch_size: {patch_size}')\n\n        self.dim = dim\n        self.input_resolution = input_resolution\n\n        self.residual_group = BasicLayer(dim=dim,\n                                         input_resolution=input_resolution,\n                                         depth=depth,\n                                         num_heads=num_heads,\n                                         window_size=window_size,\n                                         mlp_ratio=mlp_ratio,\n                                         qkv_bias=qkv_bias, qk_scale=qk_scale,\n                                         drop=drop, attn_drop=attn_drop,\n                                         drop_path=drop_path,\n                                         norm_layer=norm_layer,\n                                         downsample=downsample,\n                                         use_checkpoint=use_checkpoint)\n\n        if resi_connection == '1conv':\n            self.conv = nn.Conv2d(dim, dim, 3, 1, 1)\n\n        elif resi_connection == '3conv':\n            # to save parameters and memory\n            self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),\n                                      nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),\n    
                                  nn.LeakyReLU(negative_slope=0.2, inplace=True),\n                                      nn.Conv2d(dim // 4, dim, 3, 1, 1))\n\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,\n            norm_layer=None)\n\n        self.patch_unembed = PatchUnEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,\n            norm_layer=None)\n\n    def forward(self, x, x_size):\n        x = self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x\n        return x\n\n\n    def flops(self):\n        flops = 0\n        flops += self.residual_group.flops()\n        H, W = self.input_resolution\n        flops += H * W * self.dim * self.dim * 9\n        flops += self.patch_embed.flops()\n        flops += self.patch_unembed.flops()\n\n        return flops\n\n\nclass PatchEmbed(nn.Module):\n    r\"\"\" Image to Patch Embedding\n\n    Args:\n        img_size (int): Image size.  Default: 224.\n        patch_size (int): Patch token size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. 
Default: None\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.patches_resolution = patches_resolution\n        self.num_patches = patches_resolution[0] * patches_resolution[1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x, use_norm=True):\n        x = x.flatten(2).transpose(1, 2)  # B Ph*Pw C\n        if use_norm and self.norm is not None:\n            x = self.norm(x)\n        return x\n\n    def flops(self):\n        flops = 0\n        H, W = self.img_size\n        if self.norm is not None:\n            flops += H * W * self.embed_dim\n        return flops\n\n\nclass PatchUnEmbed(nn.Module):\n    r\"\"\" Image to Patch Unembedding\n\n    Args:\n        img_size (int): Image size.  Default: 224.\n        patch_size (int): Patch token size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. 
Default: None\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.patches_resolution = patches_resolution\n        self.num_patches = patches_resolution[0] * patches_resolution[1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n    def forward(self, x, x_size):\n        B, HW, C = x.shape\n        x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1])  # B Ph*Pw C\n        return x\n\n    def flops(self):\n        flops = 0\n        return flops\n\n\nclass Upsample(nn.Sequential):\n    \"\"\"Upsample module.\n\n    Args:\n        scale (int): Scale factor. Supported scales: 2^n and 3.\n        num_feat (int): Channel number of intermediate features.\n    \"\"\"\n\n    def __init__(self, scale, num_feat):\n        m = []\n        if (scale & (scale - 1)) == 0:  # scale = 2^n\n            for _ in range(int(math.log(scale, 2))):\n                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))\n                m.append(nn.PixelShuffle(2))\n        elif scale == 3:\n            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))\n            m.append(nn.PixelShuffle(3))\n        else:\n            raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')\n        super(Upsample, self).__init__(*m)\n\n\nclass UpsampleOneStep(nn.Sequential):\n    \"\"\"UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)\n       Used in lightweight SR to save parameters.\n\n    Args:\n        scale (int): Scale factor. 
Supported scales: 2^n and 3.\n        num_feat (int): Channel number of intermediate features.\n\n    \"\"\"\n\n    def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):\n        self.num_feat = num_feat\n        self.input_resolution = input_resolution\n        m = []\n        m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))\n        m.append(nn.PixelShuffle(scale))\n        super(UpsampleOneStep, self).__init__(*m)\n\n    def flops(self):\n        H, W = self.input_resolution\n        flops = H * W * self.num_feat * 3 * 9\n        return flops\n\n\nclass SwinIR(nn.Module):\n    r\"\"\" SwinIR\n        A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.\n\n    Args:\n        img_size (int | tuple(int)): Input image size. Default 64\n        patch_size (int | tuple(int)): Patch size. Default: 1\n        in_chans (int): Number of input image channels. Default: 3\n        embed_dim (int): Patch embedding dimension. Default: 96\n        depths (tuple(int)): Depth of each Swin Transformer layer.\n        num_heads (tuple(int)): Number of attention heads in different layers.\n        window_size (int): Window size. Default: 7\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4\n        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None\n        drop_rate (float): Dropout rate. Default: 0\n        attn_drop_rate (float): Attention dropout rate. Default: 0\n        drop_path_rate (float): Stochastic depth rate. Default: 0.1\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False\n        patch_norm (bool): If True, add normalization after patch embedding. 
Default: True\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False\n        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction\n        img_range: Image range. 1. or 255.\n        upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None\n        resi_connection: The convolutional block before residual connection. '1conv'/'3conv'\n    \"\"\"\n\n    def __init__(self, img_size=64, patch_size=1, in_chans=3,\n                 embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],\n                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,\n                 use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',\n                 **kwargs):\n        super(SwinIR, self).__init__()\n        num_in_ch = in_chans\n        num_out_ch = in_chans\n        num_feat = 64\n        self.img_range = img_range\n        if in_chans == 3:\n            rgb_mean = (0.4488, 0.4371, 0.4040)\n            self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)\n        else:\n            self.mean = torch.zeros(1, 1, 1, 1)\n        self.upscale = upscale\n        self.upsampler = upsampler\n        self.window_size = window_size\n\n        #####################################################################################################\n        ################################### 1, shallow feature extraction ###################################\n        self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)\n\n        #####################################################################################################\n        ################################### 2, deep feature extraction ######################################\n        self.num_layers = 
len(depths)\n        self.embed_dim = embed_dim\n        self.ape = ape\n        self.patch_norm = patch_norm\n        self.num_features = embed_dim\n        self.mlp_ratio = mlp_ratio\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None)\n        num_patches = self.patch_embed.num_patches\n        patches_resolution = self.patch_embed.patches_resolution\n        self.patches_resolution = patches_resolution\n        # print('patches_resolution: ', patches_resolution)\n\n        # merge non-overlapping patches into image\n        self.patch_unembed = PatchUnEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None)\n\n        # absolute position embedding\n        if self.ape:\n            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))\n            trunc_normal_(self.absolute_pos_embed, std=.02)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n\n        # build Residual Swin Transformer blocks (RSTB)\n        self.layers = nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = RSTB(dim=embed_dim,\n                         input_resolution=(patches_resolution[0],\n                                           patches_resolution[1]),\n                         depth=depths[i_layer],\n                         num_heads=num_heads[i_layer],\n                         window_size=window_size,\n                         mlp_ratio=self.mlp_ratio,\n                         qkv_bias=qkv_bias, qk_scale=qk_scale,\n                         drop=drop_rate, 
attn_drop=attn_drop_rate,\n                         drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results\n                         norm_layer=norm_layer,\n                         downsample=None,\n                         use_checkpoint=use_checkpoint,\n                         img_size=img_size,\n                         patch_size=patch_size,\n                         resi_connection=resi_connection\n\n                         )\n            self.layers.append(layer)\n        self.norm = norm_layer(self.num_features)\n\n        # build the last conv layer in deep feature extraction\n        if resi_connection == '1conv':\n            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)\n        elif resi_connection == '3conv':\n            # to save parameters and memory\n            self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),\n                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),\n                                                 nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),\n                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),\n                                                 nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))\n\n        #####################################################################################################\n        ################################ 3, high quality image reconstruction ################################\n        if self.upsampler == 'pixelshuffle':\n            # for classical SR\n            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),\n                                                      nn.LeakyReLU(inplace=True))\n            self.upsample = Upsample(upscale, num_feat)\n            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)\n        elif self.upsampler == 'pixelshuffledirect':\n         
   # for lightweight SR (to save parameters)\n            self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,\n                                            (patches_resolution[0], patches_resolution[1]))\n        elif self.upsampler == 'nearest+conv':\n            # for real-world SR (less artifacts)\n            assert self.upscale == 4, 'only support x4 now.'\n            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),\n                                                      nn.LeakyReLU(inplace=True))\n            self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)\n            self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)\n            self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)\n            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)\n            self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)\n        else:\n            # for image denoising and JPEG compression artifact reduction\n            self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'absolute_pos_embed'}\n\n    @torch.jit.ignore\n    def no_weight_decay_keywords(self):\n        return {'relative_position_bias_table'}\n\n    def check_image_size(self, x):\n        _, _, h, w = x.size()\n        mod_pad_h = (self.window_size - h % self.window_size) % self.window_size\n        mod_pad_w = (self.window_size - w % self.window_size) % self.window_size\n        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')\n        return x\n\n    
def forward_features(self, x):\n        x_size = (x.shape[2], x.shape[3])\n        x = self.patch_embed(x)\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n\n        for layer in self.layers:\n            x = layer(x, x_size)\n\n        x = self.norm(x)  # B L C\n        x = self.patch_unembed(x, x_size)\n\n        return x\n\n    def forward(self, x):\n        H, W = x.shape[2:]\n        x = self.check_image_size(x)\n        \n        self.mean = self.mean.type_as(x)\n        x = (x - self.mean) * self.img_range\n\n        if self.upsampler == 'pixelshuffle':\n            # for classical SR\n            x = self.conv_first(x)\n            x = self.conv_after_body(self.forward_features(x)) + x\n            x = self.conv_before_upsample(x)\n            x = self.conv_last(self.upsample(x))\n        elif self.upsampler == 'pixelshuffledirect':\n            # for lightweight SR\n            x = self.conv_first(x)\n            x = self.conv_after_body(self.forward_features(x)) + x\n            x = self.upsample(x)\n        elif self.upsampler == 'nearest+conv':\n            # for real-world SR\n            x = self.conv_first(x)\n            x = self.conv_after_body(self.forward_features(x)) + x\n            x = self.conv_before_upsample(x)\n            x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))\n            x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))\n            x = self.conv_last(self.lrelu(self.conv_hr(x)))\n        else:\n            # for image denoising and JPEG compression artifact reduction\n            x_first = self.conv_first(x)\n            res = self.conv_after_body(self.forward_features(x_first)) + x_first\n            x = x + self.conv_last(res)\n\n        x = x / self.img_range + self.mean\n\n        return x[:, :, :H*self.upscale, :W*self.upscale]\n\n    def flops(self):\n        flops = 0\n       
 H, W = self.patches_resolution\n        flops += H * W * 3 * self.embed_dim * 9\n        flops += self.patch_embed.flops()\n        for i, layer in enumerate(self.layers):\n            flops += layer.flops()\n        flops += H * W * 3 * self.embed_dim * self.embed_dim\n        flops += self.upsample.flops()\n        return flops\n\n\nif __name__ == '__main__':\n    upscale = 4\n    window_size = 8\n    height = (1024 // upscale // window_size + 1) * window_size\n    width = (720 // upscale // window_size + 1) * window_size\n    model = SwinIR(upscale=upscale, img_size=(height, width),\n                   window_size=window_size, img_range=1., depths=[6, 6, 6, 6],\n                   embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')\n    print(model)\n    print(height, width, model.flops() / 1e9)\n\n    x = torch.randn((1, 3, height, width))\n    x = model(x)\n    print(x.shape)\n"
  },
  {
    "path": "code/real/bsrt/model/utils/interp_methods.py",
    "content": "from math import pi\n\ntry:\n    import torch\nexcept ImportError:\n    torch = None\n\ntry:\n    import numpy\nexcept ImportError:\n    numpy = None\n\nif numpy is None and torch is None:\n    raise ImportError(\"Must have either Numpy or PyTorch but both not found\")\n\n\ndef set_framework_dependencies(x):\n    # guard against numpy being unavailable (it is None after a failed import)\n    if numpy is not None and type(x) is numpy.ndarray:\n        to_dtype = lambda a: a\n        fw = numpy\n    else:\n        to_dtype = lambda a: a.to(x.dtype)\n        fw = torch\n    eps = fw.finfo(fw.float32).eps\n    return fw, to_dtype, eps\n\n\ndef support_sz(sz):\n    def wrapper(f):\n        f.support_sz = sz\n        return f\n    return wrapper\n\n@support_sz(4)\ndef cubic(x):\n    fw, to_dtype, eps = set_framework_dependencies(x)\n    absx = fw.abs(x)\n    absx2 = absx ** 2\n    absx3 = absx ** 3\n    return ((1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) +\n            (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) *\n            to_dtype((1. < absx) & (absx <= 2.)))\n\n@support_sz(4)\ndef lanczos2(x):\n    fw, to_dtype, eps = set_framework_dependencies(x)\n    return (((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) /\n            ((pi**2 * x**2 / 2) + eps)) * to_dtype(abs(x) < 2))\n\n@support_sz(6)\ndef lanczos3(x):\n    fw, to_dtype, eps = set_framework_dependencies(x)\n    return (((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) /\n            ((pi**2 * x**2 / 3) + eps)) * to_dtype(abs(x) < 3))\n\n@support_sz(2)\ndef linear(x):\n    fw, to_dtype, eps = set_framework_dependencies(x)\n    return ((x + 1) * to_dtype((-1 <= x) & (x < 0)) + (1 - x) *\n            to_dtype((0 <= x) & (x <= 1)))\n\n@support_sz(1)\ndef box(x):\n    fw, to_dtype, eps = set_framework_dependencies(x)\n    return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1))\n"
  },
  {
    "path": "code/real/bsrt/model/utils/psconv.py",
    "content": "import torch\nimport torch.nn as nn\n\nclass PyConv2d(nn.Module):\n    \"\"\"PyConv2d with padding (general case). Applies a 2D PyConv over an input signal composed of several input planes.\n    Args:\n        in_channels (int): Number of channels in the input image\n        out_channels (list): Number of channels for each pyramid level produced by the convolution\n        pyconv_kernels (list): Spatial size of the kernel for each pyramid level\n        pyconv_groups (list): Number of blocked connections from input channels to output channels for each pyramid level\n        stride (int or tuple, optional): Stride of the convolution. Default: 1\n        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1\n        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``False``\n    Example::\n        >>> # PyConv with two pyramid levels, kernels: 3x3, 5x5\n        >>> m = PyConv2d(in_channels=64, out_channels=[32, 32], pyconv_kernels=[3, 5], pyconv_groups=[1, 4])\n        >>> input = torch.randn(4, 64, 56, 56)\n        >>> output = m(input)\n        >>> # PyConv with three pyramid levels, kernels: 3x3, 5x5, 7x7\n        >>> m = PyConv2d(in_channels=64, out_channels=[16, 16, 32], pyconv_kernels=[3, 5, 7], pyconv_groups=[1, 4, 8])\n        >>> input = torch.randn(4, 64, 56, 56)\n        >>> output = m(input)\n    \"\"\"\n    def __init__(self, in_channels, out_channels, pyconv_kernels, pyconv_groups, stride=1, dilation=1, bias=False):\n        super(PyConv2d, self).__init__()\n\n        assert len(out_channels) == len(pyconv_kernels) == len(pyconv_groups)\n\n        self.pyconv_levels = [None] * len(pyconv_kernels)\n        for i in range(len(pyconv_kernels)):\n            self.pyconv_levels[i] = nn.Conv2d(in_channels, out_channels[i], kernel_size=pyconv_kernels[i],\n                                              stride=stride, padding=pyconv_kernels[i] // 2, groups=pyconv_groups[i],\n                
                              dilation=dilation, bias=bias)\n        self.pyconv_levels = nn.ModuleList(self.pyconv_levels)\n\n    def forward(self, x):\n        out = []\n        for level in self.pyconv_levels:\n            out.append(level(x))\n\n        return torch.cat(out, 1)\n\n################################################################\n\nclass PSConv2d(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, parts=4, bias=False):\n        super(PSConv2d, self).__init__()\n        self.gwconv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, dilation, dilation, groups=parts, bias=bias)\n        self.gwconv_shift = nn.Conv2d(in_channels, out_channels, kernel_size, stride, 2 * dilation, 2 * dilation, groups=parts, bias=bias)\n        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)\n\n        def backward_hook(grad):\n            out = grad.clone()\n            out[self.mask] = 0\n            return out\n\n        # use a bool mask (byte-tensor indexing is deprecated in recent PyTorch)\n        self.mask = torch.zeros(self.conv.weight.shape).bool().cuda()\n        _in_channels = in_channels // parts\n        _out_channels = out_channels // parts\n        for i in range(parts):\n            self.mask[i * _out_channels: (i + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, : , :] = 1\n            self.mask[(i + parts//2)%parts * _out_channels: ((i + parts//2)%parts + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, :, :] = 1\n        self.conv.weight.data[self.mask] = 0\n        self.conv.weight.register_hook(backward_hook)\n\n        self.weight = self.conv.weight\n        self.bias = self.conv.bias\n\n    def forward(self, x):\n        x1, x2 = x.chunk(2, dim=1)\n        x_shift = self.gwconv_shift(torch.cat((x2, x1), dim=1))\n        return self.gwconv(x) + self.conv(x) + x_shift\n\n\n# PSConv-based Group Convolution\nclass PSGConv2d(nn.Module):\n    def __init__(self, in_channels, out_channels, 
kernel_size=3, stride=1, padding=1, dilation=1, groups=1, parts=4, bias=False):\n        super(PSGConv2d, self).__init__()\n        self.gwconv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups=groups * parts, bias=bias)\n        self.gwconv_shift = nn.Conv2d(in_channels, out_channels, kernel_size, stride, 2 * padding, 2 * dilation, groups=groups * parts, bias=bias)\n        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=bias)\n\n        def backward_hook(grad):\n            out = grad.clone()\n            out[self.mask] = 0\n            return out\n\n        self.mask = torch.zeros(self.conv.weight.shape).bool().cuda()\n        _in_channels = in_channels // (groups * parts)\n        _out_channels = out_channels // (groups * parts)\n        for i in range(parts):\n            for j in range(groups):\n                self.mask[(i + j * groups) * _out_channels: (i + j * groups + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, : , :] = 1\n                self.mask[((i + parts // 2) % parts + j * groups) * _out_channels: ((i + parts // 2) % parts + j * groups + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, :, :] = 1\n        self.conv.weight.data[self.mask] = 0\n        self.conv.weight.register_hook(backward_hook)\n        self.groups = groups\n\n        self.weight = self.conv.weight\n        self.bias = self.conv.bias\n\n    def forward(self, x):\n        x_split = (z.chunk(2, dim=1) for z in x.chunk(self.groups, dim=1))\n        x_merge = torch.cat(tuple(torch.cat((x2, x1), dim=1) for (x1, x2) in x_split), dim=1)\n        x_shift = self.gwconv_shift(x_merge)\n        gx = self.gwconv(x)\n        cx = self.conv(x)\n        # print(x.shape, gx.shape, cx.shape, x_merge.shape, x_shift.shape)\n        return gx + cx + x_shift\n\n"
  },
  {
    "path": "code/real/bsrt/model/utils/resize_right.py",
    "content": "import warnings\nfrom math import ceil\nimport model.utils.interp_methods as interp_methods\n\n\nclass NoneClass:\n    pass\n\ntry:\n    import torch\n    from torch import nn\n    nnModuleWrapped = nn.Module\nexcept ImportError:\n    warnings.warn('No PyTorch found, will work only with Numpy')\n    torch = None\n    nnModuleWrapped = NoneClass\n\ntry:\n    import numpy\nexcept ImportError:\n    warnings.warn('No Numpy found, will work only with PyTorch')\n    numpy = None\n\n\nif numpy is None and torch is None:\n    raise ImportError(\"Must have either Numpy or PyTorch but both not found\")\n\n\ndef resize(input, scale_factors=None, out_shape=None,\n           interp_method=interp_methods.cubic, support_sz=None,\n           antialiasing=True):\n    # get properties of the input tensor\n    in_shape, n_dims = input.shape, input.ndim\n\n    # fw stands for framework that can be either numpy or torch,\n    # determined by the input type (numpy is None if it failed to import)\n    fw = numpy if (numpy is not None and type(input) is numpy.ndarray) else torch\n    eps = fw.finfo(fw.float32).eps\n\n    # set missing scale factors or output shape, one according to another,\n    # scream if both missing\n    scale_factors, out_shape = set_scale_and_out_sz(in_shape, out_shape,\n                                                    scale_factors, fw)\n\n    # sort indices of dimensions according to scale of each dimension.\n    # since we are going dim by dim this is efficient\n    sorted_filtered_dims_and_scales = [(dim, scale_factors[dim])\n                                       for dim in sorted(range(n_dims),\n                                       key=lambda ind: scale_factors[ind])\n                                       if scale_factors[dim] != 1.]\n\n    # unless support size is specified by the user, it is an attribute\n    # of the interpolation method\n    if support_sz is None:\n        support_sz = interp_method.support_sz\n\n    # when using pytorch, we need to know what is the input tensor device\n    if fw 
is torch:\n        device = input.device\n    else:\n        device = None\n\n    # output begins identical to input and changes with each iteration\n    output = input\n\n    # iterate over dims\n    for dim, scale_factor in sorted_filtered_dims_and_scales:\n\n        # get 1d set of weights and fields of view for each output location\n        # along this dim\n        field_of_view, weights = prepare_weights_and_field_of_view_1d(\n            dim, scale_factor, in_shape[dim], out_shape[dim], interp_method,\n            support_sz, antialiasing, fw, eps, device)\n\n        # multiply the weights by the values in the field of view and\n        # aggregate\n        output = apply_weights(output, field_of_view, weights, dim, n_dims,\n                               fw)\n    return output\n\n\nclass ResizeLayer(nnModuleWrapped):\n    def __init__(self, in_shape, scale_factors=None, out_shape=None,\n                 interp_method=interp_methods.cubic, support_sz=None,\n                 antialiasing=True):\n        super(ResizeLayer, self).__init__()\n\n        # fw stands for framework, which can be either numpy or torch. 
since\n        # this is a torch layer, only one option in this case.\n        fw = torch\n        eps = fw.finfo(fw.float32).eps\n\n        # set missing scale factors or output shape, one according to another,\n        # scream if both missing\n        scale_factors, out_shape = set_scale_and_out_sz(in_shape, out_shape,\n                                                        scale_factors, fw)\n\n        # unless support size is specified by the user, it is an attribute\n        # of the interpolation method\n        if support_sz is None:\n            support_sz = interp_method.support_sz\n\n        self.n_dims = len(in_shape)\n\n        # sort indices of dimensions according to scale of each dimension.\n        # since we are going dim by dim this is efficient\n        self.sorted_filtered_dims_and_scales = [(dim, scale_factors[dim])\n                                                for dim in\n                                                sorted(range(self.n_dims),\n                                                key=lambda ind:\n                                                scale_factors[ind])\n                                                if scale_factors[dim] != 1.]\n\n        # iterate over dims\n        field_of_view_list = []\n        weights_list = []\n        for dim, scale_factor in self.sorted_filtered_dims_and_scales:\n\n            # get 1d set of weights and fields of view for each output\n            # location along this dim\n            # (no input tensor exists at init time, so no device is passed;\n            # the resulting Parameters move with the module)\n            field_of_view, weights = prepare_weights_and_field_of_view_1d(\n                dim, scale_factor, in_shape[dim], out_shape[dim],\n                interp_method, support_sz, antialiasing, fw, eps)\n\n            # keep weights and fields of views for all dims\n            weights_list.append(nn.Parameter(weights, requires_grad=False))\n            field_of_view_list.append(nn.Parameter(field_of_view,\n                                      requires_grad=False))\n\n        self.field_of_view = 
nn.ParameterList(field_of_view_list)\n        self.weights = nn.ParameterList(weights_list)\n        self.in_shape = in_shape\n\n    def forward(self, input):\n        # output begins identical to input and changes with each iteration\n        output = input\n\n        for (dim, scale_factor), field_of_view, weights in zip(\n                self.sorted_filtered_dims_and_scales,\n                self.field_of_view,\n                self.weights):\n            # multiply the weights by the values in the field of view and\n            # aggregate\n            output = apply_weights(output, field_of_view, weights, dim,\n                                   self.n_dims, torch)\n        return output\n\n\ndef prepare_weights_and_field_of_view_1d(dim, scale_factor, in_sz, out_sz,\n                                         interp_method, support_sz,\n                                         antialiasing, fw, eps, device=None):\n    # If antialiasing is taking place, we modify the window size and the\n    # interpolation method (see inside function)\n    interp_method, cur_support_sz = apply_antialiasing_if_needed(\n                                                             interp_method,\n                                                             support_sz,\n                                                             scale_factor,\n                                                             antialiasing)\n\n    # STEP 1- PROJECTED GRID: The non-integer locations of the projection of\n    # output pixel locations to the input tensor\n    projected_grid = get_projected_grid(in_sz, out_sz, scale_factor, fw, device)\n\n    # STEP 2- FIELDS OF VIEW: for each output pixel, map the input pixels\n    # that influence it\n    field_of_view = get_field_of_view(projected_grid, cur_support_sz, in_sz,\n                                      fw, eps)\n\n    # STEP 3- CALCULATE WEIGHTS: Match a set of weights to the pixels in the\n    # field of view for each output pixel\n    
weights = get_weights(interp_method, projected_grid, field_of_view)\n\n    return field_of_view, weights\n\n\ndef apply_weights(input, field_of_view, weights, dim, n_dims, fw):\n    # STEP 4- APPLY WEIGHTS: Each output pixel is calculated by multiplying\n    # its set of weights with the pixel values in its field of view.\n    # We now multiply the fields of view with their matching weights.\n    # We do this by tensor multiplication and broadcasting.\n    # this step is separated into a different function, so that it can be\n    # repeated with the same calculated weights and fields.\n\n    # for this operation we assume the resized dim is the first one.\n    # so we transpose and will transpose back after multiplying\n    tmp_input = fw_swapaxes(input, dim, 0, fw)\n\n    # field_of_view is a tensor of order 2: for each output (1d location\n    # along cur dim)- a list of 1d neighbor locations.\n    # note that this whole operation is applied to each dim separately,\n    # this is why it is all in 1d.\n    # neighbors = tmp_input[field_of_view] is a tensor of order image_dims+1:\n    # for each output pixel (this time indicated in all dims), these are the\n    # values of the neighbors in the 1d field of view. note that we only\n    # consider neighbors along the current dim, but such a set exists for every\n    # multi-dim location, hence the final tensor order is image_dims+1.\n    neighbors = tmp_input[field_of_view]\n\n    # weights is an order 2 tensor: for each output location along 1d- a list\n    # of weights matching the field of view. 
we augment it with ones, for\n    # broadcasting, so that when it multiplies some tensor the weights affect\n    # only its first dim.\n    tmp_weights = fw.reshape(weights, (*weights.shape, * [1] * (n_dims - 1)))\n\n    # now we simply multiply the weights with the neighbors, and then sum\n    # along the field of view, to get a single value per out pixel\n    tmp_output = (neighbors * tmp_weights).sum(1)\n\n    # we transpose back the resized dim to its original position\n    return fw_swapaxes(tmp_output, 0, dim, fw)\n\n\ndef set_scale_and_out_sz(in_shape, out_shape, scale_factors, fw):\n    # eventually we must have both scale-factors and out-sizes for all in/out\n    # dims. however, we support many possible partial arguments\n    if scale_factors is None and out_shape is None:\n        raise ValueError(\"either scale_factors or out_shape should be \"\n                         \"provided\")\n    if out_shape is not None:\n        # if out_shape has fewer dims than in_shape, we default to resizing the\n        # first dims for numpy and last dims for torch\n        out_shape = (list(out_shape) + list(in_shape[:-len(out_shape)])\n                     if fw is numpy\n                     else list(in_shape[:-len(out_shape)]) + list(out_shape))\n        if scale_factors is None:\n            # if no scale given, we calculate it as the out to in ratio\n            # (not recommended)\n            scale_factors = [out_sz / in_sz for out_sz, in_sz\n                             in zip(out_shape, in_shape)]\n    if scale_factors is not None:\n        # by default, if a single number is given as scale, we assume resizing\n        # two dims (most common are images with 2 spatial dims)\n        scale_factors = (scale_factors\n                         if isinstance(scale_factors, (list, tuple))\n                         else [scale_factors, scale_factors])\n        # if fewer scale_factors than in_shape dims, we default to resizing the\n        # first dims for numpy and last dims 
for torch\n        scale_factors = (list(scale_factors) + [1] *\n                         (len(in_shape) - len(scale_factors)) if fw is numpy\n                         else [1] * (len(in_shape) - len(scale_factors)) +\n                         list(scale_factors))\n        if out_shape is None:\n            # when no out_shape given, it is calculated by multiplying the\n            # scale by the in_shape (not recommended)\n            out_shape = [ceil(scale_factor * in_sz)\n                         for scale_factor, in_sz in\n                         zip(scale_factors, in_shape)]\n        # next line intentionally after out_shape determined for stability\n        scale_factors = [float(sf) for sf in scale_factors]\n    return scale_factors, out_shape\n\n\ndef get_projected_grid(in_sz, out_sz, scale_factor, fw, device=None):\n    # we start by having the output coordinates which are just integer locations\n    out_coordinates = fw.arange(out_sz)\n\n    # if using torch we need to match the grid tensor device to the input device\n    out_coordinates = fw_set_device(out_coordinates, device, fw)\n\n    # This is projecting the output pixel locations in 1d to the input tensor,\n    # as non-integer locations.\n    # the following formula is derived in the paper\n    # \"From Discrete to Continuous Convolutions\" by Shocher et al.\n    return (out_coordinates / scale_factor +\n            (in_sz - 1) / 2 - (out_sz - 1) / (2 * scale_factor))\n\n\ndef get_field_of_view(projected_grid, cur_support_sz, in_sz, fw, eps):\n    # for each output pixel, map which input pixels influence it, in 1d.\n    # we start by calculating the leftmost neighbor, using half of the window\n    # size (eps is for when boundary is exact int)\n    left_boundaries = fw_ceil(projected_grid - cur_support_sz / 2 - eps, fw)\n\n    # then we simply take all the pixel centers in the field by counting\n    # window size pixels from the left boundary\n    ordinal_numbers = fw.arange(ceil(cur_support_sz - 
eps))\n    # in case using torch we need to match the device\n    # (numpy arrays have no .device attribute, hence the getattr)\n    ordinal_numbers = fw_set_device(ordinal_numbers, getattr(projected_grid, 'device', None), fw)\n    field_of_view = left_boundaries[:, None] + ordinal_numbers\n\n    # next we do a trick: instead of padding, we map the field of view so that\n    # it would be like mirror padding, without actually padding\n    # (which would require enlarging the input tensor)\n    mirror = fw_cat((fw.arange(in_sz), fw.arange(in_sz - 1, -1, step=-1)), fw)\n    field_of_view = mirror[fw.remainder(field_of_view, mirror.shape[0])]\n    field_of_view = fw_set_device(field_of_view, getattr(projected_grid, 'device', None), fw)\n    return field_of_view\n\n\ndef get_weights(interp_method, projected_grid, field_of_view):\n    # the set of weights per output pixel is the result of the chosen\n    # interpolation method applied to the distances between projected grid\n    # locations and the pixel-centers in the field of view (distances are\n    # directed, can be positive or negative)\n    weights = interp_method(projected_grid[:, None] - field_of_view)\n\n    # we now carefully normalize the weights to sum to 1 for each output pixel\n    sum_weights = weights.sum(1, keepdims=True)\n    sum_weights[sum_weights == 0] = 1\n    return weights / sum_weights\n\n\ndef apply_antialiasing_if_needed(interp_method, support_sz, scale_factor,\n                                 antialiasing):\n    # antialiasing is \"stretching\" the field of view according to the scale\n    # factor (only for downscaling). this is low-pass filtering. 
this\n    # requires modifying both the interpolation (stretching the 1d\n    # function and multiplying by the scale-factor) and the window size.\n    if scale_factor >= 1.0 or not antialiasing:\n        return interp_method, support_sz\n    cur_interp_method = (lambda arg: scale_factor *\n                         interp_method(scale_factor * arg))\n    cur_support_sz = support_sz / scale_factor\n    return cur_interp_method, cur_support_sz\n\n\ndef fw_ceil(x, fw):\n    if fw is numpy:\n        return fw.int_(fw.ceil(x))\n    else:\n        return x.ceil().long()\n\n\ndef fw_cat(x, fw):\n    if fw is numpy:\n        return fw.concatenate(x)\n    else:\n        return fw.cat(x)\n\n\ndef fw_swapaxes(x, ax_1, ax_2, fw):\n    if fw is numpy:\n        return fw.swapaxes(x, ax_1, ax_2)\n    else:\n        return x.transpose(ax_1, ax_2)\n\ndef fw_set_device(x, device, fw):\n    if fw is numpy:\n        return x\n    else:\n        return x.to(device)\n"
  },
  {
    "path": "code/real/bsrt/option.py",
    "content": "import argparse\n\nparser = argparse.ArgumentParser(description='EDSR and MDSR')\n\nparser.add_argument('--n_resblocks', type=int, default=16,\n                    help='number of residual blocks')\nparser.add_argument('--n_feats', type=int, default=64,\n                    help='number of feature maps')\nparser.add_argument('--n_colors', type=int, default=3,\n                    help='number of color channels to use')\nparser.add_argument('--lr', type=float, default=1e-4,\n                    help='learning rate')\nparser.add_argument('--burst_size', type=int, default=14,\n                    help='burst size, max 14')\nparser.add_argument('--burst_channel', type=int, default=4,\n                    help='RAW channel, default:4')\nparser.add_argument('--swinfeature', action='store_true',\n                    help='use swin transformer to extract features')\nparser.add_argument('--model_level', type=str, default='S',\n                    help='S: small, L: large')\n\n################## fine-tune ##################\nparser.add_argument('--finetune', action='store_true',\n                    help='finetune model')\nparser.add_argument('--finetune_align', action='store_true',\n                    help='finetune alignment module')\nparser.add_argument('--finetune_swin', action='store_true',\n                    help='finetune swin trans module')\nparser.add_argument('--finetune_conv', action='store_true',\n                    help='finetune rest convs')\nparser.add_argument('--finetune_prelayer', action='store_true',\n                    help='finetune pre feature extraction layer')\nparser.add_argument('--finetune_upconv', action='store_true',\n                    help='finetune up conv layer')\nparser.add_argument('--finetune_spynet', action='store_true',\n                    help='finetune SpyNet module')\n\n\n# Hardware specifications\nparser.add_argument('--n_threads', type=int, default=6,\n                    help='number 
of threads for data loading')\nparser.add_argument('--cpu', action='store_true',\n                    help='use cpu only')\nparser.add_argument('--n_GPUs', type=int, default=1,\n                    help='number of GPUs')\nparser.add_argument('--seed', type=int, default=1,\n                    help='random seed')\nparser.add_argument('--local_rank', type=int, default=-1,\n                    help='proc index')\nparser.add_argument('--fp16', action='store_true',\n                    help='use fp16 only')\nparser.add_argument('--use_checkpoint', action='store_true',\n                    help='use gradient checkpointing in swin transformer')\n\n# Data specifications\nparser.add_argument('--root', type=str, default='/data/dataset/ntire21/burstsr/real',\n                    help='dataset directory')\nparser.add_argument('--val_root', type=str, default='../test_set',\n                    help='validation set directory')\nparser.add_argument('--mode', type=str, default='train',\n                    help='train or test mode')\nparser.add_argument('--scale', type=str, default='4',\n                    help='super resolution scale')\nparser.add_argument('--patch_size', type=int, default=256,\n                    help='output patch size')\nparser.add_argument('--rgb_range', type=int, default=1,\n                    help='maximum value of RGB')\n\nparser.add_argument('--chop', action='store_true',\n                    help='enable memory-efficient forward')\nparser.add_argument('--no_augment', action='store_true',\n                    help='do not use data augmentation')\n\n# Model specifications\nparser.add_argument('--model', default='LRSC_EDVR',\n                    help='model name')\n\nparser.add_argument('--act', type=str, default='relu',\n                    help='activation function')\nparser.add_argument('--pre_train', type=str, default='',\n                    help='pre-trained model directory')\nparser.add_argument('--extend', type=str, default='.',\n                    
help='pre-trained model directory')\n\nparser.add_argument('--res_scale', type=float, default=1,\n                    help='residual scaling')\nparser.add_argument('--shift_mean', default=True,\n                    help='subtract pixel mean from the input')\nparser.add_argument('--dilation', action='store_true',\n                    help='use dilated convolution')\nparser.add_argument('--precision', type=str, default='single',\n                    choices=('single', 'half'),\n                    help='FP precision for test (single | half)')\n\n\n# Option for Residual channel attention network (RCAN)\nparser.add_argument('--n_resgroups', type=int, default=20,\n                    help='number of residual groups')\nparser.add_argument('--reduction', type=int, default=16,\n                    help='number of feature maps reduction')\nparser.add_argument('--DA', action='store_true',\n                    help='use Dual Attention')\nparser.add_argument('--CA', action='store_true',\n                    help='use Channel Attention')\nparser.add_argument('--non_local', action='store_true',\n                    help='use Non-Local Attention')\n\n# Training specifications\nparser.add_argument('--reset', action='store_true',\n                    help='reset the training')\nparser.add_argument('--test_every', type=int, default=1000,\n                    help='do test per every N batches')\nparser.add_argument('--epochs', type=int, default=100,\n                    help='number of epochs to train')\nparser.add_argument('--batch_size', type=int, default=8,\n                    help='input batch size for training')\nparser.add_argument('--split_batch', type=int, default=1,\n                    help='split the batch into smaller chunks')\nparser.add_argument('--self_ensemble', action='store_true',\n                    help='use self-ensemble method for test')\nparser.add_argument('--test_only', action='store_true',\n                    help='set this option to test the 
model')\nparser.add_argument('--gan_k', type=int, default=1,\n                    help='k value for adversarial loss')\n\n# Optimization specifications\n\nparser.add_argument('--decay', type=str, default='40-80',\n                    help='learning rate decay type')\nparser.add_argument('--gamma', type=float, default=0.5,\n                    help='learning rate decay factor for step decay')\nparser.add_argument('--optimizer', default='ADAM',\n                    choices=('SGD', 'ADAM', 'RMSprop'),\n                    help='optimizer to use (SGD | ADAM | RMSprop)')\nparser.add_argument('--momentum', type=float, default=0.9,\n                    help='SGD momentum')\nparser.add_argument('--betas', type=tuple, default=(0.9, 0.999),\n                    help='ADAM beta')\nparser.add_argument('--epsilon', type=float, default=1e-8,\n                    help='ADAM epsilon for numerical stability')\nparser.add_argument('--weight_decay', type=float, default=0,\n                    help='weight decay')\nparser.add_argument('--gclip', type=float, default=0,\n                    help='gradient clipping threshold (0 = no clipping)')\n\n# Loss specifications\nparser.add_argument('--loss', type=str, default='1*L1',\n                    help='loss function configuration')\nparser.add_argument('--skip_threshold', type=float, default=1e8,\n                    help='skip batches whose error exceeds this threshold')\n\n# Log specifications\nparser.add_argument('--save', type=str, default='test',\n                    help='file name to save')\nparser.add_argument('--load', type=str, default='',\n                    help='file name to load')\nparser.add_argument('--resume', type=int, default=0,\n                    help='resume from specific checkpoint')\nparser.add_argument('--save_models', action='store_true',\n                    help='save all intermediate models')\nparser.add_argument('--print_every', type=int, default=10,\n                    help='how many batches to wait before logging 
training status')\nparser.add_argument('--save_results', action='store_true',\n                    help='save output results')\nparser.add_argument('--save_gt', action='store_true',\n                    help='save low-resolution and high-resolution images together')\n\nargs = parser.parse_args()\n\n# '--scale' accepts multiple scales joined by '+', e.g. '2+4'\nargs.scale = list(map(lambda x: int(x), args.scale.split('+')))\n\nif args.epochs == 0:\n    args.epochs = int(1e8)\n\n# Normalize string booleans ('True'/'False') to real booleans\nfor arg in vars(args):\n    if vars(args)[arg] == 'True':\n        vars(args)[arg] = True\n    elif vars(args)[arg] == 'False':\n        vars(args)[arg] = False\n\n"
  },
  {
    "path": "code/real/bsrt/pwcnet/LICENSE",
    "content": "GNU GENERAL PUBLIC LICENSE\n                       Version 3, 29 June 2007\n\n Copyright (C) 2007 Free Software Foundation, Inc. <http://fsf.org/>\n Everyone is permitted to copy and distribute verbatim copies\n of this license document, but changing it is not allowed.\n\n                            Preamble\n\n  The GNU General Public License is a free, copyleft license for\nsoftware and other kinds of works.\n\n  The licenses for most software and other practical works are designed\nto take away your freedom to share and change the works.  By contrast,\nthe GNU General Public License is intended to guarantee your freedom to\nshare and change all versions of a program--to make sure it remains free\nsoftware for all its users.  We, the Free Software Foundation, use the\nGNU General Public License for most of our software; it applies also to\nany other work released this way by its authors.  You can apply it to\nyour programs, too.\n\n  When we speak of free software, we are referring to freedom, not\nprice.  Our General Public Licenses are designed to make sure that you\nhave the freedom to distribute copies of free software (and charge for\nthem if you wish), that you receive source code or can get it if you\nwant it, that you can change the software or use pieces of it in new\nfree programs, and that you know you can do these things.\n\n  To protect your rights, we need to prevent others from denying you\nthese rights or asking you to surrender the rights.  Therefore, you have\ncertain responsibilities if you distribute copies of the software, or if\nyou modify it: responsibilities to respect the freedom of others.\n\n  For example, if you distribute copies of such a program, whether\ngratis or for a fee, you must pass on to the recipients the same\nfreedoms that you received.  You must make sure that they, too, receive\nor can get the source code.  
And you must show them these terms so they\nknow their rights.\n\n  Developers that use the GNU GPL protect your rights with two steps:\n(1) assert copyright on the software, and (2) offer you this License\ngiving you legal permission to copy, distribute and/or modify it.\n\n  For the developers' and authors' protection, the GPL clearly explains\nthat there is no warranty for this free software.  For both users' and\nauthors' sake, the GPL requires that modified versions be marked as\nchanged, so that their problems will not be attributed erroneously to\nauthors of previous versions.\n\n  Some devices are designed to deny users access to install or run\nmodified versions of the software inside them, although the manufacturer\ncan do so.  This is fundamentally incompatible with the aim of\nprotecting users' freedom to change the software.  The systematic\npattern of such abuse occurs in the area of products for individuals to\nuse, which is precisely where it is most unacceptable.  Therefore, we\nhave designed this version of the GPL to prohibit the practice for those\nproducts.  If such problems arise substantially in other domains, we\nstand ready to extend this provision to those domains in future versions\nof the GPL, as needed to protect the freedom of users.\n\n  Finally, every program is threatened constantly by software patents.\nStates should not allow patents to restrict development and use of\nsoftware on general-purpose computers, but in those that do, we wish to\navoid the special danger that patents applied to a free program could\nmake it effectively proprietary.  To prevent this, the GPL assures that\npatents cannot be used to render the program non-free.\n\n  The precise terms and conditions for copying, distribution and\nmodification follow.\n\n                       TERMS AND CONDITIONS\n\n  0. 
Definitions.\n\n  \"This License\" refers to version 3 of the GNU General Public License.\n\n  \"Copyright\" also means copyright-like laws that apply to other kinds of\nworks, such as semiconductor masks.\n\n  \"The Program\" refers to any copyrightable work licensed under this\nLicense.  Each licensee is addressed as \"you\".  \"Licensees\" and\n\"recipients\" may be individuals or organizations.\n\n  To \"modify\" a work means to copy from or adapt all or part of the work\nin a fashion requiring copyright permission, other than the making of an\nexact copy.  The resulting work is called a \"modified version\" of the\nearlier work or a work \"based on\" the earlier work.\n\n  A \"covered work\" means either the unmodified Program or a work based\non the Program.\n\n  To \"propagate\" a work means to do anything with it that, without\npermission, would make you directly or secondarily liable for\ninfringement under applicable copyright law, except executing it on a\ncomputer or modifying a private copy.  Propagation includes copying,\ndistribution (with or without modification), making available to the\npublic, and in some countries other activities as well.\n\n  To \"convey\" a work means any kind of propagation that enables other\nparties to make or receive copies.  Mere interaction with a user through\na computer network, with no transfer of a copy, is not conveying.\n\n  An interactive user interface displays \"Appropriate Legal Notices\"\nto the extent that it includes a convenient and prominently visible\nfeature that (1) displays an appropriate copyright notice, and (2)\ntells the user that there is no warranty for the work (except to the\nextent that warranties are provided), that licensees may convey the\nwork under this License, and how to view a copy of this License.  If\nthe interface presents a list of user commands or options, such as a\nmenu, a prominent item in the list meets this criterion.\n\n  1. 
Source Code.\n\n  The \"source code\" for a work means the preferred form of the work\nfor making modifications to it.  \"Object code\" means any non-source\nform of a work.\n\n  A \"Standard Interface\" means an interface that either is an official\nstandard defined by a recognized standards body, or, in the case of\ninterfaces specified for a particular programming language, one that\nis widely used among developers working in that language.\n\n  The \"System Libraries\" of an executable work include anything, other\nthan the work as a whole, that (a) is included in the normal form of\npackaging a Major Component, but which is not part of that Major\nComponent, and (b) serves only to enable use of the work with that\nMajor Component, or to implement a Standard Interface for which an\nimplementation is available to the public in source code form.  A\n\"Major Component\", in this context, means a major essential component\n(kernel, window system, and so on) of the specific operating system\n(if any) on which the executable work runs, or a compiler used to\nproduce the work, or an object code interpreter used to run it.\n\n  The \"Corresponding Source\" for a work in object code form means all\nthe source code needed to generate, install, and (for an executable\nwork) run the object code and to modify the work, including scripts to\ncontrol those activities.  However, it does not include the work's\nSystem Libraries, or general-purpose tools or generally available free\nprograms which are used unmodified in performing those activities but\nwhich are not part of the work.  
For example, Corresponding Source\nincludes interface definition files associated with source files for\nthe work, and the source code for shared libraries and dynamically\nlinked subprograms that the work is specifically designed to require,\nsuch as by intimate data communication or control flow between those\nsubprograms and other parts of the work.\n\n  The Corresponding Source need not include anything that users\ncan regenerate automatically from other parts of the Corresponding\nSource.\n\n  The Corresponding Source for a work in source code form is that\nsame work.\n\n  2. Basic Permissions.\n\n  All rights granted under this License are granted for the term of\ncopyright on the Program, and are irrevocable provided the stated\nconditions are met.  This License explicitly affirms your unlimited\npermission to run the unmodified Program.  The output from running a\ncovered work is covered by this License only if the output, given its\ncontent, constitutes a covered work.  This License acknowledges your\nrights of fair use or other equivalent, as provided by copyright law.\n\n  You may make, run and propagate covered works that you do not\nconvey, without conditions so long as your license otherwise remains\nin force.  You may convey covered works to others for the sole purpose\nof having them make modifications exclusively for you, or provide you\nwith facilities for running those works, provided that you comply with\nthe terms of this License in conveying all material for which you do\nnot control copyright.  Those thus making or running the covered works\nfor you must do so exclusively on your behalf, under your direction\nand control, on terms that prohibit them from making any copies of\nyour copyrighted material outside their relationship with you.\n\n  Conveying under any other circumstances is permitted solely under\nthe conditions stated below.  Sublicensing is not allowed; section 10\nmakes it unnecessary.\n\n  3. 
Protecting Users' Legal Rights From Anti-Circumvention Law.\n\n  No covered work shall be deemed part of an effective technological\nmeasure under any applicable law fulfilling obligations under article\n11 of the WIPO copyright treaty adopted on 20 December 1996, or\nsimilar laws prohibiting or restricting circumvention of such\nmeasures.\n\n  When you convey a covered work, you waive any legal power to forbid\ncircumvention of technological measures to the extent such circumvention\nis effected by exercising rights under this License with respect to\nthe covered work, and you disclaim any intention to limit operation or\nmodification of the work as a means of enforcing, against the work's\nusers, your or third parties' legal rights to forbid circumvention of\ntechnological measures.\n\n  4. Conveying Verbatim Copies.\n\n  You may convey verbatim copies of the Program's source code as you\nreceive it, in any medium, provided that you conspicuously and\nappropriately publish on each copy an appropriate copyright notice;\nkeep intact all notices stating that this License and any\nnon-permissive terms added in accord with section 7 apply to the code;\nkeep intact all notices of the absence of any warranty; and give all\nrecipients a copy of this License along with the Program.\n\n  You may charge any price or no price for each copy that you convey,\nand you may offer support or warranty protection for a fee.\n\n  5. Conveying Modified Source Versions.\n\n  You may convey a work based on the Program, or the modifications to\nproduce it from the Program, in the form of source code under the\nterms of section 4, provided that you also meet all of these conditions:\n\n    a) The work must carry prominent notices stating that you modified\n    it, and giving a relevant date.\n\n    b) The work must carry prominent notices stating that it is\n    released under this License and any conditions added under section\n    7.  
This requirement modifies the requirement in section 4 to\n    \"keep intact all notices\".\n\n    c) You must license the entire work, as a whole, under this\n    License to anyone who comes into possession of a copy.  This\n    License will therefore apply, along with any applicable section 7\n    additional terms, to the whole of the work, and all its parts,\n    regardless of how they are packaged.  This License gives no\n    permission to license the work in any other way, but it does not\n    invalidate such permission if you have separately received it.\n\n    d) If the work has interactive user interfaces, each must display\n    Appropriate Legal Notices; however, if the Program has interactive\n    interfaces that do not display Appropriate Legal Notices, your\n    work need not make them do so.\n\n  A compilation of a covered work with other separate and independent\nworks, which are not by their nature extensions of the covered work,\nand which are not combined with it such as to form a larger program,\nin or on a volume of a storage or distribution medium, is called an\n\"aggregate\" if the compilation and its resulting copyright are not\nused to limit the access or legal rights of the compilation's users\nbeyond what the individual works permit.  Inclusion of a covered work\nin an aggregate does not cause this License to apply to the other\nparts of the aggregate.\n\n  6. 
Conveying Non-Source Forms.\n\n  You may convey a covered work in object code form under the terms\nof sections 4 and 5, provided that you also convey the\nmachine-readable Corresponding Source under the terms of this License,\nin one of these ways:\n\n    a) Convey the object code in, or embodied in, a physical product\n    (including a physical distribution medium), accompanied by the\n    Corresponding Source fixed on a durable physical medium\n    customarily used for software interchange.\n\n    b) Convey the object code in, or embodied in, a physical product\n    (including a physical distribution medium), accompanied by a\n    written offer, valid for at least three years and valid for as\n    long as you offer spare parts or customer support for that product\n    model, to give anyone who possesses the object code either (1) a\n    copy of the Corresponding Source for all the software in the\n    product that is covered by this License, on a durable physical\n    medium customarily used for software interchange, for a price no\n    more than your reasonable cost of physically performing this\n    conveying of source, or (2) access to copy the\n    Corresponding Source from a network server at no charge.\n\n    c) Convey individual copies of the object code with a copy of the\n    written offer to provide the Corresponding Source.  This\n    alternative is allowed only occasionally and noncommercially, and\n    only if you received the object code with such an offer, in accord\n    with subsection 6b.\n\n    d) Convey the object code by offering access from a designated\n    place (gratis or for a charge), and offer equivalent access to the\n    Corresponding Source in the same way through the same place at no\n    further charge.  You need not require recipients to copy the\n    Corresponding Source along with the object code.  
If the place to\n    copy the object code is a network server, the Corresponding Source\n    may be on a different server (operated by you or a third party)\n    that supports equivalent copying facilities, provided you maintain\n    clear directions next to the object code saying where to find the\n    Corresponding Source.  Regardless of what server hosts the\n    Corresponding Source, you remain obligated to ensure that it is\n    available for as long as needed to satisfy these requirements.\n\n    e) Convey the object code using peer-to-peer transmission, provided\n    you inform other peers where the object code and Corresponding\n    Source of the work are being offered to the general public at no\n    charge under subsection 6d.\n\n  A separable portion of the object code, whose source code is excluded\nfrom the Corresponding Source as a System Library, need not be\nincluded in conveying the object code work.\n\n  A \"User Product\" is either (1) a \"consumer product\", which means any\ntangible personal property which is normally used for personal, family,\nor household purposes, or (2) anything designed or sold for incorporation\ninto a dwelling.  In determining whether a product is a consumer product,\ndoubtful cases shall be resolved in favor of coverage.  For a particular\nproduct received by a particular user, \"normally used\" refers to a\ntypical or common use of that class of product, regardless of the status\nof the particular user or of the way in which the particular user\nactually uses, or expects or is expected to use, the product.  
A product\nis a consumer product regardless of whether the product has substantial\ncommercial, industrial or non-consumer uses, unless such uses represent\nthe only significant mode of use of the product.\n\n  \"Installation Information\" for a User Product means any methods,\nprocedures, authorization keys, or other information required to install\nand execute modified versions of a covered work in that User Product from\na modified version of its Corresponding Source.  The information must\nsuffice to ensure that the continued functioning of the modified object\ncode is in no case prevented or interfered with solely because\nmodification has been made.\n\n  If you convey an object code work under this section in, or with, or\nspecifically for use in, a User Product, and the conveying occurs as\npart of a transaction in which the right of possession and use of the\nUser Product is transferred to the recipient in perpetuity or for a\nfixed term (regardless of how the transaction is characterized), the\nCorresponding Source conveyed under this section must be accompanied\nby the Installation Information.  But this requirement does not apply\nif neither you nor any third party retains the ability to install\nmodified object code on the User Product (for example, the work has\nbeen installed in ROM).\n\n  The requirement to provide Installation Information does not include a\nrequirement to continue to provide support service, warranty, or updates\nfor a work that has been modified or installed by the recipient, or for\nthe User Product in which it has been modified or installed.  
Access to a\nnetwork may be denied when the modification itself materially and\nadversely affects the operation of the network or violates the rules and\nprotocols for communication across the network.\n\n  Corresponding Source conveyed, and Installation Information provided,\nin accord with this section must be in a format that is publicly\ndocumented (and with an implementation available to the public in\nsource code form), and must require no special password or key for\nunpacking, reading or copying.\n\n  7. Additional Terms.\n\n  \"Additional permissions\" are terms that supplement the terms of this\nLicense by making exceptions from one or more of its conditions.\nAdditional permissions that are applicable to the entire Program shall\nbe treated as though they were included in this License, to the extent\nthat they are valid under applicable law.  If additional permissions\napply only to part of the Program, that part may be used separately\nunder those permissions, but the entire Program remains governed by\nthis License without regard to the additional permissions.\n\n  When you convey a copy of a covered work, you may at your option\nremove any additional permissions from that copy, or from any part of\nit.  (Additional permissions may be written to require their own\nremoval in certain cases when you modify the work.)  
You may place\nadditional permissions on material, added by you to a covered work,\nfor which you have or can give appropriate copyright permission.\n\n  Notwithstanding any other provision of this License, for material you\nadd to a covered work, you may (if authorized by the copyright holders of\nthat material) supplement the terms of this License with terms:\n\n    a) Disclaiming warranty or limiting liability differently from the\n    terms of sections 15 and 16 of this License; or\n\n    b) Requiring preservation of specified reasonable legal notices or\n    author attributions in that material or in the Appropriate Legal\n    Notices displayed by works containing it; or\n\n    c) Prohibiting misrepresentation of the origin of that material, or\n    requiring that modified versions of such material be marked in\n    reasonable ways as different from the original version; or\n\n    d) Limiting the use for publicity purposes of names of licensors or\n    authors of the material; or\n\n    e) Declining to grant rights under trademark law for use of some\n    trade names, trademarks, or service marks; or\n\n    f) Requiring indemnification of licensors and authors of that\n    material by anyone who conveys the material (or modified versions of\n    it) with contractual assumptions of liability to the recipient, for\n    any liability that these contractual assumptions directly impose on\n    those licensors and authors.\n\n  All other non-permissive additional terms are considered \"further\nrestrictions\" within the meaning of section 10.  If the Program as you\nreceived it, or any part of it, contains a notice stating that it is\ngoverned by this License along with a term that is a further\nrestriction, you may remove that term.  
If a license document contains\na further restriction but permits relicensing or conveying under this\nLicense, you may add to a covered work material governed by the terms\nof that license document, provided that the further restriction does\nnot survive such relicensing or conveying.\n\n  If you add terms to a covered work in accord with this section, you\nmust place, in the relevant source files, a statement of the\nadditional terms that apply to those files, or a notice indicating\nwhere to find the applicable terms.\n\n  Additional terms, permissive or non-permissive, may be stated in the\nform of a separately written license, or stated as exceptions;\nthe above requirements apply either way.\n\n  8. Termination.\n\n  You may not propagate or modify a covered work except as expressly\nprovided under this License.  Any attempt otherwise to propagate or\nmodify it is void, and will automatically terminate your rights under\nthis License (including any patent licenses granted under the third\nparagraph of section 11).\n\n  However, if you cease all violation of this License, then your\nlicense from a particular copyright holder is reinstated (a)\nprovisionally, unless and until the copyright holder explicitly and\nfinally terminates your license, and (b) permanently, if the copyright\nholder fails to notify you of the violation by some reasonable means\nprior to 60 days after the cessation.\n\n  Moreover, your license from a particular copyright holder is\nreinstated permanently if the copyright holder notifies you of the\nviolation by some reasonable means, this is the first time you have\nreceived notice of violation of this License (for any work) from that\ncopyright holder, and you cure the violation prior to 30 days after\nyour receipt of the notice.\n\n  Termination of your rights under this section does not terminate the\nlicenses of parties who have received copies or rights from you under\nthis License.  
If your rights have been terminated and not permanently\nreinstated, you do not qualify to receive new licenses for the same\nmaterial under section 10.\n\n  9. Acceptance Not Required for Having Copies.\n\n  You are not required to accept this License in order to receive or\nrun a copy of the Program.  Ancillary propagation of a covered work\noccurring solely as a consequence of using peer-to-peer transmission\nto receive a copy likewise does not require acceptance.  However,\nnothing other than this License grants you permission to propagate or\nmodify any covered work.  These actions infringe copyright if you do\nnot accept this License.  Therefore, by modifying or propagating a\ncovered work, you indicate your acceptance of this License to do so.\n\n  10. Automatic Licensing of Downstream Recipients.\n\n  Each time you convey a covered work, the recipient automatically\nreceives a license from the original licensors, to run, modify and\npropagate that work, subject to this License.  You are not responsible\nfor enforcing compliance by third parties with this License.\n\n  An \"entity transaction\" is a transaction transferring control of an\norganization, or substantially all assets of one, or subdividing an\norganization, or merging organizations.  If propagation of a covered\nwork results from an entity transaction, each party to that\ntransaction who receives a copy of the work also receives whatever\nlicenses to the work the party's predecessor in interest had or could\ngive under the previous paragraph, plus a right to possession of the\nCorresponding Source of the work from the predecessor in interest, if\nthe predecessor has it or can get it with reasonable efforts.\n\n  You may not impose any further restrictions on the exercise of the\nrights granted or affirmed under this License.  
For example, you may\nnot impose a license fee, royalty, or other charge for exercise of\nrights granted under this License, and you may not initiate litigation\n(including a cross-claim or counterclaim in a lawsuit) alleging that\nany patent claim is infringed by making, using, selling, offering for\nsale, or importing the Program or any portion of it.\n\n  11. Patents.\n\n  A \"contributor\" is a copyright holder who authorizes use under this\nLicense of the Program or a work on which the Program is based.  The\nwork thus licensed is called the contributor's \"contributor version\".\n\n  A contributor's \"essential patent claims\" are all patent claims\nowned or controlled by the contributor, whether already acquired or\nhereafter acquired, that would be infringed by some manner, permitted\nby this License, of making, using, or selling its contributor version,\nbut do not include claims that would be infringed only as a\nconsequence of further modification of the contributor version.  For\npurposes of this definition, \"control\" includes the right to grant\npatent sublicenses in a manner consistent with the requirements of\nthis License.\n\n  Each contributor grants you a non-exclusive, worldwide, royalty-free\npatent license under the contributor's essential patent claims, to\nmake, use, sell, offer for sale, import and otherwise run, modify and\npropagate the contents of its contributor version.\n\n  In the following three paragraphs, a \"patent license\" is any express\nagreement or commitment, however denominated, not to enforce a patent\n(such as an express permission to practice a patent or covenant not to\nsue for patent infringement).  
To \"grant\" such a patent license to a\nparty means to make such an agreement or commitment not to enforce a\npatent against the party.\n\n  If you convey a covered work, knowingly relying on a patent license,\nand the Corresponding Source of the work is not available for anyone\nto copy, free of charge and under the terms of this License, through a\npublicly available network server or other readily accessible means,\nthen you must either (1) cause the Corresponding Source to be so\navailable, or (2) arrange to deprive yourself of the benefit of the\npatent license for this particular work, or (3) arrange, in a manner\nconsistent with the requirements of this License, to extend the patent\nlicense to downstream recipients.  \"Knowingly relying\" means you have\nactual knowledge that, but for the patent license, your conveying the\ncovered work in a country, or your recipient's use of the covered work\nin a country, would infringe one or more identifiable patents in that\ncountry that you have reason to believe are valid.\n\n  If, pursuant to or in connection with a single transaction or\narrangement, you convey, or propagate by procuring conveyance of, a\ncovered work, and grant a patent license to some of the parties\nreceiving the covered work authorizing them to use, propagate, modify\nor convey a specific copy of the covered work, then the patent license\nyou grant is automatically extended to all recipients of the covered\nwork and works based on it.\n\n  A patent license is \"discriminatory\" if it does not include within\nthe scope of its coverage, prohibits the exercise of, or is\nconditioned on the non-exercise of one or more of the rights that are\nspecifically granted under this License.  
You may not convey a covered\nwork if you are a party to an arrangement with a third party that is\nin the business of distributing software, under which you make payment\nto the third party based on the extent of your activity of conveying\nthe work, and under which the third party grants, to any of the\nparties who would receive the covered work from you, a discriminatory\npatent license (a) in connection with copies of the covered work\nconveyed by you (or copies made from those copies), or (b) primarily\nfor and in connection with specific products or compilations that\ncontain the covered work, unless you entered into that arrangement,\nor that patent license was granted, prior to 28 March 2007.\n\n  Nothing in this License shall be construed as excluding or limiting\nany implied license or other defenses to infringement that may\notherwise be available to you under applicable patent law.\n\n  12. No Surrender of Others' Freedom.\n\n  If conditions are imposed on you (whether by court order, agreement or\notherwise) that contradict the conditions of this License, they do not\nexcuse you from the conditions of this License.  If you cannot convey a\ncovered work so as to satisfy simultaneously your obligations under this\nLicense and any other pertinent obligations, then as a consequence you may\nnot convey it at all.  For example, if you agree to terms that obligate you\nto collect a royalty for further conveying from those to whom you convey\nthe Program, the only way you could satisfy both those terms and this\nLicense would be to refrain entirely from conveying the Program.\n\n  13. Use with the GNU Affero General Public License.\n\n  Notwithstanding any other provision of this License, you have\npermission to link or combine any covered work with a work licensed\nunder version 3 of the GNU Affero General Public License into a single\ncombined work, and to convey the resulting work.  
The terms of this\nLicense will continue to apply to the part which is the covered work,\nbut the special requirements of the GNU Affero General Public License,\nsection 13, concerning interaction through a network will apply to the\ncombination as such.\n\n  14. Revised Versions of this License.\n\n  The Free Software Foundation may publish revised and/or new versions of\nthe GNU General Public License from time to time.  Such new versions will\nbe similar in spirit to the present version, but may differ in detail to\naddress new problems or concerns.\n\n  Each version is given a distinguishing version number.  If the\nProgram specifies that a certain numbered version of the GNU General\nPublic License \"or any later version\" applies to it, you have the\noption of following the terms and conditions either of that numbered\nversion or of any later version published by the Free Software\nFoundation.  If the Program does not specify a version number of the\nGNU General Public License, you may choose any version ever published\nby the Free Software Foundation.\n\n  If the Program specifies that a proxy can decide which future\nversions of the GNU General Public License can be used, that proxy's\npublic statement of acceptance of a version permanently authorizes you\nto choose that version for the Program.\n\n  Later license versions may give you additional or different\npermissions.  However, no additional obligations are imposed on any\nauthor or copyright holder as a result of your choosing to follow a\nlater version.\n\n  15. Disclaimer of Warranty.\n\n  THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY\nAPPLICABLE LAW.  EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT\nHOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM \"AS IS\" WITHOUT WARRANTY\nOF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,\nTHE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR\nPURPOSE.  
THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM\nIS WITH YOU.  SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF\nALL NECESSARY SERVICING, REPAIR OR CORRECTION.\n\n  16. Limitation of Liability.\n\n  IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING\nWILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS\nTHE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY\nGENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE\nUSE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF\nDATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD\nPARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),\nEVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF\nSUCH DAMAGES.\n\n  17. Interpretation of Sections 15 and 16.\n\n  If the disclaimer of warranty and limitation of liability provided\nabove cannot be given local legal effect according to their terms,\nreviewing courts shall apply local law that most closely approximates\nan absolute waiver of all civil liability in connection with the\nProgram, unless a warranty or assumption of liability accompanies a\ncopy of the Program in return for a fee.\n\n                     END OF TERMS AND CONDITIONS\n\n            How to Apply These Terms to Your New Programs\n\n  If you develop a new program, and you want it to be of the greatest\npossible use to the public, the best way to achieve this is to make it\nfree software which everyone can redistribute and change under these terms.\n\n  To do so, attach the following notices to the program.  
It is safest\nto attach them to the start of each source file to most effectively\nstate the exclusion of warranty; and each file should have at least\nthe \"copyright\" line and a pointer to where the full notice is found.\n\n    {one line to give the program's name and a brief idea of what it does.}\n    Copyright (C) {year}  {name of author}\n\n    This program is free software: you can redistribute it and/or modify\n    it under the terms of the GNU General Public License as published by\n    the Free Software Foundation, either version 3 of the License, or\n    (at your option) any later version.\n\n    This program is distributed in the hope that it will be useful,\n    but WITHOUT ANY WARRANTY; without even the implied warranty of\n    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n    GNU General Public License for more details.\n\n    You should have received a copy of the GNU General Public License\n    along with this program.  If not, see <http://www.gnu.org/licenses/>.\n\nAlso add information on how to contact you by electronic and paper mail.\n\n  If the program does terminal interaction, make it output a short\nnotice like this when it starts in an interactive mode:\n\n    {project}  Copyright (C) {year}  {fullname}\n    This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.\n    This is free software, and you are welcome to redistribute it\n    under certain conditions; type `show c' for details.\n\nThe hypothetical commands `show w' and `show c' should show the appropriate\nparts of the General Public License.  
Of course, your program's commands\nmight be different; for a GUI interface, you would use an \"about box\".\n\n  You should also get your employer (if you work as a programmer) or school,\nif any, to sign a \"copyright disclaimer\" for the program, if necessary.\nFor more information on this, and how to apply and follow the GNU GPL, see\n<http://www.gnu.org/licenses/>.\n\n  The GNU General Public License does not permit incorporating your program\ninto proprietary programs.  If your program is a subroutine library, you\nmay consider it more useful to permit linking proprietary applications with\nthe library.  If this is what you want to do, use the GNU Lesser General\nPublic License instead of this License.  But first, please read\n<http://www.gnu.org/philosophy/why-not-lgpl.html>."
  },
  {
    "path": "code/real/bsrt/pwcnet/README.md",
    "content": "# pytorch-pwc\nThis is a personal reimplementation of PWC-Net [1] using PyTorch. Should you be making use of this work, please cite the paper accordingly. Also, make sure to adhere to the <a href=\"https://github.com/NVlabs/PWC-Net#license\">licensing terms</a> of the authors. Should you be making use of this particular implementation, please acknowledge it appropriately [2].\n\n<a href=\"https://arxiv.org/abs/1709.02371\" rel=\"Paper\"><img src=\"http://www.arxiv-sanity.com/static/thumbs/1709.02371v1.pdf.jpg\" alt=\"Paper\" width=\"100%\"></a>\n\nFor the original version of this work, please see: https://github.com/NVlabs/PWC-Net\n<br />\nAnother optical flow implementation from me: https://github.com/sniklaus/pytorch-liteflownet\n<br />\nAnd another optical flow implementation from me: https://github.com/sniklaus/pytorch-unflow\n<br />\nYet another optical flow implementation from me: https://github.com/sniklaus/pytorch-spynet\n\n## background\nThe authors of PWC-Net are thankfully already providing a reference implementation in PyTorch. However, its initial version did not reach the performance of the original Caffe version. This is why I created this repositroy, in which I replicated the performance of the official Caffe version by utilizing its weights.\n\nThe official PyTorch implementation has adopted my approach of using the Caffe weights since then, which is why they are all performing equally well now. Many people have reported issues with CUDA when trying to get the official PyTorch version to run though, while my reimplementaiton does not seem to be subject to such problems.\n\n## setup\nTo download the pre-trained models, run `bash download.bash`. These originate from the original authors, I just converted them to PyTorch.\n\nThe correlation layer is implemented in CUDA using CuPy, which is why CuPy is a required dependency. 
It can be installed using `pip install cupy` or alternatively using one of the provided binary packages as outlined in the CuPy repository.\n\n## usage\nTo run it on your own pair of images, use the following command. You can choose between two models; please see their paper and the code for more details.\n\n```\npython run.py --model default --first ./images/first.png --second ./images/second.png --out ./out.flo\n```\n\nI am afraid that I cannot guarantee that this reimplementation is correct. However, it produced results identical to the Caffe implementation of the original authors in the examples that I tried. Please feel free to contribute to this repository by submitting issues and pull requests.\n\n## comparison\n<p align=\"center\"><img src=\"comparison/comparison.gif?raw=true\" alt=\"Comparison\"></p>\n\n## license\nAs stated in the <a href=\"https://github.com/NVlabs/PWC-Net#license\">licensing terms</a> of the authors of the paper, the models are free for non-commercial share-alike purposes. Please make sure to further consult their licensing terms.\n\n## references\n```\n[1]  @inproceedings{Sun_CVPR_2018,\n         author = {Deqing Sun and Xiaodong Yang and Ming-Yu Liu and Jan Kautz},\n         title = {{PWC-Net}: {CNNs} for Optical Flow Using Pyramid, Warping, and Cost Volume},\n         booktitle = {IEEE Conference on Computer Vision and Pattern Recognition},\n         year = {2018}\n     }\n```\n\n```\n[2]  @misc{pytorch-pwc,\n         author = {Simon Niklaus},\n         title = {A Reimplementation of {PWC-Net} Using {PyTorch}},\n         year = {2018},\n         howpublished = {\\url{https://github.com/sniklaus/pytorch-pwc}}\n     }\n```"
  },
  {
    "path": "code/real/bsrt/pwcnet/__init__.py",
    "content": ""
  },
  {
    "path": "code/real/bsrt/pwcnet/comparison/comparison.py",
    "content": "#!/usr/bin/env python\n\nimport math\nimport moviepy\nimport moviepy.editor\nimport numpy\nimport PIL\nimport PIL.Image\nimport PIL.ImageFont\nimport PIL.ImageDraw\n\nintX = 32\nintY = 436 - 64\n\nobjImages = [ {\n\t'strFile': 'official - caffe.png',\n\t'strText': 'official - Caffe'\n}, {\n\t'strFile': 'this - pytorch.png',\n\t'strText': 'this - PyTorch'\n} ]\n\nnpyImages = []\n\nfor objImage in objImages:\n\tobjOutput = PIL.Image.open(objImage['strFile']).convert('RGB')\n\n\tfor intU in [ intShift - 10 for intShift in range(20) ]:\n\t\tfor intV in [ intShift - 10 for intShift in range(20) ]:\n\t\t\tif math.sqrt(math.pow(intU, 2.0) + math.pow(intV, 2.0)) <= 5.0:\n\t\t\t\tPIL.ImageDraw.Draw(objOutput).text((intX + intU, intY + intV), objImage['strText'], (255, 255, 255), PIL.ImageFont.truetype('freefont/FreeSerifBold.ttf', 32))\n\t\t\t# end\n\t\t# end\n\t# end\n\n\tPIL.ImageDraw.Draw(objOutput).text((intX, intY), objImage['strText'], (0, 0, 0), PIL.ImageFont.truetype('freefont/FreeSerifBold.ttf', 32))\n\n\tnpyImages.append(numpy.array(objOutput))\n# end\n\nmoviepy.editor.ImageSequenceClip(sequence=npyImages, fps=1).write_gif(filename='comparison.gif', program='ImageMagick', opt='optimizeplus')"
  },
  {
    "path": "code/real/bsrt/pwcnet/correlation/README.md",
    "content": "This is an adaptation of the <a href=\"https://github.com/lmb-freiburg/flownet2\">FlowNet2 implementation</a> in order to compute cost volumes. Should you be making use of this work, please make sure to adhere to the <a href=\"https://github.com/lmb-freiburg/flownet2#license-and-citation\">licensing terms</a> of the original authors. Should you be making use or modify this particular implementation, please acknowledge it appropriately."
  },
  {
    "path": "code/real/bsrt/pwcnet/correlation/correlation.py",
    "content": "#!/usr/bin/env python\n\nimport torch\n\nimport cupy\nimport re\n# from torch.cuda.amp import custom_fwd, custom_bwd\n\nkernel_Correlation_rearrange = '''\n\textern \"C\" __global__ void kernel_Correlation_rearrange(\n\t\tconst int n,\n\t\tconst float* input,\n\t\tfloat* output\n\t) {\n\t  int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x;\n\n\t  if (intIndex >= n) {\n\t    return;\n\t  }\n\n\t  int intSample = blockIdx.z;\n\t  int intChannel = blockIdx.y;\n\n\t  float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex];\n\n\t  __syncthreads();\n\n\t  int intPaddedY = (intIndex / SIZE_3(input)) + 4;\n\t  int intPaddedX = (intIndex % SIZE_3(input)) + 4;\n\t  int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX;\n\n\t  output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue;\n\t}\n'''\n\nkernel_Correlation_updateOutput = '''\n\textern \"C\" __global__ void kernel_Correlation_updateOutput(\n\t  const int n,\n\t  const float* rbot0,\n\t  const float* rbot1,\n\t  float* top\n\t) {\n\t  extern __shared__ char patch_data_char[];\n\n\t  float *patch_data = (float *)patch_data_char;\n\n\t  // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1\n\t  int x1 = blockIdx.x + 4;\n\t  int y1 = blockIdx.y + 4;\n\t  int item = blockIdx.z;\n\t  int ch_off = threadIdx.x;\n\n\t  // Load 3D patch into shared shared memory\n\t  for (int j = 0; j < 1; j++) { // HEIGHT\n\t    for (int i = 0; i < 1; i++) { // WIDTH\n\t      int ji_off = (j + i) * SIZE_3(rbot0);\n\t      for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS\n\t        int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch;\n\t        int idxPatchData = ji_off + ch;\n\t        patch_data[idxPatchData] = rbot0[idx1];\n\t      }\n\t    }\n\t  }\n\n\t  __syncthreads();\n\n\t  
__shared__ float sum[32];\n\n\t  // Compute correlation\n\t  for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) {\n\t    sum[ch_off] = 0;\n\n\t    int s2o = top_channel % 9 - 4;\n\t    int s2p = top_channel / 9 - 4;\n\n\t    for (int j = 0; j < 1; j++) { // HEIGHT\n\t      for (int i = 0; i < 1; i++) { // WIDTH\n\t        int ji_off = (j + i) * SIZE_3(rbot0);\n\t        for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS\n\t          int x2 = x1 + s2o;\n\t          int y2 = y1 + s2p;\n\n\t          int idxPatchData = ji_off + ch;\n\t          int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch;\n\n\t          sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2];\n\t        }\n\t      }\n\t    }\n\n\t    __syncthreads();\n\n\t    if (ch_off == 0) {\n\t      float total_sum = 0;\n\t      for (int idx = 0; idx < 32; idx++) {\n\t        total_sum += sum[idx];\n\t      }\n\t      const int sumelems = SIZE_3(rbot0);\n\t      const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x;\n\t      top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems;\n\t    }\n\t  }\n\t}\n'''\n\nkernel_Correlation_updateGradFirst = '''\n\t#define ROUND_OFF 50000\n\n\textern \"C\" __global__ void kernel_Correlation_updateGradFirst(\n\t  const int n,\n\t  const int intSample,\n\t  const float* rbot0,\n\t  const float* rbot1,\n\t  const float* gradOutput,\n\t  float* gradFirst,\n\t  float* gradSecond\n\t) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {\n\t  int n = intIndex % SIZE_1(gradFirst); // channels\n\t  int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos\n\t  int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos\n\n\t  // round_off is a trick to enable integer division with ceil, even for negative numbers\n\t  // We use a large offset, for the 
inner part not to become negative.\n\t  const int round_off = ROUND_OFF;\n\t  const int round_off_s1 = round_off;\n\n\t  // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:\n\t  int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)\n\t  int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)\n\n\t  // Same here:\n\t  int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4)\n\t  int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4)\n\n\t  float sum = 0;\n\t  if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {\n\t    xmin = max(0,xmin);\n\t    xmax = min(SIZE_3(gradOutput)-1,xmax);\n\n\t    ymin = max(0,ymin);\n\t    ymax = min(SIZE_2(gradOutput)-1,ymax);\n\n\t    for (int p = -4; p <= 4; p++) {\n\t      for (int o = -4; o <= 4; o++) {\n\t        // Get rbot1 data:\n\t        int s2o = o;\n\t        int s2p = p;\n\t        int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n;\n\t        float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n]\n\n\t        // Index offset for gradOutput in following loops:\n\t        int op = (p+4) * 9 + (o+4); // index[o,p]\n\t        int idxopoffset = (intSample * SIZE_1(gradOutput) + op);\n\n\t        for (int y = ymin; y <= ymax; y++) {\n\t          for (int x = xmin; x <= xmax; x++) {\n\t            int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]\n\t            sum += gradOutput[idxgradOutput] * bot1tmp;\n\t          }\n\t        }\n\t      }\n\t    }\n\t  }\n\t  const int sumelems = SIZE_1(gradFirst);\n\t  const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4);\n\t  gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems;\n\t} }\n'''\n\nkernel_Correlation_updateGradSecond = 
'''\n\t#define ROUND_OFF 50000\n\n\textern \"C\" __global__ void kernel_Correlation_updateGradSecond(\n\t  const int n,\n\t  const int intSample,\n\t  const float* rbot0,\n\t  const float* rbot1,\n\t  const float* gradOutput,\n\t  float* gradFirst,\n\t  float* gradSecond\n\t) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {\n\t  int n = intIndex % SIZE_1(gradSecond); // channels\n\t  int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos\n\t  int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos\n\n\t  // round_off is a trick to enable integer division with ceil, even for negative numbers\n\t  // We use a large offset, for the inner part not to become negative.\n\t  const int round_off = ROUND_OFF;\n\t  const int round_off_s1 = round_off;\n\n\t  float sum = 0;\n\t  for (int p = -4; p <= 4; p++) {\n\t    for (int o = -4; o <= 4; o++) {\n\t      int s2o = o;\n\t      int s2p = p;\n\n\t      //Get X,Y ranges and clamp\n\t      // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:\n\t      int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)\n\t      int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)\n\n\t      // Same here:\n\t      int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o)\n\t      int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p)\n\n\t      if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {\n\t        xmin = max(0,xmin);\n\t        xmax = min(SIZE_3(gradOutput)-1,xmax);\n\n\t        ymin = max(0,ymin);\n\t        ymax = min(SIZE_2(gradOutput)-1,ymax);\n\n\t        // Get rbot0 data:\n\t        int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n;\n\t        float 
bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n]\n\n\t        // Index offset for gradOutput in following loops:\n\t        int op = (p+4) * 9 + (o+4); // index[o,p]\n\t        int idxopoffset = (intSample * SIZE_1(gradOutput) + op);\n\n\t        for (int y = ymin; y <= ymax; y++) {\n\t          for (int x = xmin; x <= xmax; x++) {\n\t            int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]\n\t            sum += gradOutput[idxgradOutput] * bot0tmp;\n\t          }\n\t        }\n\t      }\n\t    }\n\t  }\n\t  const int sumelems = SIZE_1(gradSecond);\n\t  const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4);\n\t  gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems;\n\t} }\n'''\n\ndef cupy_kernel(strFunction, objVariables):\n\tstrKernel = globals()[strFunction]\n\n\twhile True:\n\t\tobjMatch = re.search('(SIZE_)([0-4])(\\()([^\\)]*)(\\))', strKernel)\n\n\t\tif objMatch is None:\n\t\t\tbreak\n\t\t# end\n\n\t\tintArg = int(objMatch.group(2))\n\n\t\tstrTensor = objMatch.group(4)\n\t\tintSizes = objVariables[strTensor].size()\n\n\t\tstrKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg]))\n\t# end\n\n\twhile True:\n\t\tobjMatch = re.search('(VALUE_)([0-4])(\\()([^\\)]+)(\\))', strKernel)\n\n\t\tif objMatch is None:\n\t\t\tbreak\n\t\t# end\n\n\t\tintArgs = int(objMatch.group(2))\n\t\tstrArgs = objMatch.group(4).split(',')\n\n\t\tstrTensor = strArgs[0]\n\t\tintStrides = objVariables[strTensor].stride()\n\t\tstrIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ]\n\n\t\tstrKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']')\n\t# end\n\n\treturn strKernel\n# end\n\n@cupy.memoize(for_each_device=True)\ndef cupy_launch(strFunction, strKernel):\n\treturn 
cupy.cuda.compile_with_cache(strKernel).get_function(strFunction)\n# end\n\nclass _FunctionCorrelation(torch.autograd.Function):\n\t@staticmethod\n\t# @custom_fwd#(cast_inputs=torch.float32)\n\tdef forward(self, first, second):\n\t\trbot0 = first.new_zeros([ first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1] ])\n\t\trbot1 = first.new_zeros([ first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1] ])\n\n\t\tself.save_for_backward(first, second, rbot0, rbot1)\n\n\t\tassert(first.is_contiguous() == True)\n\t\tassert(second.is_contiguous() == True)\n\n\t\toutput = first.new_zeros([ first.shape[0], 81, first.shape[2], first.shape[3] ])\n\n\t\tif first.is_cuda == True:\n\t\t\tn = first.shape[2] * first.shape[3]\n\t\t\tcupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {\n\t\t\t\t'input': first,\n\t\t\t\t'output': rbot0\n\t\t\t}))(\n\t\t\t\tgrid=tuple([ int((n + 16 - 1) / 16), first.shape[1], first.shape[0] ]),\n\t\t\t\tblock=tuple([ 16, 1, 1 ]),\n\t\t\t\targs=[ n, first.data_ptr(), rbot0.data_ptr() ]\n\t\t\t)\n\n\t\t\tn = second.shape[2] * second.shape[3]\n\t\t\tcupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {\n\t\t\t\t'input': second,\n\t\t\t\t'output': rbot1\n\t\t\t}))(\n\t\t\t\tgrid=tuple([ int((n + 16 - 1) / 16), second.shape[1], second.shape[0] ]),\n\t\t\t\tblock=tuple([ 16, 1, 1 ]),\n\t\t\t\targs=[ n, second.data_ptr(), rbot1.data_ptr() ]\n\t\t\t)\n\n\t\t\tn = output.shape[1] * output.shape[2] * output.shape[3]\n\t\t\tcupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', {\n\t\t\t\t'rbot0': rbot0,\n\t\t\t\t'rbot1': rbot1,\n\t\t\t\t'top': output\n\t\t\t}))(\n\t\t\t\tgrid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]),\n\t\t\t\tblock=tuple([ 32, 1, 1 ]),\n\t\t\t\tshared_mem=first.shape[1] * 4,\n\t\t\t\targs=[ n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ]\n\t\t\t)\n\n\t\telif first.is_cuda 
== False:\n\t\t\traise NotImplementedError()\n\n\t\t# end\n\n\t\treturn output\n\t# end\n\n\t@staticmethod\n\t# @custom_bwd\n\tdef backward(self, gradOutput):\n\t\tfirst, second, rbot0, rbot1 = self.saved_tensors\n\n\t\tassert(gradOutput.is_contiguous() == True)\n\n\t\tgradFirst = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[0] == True else None\n\t\tgradSecond = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[1] == True else None\n\n\t\tif first.is_cuda == True:\n\t\t\tif gradFirst is not None:\n\t\t\t\tfor intSample in range(first.shape[0]):\n\t\t\t\t\tn = first.shape[1] * first.shape[2] * first.shape[3]\n\t\t\t\t\tcupy_launch('kernel_Correlation_updateGradFirst', cupy_kernel('kernel_Correlation_updateGradFirst', {\n\t\t\t\t\t\t'rbot0': rbot0,\n\t\t\t\t\t\t'rbot1': rbot1,\n\t\t\t\t\t\t'gradOutput': gradOutput,\n\t\t\t\t\t\t'gradFirst': gradFirst,\n\t\t\t\t\t\t'gradSecond': None\n\t\t\t\t\t}))(\n\t\t\t\t\t\tgrid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),\n\t\t\t\t\t\tblock=tuple([ 512, 1, 1 ]),\n\t\t\t\t\t\targs=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), gradFirst.data_ptr(), None ]\n\t\t\t\t\t)\n\t\t\t\t# end\n\t\t\t# end\n\n\t\t\tif gradSecond is not None:\n\t\t\t\tfor intSample in range(first.shape[0]):\n\t\t\t\t\tn = first.shape[1] * first.shape[2] * first.shape[3]\n\t\t\t\t\tcupy_launch('kernel_Correlation_updateGradSecond', cupy_kernel('kernel_Correlation_updateGradSecond', {\n\t\t\t\t\t\t'rbot0': rbot0,\n\t\t\t\t\t\t'rbot1': rbot1,\n\t\t\t\t\t\t'gradOutput': gradOutput,\n\t\t\t\t\t\t'gradFirst': None,\n\t\t\t\t\t\t'gradSecond': gradSecond\n\t\t\t\t\t}))(\n\t\t\t\t\t\tgrid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),\n\t\t\t\t\t\tblock=tuple([ 512, 1, 1 ]),\n\t\t\t\t\t\targs=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, gradSecond.data_ptr() ]\n\t\t\t\t\t)\n\t\t\t\t# 
end\n\t\t\t# end\n\n\t\telif first.is_cuda == False:\n\t\t\traise NotImplementedError()\n\n\t\t# end\n\n\t\treturn gradFirst, gradSecond\n\t# end\n# end\n\ndef FunctionCorrelation(tenFirst, tenSecond):\n\treturn _FunctionCorrelation.apply(tenFirst, tenSecond)\n# end\n\nclass ModuleCorrelation(torch.nn.Module):\n\tdef __init__(self):\n\t\tsuper(ModuleCorrelation, self).__init__()\n\t# end\n\n\tdef forward(self, tenFirst, tenSecond):\n\t\treturn _FunctionCorrelation.apply(tenFirst, tenSecond)\n\t# end\n# end"
  },
  {
    "path": "code/real/bsrt/pwcnet/download.bash",
    "content": "#!/bin/bash\n\nwget --verbose --continue --timestamping http://content.sniklaus.com/github/pytorch-pwc/network-chairs-things.pytorch\nwget --verbose --continue --timestamping http://content.sniklaus.com/github/pytorch-pwc/network-default.pytorch"
  },
  {
    "path": "code/real/bsrt/pwcnet/images/README.md",
    "content": "The used example originates from the MPI Sintel dataset: http://sintel.is.tue.mpg.de/"
  },
  {
    "path": "code/real/bsrt/pwcnet/pwcnet.py",
    "content": "# Based on run.py from PWCNet\nimport torch\n\nimport getopt\nimport math\nimport numpy\nimport PIL.Image\nimport sys\nfrom torch.cuda.amp import autocast\n\ntry:\n    from pwcnet.correlation import correlation # the custom cost volume layer\nexcept:\n    sys.path.insert(0, './correlation'); import correlation # you should consider upgrading python\n\n\nbackwarp_tenGrid = {}\nbackwarp_tenPartial = {}\n\n# @autocast(enabled=False)\ndef backwarp(tenInput, tenFlow):\n    if str(tenFlow.shape) not in backwarp_tenGrid:\n        tenHor = torch.linspace(-1.0 + (1.0 / tenFlow.shape[3]), 1.0 - (1.0 / tenFlow.shape[3]), tenFlow.shape[3]).view(1, 1, 1, -1).expand(-1, -1, tenFlow.shape[2], -1)\n        tenVer = torch.linspace(-1.0 + (1.0 / tenFlow.shape[2]), 1.0 - (1.0 / tenFlow.shape[2]), tenFlow.shape[2]).view(1, 1, -1, 1).expand(-1, -1, -1, tenFlow.shape[3])\n\n        backwarp_tenGrid[str(tenFlow.shape)] = torch.cat([tenHor, tenVer], 1).cuda()\n\n    if str(tenFlow.shape) not in backwarp_tenPartial:\n        backwarp_tenPartial[str(tenFlow.shape)] = tenFlow.new_ones([ tenFlow.shape[0], 1, tenFlow.shape[2], tenFlow.shape[3] ])\n\n    tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1)\n    tenInput = torch.cat([ tenInput, backwarp_tenPartial[str(tenFlow.shape)] ], 1)\n\n    tenOutput = torch.nn.functional.grid_sample(input=tenInput, grid=(backwarp_tenGrid[str(tenFlow.shape)] + tenFlow).permute(0, 2, 3, 1),\n                                                mode='bilinear', padding_mode='zeros', align_corners=False)\n\n    tenMask = tenOutput[:, -1:, :, :]\n    tenMask[tenMask > 0.999] = 1.0\n    tenMask[tenMask < 1.0] = 0.0\n\n    return tenOutput[:, :-1, :, :].contiguous() * tenMask.contiguous()\n\n\nclass Network(torch.nn.Module):\n    def __init__(self):\n        super(Network, self).__init__()\n\n        class Extractor(torch.nn.Module):\n            def 
__init__(self):\n                super(Extractor, self).__init__()\n\n                self.netOne = torch.nn.Sequential(\n                    torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n                    torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n                    torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)\n                )\n\n                self.netTwo = torch.nn.Sequential(\n                    torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n                    torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n                    torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)\n                )\n\n                self.netThr = torch.nn.Sequential(\n                    torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n                    torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n                    torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)\n                )\n\n                self.netFou = torch.nn.Sequential(\n                
    torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n                    torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n                    torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)\n                )\n\n                self.netFiv = torch.nn.Sequential(\n                    torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n                    torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n                    torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)\n                )\n\n                self.netSix = torch.nn.Sequential(\n                    torch.nn.Conv2d(in_channels=128, out_channels=196, kernel_size=3, stride=2, padding=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n                    torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n                    torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)\n                )\n\n            def forward(self, tenInput):\n                tenOne = self.netOne(tenInput)\n                tenTwo = self.netTwo(tenOne)\n                tenThr = self.netThr(tenTwo)\n            
    tenFou = self.netFou(tenThr)\n                tenFiv = self.netFiv(tenFou)\n                tenSix = self.netSix(tenFiv)\n\n                return [tenOne, tenTwo, tenThr, tenFou, tenFiv, tenSix]\n\n        class Decoder(torch.nn.Module):\n            def __init__(self, intLevel):\n                super(Decoder, self).__init__()\n                intPrevious = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 1]\n                intCurrent = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 0]\n\n                if intLevel < 6: self.netUpflow = torch.nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=4, stride=2, padding=1)\n                if intLevel < 6: self.netUpfeat = torch.nn.ConvTranspose2d(in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=4, stride=2, padding=1)\n                if intLevel < 6: self.fltBackwarp = [ None, None, None, 5.0, 2.5, 1.25, 0.625, None ][intLevel + 1]\n\n                self.netOne = torch.nn.Sequential(\n                    torch.nn.Conv2d(in_channels=intCurrent, out_channels=128, kernel_size=3, stride=1, padding=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)\n                )\n\n                self.netTwo = torch.nn.Sequential(\n                    torch.nn.Conv2d(in_channels=intCurrent + 128, out_channels=128, kernel_size=3, stride=1, padding=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)\n                )\n\n                self.netThr = torch.nn.Sequential(\n                    torch.nn.Conv2d(in_channels=intCurrent + 128 + 128, out_channels=96, kernel_size=3, stride=1, padding=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)\n                )\n\n                self.netFou = torch.nn.Sequential(\n                    torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96, 
out_channels=64, kernel_size=3, stride=1, padding=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)\n                )\n\n                self.netFiv = torch.nn.Sequential(\n                    torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64, out_channels=32, kernel_size=3, stride=1, padding=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)\n                )\n\n                self.netSix = torch.nn.Sequential(\n                    torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=3, stride=1, padding=1)\n                )\n            # end\n\n            def forward(self, tenFirst, tenSecond, objPrevious):\n                tenFlow = None\n                tenFeat = None\n\n                if objPrevious is None:\n                    tenFlow = None\n                    tenFeat = None\n\n                    tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=tenFirst, tenSecond=tenSecond), negative_slope=0.1, inplace=False)\n\n                    tenFeat = torch.cat([ tenVolume ], 1)\n\n                elif objPrevious is not None:\n                    tenFlow = self.netUpflow(objPrevious['tenFlow'])\n                    tenFeat = self.netUpfeat(objPrevious['tenFeat'])\n\n                    tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=tenFirst, tenSecond=backwarp(tenInput=tenSecond, tenFlow=tenFlow * self.fltBackwarp)), negative_slope=0.1, inplace=False)\n\n                    tenFeat = torch.cat([ tenVolume, tenFirst, tenFlow, tenFeat ], 1)\n\n                tenFeat = torch.cat([ self.netOne(tenFeat), tenFeat ], 1)\n                tenFeat = torch.cat([ self.netTwo(tenFeat), tenFeat ], 1)\n                tenFeat = torch.cat([ self.netThr(tenFeat), tenFeat ], 1)\n                tenFeat = torch.cat([ self.netFou(tenFeat), tenFeat ], 1)\n                tenFeat = 
torch.cat([ self.netFiv(tenFeat), tenFeat ], 1)\n\n                tenFlow = self.netSix(tenFeat)\n\n                return {\n                    'tenFlow': tenFlow,\n                    'tenFeat': tenFeat\n                }\n\n        class Refiner(torch.nn.Module):\n            def __init__(self):\n                super(Refiner, self).__init__()\n\n                self.netMain = torch.nn.Sequential(\n                    torch.nn.Conv2d(in_channels=81 + 32 + 2 + 2 + 128 + 128 + 96 + 64 + 32, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n                    torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=2, dilation=2),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n                    torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=4, dilation=4),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n                    torch.nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=8, dilation=8),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n                    torch.nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=16, dilation=16),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n                    torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1),\n                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n                    torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1, dilation=1)\n                )\n\n            def forward(self, tenInput):\n                return self.netMain(tenInput)\n\n        self.netExtractor = Extractor()\n\n        self.netTwo = Decoder(2)\n        self.netThr = Decoder(3)\n        self.netFou = Decoder(4)\n 
       self.netFiv = Decoder(5)\n        self.netSix = Decoder(6)\n\n        self.netRefiner = Refiner()\n\n        # self.load_state_dict({ strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load(__file__.replace('run.py', 'network-' + arguments_strModel + '.pytorch')).items() })\n\n    def forward(self, tenFirst, tenSecond):\n        tenFirst = self.netExtractor(tenFirst)\n        tenSecond = self.netExtractor(tenSecond)\n\n        objEstimate = self.netSix(tenFirst[-1], tenSecond[-1], None)\n        objEstimate = self.netFiv(tenFirst[-2], tenSecond[-2], objEstimate)\n        objEstimate = self.netFou(tenFirst[-3], tenSecond[-3], objEstimate)\n        objEstimate = self.netThr(tenFirst[-4], tenSecond[-4], objEstimate)\n        objEstimate = self.netTwo(tenFirst[-5], tenSecond[-5], objEstimate)\n\n        return objEstimate['tenFlow'] + self.netRefiner(objEstimate['tenFeat'])\n\n\nclass PWCNet(torch.nn.Module):\n    def __init__(self, load_pretrained=True, weights_path=None, rgb2bgr=False):\n        super(PWCNet, self).__init__()\n        self.net = Network()\n        self.rgb2bgr = rgb2bgr\n\n        if load_pretrained:\n            if weights_path is None:\n                raise ValueError('weights_path must be specified when load_pretrained=True')\n            else:\n                weights_dict = torch.load(weights_path)\n                self.net.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight\n                                          in weights_dict.items()})\n\n    # @autocast()\n    def forward(self, source_img, target_img):\n        assert (source_img.shape[-1] == target_img.shape[-1])\n        assert (source_img.shape[-2] == target_img.shape[-2])\n\n        int_width = source_img.shape[-1]\n        int_height = source_img.shape[-2]\n\n        source_img = source_img.view(-1, 3, int_height, int_width)\n        target_img = target_img.view(-1, 3, int_height, int_width)\n\n        if self.rgb2bgr:\n            source_img = source_img[:, [2, 1, 
0]].contiguous()\n            target_img = target_img[:, [2, 1, 0]].contiguous()\n\n        int_preprocessed_width = int(math.floor(math.ceil(int_width / 64.0) * 64.0))\n        int_preprocessed_height = int(math.floor(math.ceil(int_height / 64.0) * 64.0))\n\n        # Make size multiple of 64\n        source_img_re = torch.nn.functional.interpolate(input=source_img,\n                                                        size=(int_preprocessed_height, int_preprocessed_width),\n                                                        mode='bilinear', align_corners=False)\n        target_img_re = torch.nn.functional.interpolate(input=target_img,\n                                                        size=(int_preprocessed_height, int_preprocessed_width),\n                                                        mode='bilinear', align_corners=False)\n\n        flow = self.net(target_img_re, source_img_re)\n        flow = 20.0 * torch.nn.functional.interpolate(input=flow, size=(int_height, int_width), mode='bilinear',\n                                                      align_corners=False)\n\n        scale_factor_x = float(int_width) / float(int_preprocessed_width)\n        scale_factor_y = float(int_height) / float(int_preprocessed_height)\n        flow = torch.stack((flow[:, 0] * scale_factor_x, flow[:, 1] * scale_factor_y), dim=1)\n\n        return flow\n"
  },
  {
    "path": "code/real/bsrt/pwcnet/requirements.txt",
    "content": "cupy>=5.0.0\nnumpy>=1.15.0\nPillow>=5.0.0\ntorch>=1.3.0"
  },
  {
    "path": "code/real/bsrt/pwcnet/run.py",
    "content": "#!/usr/bin/env python\n\nimport torch\n\nimport getopt\nimport math\nimport numpy\nimport os\nimport PIL\nimport PIL.Image\nimport sys\n\ntry:\n\tfrom .correlation import correlation # the custom cost volume layer\nexcept:\n\tsys.path.insert(0, './correlation'); import correlation # you should consider upgrading python\n# end\n\n##########################################################\n\nassert(int(str('').join(torch.__version__.split('.')[0:2])) >= 13) # requires at least pytorch version 1.3.0\n\ntorch.set_grad_enabled(False) # make sure to not compute gradients for computational performance\n\ntorch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance\n\n##########################################################\n\narguments_strModel = 'default'\narguments_strFirst = './images/first.png'\narguments_strSecond = './images/second.png'\narguments_strOut = './out.flo'\n\nfor strOption, strArgument in getopt.getopt(sys.argv[1:], '', [ strParameter[2:] + '=' for strParameter in sys.argv[1::2] ])[0]:\n\tif strOption == '--model' and strArgument != '': arguments_strModel = strArgument # which model to use\n\tif strOption == '--first' and strArgument != '': arguments_strFirst = strArgument # path to the first frame\n\tif strOption == '--second' and strArgument != '': arguments_strSecond = strArgument # path to the second frame\n\tif strOption == '--out' and strArgument != '': arguments_strOut = strArgument # path to where the output should be stored\n# end\n\n##########################################################\n\nbackwarp_tenGrid = {}\nbackwarp_tenPartial = {}\n\ndef backwarp(tenInput, tenFlow):\n\tif str(tenFlow.shape) not in backwarp_tenGrid:\n\t\ttenHor = torch.linspace(-1.0 + (1.0 / tenFlow.shape[3]), 1.0 - (1.0 / tenFlow.shape[3]), tenFlow.shape[3]).view(1, 1, 1, -1).expand(-1, -1, tenFlow.shape[2], -1)\n\t\ttenVer = torch.linspace(-1.0 + (1.0 / tenFlow.shape[2]), 1.0 - (1.0 / tenFlow.shape[2]), 
tenFlow.shape[2]).view(1, 1, -1, 1).expand(-1, -1, -1, tenFlow.shape[3])\n\n\t\tbackwarp_tenGrid[str(tenFlow.shape)] = torch.cat([ tenHor, tenVer ], 1).cuda()\n\t# end\n\n\tif str(tenFlow.shape) not in backwarp_tenPartial:\n\t\tbackwarp_tenPartial[str(tenFlow.shape)] = tenFlow.new_ones([ tenFlow.shape[0], 1, tenFlow.shape[2], tenFlow.shape[3] ])\n\t# end\n\n\ttenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1)\n\ttenInput = torch.cat([ tenInput, backwarp_tenPartial[str(tenFlow.shape)] ], 1)\n\n\ttenOutput = torch.nn.functional.grid_sample(input=tenInput, grid=(backwarp_tenGrid[str(tenFlow.shape)] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=False)\n\n\ttenMask = tenOutput[:, -1:, :, :]; tenMask[tenMask > 0.999] = 1.0; tenMask[tenMask < 1.0] = 0.0\n\n\treturn tenOutput[:, :-1, :, :] * tenMask\n# end\n\n##########################################################\n\nclass Network(torch.nn.Module):\n\tdef __init__(self):\n\t\tsuper(Network, self).__init__()\n\n\t\tclass Extractor(torch.nn.Module):\n\t\t\tdef __init__(self):\n\t\t\t\tsuper(Extractor, self).__init__()\n\n\t\t\t\tself.netOne = torch.nn.Sequential(\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1)\n\t\t\t\t)\n\n\t\t\t\tself.netTwo = torch.nn.Sequential(\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, 
negative_slope=0.1),\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1)\n\t\t\t\t)\n\n\t\t\t\tself.netThr = torch.nn.Sequential(\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1)\n\t\t\t\t)\n\n\t\t\t\tself.netFou = torch.nn.Sequential(\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1)\n\t\t\t\t)\n\n\t\t\t\tself.netFiv = torch.nn.Sequential(\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, 
negative_slope=0.1)\n\t\t\t\t)\n\n\t\t\t\tself.netSix = torch.nn.Sequential(\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=128, out_channels=196, kernel_size=3, stride=2, padding=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1)\n\t\t\t\t)\n\t\t\t# end\n\n\t\t\tdef forward(self, tenInput):\n\t\t\t\ttenOne = self.netOne(tenInput)\n\t\t\t\ttenTwo = self.netTwo(tenOne)\n\t\t\t\ttenThr = self.netThr(tenTwo)\n\t\t\t\ttenFou = self.netFou(tenThr)\n\t\t\t\ttenFiv = self.netFiv(tenFou)\n\t\t\t\ttenSix = self.netSix(tenFiv)\n\n\t\t\t\treturn [ tenOne, tenTwo, tenThr, tenFou, tenFiv, tenSix ]\n\t\t\t# end\n\t\t# end\n\n\t\tclass Decoder(torch.nn.Module):\n\t\t\tdef __init__(self, intLevel):\n\t\t\t\tsuper(Decoder, self).__init__()\n\n\t\t\t\tintPrevious = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 1]\n\t\t\t\tintCurrent = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 0]\n\n\t\t\t\tif intLevel < 6: self.netUpflow = torch.nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=4, stride=2, padding=1)\n\t\t\t\tif intLevel < 6: self.netUpfeat = torch.nn.ConvTranspose2d(in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=4, stride=2, padding=1)\n\t\t\t\tif intLevel < 6: self.fltBackwarp = [ None, None, None, 5.0, 2.5, 1.25, 0.625, None ][intLevel + 1]\n\n\t\t\t\tself.netOne = torch.nn.Sequential(\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=intCurrent, out_channels=128, kernel_size=3, stride=1, padding=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, 
negative_slope=0.1)\n\t\t\t\t)\n\n\t\t\t\tself.netTwo = torch.nn.Sequential(\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=intCurrent + 128, out_channels=128, kernel_size=3, stride=1, padding=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1)\n\t\t\t\t)\n\n\t\t\t\tself.netThr = torch.nn.Sequential(\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=intCurrent + 128 + 128, out_channels=96, kernel_size=3, stride=1, padding=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1)\n\t\t\t\t)\n\n\t\t\t\tself.netFou = torch.nn.Sequential(\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96, out_channels=64, kernel_size=3, stride=1, padding=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1)\n\t\t\t\t)\n\n\t\t\t\tself.netFiv = torch.nn.Sequential(\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64, out_channels=32, kernel_size=3, stride=1, padding=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1)\n\t\t\t\t)\n\n\t\t\t\tself.netSix = torch.nn.Sequential(\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=3, stride=1, padding=1)\n\t\t\t\t)\n\t\t\t# end\n\n\t\t\tdef forward(self, tenFirst, tenSecond, objPrevious):\n\t\t\t\ttenFlow = None\n\t\t\t\ttenFeat = None\n\n\t\t\t\tif objPrevious is None:\n\t\t\t\t\ttenFlow = None\n\t\t\t\t\ttenFeat = None\n\n\t\t\t\t\ttenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=tenFirst, tenSecond=tenSecond), negative_slope=0.1, inplace=False)\n\n\t\t\t\t\ttenFeat = torch.cat([ tenVolume ], 1)\n\n\t\t\t\telif objPrevious is not None:\n\t\t\t\t\ttenFlow = self.netUpflow(objPrevious['tenFlow'])\n\t\t\t\t\ttenFeat = self.netUpfeat(objPrevious['tenFeat'])\n\n\t\t\t\t\ttenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=tenFirst, tenSecond=backwarp(tenInput=tenSecond, tenFlow=tenFlow * self.fltBackwarp)), negative_slope=0.1, 
inplace=False)\n\n\t\t\t\t\ttenFeat = torch.cat([ tenVolume, tenFirst, tenFlow, tenFeat ], 1)\n\n\t\t\t\t# end\n\n\t\t\t\ttenFeat = torch.cat([ self.netOne(tenFeat), tenFeat ], 1)\n\t\t\t\ttenFeat = torch.cat([ self.netTwo(tenFeat), tenFeat ], 1)\n\t\t\t\ttenFeat = torch.cat([ self.netThr(tenFeat), tenFeat ], 1)\n\t\t\t\ttenFeat = torch.cat([ self.netFou(tenFeat), tenFeat ], 1)\n\t\t\t\ttenFeat = torch.cat([ self.netFiv(tenFeat), tenFeat ], 1)\n\n\t\t\t\ttenFlow = self.netSix(tenFeat)\n\n\t\t\t\treturn {\n\t\t\t\t\t'tenFlow': tenFlow,\n\t\t\t\t\t'tenFeat': tenFeat\n\t\t\t\t}\n\t\t\t# end\n\t\t# end\n\n\t\tclass Refiner(torch.nn.Module):\n\t\t\tdef __init__(self):\n\t\t\t\tsuper(Refiner, self).__init__()\n\n\t\t\t\tself.netMain = torch.nn.Sequential(\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=81 + 32 + 2 + 2 + 128 + 128 + 96 + 64 + 32, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=2, dilation=2),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=4, dilation=4),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=8, dilation=8),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=16, dilation=16),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1),\n\t\t\t\t\ttorch.nn.LeakyReLU(inplace=False, negative_slope=0.1),\n\t\t\t\t\ttorch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1, dilation=1)\n\t\t\t\t)\n\t\t\t# 
end\n\n\t\t\tdef forward(self, tenInput):\n\t\t\t\treturn self.netMain(tenInput)\n\t\t\t# end\n\t\t# end\n\n\t\tself.netExtractor = Extractor()\n\n\t\tself.netTwo = Decoder(2)\n\t\tself.netThr = Decoder(3)\n\t\tself.netFou = Decoder(4)\n\t\tself.netFiv = Decoder(5)\n\t\tself.netSix = Decoder(6)\n\n\t\tself.netRefiner = Refiner()\n\n\t\tself.load_state_dict({ strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load(__file__.replace('run.py', 'network-' + arguments_strModel + '.pytorch')).items() })\n\t# end\n\n\tdef forward(self, tenFirst, tenSecond):\n\t\ttenFirst = self.netExtractor(tenFirst)\n\t\ttenSecond = self.netExtractor(tenSecond)\n\n\t\tobjEstimate = self.netSix(tenFirst[-1], tenSecond[-1], None)\n\t\tobjEstimate = self.netFiv(tenFirst[-2], tenSecond[-2], objEstimate)\n\t\tobjEstimate = self.netFou(tenFirst[-3], tenSecond[-3], objEstimate)\n\t\tobjEstimate = self.netThr(tenFirst[-4], tenSecond[-4], objEstimate)\n\t\tobjEstimate = self.netTwo(tenFirst[-5], tenSecond[-5], objEstimate)\n\n\t\treturn objEstimate['tenFlow'] + self.netRefiner(objEstimate['tenFeat'])\n\t# end\n# end\n\nnetNetwork = None\n\n##########################################################\n\ndef estimate(tenFirst, tenSecond):\n\tglobal netNetwork\n\n\tif netNetwork is None:\n\t\tnetNetwork = Network().cuda().eval()\n\t# end\n\n\tassert(tenFirst.shape[1] == tenSecond.shape[1])\n\tassert(tenFirst.shape[2] == tenSecond.shape[2])\n\n\tintWidth = tenFirst.shape[2]\n\tintHeight = tenFirst.shape[1]\n\n\tassert(intWidth == 1024) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue\n\tassert(intHeight == 436) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue\n\n\ttenPreprocessedFirst = tenFirst.cuda().view(1, 3, intHeight, intWidth)\n\ttenPreprocessedSecond = tenSecond.cuda().view(1, 3, intHeight, intWidth)\n\n\tintPreprocessedWidth = 
int(math.floor(math.ceil(intWidth / 64.0) * 64.0))\n\tintPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0))\n\n\ttenPreprocessedFirst = torch.nn.functional.interpolate(input=tenPreprocessedFirst, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)\n\ttenPreprocessedSecond = torch.nn.functional.interpolate(input=tenPreprocessedSecond, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)\n\n\ttenFlow = 20.0 * torch.nn.functional.interpolate(input=netNetwork(tenPreprocessedFirst, tenPreprocessedSecond), size=(intHeight, intWidth), mode='bilinear', align_corners=False)\n\n\ttenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth)\n\ttenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight)\n\n\treturn tenFlow[0, :, :, :].cpu()\n# end\n\n##########################################################\n\nif __name__ == '__main__':\n\ttenFirst = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(arguments_strFirst))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0)))\n\ttenSecond = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(arguments_strSecond))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0)))\n\n\ttenOutput = estimate(tenFirst, tenSecond)\n\n\tobjOutput = open(arguments_strOut, 'wb')\n\n\tnumpy.array([ 80, 73, 69, 72 ], numpy.uint8).tofile(objOutput)\n\tnumpy.array([ tenOutput.shape[2], tenOutput.shape[1] ], numpy.int32).tofile(objOutput)\n\tnumpy.array(tenOutput.numpy().transpose(1, 2, 0), numpy.float32).tofile(objOutput)\n\n\tobjOutput.close()\n\tprint('finished...')\n# end"
  },
  {
    "path": "code/real/bsrt/requirements.txt",
    "content": "matplotlib\nimageio\nopencv-python\ntensorboardX\ntqdm\ntimm"
  },
  {
    "path": "code/real/bsrt/scripts/__init__.py",
    "content": ""
  },
  {
    "path": "code/real/bsrt/scripts/cal_mean_std.py",
    "content": "import torch\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom datasets.burstsr_dataset import BurstSRDataset, flatten_raw_image\nfrom datasets.synthetic_burst_train_set import SyntheticBurst\nfrom datasets.zurich_raw2rgb_dataset import ZurichRAW2RGB\n\ndef main():\n    train_zurich_raw2rgb = ZurichRAW2RGB(root='/data/dataset/ntire21/burstsr/synthetic', split='train')\n    train_data = SyntheticBurst(train_zurich_raw2rgb, burst_size=14, crop_sz=384)\n    means = []\n    stds = []\n\n    for data in tqdm(train_data):\n        print(data.shape)\n        break\n\n\nif __name__ == '__main__':\n    # if not args.cpu: torch.cuda.set_device(0)\n    main()\n"
  },
  {
    "path": "code/real/bsrt/scripts/demo.sh",
    "content": "set -ex\nrlaunch --cpu=4 --gpu=1 --memory=10240 -- python ./scripts/evaluate_burstsr_val.py\n"
  },
  {
    "path": "code/real/bsrt/scripts/download_burstsr_dataset.py",
    "content": "import os\nimport urllib.request\nimport zipfile\nimport shutil\nimport argparse\n\n\ndef download_burstsr_dataset(download_path):\n    out_dir = download_path + '/burstsr_dataset'\n\n    # Download train folders\n    for i in range(9):\n        if not os.path.isfile('{}/train_{:02d}.zip'.format(out_dir, i)):\n            print('Downloading train_{:02d}'.format(i))\n\n            urllib.request.urlretrieve('https://data.vision.ee.ethz.ch/bhatg/BurstSRChallenge/train_{:02d}.zip'.format(i),\n                                       '{}/tmp.zip'.format(out_dir))\n\n            os.rename('{}/tmp.zip'.format(out_dir), '{}/train_{:02d}.zip'.format(out_dir, i))\n\n    # Download val folder\n    if not os.path.isfile('{}/val.zip'.format(out_dir)):\n        print('Downloading val')\n\n        urllib.request.urlretrieve('https://data.vision.ee.ethz.ch/bhatg/BurstSRChallenge/val.zip',\n                                   '{}/tmp.zip'.format(out_dir))\n\n        os.rename('{}/tmp.zip'.format(out_dir), '{}/val.zip'.format(out_dir))\n\n    # Unpack train set\n    for i in range(9):\n        print('Unpacking train_{:02d}'.format(i))\n        with zipfile.ZipFile('{}/train_{:02d}.zip'.format(out_dir, i), 'r') as zip_ref:\n            zip_ref.extractall('{}'.format(out_dir))\n\n    # Move files to a common directory\n    os.makedirs('{}/train'.format(out_dir), exist_ok=True)\n\n    for i in range(9):\n        file_list = os.listdir('{}/train_{:02d}'.format(out_dir, i))\n\n        for b in file_list:\n            source_dir = '{}/train_{:02d}/{}'.format(out_dir, i, b)\n            dst_dir = '{}/train/{}'.format(out_dir, b)\n\n            if os.path.isdir(source_dir):\n                shutil.move(source_dir, dst_dir)\n\n    # Delete individual subsets\n    for i in range(9):\n        shutil.rmtree('{}/train_{:02d}'.format(out_dir, i))\n\n    # Unpack val set\n    print('Unpacking val')\n    with zipfile.ZipFile('{}/val.zip'.format(out_dir), 'r') as zip_ref:\n        
zip_ref.extractall('{}'.format(out_dir))\n\n\ndef main():\n    parser = argparse.ArgumentParser(description='Downloads and unpacks BurstSR dataset')\n    parser.add_argument('path', type=str, help='Path where the dataset will be downloaded')\n\n    args = parser.parse_args()\n\n    download_burstsr_dataset(args.path)\n\n\nif __name__ == '__main__':\n    main()\n\n\n"
  },
  {
    "path": "code/real/bsrt/scripts/evaluate.sh",
    "content": "set -ex\nrlaunch --cpu=4 --gpu=1 --memory=10240 -- python scripts/evaluate_burstsr_val.py\n"
  },
  {
    "path": "code/real/bsrt/scripts/evaluate_burstsr_val.py",
    "content": "import torch.nn.functional as F\nfrom datasets.burstsr_dataset import BurstSRDataset\nfrom utils.metrics import AlignedPSNR\nfrom pwcnet.pwcnet import PWCNet\n\nroot = '/data/dataset/ntire21/burstsr/real/NTIRE/burstsr_dataset'\n\nclass SimpleBaseline:\n    def __init__(self):\n        pass\n\n    def __call__(self, burst):\n        burst_rgb = burst[:, 0, [0, 1, 3]]\n        burst_rgb = burst_rgb.view(-1, *burst_rgb.shape[-3:])\n        burst_rgb = F.interpolate(burst_rgb, scale_factor=8, mode='bilinear')\n        return burst_rgb\n\n\ndef main():\n    # Load dataset\n    dataset = BurstSRDataset(root=root,\n                             split='val', burst_size=14, crop_sz=80, random_flip=False)\n\n    # TODO Set your network here\n    net = SimpleBaseline()\n\n    device = 'cuda'\n\n    # Load alignment network, used in AlignedPSNR\n    alignment_net = PWCNet(load_pretrained=True,\n                           weights_path='PATH_TO_PWCNET_WEIGHTS')\n    alignment_net = alignment_net.to(device)\n    aligned_psnr_fn = AlignedPSNR(alignment_net=alignment_net, boundary_ignore=40)\n\n    scores_all = []\n    for idx in range(len(dataset)):\n        burst, frame_gt, meta_info_burst, meta_info_gt = dataset[idx]\n        burst = burst.unsqueeze(0).to(device)\n        frame_gt = frame_gt.unsqueeze(0).to(device)\n\n        net_pred = net(burst)\n\n        # Calculate Aligned PSNR\n        score = aligned_psnr_fn(net_pred, frame_gt, burst)\n\n        scores_all.append(score)\n\n    mean_psnr = sum(scores_all) / len(scores_all)\n\n    print('Mean PSNR is {:0.3f}'.format(mean_psnr.item()))\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "code/real/bsrt/scripts/save_results_synburst_val.py",
    "content": "import torch.nn.functional as F\nimport cv2\nfrom datasets.synthetic_burst_val_set import SyntheticBurstVal\nimport torch\nimport numpy as np\nimport os\n\n\nclass SimpleBaseline:\n    def __init__(self):\n        pass\n\n    def __call__(self, burst):\n        burst_rgb = burst[:, 0, [0, 1, 3]]\n        burst_rgb = burst_rgb.view(-1, *burst_rgb.shape[-3:])\n        burst_rgb = F.interpolate(burst_rgb, scale_factor=8, mode='bilinear')\n        return burst_rgb\n\n\ndef main():\n    dataset = SyntheticBurstVal('PATH_TO_SyntheticBurstVal')\n    out_dir = 'PATH_WHERE_RESULTS_ARE_SAVED'\n\n    # TODO Set your network here\n    net = SimpleBaseline()\n\n    device = 'cuda'\n    os.makedirs(out_dir, exist_ok=True)\n\n    for idx in range(len(dataset)):\n        burst, burst_name = dataset[idx]\n\n        burst = burst.to(device).unsqueeze(0)\n\n        with torch.no_grad():\n            net_pred = net(burst)\n\n        # Normalize to 0  2^14 range and convert to numpy array\n        net_pred_np = (net_pred.squeeze(0).permute(1, 2, 0).clamp(0.0, 1.0) * 2 ** 14).cpu().numpy().astype(np.uint16)\n\n        # Save predictions as png\n        cv2.imwrite('{}/{}.png'.format(out_dir, burst_name), net_pred_np)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "code/real/bsrt/scripts/test_burstsr_dataset.py",
    "content": "import torch.nn.functional as F\nimport cv2\nfrom datasets.burstsr_dataset import BurstSRDataset\nfrom torch.utils.data.dataloader import DataLoader\nfrom utils.metrics import AlignedPSNR\nfrom utils.postprocessing_functions import BurstSRPostProcess\nfrom utils.data_format_utils import convert_dict\nfrom pwcnet.pwcnet import PWCNet\n\n\ndef main():\n    # Load dataset\n    dataset = BurstSRDataset(root='PATH_TO_BURST_SR',\n                             split='val', burst_size=3, crop_sz=56, random_flip=False)\n\n    data_loader = DataLoader(dataset, batch_size=2)\n\n    # Load alignment network, used in AlignedPSNR\n    alignment_net = PWCNet(load_pretrained=True,\n                           weights_path='PATH_TO_PWCNET_WEIGHTS')\n    alignment_net = alignment_net.to('cuda')\n\n    aligned_psnr_fn = AlignedPSNR(alignment_net=alignment_net, boundary_ignore=40)\n\n    # Postprocessing function to obtain sRGB images\n    postprocess_fn = BurstSRPostProcess(return_np=True)\n\n    for d in data_loader:\n        burst, frame_gt, meta_info_burst, meta_info_gt = d\n\n        # A simple baseline which upsamples the base image using bilinear upsampling\n        burst_rgb = burst[:, 0, [0, 1, 3]]\n        burst_rgb = burst_rgb.view(-1, *burst_rgb.shape[-3:])\n        burst_rgb = F.interpolate(burst_rgb, scale_factor=8, mode='bilinear')\n\n        # Calculate Aligned PSNR\n        score = aligned_psnr_fn(burst_rgb.cuda(), frame_gt.cuda(), burst.cuda())\n        print('PSNR is {:0.3f}'.format(score))\n\n        meta_info_gt = convert_dict(meta_info_gt, burst.shape[0])\n\n        # Apply simple post-processing to obtain RGB images\n        pred_0 = postprocess_fn.process(burst_rgb[0], meta_info_gt[0])\n        gt_0 = postprocess_fn.process(frame_gt[0], meta_info_gt[0])\n\n        pred_0 = cv2.cvtColor(pred_0, cv2.COLOR_RGB2BGR)\n        gt_0 = cv2.cvtColor(gt_0, cv2.COLOR_RGB2BGR)\n\n        # Visualize input, ground truth\n        cv2.imshow('Input (Demosaicekd 
+ Upsampled)', pred_0)\n        cv2.imshow('GT', gt_0)\n\n        input_key = cv2.waitKey(0)\n        if input_key == ord('q'):\n            return\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "code/real/bsrt/scripts/test_synthetic_bursts.py",
    "content": "import torch.nn.functional as F\nimport cv2\nfrom datasets.synthetic_burst_train_set import SyntheticBurst\nfrom torch.utils.data.dataloader import DataLoader\nfrom utils.metrics import PSNR\nfrom utils.postprocessing_functions import SimplePostProcess\nfrom utils.data_format_utils import convert_dict\nfrom datasets.zurich_raw2rgb_dataset import ZurichRAW2RGB\n\n\ndef main():\n    zurich_raw2rgb = ZurichRAW2RGB(root='PATH_TO_ZURICH_RAW_TO_RGB', split='test')\n    dataset = SyntheticBurst(zurich_raw2rgb, burst_size=3, crop_sz=256)\n\n    data_loader = DataLoader(dataset, batch_size=2)\n\n    # Function to calculate PSNR. Note that the boundary pixels (40 pixels) will be ignored during PSNR computation\n    psnr_fn = PSNR(boundary_ignore=40)\n\n    # Postprocessing function to obtain sRGB images\n    postprocess_fn = SimplePostProcess(return_np=True)\n\n    for d in data_loader:\n        burst, frame_gt, flow_vectors, meta_info = d\n\n        # A simple baseline which upsamples the base image using bilinear upsampling\n        burst_rgb = burst[:, 0, [0, 1, 3]]\n        burst_rgb = burst_rgb.view(-1, *burst_rgb.shape[-3:])\n        burst_rgb = F.interpolate(burst_rgb, scale_factor=8, mode='bilinear')\n\n        # Calculate PSNR\n        score = psnr_fn(burst_rgb, frame_gt)\n\n        print('PSNR is {:0.3f}'.format(score))\n\n        meta_info = convert_dict(meta_info, burst.shape[0])\n\n        # Apply simple post-processing to obtain RGB images\n        pred_0 = postprocess_fn.process(burst_rgb[0], meta_info[0])\n        gt_0 = postprocess_fn.process(frame_gt[0], meta_info[0])\n\n        pred_0 = cv2.cvtColor(pred_0, cv2.COLOR_RGB2BGR)\n        gt_0 = cv2.cvtColor(gt_0, cv2.COLOR_RGB2BGR)\n\n        # Visualize input, ground truth\n        cv2.imshow('Input (Demosaicekd + Upsampled)', pred_0)\n        cv2.imshow('GT', gt_0)\n\n        input_key = cv2.waitKey(0)\n        if input_key == ord('q'):\n            return\n\n\nif __name__ == '__main__':\n   
 main()\n"
  },
  {
    "path": "code/real/bsrt/test.py",
    "content": "import torch.nn.functional as F\nimport cv2\n\nimport torch\nimport numpy as np\nimport os\nfrom tqdm import tqdm\n\n\nfrom datasets.realworld_burst_test_set import RealWorldBurstTest\nfrom datasets.burstsr_dataset import flatten_raw_image_batch, pack_raw_image_batch\nimport model\n\nimport utility\nfrom option import args\n\nimport torch.multiprocessing as mp\nimport torch.backends.cudnn as cudnn\nimport torch.utils.data.distributed\nimport time\n\n\ncheckpoint = utility.checkpoint(args)\n\ndef main_worker(local_rank, nprocs, args):\n    device = 'cuda'\n    cudnn.benchmark = True\n    args.local_rank = local_rank\n    utility.setup(local_rank, nprocs)\n    torch.cuda.set_device(local_rank)\n\n    dataset = RealWorldBurstTest(args.root)\n    out_dir = 'bsrt_realworld'\n    os.makedirs(out_dir, exist_ok=True)\n\n    _model = model.Model(args, checkpoint)\n\n    tt = []\n    for idx in tqdm(range(len(dataset))):\n        burst, meta_info = dataset[idx]\n        burst_name = meta_info['burst_name']\n\n        burst = burst.to(device).unsqueeze(0)\n\n        with torch.no_grad():\n            tic = time.time()\n            sr = _model(burst, 0).float()\n            toc = time.time()\n            tt.append(toc-tic)\n\n        # Normalize to 0  2^14 range and convert to numpy array\n        net_pred_np = (sr.squeeze(0).permute(1, 2, 0).clamp(0.0, 1.0) * 2 ** 14).cpu().numpy().astype(np.uint16)\n        cv2.imwrite('{}/{}.png'.format(out_dir, burst_name), net_pred_np)\n\n    print('avg time: {:.4f}'.format(np.mean(tt)))\n    utility.cleanup()\n\ndef main():\n    mp.spawn(main_worker, nprocs=1, args=(1, args))\n\nif __name__ == '__main__':\n    main()\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n"
  },
  {
    "path": "code/real/bsrt/test_real.py",
    "content": "\nimport cv2\nimport torch\nimport numpy as np\nimport os\nfrom tqdm import tqdm\nimport random\nimport utility\nfrom option import args\nimport torchvision.utils as tvutils\nfrom pwcnet.pwcnet import PWCNet\n\nfrom utils.postprocessing_functions import BurstSRPostProcess\nfrom datasets.burstsr_dataset import BurstSRDataset, flatten_raw_image_batch, pack_raw_image\nfrom utils.metrics import AlignedPSNR\nfrom utils.data_format_utils import convert_dict\nfrom data_processing.camera_pipeline import demosaic\nimport model\n\nimport torch.multiprocessing as mp\nimport torch.backends.cudnn as cudnn\nimport torch.utils.data.distributed\nimport time\n\nfrom torchsummaryX import summary\n\n\ncheckpoint = utility.checkpoint(args)\n\n\ndef main():\n    mp.spawn(main_worker, nprocs=1, args=(1, args))\n\n\ndef main_worker(local_rank, nprocs, args):\n    cudnn.benchmark = True\n    args.local_rank = local_rank\n    utility.setup(local_rank, nprocs)\n    torch.cuda.set_device(local_rank)\n\n    dataset = BurstSRDataset(root=args.root, burst_size=14, crop_sz=80, split='val')\n    out_dir = 'val/bsrt_real'\n\n    _model = model.Model(args, checkpoint)\n\n    for param in _model.parameters():\n        param.requires_grad = False\n\n    alignment_net = PWCNet(load_pretrained=True,\n                           weights_path='./pwcnet/pwcnet-network-default.pth')\n    alignment_net = alignment_net.to('cuda')\n    for param in alignment_net.parameters():\n        param.requires_grad = False\n\n    aligned_psnr_fn = AlignedPSNR(alignment_net=alignment_net, boundary_ignore=40)\n\n    postprocess_fn = BurstSRPostProcess(return_np=True)\n\n    os.makedirs(out_dir, exist_ok=True)\n\n    tt = []\n    psnrs, ssims, lpipss = [], [], []\n    for idx in tqdm(range(len(dataset))):\n        burst_, gt, meta_info_burst, meta_info_gt = dataset[idx]\n        burst_ = burst_.unsqueeze(0).cuda()\n        gt = gt.unsqueeze(0).cuda()\n        # burst = flatten_raw_image_batch(burst_)\n       
 name = meta_info_burst['burst_name']\n\n        with torch.no_grad():\n            tic = time.time()\n            sr = _model(burst_, 0).float()\n            toc = time.time()\n            tt.append(toc-tic)\n\n            # sr_int = (sr.clamp(0.0, 1.0) * 2 ** 14).short()\n            # sr = sr_int.float() / (2 ** 14)\n\n            psnr, ssim, lpips = aligned_psnr_fn(sr, gt, burst_)\n            psnrs.append(psnr.item())\n            ssims.append(ssim.item())\n            lpipss.append(lpips.item())\n\n        # lrs = burst_[0]\n        # os.makedirs(f'{out_dir}/{name}', exist_ok=True)\n        # for i, lr in enumerate(lrs):\n        #     # print(lr[[0, 1, 3],...].shape)\n        #     lr = postprocess_fn.process(lr[[0, 1, 3],...], meta_info_burst)\n        #     lr = cv2.cvtColor(lr, cv2.COLOR_RGB2BGR)\n        #     cv2.imwrite('{}/{}/{:2d}.png'.format(out_dir, name, i), lr)\n\n        # gt = postprocess_fn.process(gt[0], meta_info_burst)\n        # gt = cv2.cvtColor(gt, cv2.COLOR_RGB2BGR)\n        # cv2.imwrite('{}/{}_gt.png'.format(out_dir, name), gt)\n\n        # sr_ = postprocess_fn.process(sr[0], meta_info_burst)\n        # sr_ = cv2.cvtColor(sr_, cv2.COLOR_RGB2BGR)\n        # cv2.imwrite('{}/{}_bsrt.png'.format(out_dir, name), sr_)\n\n        del burst_\n        del sr\n        del gt\n\n\n    print(f'avg PSNR: {np.mean(psnrs):.6f}')\n    print(f'avg SSIM: {np.mean(ssims):.6f}')\n    print(f'avg LPIPS: {np.mean(lpipss):.6f}')\n    print(f' avg time: {np.mean(tt):.6f}')\n\n    # utility.cleanup()\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "code/real/bsrt/trainer.py",
    "content": "import os\nimport sys\nfrom decimal import Decimal\nimport cv2\nimport utility\nimport torchvision.utils as tvutils\nimport torch.nn.functional as F\nimport random\n\nimport torch\nfrom tensorboardX import SummaryWriter\nfrom pwcnet.pwcnet import PWCNet\n\nfrom utils.postprocessing_functions import BurstSRPostProcess\nfrom utils.data_format_utils import convert_dict\nfrom utils.metrics import AlignedL1, AlignedPSNR\nfrom datasets.burstsr_dataset import pack_raw_image, flatten_raw_image_batch, pack_raw_image_batch\nfrom data_processing.camera_pipeline import demosaic\nfrom tqdm import tqdm\nfrom loss.filter import Filter\n\nfrom torch.cuda.amp import autocast as autocast, GradScaler\n\ntrain_log_dir = '../train_log/'\n\nexp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1]\ntfboard_name = exp_name + \"_\"\nexp_train_log_dir = os.path.join(train_log_dir, exp_name)\n\nLOG_DIR = os.path.join(exp_train_log_dir, 'logs')\n\n# save img path\nIMG_SAVE_DIR = os.path.join(exp_train_log_dir, 'img_log')\n# Where to load model\nLOAD_MODEL_DIR = os.path.join(exp_train_log_dir, 'models')\n# Where to save new model\nSAVE_MODEL_DIR = os.path.join(exp_train_log_dir, 'real_models')\n\n# Where to save visualization images (for report)\nRESULTS_DIR = os.path.join(exp_train_log_dir, 'report')\n\nutility.mkdir(SAVE_MODEL_DIR)\nutility.mkdir(IMG_SAVE_DIR)\nutility.mkdir(LOG_DIR)\n\n\nclass Trainer():\n    def __init__(self, args, train_loader, train_sampler, valid_loader, my_model, my_loss, ckp):\n        self.args = args\n        self.scale = args.scale[0]\n\n        self.ckp = ckp\n        self.loader_train = train_loader\n        self.loader_valid = valid_loader\n        self.train_sampler = train_sampler\n        self.model = my_model\n        self.loss = my_loss\n        self.optimizer = utility.make_optimizer(args, self.model)\n\n        # Postprocessing function to obtain sRGB images\n        self.postprocess_fn = BurstSRPostProcess(return_np=True)\n\n 
       self.alignment_net = PWCNet(load_pretrained=True,\n                           weights_path='./pwcnet/pwcnet-network-default.pth')\n        self.alignment_net = self.alignment_net.to('cuda')\n        for param in self.alignment_net.parameters():\n            param.requires_grad = False\n\n        self.aligned_psnr_fn = AlignedPSNR(alignment_net=self.alignment_net, boundary_ignore=40)\n\n        if 'L1' in args.loss:\n            self.aligned_loss = AlignedL1(alignment_net=self.alignment_net, boundary_ignore=40)\n\n        if self.args.fp16:\n            self.scaler = GradScaler()\n\n        self.best_psnr = 0.\n        self.best_epoch = 0\n\n        if self.args.load != '':\n            self.optimizer.load(ckp.dir, epoch=len(ckp.log))\n\n        self.error_last = 1e8\n        self.glob_iter = 0\n\n        self.log_dir = LOG_DIR + \"/\" + args.save\n        self.img_save_dir = IMG_SAVE_DIR + \"/\" + args.save\n        # Where to load model\n        self.load_model_dir = LOAD_MODEL_DIR + \"/\" + args.save\n        # Where to save new model\n        self.save_model_dir = SAVE_MODEL_DIR + \"/\" + args.save\n\n        # Where to save visualization images (for report)\n        self.results_dir = RESULTS_DIR + \"/\" + args.save\n        self.writer = SummaryWriter(log_dir=self.log_dir)\n\n        utility.mkdir(self.save_model_dir)\n        utility.mkdir(self.img_save_dir)\n        utility.mkdir(self.log_dir)\n        utility.mkdir('frames')\n\n\n    def train(self):\n        self.loss.step()\n        epoch = self.optimizer.get_last_epoch() + 1\n        lr = self.optimizer.get_lr()\n\n        if self.train_sampler:\n            self.train_sampler.set_epoch(epoch)\n        if epoch % 100 == 0:\n            self.ckp.write_log(\n                '[Epoch {}]\\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))\n            )\n        self.loss.start_log()\n\n        # self.test()\n        self.model.train()\n        if self.args.local_rank <= 0:\n            timer_data, 
timer_model, timer_epoch = utility.timer(), utility.timer(), utility.timer()\n            timer_epoch.tic()\n\n        for batch, batch_value in enumerate(self.loader_train):\n\n            burst, gt, meta_info_burst, meta_info_gt = batch_value\n            burst, gt = self.prepare(burst, gt)\n            # burst = flatten_raw_image_batch(burst_)\n\n            if self.args.local_rank == 0:\n                timer_data.hold()\n                timer_model.tic()\n\n            if self.args.fp16:\n                with autocast():\n                    sr = self.model(burst, 0).float()\n                    # loss = self.aligned_loss(sr, gt, burst)\n            else:\n                sr = self.model(burst, 0)\n            \n            loss = self.aligned_loss(sr, gt, burst)\n\n            if self.args.n_GPUs > 1:\n                torch.distributed.barrier()\n                reduced_loss = utility.reduce_mean(loss, self.args.n_GPUs)\n\n            else:\n                reduced_loss = loss\n\n            self.model.zero_grad()\n            if self.args.fp16:\n                self.scaler.scale(loss).backward()\n                # torch.nn.utils.clip_grad_value_(self.model.parameters(), .01)\n                if torch.isinf(sr).sum() + torch.isnan(sr).sum() <= 0:\n                    self.scaler.step(self.optimizer)\n                    self.scaler.update()\n                else:\n                    print(f'Nan num: {torch.isnan(sr).sum()}, inf num: {torch.isinf(sr).sum()}')\n                    reduced_loss = None\n                    os._exit(0)\n                    sys.exit(0)\n            else:\n                loss.backward()\n                # torch.nn.utils.clip_grad_value_(self.model.parameters(), .01)\n                if torch.isinf(sr).sum() + torch.isnan(sr).sum() <= 0:\n                    self.optimizer.step()\n                else:\n                    print(f'Nan num: {torch.isnan(sr).sum()}, inf num: {torch.isinf(sr).sum()}')\n                    reduced_loss 
= None\n\n            if self.args.local_rank == 0:\n                timer_model.hold()\n                if epoch % 1 == 0 and batch % 10 == 0:\n                    self.writer.add_scalars('Loss', {tfboard_name + '_mse_L1': reduced_loss.detach().cpu().numpy()},\n                                            self.glob_iter)\n\n                if (batch + 1) % self.args.print_every == 0:\n                    self.ckp.write_log('[{}/{}]\\t[{:.4f}]\\t{:.1f}+{:.1f}s'.format(\n                        (batch + 1) * self.args.batch_size,\n                        len(self.loader_train.dataset),\n                        reduced_loss.item(),\n                        timer_model.release(),\n                        timer_data.release()))\n\n                self.glob_iter += 1\n                timer_data.tic()\n\n            if self.args.local_rank <= 0 and (batch + 1) % 200 == 0:\n                if not self.args.test_only:\n                    filename = exp_name + '_latest' + '.pth'\n                    self.save_model(filename)\n\n\n        if self.args.local_rank <= 0:\n            timer_epoch.hold()\n            print('Epoch {} cost time: {:.1f}s, lr: {:5f}'.format(epoch, timer_epoch.release(), lr))\n            if (epoch) % 1 == 0 and not self.args.test_only:\n                filename = exp_name + '_epoch_' + str(epoch) + '.pth'\n                self.save_model(filename)\n\n            if not self.args.test_only:\n                filename = exp_name + '_latest' + '.pth'\n                self.save_model(filename)\n\n        torch.cuda.synchronize()\n        torch.cuda.empty_cache()\n        self.test()\n        self.loss.end_log(len(self.loader_train))\n        self.error_last = self.loss.log[-1, -1]\n        self.optimizer.schedule()\n\n    def test(self):\n        torch.set_grad_enabled(False)\n\n        def ttaup(burst):\n            # burst0 = flatten_raw_image_batch(burst) # B, T, C, H, W\n            # burst1 = utility.bayer_aug(burst0, flip_h=False, flip_w=False, 
transpose=True)\n\n            # burst0 = pack_raw_image_batch(burst0)\n            # burst1 = pack_raw_image_batch(burst1)\n            return [burst]\n\n        def ttadown(bursts):\n            burst0 = bursts[0]\n            # burst1 = bursts[1].permute(0, 1, 3, 2)\n            # out = (burst0 + burst1) / 2\n            out = burst0\n            return out\n\n        epoch = self.optimizer.get_last_epoch() + 1\n        self.model.eval()\n        if self.args.local_rank == 0:\n            print(\"Testing...\")\n            timer_test = utility.timer()\n        if epoch == 1 or epoch % 1 == 0:\n            self.model.eval()\n            total_psnr = 0\n            total_ssim = 0\n            total_lpips = 0\n            count = 0\n            for i, batch_value in tqdm(enumerate(self.loader_valid)):\n\n                burst, gt, meta_info_burst, meta_info_gt = batch_value\n                burst, gt = self.prepare(burst, gt)\n\n                # burst_ = flatten_raw_image_batch(burst)\n\n                bursts = ttaup(burst)\n\n                with torch.no_grad():\n                    srs = []\n                    for b in bursts:\n                        if self.args.fp16:\n                            with autocast():\n                                sr = self.model(b, 0).float()\n                        else:\n                            sr = self.model(b, 0).float()\n                        srs.append(sr)\n                    sr = ttadown(srs)\n                    \n                # sr_int = (sr.clamp(0.0, 1.0) * 2 ** 14).short()\n                # sr = sr_int.float() / (2 ** 14)\n                score, ssim_score, lpips_score = self.aligned_psnr_fn(sr, gt, burst)\n\n                if self.args.n_GPUs > 1:\n                    torch.distributed.barrier()\n                    score = utility.reduce_mean(score, self.args.n_GPUs)\n                    ssim_score = utility.reduce_mean(ssim_score, self.args.n_GPUs)\n                    lpips_score = 
utility.reduce_mean(lpips_score, self.args.n_GPUs)\n\n                total_psnr += score\n                total_ssim += ssim_score\n                total_lpips += lpips_score\n                count += 1\n\n                # # if i > 3 and i < 6 and self.args.local_rank == 0:\n                # if i > 200 and i < 400 and self.args.local_rank <= 0:\n                #     meta_info_gt = convert_dict(meta_info_gt, burst.shape[0])\n                #     meta_info_burst = convert_dict(meta_info_burst, burst.shape[0])\n                #     # Apply simple post-processing to obtain RGB images\n\n                #     in_ = demosaic(burst[0][0])\n                #     in_ = self.postprocess_fn.process(in_, meta_info_burst[0])\n                #     sr_ = self.postprocess_fn.process(sr[0], meta_info_gt[0])\n                #     # gt_ = self.postprocess_fn.process(gt[0], meta_info_gt[0])\n\n                #     in_ = cv2.cvtColor(in_, cv2.COLOR_RGB2BGR)\n                #     sr_ = cv2.cvtColor(sr_, cv2.COLOR_RGB2BGR)\n                #     # gt_ = cv2.cvtColor(gt_, cv2.COLOR_RGB2BGR)\n\n                #     cv2.imwrite('frames/{}_in.png'.format(i), in_)\n                #     cv2.imwrite('frames/{}_gt.png'.format(i), gt_)\n                #     cv2.imwrite('frames/{}_sr.png'.format(i), sr_)\n\n            total_psnr = total_psnr / count\n            total_ssim = total_ssim / count\n            total_lpips = total_lpips / count\n\n            if self.args.local_rank == 0:\n                print(\"[Epoch: {}]\\n[PSNR: {:.4f}][SSIM: {:.4f}][LPIPS: {:.4f}][Best PSNR: {:.4f}][Best Epoch: {}]\"\n                    .format(epoch, total_psnr, total_ssim, total_lpips, self.best_psnr, self.best_epoch))\n                if epoch >= 1 and total_psnr > self.best_psnr:\n                    self.best_psnr = total_psnr\n                    self.best_epoch = epoch\n                    filename = exp_name + 'best_epoch.pth'\n                    self.save_model(filename)\n                
self.writer.add_scalars('PSNR', {tfboard_name + '_PSNR': total_psnr}, self.glob_iter)\n\n                print('Forward: {:.2f}s\\n'.format(timer_test.toc()))\n\n        torch.cuda.synchronize()\n        torch.set_grad_enabled(True)\n        torch.cuda.empty_cache()\n\n    def save_model(self, filename):\n        print('save model...')\n        net_save_path = os.path.join(self.save_model_dir, filename)\n        model = self.model.model\n        if self.args.n_GPUs > 1:\n            model = model.module\n\n        torch.save(model.state_dict(), net_save_path)\n\n    def prepare(self, *args):\n        device = torch.device('cpu' if self.args.cpu else 'cuda:{}'.format(self.args.local_rank))\n\n        def _prepare(tensor):\n            if self.args.precision == 'half': tensor = tensor.half()\n            return tensor.to(device)\n\n        # print(_prepare(args[0]).device)\n        return [_prepare(a) for a in args]\n\n    def terminate(self):\n        if self.args.test_only:\n            self.test()\n            return True\n        else:\n            epoch = self.optimizer.get_last_epoch() + 1\n            return epoch >= self.args.epochs\n"
  },
  {
    "path": "code/real/bsrt/utility.py",
    "content": "import math\nimport time\nimport datetime\nfrom multiprocessing import Process\nfrom multiprocessing import Queue\n\nimport matplotlib.pyplot as plt\n\nimport numpy as np\nimport imageio\nimport os\nimport sys\n\nimport torch\nimport torch.optim as optim\nimport torch.optim.lr_scheduler as lrs\n\nimport torch.distributed as dist\nimport matplotlib\n\nmatplotlib.use('Agg')\n\n\ndef reduce_mean(tensor, nprocs):\n    rt = tensor.clone()\n    dist.all_reduce(rt, op=dist.ReduceOp.SUM)\n    rt /= nprocs\n    return rt\n\n\ndef setup(rank, world_size):\n    if sys.platform == 'win32':\n        # Distributed package only covers collective communications with Gloo\n        # backend and FileStore on Windows platform. Set init_method parameter\n        # in init_process_group to a local file.\n        # Example init_method=\"file:///f:/libtmp/some_file\"\n        init_method = \"tcp://localhost:1234\"\n\n        # initialize the process group\n        dist.init_process_group(\n            \"gloo\",\n            init_method=init_method,\n            rank=rank,\n            world_size=world_size\n        )\n    else:\n        os.environ['MASTER_ADDR'] = 'localhost'\n        os.environ['MASTER_PORT'] = '12256'\n\n        # initialize the process group\n        dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n\n\ndef cleanup():\n    dist.destroy_process_group()\n\n\ndef mkdir(path):\n    if not os.path.exists(path):\n        os.makedirs(path)\n\n\nclass timer():\n    def __init__(self):\n        self.acc = 0\n        self.tic()\n\n    def tic(self):\n        self.t0 = time.time()\n\n    def toc(self, restart=False):\n        diff = time.time() - self.t0\n        if restart: self.t0 = time.time()\n        return diff\n\n    def hold(self):\n        self.acc += self.toc()\n\n    def release(self):\n        ret = self.acc\n        self.acc = 0\n\n        return ret\n\n    def reset(self):\n        self.acc = 0\n\n\nclass checkpoint():\n    def 
__init__(self, args):\n        self.args = args\n        self.ok = True\n        self.log = torch.Tensor()\n        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')\n\n        if not args.load:\n            if not args.save:\n                args.save = now\n            self.dir = os.path.join('..', 'experiment', args.save)\n        else:\n            self.dir = os.path.join('..', 'experiment', args.load)\n            if os.path.exists(self.dir):\n                self.log = torch.load(self.get_path('psnr_log.pt'))\n                print('Continue from epoch {}...'.format(len(self.log)))\n            else:\n                args.load = ''\n\n        if args.reset:\n            os.system('rm -rf ' + self.dir)\n            args.load = ''\n\n        os.makedirs(self.dir, exist_ok=True)\n        os.makedirs(self.get_path('model'), exist_ok=True)\n        # for d in args.data_test:\n        #     os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True)\n\n        open_type = 'a' if os.path.exists(self.get_path('log.txt')) else 'w'\n        self.log_file = open(self.get_path('log.txt'), open_type)\n        with open(self.get_path('config.txt'), open_type) as f:\n            f.write(now + '\\n\\n')\n            for arg in vars(args):\n                f.write('{}: {}\\n'.format(arg, getattr(args, arg)))\n            f.write('\\n')\n\n        self.n_processes = 8\n\n    def get_path(self, *subdir):\n        return os.path.join(self.dir, *subdir)\n\n    def save(self, trainer, epoch, is_best=False):\n        trainer.model.save(self.get_path('model'), epoch, is_best=is_best)\n        trainer.loss.save(self.dir)\n        trainer.loss.plot_loss(self.dir, epoch)\n\n        self.plot_psnr(epoch)\n        trainer.optimizer.save(self.dir)\n        torch.save(self.log, self.get_path('psnr_log.pt'))\n\n    def add_log(self, log):\n        self.log = torch.cat([self.log, log])\n\n    def write_log(self, log, refresh=False):\n        print(log)\n        
self.log_file.write(log + '\\n')\n        if refresh:\n            self.log_file.close()\n            self.log_file = open(self.get_path('log.txt'), 'a')\n\n    def done(self):\n        self.log_file.close()\n\n    def plot_psnr(self, epoch):\n        axis = np.linspace(1, epoch, epoch)\n        for idx_data, d in enumerate(self.args.data_test):\n            label = 'SR on {}'.format(d)\n            fig = plt.figure()\n            plt.title(label)\n            for idx_scale, scale in enumerate(self.args.scale):\n                plt.plot(\n                    axis,\n                    self.log[:, idx_data, idx_scale].numpy(),\n                    label='Scale {}'.format(scale)\n                )\n            plt.legend()\n            plt.xlabel('Epochs')\n            plt.ylabel('PSNR')\n            plt.grid(True)\n            plt.savefig(self.get_path('test_{}.pdf'.format(d)))\n            plt.close(fig)\n\n    def begin_background(self):\n        self.queue = Queue()\n\n        def bg_target(queue):\n            while True:\n                if not queue.empty():\n                    filename, tensor = queue.get()\n                    if filename is None: break\n                    imageio.imwrite(filename, tensor.numpy())\n\n        self.process = [\n            Process(target=bg_target, args=(self.queue,)) \\\n            for _ in range(self.n_processes)\n        ]\n\n        for p in self.process: p.start()\n\n    def end_background(self):\n        for _ in range(self.n_processes): self.queue.put((None, None))\n        while not self.queue.empty(): time.sleep(1)\n        for p in self.process: p.join()\n\n    def save_results(self, dataset, filename, save_list, scale):\n        if self.args.save_results:\n            filename = self.get_path(\n                'results-{}'.format(dataset.dataset.name),\n                '{}_x{}_'.format(filename, scale)\n            )\n\n            postfix = ('SR', 'LR', 'HR')\n            for v, p in zip(save_list, postfix):\n   
             normalized = v[0].mul(255 / self.args.rgb_range)\n                tensor_cpu = normalized.byte().permute(1, 2, 0).cpu()\n                self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu))\n\n\ndef quantize(img, rgb_range):\n    pixel_range = 255 / rgb_range\n    return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)\n\n\ndef calc_psnr(sr, hr, scale, rgb_range, dataset=None):\n    if hr.nelement() == 1: return 0\n\n    diff = (sr - hr) / rgb_range\n    if dataset and dataset.dataset.benchmark:\n        shave = scale\n        if diff.size(1) > 1:\n            gray_coeffs = [65.738, 129.057, 25.064]\n            convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256\n            diff = diff.mul(convert).sum(dim=1)\n    else:\n        shave = scale + 6\n\n    valid = diff[..., shave:-shave, shave:-shave]\n    mse = valid.pow(2).mean()\n\n    return -10 * math.log10(mse)\n\n\ndef make_optimizer(args, target):\n    '''\n        make optimizer and scheduler together\n    '''\n    # optimizer\n    trainable = filter(lambda x: x.requires_grad, target.parameters())\n    kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay}\n\n    if args.optimizer == 'SGD':\n        optimizer_class = optim.SGD\n        kwargs_optimizer['momentum'] = args.momentum\n    elif args.optimizer == 'ADAM':\n        optimizer_class = optim.Adam\n        kwargs_optimizer['betas'] = args.betas\n        kwargs_optimizer['eps'] = args.epsilon\n    elif args.optimizer == 'RMSprop':\n        optimizer_class = optim.RMSprop\n        kwargs_optimizer['eps'] = args.epsilon\n\n    # scheduler\n    milestones = list(map(lambda x: int(x), args.decay.split('-')))\n    kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma}\n    scheduler_class = lrs.MultiStepLR\n\n    class CustomOptimizer(optimizer_class):\n        def __init__(self, *args, **kwargs):\n            super(CustomOptimizer, self).__init__(*args, **kwargs)\n\n        def 
_register_scheduler(self, scheduler_class, **kwargs):\n            self.scheduler = scheduler_class(self, **kwargs)\n\n        def save(self, save_dir):\n            torch.save(self.state_dict(), self.get_dir(save_dir))\n\n        def load(self, load_dir, epoch=1):\n            self.load_state_dict(torch.load(self.get_dir(load_dir)))\n            if epoch > 1:\n                for _ in range(epoch): self.scheduler.step()\n\n        def get_dir(self, dir_path):\n            return os.path.join(dir_path, 'optimizer.pt')\n\n        def schedule(self):\n            self.scheduler.step()\n\n        def get_lr(self):\n            return self.scheduler.get_last_lr()[0]\n\n        def get_last_epoch(self):\n            return self.scheduler.last_epoch\n\n    optimizer = CustomOptimizer(trainable, **kwargs_optimizer)\n    optimizer._register_scheduler(scheduler_class, **kwargs_scheduler)\n    return optimizer\n\n\ndef write_gray_to_tfboard(img):\n    img_debug = img[0, ...].detach().cpu().numpy()\n\n    # img_debug = cv2.normalize(img_debug, None, 0, 255,\n    #                           cv2.NORM_MINMAX, cv2.CV_8U)\n    img_debug = img_debug * 255\n    img_debug = np.clip(img_debug, 0, 255)\n    img_debug = img_debug.astype(np.uint8)\n    return img_debug[0, ...]\n\n\n######################## BayerUnifyAug ############################\n\nBAYER_PATTERNS = [\"RGGB\", \"BGGR\", \"GRBG\", \"GBRG\"]\nNORMALIZATION_MODE = [\"crop\", \"pad\"]\n\n\ndef bayer_unify(raw, input_pattern, target_pattern, mode) -> torch.Tensor:\n    \"\"\"\n    Convert a bayer raw image from one bayer pattern to another.\n    mode: {\"crop\", \"pad\"}\n        The way to handle submosaic shift. \"crop\" abandons the outermost pixels,\n        and \"pad\" introduces extra pixels. 
Use \"crop\" in training and \"pad\" in\n        testing.\n    \"\"\"\n\n    if input_pattern == target_pattern:\n        h_offset, w_offset = 0, 0\n    elif input_pattern[0] == target_pattern[2] and input_pattern[1] == target_pattern[3]:\n        h_offset, w_offset = 1, 0\n    elif input_pattern[0] == target_pattern[1] and input_pattern[2] == target_pattern[3]:\n        h_offset, w_offset = 0, 1\n    elif input_pattern[0] == target_pattern[3] and input_pattern[1] == target_pattern[2]:\n        h_offset, w_offset = 1, 1\n    else:  # This cannot happen for patterns in [\"RGGB\", \"BGGR\", \"GRBG\", \"GBRG\"]\n        raise RuntimeError('Unexpected pair of input and target bayer pattern!')\n\n    if mode == \"pad\":\n        # out = np.pad(raw, [[h_offset, h_offset], [w_offset, w_offset]], 'reflect')\n        out = F.pad(raw, (w_offset, w_offset, h_offset, h_offset), mode='reflect')\n    elif mode == \"crop\":\n        _, _, _, h, w = raw.shape\n        out = raw[..., h_offset:h - h_offset, w_offset:w - w_offset]\n    else:\n        raise ValueError('Unknown normalization mode!')\n\n    return out\n\n\ndef bayer_aug(raw, flip_h=False, flip_w=False, transpose=False, input_pattern='RGGB') -> torch.Tensor:\n    \"\"\"\n    Apply augmentation to a bayer raw image.\n    \"\"\"\n\n    aug_pattern, target_pattern = input_pattern, input_pattern\n\n    out = raw\n    if flip_h:\n        out = torch.flip(out, [3]) # GBRG, RGGB\n        aug_pattern = aug_pattern[2] + aug_pattern[3] + aug_pattern[0] + aug_pattern[1]\n    if flip_w:\n        out = torch.flip(out, [4])\n        aug_pattern = aug_pattern[1] + aug_pattern[0] + aug_pattern[3] + aug_pattern[2]\n    if transpose:\n        out = out.permute(0, 1, 2, 4, 3)\n        aug_pattern = aug_pattern[0] + aug_pattern[2] + aug_pattern[1] + aug_pattern[3]\n\n    out = bayer_unify(out, aug_pattern, target_pattern, \"crop\")\n    return out\n"
  },
  {
    "path": "code/real/bsrt/utils/__init__.py",
    "content": ""
  },
  {
    "path": "code/real/bsrt/utils/data_format_utils.py",
    "content": "import numpy as np\nimport torch\nimport cv2 as cv\n\n\ndef numpy_to_torch(a: np.ndarray):\n    return torch.from_numpy(a).float().permute(2, 0, 1)\n\n\ndef torch_to_numpy(a: torch.Tensor):\n    return a.permute(1, 2, 0).cpu().numpy()\n\n\ndef torch_to_npimage(a: torch.Tensor, unnormalize=True):\n    a_np = torch_to_numpy(a)\n\n    if unnormalize:\n        a_np = a_np * 255\n    a_np = a_np.astype(np.uint8)\n    return cv.cvtColor(a_np, cv.COLOR_RGB2BGR)\n\n\ndef npimage_to_torch(a, normalize=True, input_bgr=True):\n    if input_bgr:\n        a = cv.cvtColor(a, cv.COLOR_BGR2RGB)\n    a_t = numpy_to_torch(a)\n\n    if normalize:\n        a_t = a_t / 255.0\n\n    return a_t\n\n\ndef convert_dict(base_dict, batch_sz):\n    out_dict = []\n    for b_elem in range(batch_sz):\n        b_info = {}\n        for k, v in base_dict.items():\n            if isinstance(v, (list, torch.Tensor)):\n                b_info[k] = v[b_elem]\n        out_dict.append(b_info)\n\n    return out_dict"
  },
  {
    "path": "code/real/bsrt/utils/debayer.py",
"content": "import torch\nimport torch.nn\nimport torch.nn.functional\n\nclass Debayer3x3(torch.nn.Module):\n    '''Demosaicing of Bayer images using 3x3 convolutions.\n\n    Requires BG-Bayer color filter array layout. That is,\n    the image[1,1]='B', image[1,2]='G'. This corresponds\n    to OpenCV naming conventions.\n\n    Compared to Debayer2x2 this method does not use upsampling.\n    Instead, we identify five 3x3 interpolation kernels that\n    are sufficient to reconstruct every color channel at every\n    pixel location.\n\n    We convolve the image with these 5 kernels using stride=1\n    and a one pixel replication padding. Finally, we gather\n    the correct channel values for each pixel location. To do so,\n    we recognize that the Bayer pattern repeats horizontally and\n    vertically every 2 pixels. Therefore, we define the correct\n    index lookups for a 2x2 grid cell and then repeat to image\n    dimensions.\n\n    Note, in every 2x2 grid cell we have red, blue and two greens\n    (G1,G2). 
The lookups for the two greens differ.\n    '''\n\n    def __init__(self):\n        super(Debayer3x3, self).__init__()\n\n        self.kernels = torch.nn.Parameter(\n            torch.tensor([\n                [0,0,0],\n                [0,1,0],\n                [0,0,0],\n\n                [0, 0.25, 0],\n                [0.25, 0, 0.25],\n                [0, 0.25, 0],\n\n                [0.25, 0, 0.25],\n                [0, 0, 0],\n                [0.25, 0, 0.25],\n\n                [0, 0, 0],\n                [0.5, 0, 0.5],\n                [0, 0, 0],\n\n                [0, 0.5, 0],\n                [0, 0, 0],\n                [0, 0.5, 0],\n            ]).view(5,1,3,3), requires_grad=False\n        )\n\n\n        self.index = torch.nn.Parameter(\n            torch.tensor([\n                # dest channel r\n                [0, 3], # pixel is R,G1\n                [4, 2], # pixel is G2,B\n                # dest channel g\n                [1, 0], # pixel is R,G1\n                [0, 1], # pixel is G2,B\n                # dest channel b\n                [2, 4], # pixel is R,G1\n                [3, 0], # pixel is G2,B\n            ]).view(1,3,2,2), requires_grad=False\n        )\n\n    def forward(self, x):\n        '''Debayer image.\n\n        Parameters\n        ----------\n        x : Bx1xHxW tensor\n            Images to debayer\n\n        Returns\n        -------\n        rgb : Bx3xHxW tensor\n            Color images in RGB channel order.\n        '''\n        B,C,H,W = x.shape\n\n        x = torch.nn.functional.pad(x, (1,1,1,1), mode='replicate')\n        c = torch.nn.functional.conv2d(x, self.kernels, stride=1)\n        rgb = torch.gather(c, 1, self.index.repeat(B,1,H//2,W//2))\n        return rgb\n\nclass Debayer2x2(torch.nn.Module):\n    '''Demosaicing of Bayer images using 2x2 convolutions.\n\n    Requires BG-Bayer color filter array layout. That is,\n    the image[1,1]='B', image[1,2]='G'. 
This corresponds\n    to OpenCV naming conventions.\n    '''\n\n    def __init__(self):\n        super(Debayer2x2, self).__init__()\n\n        self.kernels = torch.nn.Parameter(\n            torch.tensor([\n                [1, 0],\n                [0, 0],\n\n                [0, 0.5],\n                [0.5, 0],\n\n                [0, 0],\n                [0, 1],\n            ]).view(3,1,2,2), requires_grad=False\n        )\n\n    def forward(self, x):\n        '''Debayer image.\n\n        Parameters\n        ----------\n        x : Bx1xHxW tensor\n            Images to debayer\n\n        Returns\n        -------\n        rgb : Bx3xHxW tensor\n            Color images in RGB channel order.\n        '''\n\n        x = torch.nn.functional.conv2d(x, self.kernels, stride=2)\n        x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)\n        return x\n\nclass DebayerSplit(torch.nn.Module):\n    '''Demosaicing of Bayer images using 3x3 green convolution and red,blue upsampling.\n\n    Requires BG-Bayer color filter array layout. That is,\n    the image[1,1]='B', image[1,2]='G'. 
This corresponds\n    to OpenCV naming conventions.\n    '''\n    def __init__(self):\n        super().__init__()\n\n        self.pad = torch.nn.ReflectionPad2d(1)\n        self.kernel = torch.nn.Parameter(\n            torch.tensor([\n                [0,1,0],\n                [1,0,1],\n                [0,1,0]\n            ])[None, None] * 0.25)\n\n    def forward(self, x):\n        '''Debayer image.\n\n        Parameters\n        ----------\n        x : Bx1xHxW tensor\n            Images to debayer\n\n        Returns\n        -------\n        rgb : Bx3xHxW tensor\n            Color images in RGB channel order.\n        '''\n        B,_,H,W = x.shape\n        red = x[:, :, ::2, ::2]\n        blue = x[:, :, 1::2, 1::2]\n\n        green = torch.nn.functional.conv2d(self.pad(x), self.kernel)\n        green[:, :, ::2, 1::2] = x[:, :, ::2, 1::2]\n        green[:, :, 1::2, ::2] = x[:, :, 1::2, ::2]\n\n        return torch.cat((\n            torch.nn.functional.interpolate(red, size=(H, W), mode='bilinear', align_corners=False),\n            green,\n            torch.nn.functional.interpolate(blue, size=(H, W), mode='bilinear', align_corners=False)),\n            dim=1)"
  },
  {
    "path": "code/real/bsrt/utils/interp_methods.py",
    "content": "from math import pi\n\ntry:\n    import torch\nexcept ImportError:\n    torch = None\n\ntry:\n    import numpy\nexcept ImportError:\n    numpy = None\n\nif numpy is None and torch is None:\n    raise ImportError(\"Must have either Numpy or PyTorch but both not found\")\n\n\ndef set_framework_dependencies(x):\n    if type(x) is numpy.ndarray:\n        to_dtype = lambda a: a\n        fw = numpy\n    else:\n        to_dtype = lambda a: a.to(x.dtype)\n        fw = torch\n    eps = fw.finfo(fw.float32).eps\n    return fw, to_dtype, eps\n\n\ndef support_sz(sz):\n    def wrapper(f):\n        f.support_sz = sz\n        return f\n    return wrapper\n\n@support_sz(4)\ndef cubic(x):\n    fw, to_dtype, eps = set_framework_dependencies(x)\n    absx = fw.abs(x)\n    absx2 = absx ** 2\n    absx3 = absx ** 3\n    return ((1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) +\n            (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) *\n            to_dtype((1. < absx) & (absx <= 2.)))\n\n@support_sz(4)\ndef lanczos2(x):\n    fw, to_dtype, eps = set_framework_dependencies(x)\n    return (((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) /\n            ((pi**2 * x**2 / 2) + eps)) * to_dtype(abs(x) < 2))\n\n@support_sz(6)\ndef lanczos3(x):\n    fw, to_dtype, eps = set_framework_dependencies(x)\n    return (((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) /\n            ((pi**2 * x**2 / 3) + eps)) * to_dtype(abs(x) < 3))\n\n@support_sz(2)\ndef linear(x):\n    fw, to_dtype, eps = set_framework_dependencies(x)\n    return ((x + 1) * to_dtype((-1 <= x) & (x < 0)) + (1 - x) *\n            to_dtype((0 <= x) & (x <= 1)))\n\n@support_sz(1)\ndef box(x):\n    fw, to_dtype, eps = set_framework_dependencies(x)\n    return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1))\n"
  },
  {
    "path": "code/real/bsrt/utils/metrics.py",
    "content": "import math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport utils.spatial_color_alignment as sca_utils\nfrom utils.spatial_color_alignment import get_gaussian_kernel, match_colors\nfrom utils.warp import warp\nfrom torch.cuda.amp import autocast\nfrom loss.Charbonnier import CharbonnierLoss as CBLoss\nfrom loss.mssim import MSSSIM\nfrom pytorch_msssim import ssim\nimport lpips\n\n\nclass MSSSIMLoss(nn.Module):\n    def __init__(self, boundary_ignore=None):\n        super().__init__()\n        self.boundary_ignore = boundary_ignore\n        self.msssim = MSSSIM()\n\n    def forward(self, pred, gt, valid=None):\n        if self.boundary_ignore is not None:\n            pred = pred[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n            gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n        pred_m = pred\n        gt_m = gt\n\n        loss = self.msssim(pred_m, gt_m)\n\n        return loss\n\nclass CharbonnierLoss(nn.Module):\n    def __init__(self, boundary_ignore=None):\n        super().__init__()\n        self.boundary_ignore = boundary_ignore\n        self.charbonnier_loss = CBLoss(reduce=True)\n\n    def forward(self, pred, gt, valid=None):\n        if self.boundary_ignore is not None:\n            pred = pred[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n            gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n        pred_m = pred\n        gt_m = gt\n\n        loss = self.charbonnier_loss(pred_m, gt_m)\n\n        return loss\n\nclass L1(nn.Module):\n    def __init__(self, boundary_ignore=None):\n        super().__init__()\n        self.boundary_ignore = boundary_ignore\n\n    def forward(self, pred, gt, valid=None):\n        if self.boundary_ignore is not None:\n            pred = pred[..., 
self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n            gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n            if valid is not None:\n                valid = valid[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n        pred_m = pred\n        gt_m = gt\n\n        if valid is None:\n            mse = F.l1_loss(pred_m, gt_m)\n        else:\n            mse = F.l1_loss(pred_m, gt_m, reduction='none')\n\n            eps = 1e-12\n            elem_ratio = mse.numel() / valid.numel()\n            mse = (mse * valid.float()).sum() / (valid.float().sum()*elem_ratio + eps)\n\n        return mse\n\nclass L2(nn.Module):\n    def __init__(self, boundary_ignore=None):\n        super().__init__()\n        self.boundary_ignore = boundary_ignore\n\n    def forward(self, pred, gt, valid=None):\n        if self.boundary_ignore is not None:\n            pred = pred[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n            gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n            if valid is not None:\n                valid = valid[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n        pred_m = pred\n        gt_m = gt\n\n        if valid is None:\n            mse = F.mse_loss(pred_m, gt_m)\n        else:\n            mse = F.mse_loss(pred_m, gt_m, reduction='none')\n\n            eps = 1e-12\n            elem_ratio = mse.numel() / valid.numel()\n            mse = (mse * valid.float()).sum() / (valid.float().sum()*elem_ratio + eps)\n\n        return mse\n\n\nclass PSNR(nn.Module):\n    def __init__(self, boundary_ignore=None, max_value=1.0):\n        super().__init__()\n        self.l2 = L2(boundary_ignore=boundary_ignore)\n        self.max_value = max_value\n\n    def psnr(self, 
pred, gt, valid=None):\n        mse = self.l2(pred, gt, valid=valid)\n\n        psnr = 20 * math.log10(self.max_value) - 10.0 * mse.log10()\n\n        return psnr\n\n    def forward(self, pred, gt, valid=None):\n        assert pred.dim() == 4 and pred.shape == gt.shape\n        if valid is None:\n            psnr_all = [self.psnr(p.unsqueeze(0), g.unsqueeze(0)) for p, g in\n                        zip(pred, gt)]\n        else:\n            psnr_all = [self.psnr(p.unsqueeze(0), g.unsqueeze(0), v.unsqueeze(0)) for p, g, v in zip(pred, gt, valid)]\n        psnr = sum(psnr_all) / len(psnr_all)\n        return psnr\n\n\nclass AlignedL1(nn.Module):\n    def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None):\n        super().__init__()\n        self.sr_factor = sr_factor\n        self.boundary_ignore = boundary_ignore\n        self.alignment_net = alignment_net\n\n        self.gauss_kernel, self.ksz = get_gaussian_kernel(sd=1.5)\n\n    def forward(self, pred, gt, burst_input):\n        # Estimate flow between the prediction and the ground truth\n        with torch.no_grad():\n            flow = self.alignment_net(pred / (pred.max() + 1e-6), gt / (gt.max() + 1e-6))\n\n        # Warp the prediction to the ground truth coordinates\n        pred_warped = warp(pred, flow)\n\n        # Warp the base input frame to the ground truth. 
This will be used to estimate the color transformation between\n        # the input and the ground truth\n        sr_factor = self.sr_factor\n        ds_factor = 1.0 / float(2.0 * sr_factor)\n        flow_ds = F.interpolate(flow, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False) * ds_factor\n\n        burst_0 = burst_input[:, 0, [0, 1, 3]].contiguous()\n        burst_0_warped = warp(burst_0, flow_ds)\n        frame_gt_ds = F.interpolate(gt, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False)\n\n        # Match the colorspace between the prediction and ground truth\n        pred_warped_m, valid = match_colors(frame_gt_ds, burst_0_warped, pred_warped, self.ksz,\n                                                      self.gauss_kernel)\n\n        # Ignore boundary pixels if specified\n        if self.boundary_ignore is not None:\n            pred_warped_m = pred_warped_m[..., self.boundary_ignore:-self.boundary_ignore,\n                            self.boundary_ignore:-self.boundary_ignore]\n            gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n            valid = valid[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n        pred_warped_m = pred_warped_m.contiguous()\n        gt = gt.contiguous()\n        # Estimate MSE\n        l1 = F.l1_loss(pred_warped_m, gt, reduction='none')\n\n        eps = 1e-12\n        elem_ratio = l1.numel() / valid.numel()\n        l1 = (l1 * valid.float()).sum() / (valid.float().sum()*elem_ratio + eps)\n\n        return l1\n\nclass AlignedL2(nn.Module):\n    def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None):\n        super().__init__()\n        self.sr_factor = sr_factor\n        self.boundary_ignore = boundary_ignore\n        self.alignment_net = alignment_net\n        self.loss_fn = lpips.LPIPS(net='alex').cuda()\n\n        
self.gauss_kernel, self.ksz = sca_utils.get_gaussian_kernel(sd=1.5)\n\n    def forward(self, pred, gt, burst_input):\n        # Estimate flow between the prediction and the ground truth\n        with torch.no_grad():\n            flow = self.alignment_net(pred / (pred.max() + 1e-6), gt / (gt.max() + 1e-6))\n\n        # Warp the prediction to the ground truth coordinates\n        pred_warped = warp(pred, flow)\n\n        # Warp the base input frame to the ground truth. This will be used to estimate the color transformation between\n        # the input and the ground truth\n        sr_factor = self.sr_factor\n        ds_factor = 1.0 / float(2.0 * sr_factor)\n        flow_ds = F.interpolate(flow, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False) * ds_factor\n\n        burst_0 = burst_input[:, 0, [0, 1, 3]].contiguous()\n        burst_0_warped = warp(burst_0, flow_ds)\n        frame_gt_ds = F.interpolate(gt, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False)\n\n        # Match the colorspace between the prediction and ground truth\n        pred_warped_m, valid = sca_utils.match_colors(frame_gt_ds, burst_0_warped, pred_warped, self.ksz,\n                                                      self.gauss_kernel)\n\n        # Ignore boundary pixels if specified\n        if self.boundary_ignore is not None:\n            pred_warped_m = pred_warped_m[..., self.boundary_ignore:-self.boundary_ignore,\n                            self.boundary_ignore:-self.boundary_ignore]\n            gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n            valid = valid[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n        # Estimate MSE\n        mse = F.mse_loss(pred_warped_m.contiguous(), gt.contiguous(), reduction='none')\n\n        eps = 1e-12\n        elem_ratio = mse.numel() / valid.numel()\n        
mse = (mse * valid.float()).sum() / (valid.float().sum()*elem_ratio + eps)\n\n        ss = ssim(pred_warped_m.contiguous(), gt.contiguous(), data_range=1.0, size_average=True)\n        # eps = 1e-12\n        # elem_ratio = ss.numel() / valid.numel()\n        # ss = (ss * valid.float()).sum() / (valid.float().sum()*elem_ratio + eps)\n\n        lp = self.loss_fn(pred_warped_m.contiguous(), gt.contiguous()).squeeze()\n\n        return mse, ss, lp\n\n\nclass AlignedPSNR(nn.Module):\n    def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None, max_value=1.0):\n        super().__init__()\n        self.l2 = AlignedL2(alignment_net=alignment_net, sr_factor=sr_factor, boundary_ignore=boundary_ignore)\n        self.max_value = max_value\n\n    def psnr(self, pred, gt, burst_input):\n        mse, ss, lp = self.l2(pred, gt, burst_input)\n\n        psnr = 20 * math.log10(self.max_value) - 10.0 * mse.log10()\n\n        return psnr, ss, lp\n\n    def forward(self, pred, gt, burst_input):\n        all_scores = [self.psnr(p.unsqueeze(0), g.unsqueeze(0), bi.unsqueeze(0)) for p, g, bi in zip(pred, gt, burst_input)]\n        psnr = sum([score[0] for score in all_scores]) / len(all_scores)\n        ssim_ = sum([score[1] for score in all_scores]) / len(all_scores)\n        lpips_ = sum([score[2] for score in all_scores]) / len(all_scores)\n        return psnr, ssim_, lpips_\n\n\n\nclass AlignedSSIM(nn.Module):\n    def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None):\n        super().__init__()\n        self.sr_factor = sr_factor\n        self.boundary_ignore = boundary_ignore\n        self.alignment_net = alignment_net\n\n        self.gauss_kernel, self.ksz = sca_utils.get_gaussian_kernel(sd=1.5)\n\n    def _ssim(self, pred, gt, burst_input):\n        # Estimate flow between the prediction and the ground truth\n        with torch.no_grad():\n            flow = self.alignment_net(pred / (pred.max() + 1e-6), gt / (gt.max() + 1e-6))\n\n        # Warp the 
prediction to the ground truth coordinates\n        pred_warped = warp(pred, flow)\n\n        # Warp the base input frame to the ground truth. This will be used to estimate the color transformation between\n        # the input and the ground truth\n        sr_factor = self.sr_factor\n        ds_factor = 1.0 / float(2.0 * sr_factor)\n        flow_ds = F.interpolate(flow, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False) * ds_factor\n\n        burst_0 = burst_input[:, 0, [0, 1, 3]].contiguous()\n        burst_0_warped = warp(burst_0, flow_ds)\n        frame_gt_ds = F.interpolate(gt, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False)\n\n        # Match the colorspace between the prediction and ground truth\n        pred_warped_m, valid = sca_utils.match_colors(frame_gt_ds, burst_0_warped, pred_warped, self.ksz,\n                                                      self.gauss_kernel)\n\n        # Ignore boundary pixels if specified\n        if self.boundary_ignore is not None:\n            pred_warped_m = pred_warped_m[..., self.boundary_ignore:-self.boundary_ignore,\n                            self.boundary_ignore:-self.boundary_ignore]\n            gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n            valid = valid[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n        # Estimate SSIM\n        ss = ssim(pred_warped_m.contiguous(), gt.contiguous(), data_range=1.0, size_average=True)\n        # eps = 1e-12\n        # elem_ratio = ss.numel() / valid.numel()\n        # ss = (ss * valid.float()).sum() / (valid.float().sum()*elem_ratio + eps)\n\n        return ss\n\n    def forward(self, pred, gt, burst_input):\n        ssim_all = [self._ssim(p.unsqueeze(0), g.unsqueeze(0), bi.unsqueeze(0)) for p, g, bi in zip(pred, gt, burst_input)]\n        _ssim = 
sum(ssim_all) / len(ssim_all)\n        return _ssim\n\n\nclass AlignedLPIPS(nn.Module):\n    def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None):\n        super().__init__()\n        self.sr_factor = sr_factor\n        self.boundary_ignore = boundary_ignore\n        self.alignment_net = alignment_net\n        self.loss_fn = lpips.LPIPS(net='alex').cuda()\n\n        self.gauss_kernel, self.ksz = sca_utils.get_gaussian_kernel(sd=1.5)\n\n    def _lpips(self, pred, gt, burst_input):\n        # Estimate flow between the prediction and the ground truth\n        with torch.no_grad():\n            flow = self.alignment_net(pred / (pred.max() + 1e-6), gt / (gt.max() + 1e-6))\n\n        # Warp the prediction to the ground truth coordinates\n        pred_warped = warp(pred, flow)\n\n        # Warp the base input frame to the ground truth. This will be used to estimate the color transformation between\n        # the input and the ground truth\n        sr_factor = self.sr_factor\n        ds_factor = 1.0 / float(2.0 * sr_factor)\n        flow_ds = F.interpolate(flow, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False) * ds_factor\n\n        burst_0 = burst_input[:, 0, [0, 1, 3]].contiguous()\n        burst_0_warped = warp(burst_0, flow_ds)\n        frame_gt_ds = F.interpolate(gt, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False)\n\n        # Match the colorspace between the prediction and ground truth\n        pred_warped_m, valid = sca_utils.match_colors(frame_gt_ds, burst_0_warped, pred_warped, self.ksz,\n                                                      self.gauss_kernel)\n\n        # Ignore boundary pixels if specified\n        if self.boundary_ignore is not None:\n            pred_warped_m = pred_warped_m[..., self.boundary_ignore:-self.boundary_ignore,\n                            self.boundary_ignore:-self.boundary_ignore]\n            gt = gt[..., 
self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n            valid = valid[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n        # Estimate LPIPS\n        lp = self.loss_fn(pred_warped_m.contiguous(), gt.contiguous()).squeeze()\n        return lp\n\n    def forward(self, pred, gt, burst_input):\n        lpips_all = [self._lpips(p.unsqueeze(0), g.unsqueeze(0), bi.unsqueeze(0)) for p, g, bi in zip(pred, gt, burst_input)]\n        _lpips = sum(lpips_all) / len(lpips_all)\n        return _lpips\n"
  },
  {
    "path": "code/real/bsrt/utils/postprocessing_functions.py",
    "content": "import torch\nimport numpy as np\nimport utils.data_format_utils as df_utils\nfrom data_processing.camera_pipeline import apply_gains, apply_ccm, apply_smoothstep, gamma_compression\n\n\nclass SimplePostProcess:\n    def __init__(self, gains=True, ccm=True, gamma=True, smoothstep=True, return_np=False):\n        self.gains = gains\n        self.ccm = ccm\n        self.gamma = gamma\n        self.smoothstep = smoothstep\n        self.return_np = return_np\n\n    def process(self, image, meta_info):\n        return process_linear_image_rgb(image, meta_info, self.gains, self.ccm, self.gamma,\n                                        self.smoothstep, self.return_np)\n\n\ndef process_linear_image_rgb(image, meta_info, gains=True, ccm=True, gamma=True, smoothstep=True, return_np=False):\n    if gains:\n        image = apply_gains(image, meta_info['rgb_gain'], meta_info['red_gain'], meta_info['blue_gain'])\n\n    if ccm:\n        image = apply_ccm(image, meta_info['cam2rgb'])\n\n    if meta_info['gamma'] and gamma:\n        image = gamma_compression(image)\n\n    if meta_info['smoothstep'] and smoothstep:\n        image = apply_smoothstep(image)\n\n    image = image.clamp(0.0, 1.0)\n\n    if return_np:\n        image = df_utils.torch_to_npimage(image)\n    return image\n\n\nclass BurstSRPostProcess:\n    def __init__(self, no_white_balance=False, gamma=True, smoothstep=True, return_np=False):\n        self.no_white_balance = no_white_balance\n        self.gamma = gamma\n        self.smoothstep = smoothstep\n        self.return_np = return_np\n\n    def process(self, image, meta_info, external_norm_factor=None):\n        return process_burstsr_image_rgb(image, meta_info, external_norm_factor=external_norm_factor,\n                                         no_white_balance=self.no_white_balance, gamma=self.gamma,\n                                         smoothstep=self.smoothstep, return_np=self.return_np)\n\n\ndef process_burstsr_image_rgb(im, meta_info, 
return_np=False, external_norm_factor=None, gamma=True, smoothstep=True,\n                              no_white_balance=False):\n    im = im * meta_info.get('norm_factor', 1.0)\n\n    if not meta_info.get('black_level_subtracted', False):\n        im = (im - torch.tensor(meta_info['black_level'])[[0, 1, -1]].view(3, 1, 1).to(im.device))\n\n    if not meta_info.get('while_balance_applied', False) and not no_white_balance:\n        im = im * (meta_info['cam_wb'][[0, 1, -1]].view(3, 1, 1) / meta_info['cam_wb'][1]).to(im.device)\n\n    im_out = im\n\n    if external_norm_factor is None:\n        im_out = im_out / im_out.max()\n    else:\n        im_out = im_out / external_norm_factor\n\n    im_out = im_out.clamp(0.0, 1.0)\n\n    if gamma:\n        im_out = im_out ** (1.0 / 2.2)\n\n    if smoothstep:\n        # Smooth curve\n        im_out = 3 * im_out ** 2 - 2 * im_out ** 3\n\n    if return_np:\n        im_out = im_out.permute(1, 2, 0).cpu().numpy() * 255.0\n        im_out = im_out.astype(np.uint8)\n\n    return im_out\n"
  },
  {
    "path": "code/real/bsrt/utils/resize_right.py",
    "content": "import warnings\nfrom math import ceil\nimport interp_methods\n\n\nclass NoneClass:\n    pass\n\ntry:\n    import torch\n    from torch import nn\n    nnModuleWrapped = nn.Module\nexcept ImportError:\n    warnings.warn('No PyTorch found, will work only with Numpy')\n    torch = None\n    nnModuleWrapped = NoneClass\n\ntry:\n    import numpy\nexcept ImportError:\n    warnings.warn('No Numpy found, will work only with PyTorch')\n    numpy = None\n\n\nif numpy is None and torch is None:\n    raise ImportError(\"Must have either Numpy or PyTorch but both not found\")\n\n\ndef resize(input, scale_factors=None, out_shape=None,\n           interp_method=interp_methods.cubic, support_sz=None,\n           antialiasing=True):\n    # get properties of the input tensor\n    in_shape, n_dims = input.shape, input.ndim\n\n    # fw stands for framework that can be either numpy or torch,\n    # determined by the input type\n    fw = numpy if type(input) is numpy.ndarray else torch\n    eps = fw.finfo(fw.float32).eps\n\n    # set missing scale factors or output shapem one according to another,\n    # scream if both missing\n    scale_factors, out_shape = set_scale_and_out_sz(in_shape, out_shape,\n                                                    scale_factors, fw)\n\n    # sort indices of dimensions according to scale of each dimension.\n    # since we are going dim by dim this is efficient\n    sorted_filtered_dims_and_scales = [(dim, scale_factors[dim])\n                                       for dim in sorted(range(n_dims),\n                                       key=lambda ind: scale_factors[ind])\n                                       if scale_factors[dim] != 1.]\n\n    # unless support size is specified by the user, it is an attribute\n    # of the interpolation method\n    if support_sz is None:\n        support_sz = interp_method.support_sz\n\n    # when using pytorch, we need to know what is the input tensor device\n    if fw is torch:\n        device = 
input.device\n\n    # output begins identical to input and changes with each iteration\n    output = input\n\n    # iterate over dims\n    for dim, scale_factor in sorted_filtered_dims_and_scales:\n\n        # get 1d set of weights and fields of view for each output location\n        # along this dim\n        field_of_view, weights = prepare_weights_and_field_of_view_1d(\n            dim, scale_factor, in_shape[dim], out_shape[dim], interp_method,\n            support_sz, antialiasing, fw, eps, device)\n\n        # multiply the weights by the values in the field of view and\n        # aggregate\n        output = apply_weights(output, field_of_view, weights, dim, n_dims,\n                               fw)\n    return output\n\n\nclass ResizeLayer(nnModuleWrapped):\n    def __init__(self, in_shape, scale_factors=None, out_shape=None,\n                 interp_method=interp_methods.cubic, support_sz=None,\n                 antialiasing=True):\n        super(ResizeLayer, self).__init__()\n\n        # fw stands for framework, that can be either numpy or torch. 
since\n        # this is a torch layer, only one option in this case.\n        fw = torch\n        eps = fw.finfo(fw.float32).eps\n\n        # set missing scale factors or output shape, one according to the other,\n        # scream if both missing\n        scale_factors, out_shape = set_scale_and_out_sz(in_shape, out_shape,\n                                                        scale_factors, fw)\n\n        # unless support size is specified by the user, it is an attribute\n        # of the interpolation method\n        if support_sz is None:\n            support_sz = interp_method.support_sz\n\n        self.n_dims = len(in_shape)\n\n        # sort indices of dimensions according to scale of each dimension.\n        # since we are going dim by dim this is efficient\n        self.sorted_filtered_dims_and_scales = [(dim, scale_factors[dim])\n                                                for dim in\n                                                sorted(range(self.n_dims),\n                                                key=lambda ind:\n                                                scale_factors[ind])\n                                                if scale_factors[dim] != 1.]\n\n        # iterate over dims\n        field_of_view_list = []\n        weights_list = []\n        for dim, scale_factor in self.sorted_filtered_dims_and_scales:\n\n            # get 1d set of weights and fields of view for each output\n            # location along this dim. no input tensor exists yet at\n            # construction time, so build on CPU; the resulting parameters\n            # move with the module via .to()/.cuda()\n            field_of_view, weights = prepare_weights_and_field_of_view_1d(\n                dim, scale_factor, in_shape[dim], out_shape[dim],\n                interp_method, support_sz, antialiasing, fw, eps, 'cpu')\n\n            # keep weights and fields of views for all dims\n            weights_list.append(nn.Parameter(weights, requires_grad=False))\n            field_of_view_list.append(nn.Parameter(field_of_view,\n                                      requires_grad=False))\n\n        self.field_of_view = 
nn.ParameterList(field_of_view_list)\n        self.weights = nn.ParameterList(weights_list)\n        self.in_shape = in_shape\n\n    def forward(self, input):\n        # output begins identical to input and changes with each iteration\n        output = input\n\n        for (dim, scale_factor), field_of_view, weights in zip(\n                self.sorted_filtered_dims_and_scales,\n                self.field_of_view,\n                self.weights):\n            # multiply the weights by the values in the field of view and\n            # aggreagate\n            output = apply_weights(output, field_of_view, weights, dim,\n                                   self.n_dims, torch)\n        return output\n\n\ndef prepare_weights_and_field_of_view_1d(dim, scale_factor, in_sz, out_sz,\n                                         interp_method, support_sz,\n                                         antialiasing, fw, eps, device=None):\n    # If antialiasing is taking place, we modify the window size and the\n    # interpolation method (see inside function)\n    interp_method, cur_support_sz = apply_antialiasing_if_needed(\n                                                             interp_method,\n                                                             support_sz,\n                                                             scale_factor,\n                                                             antialiasing)\n\n    # STEP 1- PROJECTED GRID: The non-integer locations of the projection of\n    # output pixel locations to the input tensor\n    projected_grid = get_projected_grid(in_sz, out_sz, scale_factor, fw, device)\n\n    # STEP 2- FIELDS OF VIEW: for each output pixels, map the input pixels\n    # that influence it\n    field_of_view = get_field_of_view(projected_grid, cur_support_sz, in_sz,\n                                      fw, eps)\n\n    # STEP 3- CALCULATE WEIGHTS: Match a set of weights to the pixels in the\n    # field of view for each output pixel\n    
weights = get_weights(interp_method, projected_grid, field_of_view)\n\n    return field_of_view, weights\n\n\ndef apply_weights(input, field_of_view, weights, dim, n_dims, fw):\n    # STEP 4- APPLY WEIGHTS: Each output pixel is calculated by multiplying\n    # its set of weights with the pixel values in its field of view.\n    # We now multiply the fields of view with their matching weights.\n    # We do this by tensor multiplication and broadcasting.\n    # this step is separated to a different function, so that it can be\n    # repeated with the same calculated weights and fields.\n\n    # for this operation we assume the resized dim is the first one.\n    # so we transpose and will transpose back after multiplying\n    tmp_input = fw_swapaxes(input, dim, 0, fw)\n\n    # field_of_view is a tensor of order 2: for each output (1d location\n    # along cur dim)- a list of 1d neighbor locations.\n    # note that this whole operation is applied to each dim separately,\n    # this is why it is all in 1d.\n    # neighbors = tmp_input[field_of_view] is a tensor of order image_dims+1:\n    # for each output pixel (this time indicated in all dims), these are the\n    # values of the neighbors in the 1d field of view. note that we only\n    # consider neighbors along the current dim, but such set exists for every\n    # multi-dim location, hence the final tensor order is image_dims+1.\n    neighbors = tmp_input[field_of_view]\n\n    # weights is an order 2 tensor: for each output location along 1d- a list\n    # of weights matching the field of view. 
we augment it with ones, for\n    # broadcasting, so that when it multiplies some tensor the weights affect\n    # only its first dim.\n    tmp_weights = fw.reshape(weights, (*weights.shape, * [1] * (n_dims - 1)))\n\n    # now we simply multiply the weights with the neighbors, and then sum\n    # along the field of view, to get a single value per out pixel\n    tmp_output = (neighbors * tmp_weights).sum(1)\n\n    # we transpose back the resized dim to its original position\n    return fw_swapaxes(tmp_output, 0, dim, fw)\n\n\ndef set_scale_and_out_sz(in_shape, out_shape, scale_factors, fw):\n    # eventually we must have both scale-factors and out-sizes for all in/out\n    # dims. however, we support many possible partial arguments\n    if scale_factors is None and out_shape is None:\n        raise ValueError(\"either scale_factors or out_shape should be \"\n                         \"provided\")\n    if out_shape is not None:\n        # if out_shape has fewer dims than in_shape, we default to resizing\n        # the first dims for numpy and the last dims for torch\n        out_shape = (list(out_shape) + list(in_shape[len(out_shape):])\n                     if fw is numpy\n                     else list(in_shape[:-len(out_shape)]) + list(out_shape))\n        if scale_factors is None:\n            # if no scale given, we calculate it as the out to in ratio\n            # (not recommended)\n            scale_factors = [out_sz / in_sz for out_sz, in_sz\n                             in zip(out_shape, in_shape)]\n    if scale_factors is not None:\n        # by default, if a single number is given as scale, we assume resizing\n        # two dims (most common are images with 2 spatial dims)\n        scale_factors = (scale_factors\n                         if isinstance(scale_factors, (list, tuple))\n                         else [scale_factors, scale_factors])\n        # if fewer scale_factors than in_shape dims, we default to resizing\n        # the first dims for numpy and the last dims 
for torch\n        scale_factors = (list(scale_factors) + [1] *\n                         (len(in_shape) - len(scale_factors)) if fw is numpy\n                         else [1] * (len(in_shape) - len(scale_factors)) +\n                         list(scale_factors))\n        if out_shape is None:\n            # when no out_shape given, it is calculated by multiplying the\n            # scale by the in_shape (not recommended)\n            out_shape = [ceil(scale_factor * in_sz)\n                         for scale_factor, in_sz in\n                         zip(scale_factors, in_shape)]\n        # next line intentionally after out_shape determined for stability\n        scale_factors = [float(sf) for sf in scale_factors]\n    return scale_factors, out_shape\n\n\ndef get_projected_grid(in_sz, out_sz, scale_factor, fw, device=None):\n    # we start by having the output coordinates which are just integer locations\n    out_coordinates = fw.arange(out_sz)\n\n    # if using torch we need to match the grid tensor device to the input device\n    out_coordinates = fw_set_device(out_coordinates, device, fw)\n\n    # This is projecting the output pixel locations in 1d to the input tensor,\n    # as non-integer locations.\n    # the following formula is derived in the paper\n    # \"From Discrete to Continuous Convolutions\" by Shocher et al.\n    return (out_coordinates / scale_factor +\n            (in_sz - 1) / 2 - (out_sz - 1) / (2 * scale_factor))\n\n\ndef get_field_of_view(projected_grid, cur_support_sz, in_sz, fw, eps):\n    # for each output pixel, map which input pixels influence it, in 1d.\n    # we start by calculating the leftmost neighbor, using half of the window\n    # size (eps is for when boundary is exact int)\n    left_boundaries = fw_ceil(projected_grid - cur_support_sz / 2 - eps, fw)\n\n    # then we simply take all the pixel centers in the field by counting\n    # window size pixels from the left boundary\n    ordinal_numbers = fw.arange(ceil(cur_support_sz - 
eps))\n    # in case using torch we need to match the device\n    ordinal_numbers = fw_set_device(ordinal_numbers, projected_grid.device, fw)\n    field_of_view = left_boundaries[:, None] + ordinal_numbers\n\n    # next we do a trick: instead of padding, we map the field of view so that\n    # it would be like mirror padding, without actually padding\n    # (which would require enlarging the input tensor)\n    mirror = fw_cat((fw.arange(in_sz), fw.arange(in_sz - 1, -1, step=-1)), fw)\n    field_of_view = mirror[fw.remainder(field_of_view, mirror.shape[0])]\n    field_of_view = fw_set_device(field_of_view, projected_grid.device, fw)\n    return field_of_view\n\n\ndef get_weights(interp_method, projected_grid, field_of_view):\n    # the set of weights for each output pixel is the result of the chosen\n    # interpolation method applied to the distances between projected grid\n    # locations and the pixel-centers in the field of view (distances are\n    # directed, can be positive or negative)\n    weights = interp_method(projected_grid[:, None] - field_of_view)\n\n    # we now carefully normalize the weights to sum to 1 per output pixel\n    sum_weights = weights.sum(1, keepdims=True)\n    sum_weights[sum_weights == 0] = 1\n    return weights / sum_weights\n\n\ndef apply_antialiasing_if_needed(interp_method, support_sz, scale_factor,\n                                 antialiasing):\n    # antialiasing is \"stretching\" the field of view according to the scale\n    # factor (only for downscaling). this is low-pass filtering. 
this\n    # requires modifying both the interpolation (stretching the 1d\n    # function and multiplying by the scale-factor) and the window size.\n    if scale_factor >= 1.0 or not antialiasing:\n        return interp_method, support_sz\n    cur_interp_method = (lambda arg: scale_factor *\n                         interp_method(scale_factor * arg))\n    cur_support_sz = support_sz / scale_factor\n    return cur_interp_method, cur_support_sz\n\n\ndef fw_ceil(x, fw):\n    if fw is numpy:\n        return fw.int_(fw.ceil(x))\n    else:\n        return x.ceil().long()\n\n\ndef fw_cat(x, fw):\n    if fw is numpy:\n        return fw.concatenate(x)\n    else:\n        return fw.cat(x)\n\n\ndef fw_swapaxes(x, ax_1, ax_2, fw):\n    if fw is numpy:\n        return fw.swapaxes(x, ax_1, ax_2)\n    else:\n        return x.transpose(ax_1, ax_2)\n\ndef fw_set_device(x, device, fw):\n    if fw is numpy:\n        return x\n    else:\n        return x.to(device)\n"
  },
  {
    "path": "code/real/bsrt/utils/spatial_color_alignment.py",
    "content": "import math\nimport torch\nimport torch.nn.functional as F\n\n\ndef gauss_1d(sz, sigma, center, end_pad=0, density=False):\n    \"\"\" Returns a 1-D Gaussian \"\"\"\n    k = torch.arange(-(sz-1)/2, (sz+1)/2 + end_pad).reshape(1, -1)\n    gauss = torch.exp(-1.0/(2*sigma**2) * (k - center.reshape(-1, 1))**2)\n    if density:\n        gauss /= math.sqrt(2*math.pi) * sigma\n    return gauss\n\n\ndef gauss_2d(sz, sigma, center, end_pad=(0, 0), density=False):\n    \"\"\" Returns a 2-D Gaussian \"\"\"\n    if isinstance(sigma, (float, int)):\n        sigma = (sigma, sigma)\n    if isinstance(sz, int):\n        sz = (sz, sz)\n\n    if isinstance(center, (list, tuple)):\n        center = torch.tensor(center).view(1, 2)\n\n    return gauss_1d(sz[0], sigma[0], center[:, 0], end_pad[0], density).reshape(center.shape[0], 1, -1) * \\\n           gauss_1d(sz[1], sigma[1], center[:, 1], end_pad[1], density).reshape(center.shape[0], -1, 1)\n\n\ndef get_gaussian_kernel(sd):\n    \"\"\" Returns a Gaussian kernel with standard deviation sd \"\"\"\n    ksz = int(4 * sd + 1)\n    assert ksz % 2 == 1\n    K = gauss_2d(ksz, sd, (0.0, 0.0), density=True)\n    K = K / K.sum()\n    return K.unsqueeze(0), ksz\n\n\ndef apply_kernel(im, ksz, gauss_kernel):\n    shape = im.shape\n    im = im.view(-1, 1, *im.shape[-2:])\n\n    pad = [ksz // 2, ksz // 2, ksz // 2, ksz // 2]\n    im = F.pad(im, pad, mode='reflect')\n    im_mean = F.conv2d(im, gauss_kernel).view(shape)\n    return im_mean\n\n\ndef match_colors(im_ref, im_q, im_test, ksz, gauss_kernel):\n    \"\"\" Estimates a color transformation matrix between im_ref and im_q. 
Applies the estimated transformation to\n        im_test\n    \"\"\"\n    gauss_kernel = gauss_kernel.to(im_ref.device)\n    bi = 5\n\n    # Apply Gaussian smoothing\n    im_ref_mean = apply_kernel(im_ref, ksz, gauss_kernel)[:, :, bi:-bi, bi:-bi].contiguous()\n    im_q_mean = apply_kernel(im_q, ksz, gauss_kernel)[:, :, bi:-bi, bi:-bi].contiguous()\n\n    im_ref_mean_re = im_ref_mean.view(*im_ref_mean.shape[:2], -1)\n    im_q_mean_re = im_q_mean.view(*im_q_mean.shape[:2], -1)\n\n    # Estimate color transformation matrix by minimizing the least squares error\n    c_mat_all = []\n    for ir, iq in zip(im_ref_mean_re, im_q_mean_re):\n        c = torch.lstsq(ir.t(), iq.t())\n        c = c.solution[:3]\n        c_mat_all.append(c)\n\n    c_mat = torch.stack(c_mat_all, dim=0)\n    im_q_mean_conv = torch.matmul(im_q_mean_re.permute(0, 2, 1), c_mat).permute(0, 2, 1)\n    im_q_mean_conv = im_q_mean_conv.view(im_q_mean.shape)\n\n    err = ((im_q_mean_conv - im_ref_mean) * 255.0).norm(dim=1)\n\n    thresh = 20\n\n    # If error is larger than a threshold, ignore these pixels\n    valid = err < thresh\n\n    pad = (im_q.shape[-1] - valid.shape[-1]) // 2\n    pad = [pad, pad, pad, pad]\n    valid = F.pad(valid, pad)\n\n    upsample_factor = im_test.shape[-1] / valid.shape[-1]\n    valid = F.interpolate(valid.unsqueeze(1).float(), scale_factor=upsample_factor, mode='bilinear', align_corners=False)\n    valid = valid > 0.9\n\n    # Apply the transformation to test image\n    im_test_re = im_test.view(*im_test.shape[:2], -1)\n    im_t_conv = torch.matmul(im_test_re.permute(0, 2, 1), c_mat).permute(0, 2, 1)\n    im_t_conv = im_t_conv.view(im_test.shape)\n\n    return im_t_conv, valid\n\n"
  },
  {
    "path": "code/real/bsrt/utils/stn.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass SpatialTransformer(nn.Module):\n    \"\"\"\n    [SpatialTransformer] represesents a spatial transformation block\n    that uses the output from the UNet to preform an grid_sample\n    https://pytorch.org/docs/stable/nn.functional.html#grid-sample\n    \"\"\"\n    def __init__(self, size, mode='bilinear'):\n        \"\"\"\n        Instiatiate the block\n            :param size: size of input to the spatial transformer block\n            :param mode: method of interpolation for grid_sampler\n        \"\"\"\n        super(OldSpatialTransformer, self).__init__()\n        if isinstance(size, int):\n            size = (size, size)\n        # Create sampling grid\n        vectors = [ torch.arange(0, s) for s in size ]\n        grids = torch.meshgrid(vectors)\n        grid  = torch.stack(grids) # y, x, z\n        grid  = torch.unsqueeze(grid, 0)  #add batch\n        grid = grid.type(torch.FloatTensor)\n        self.register_buffer('grid', grid)\n\n        self.mode = mode\n\n    def forward(self, src, flow):\n        \"\"\"\n        Push the src and flow through the spatial transform block\n            :param src: the original moving image\n            :param flow: the output from the U-Net\n        \"\"\"\n        new_locs = self.grid + flow\n\n        shape = flow.shape[2:]\n\n        # Need to normalize grid values to [-1, 1] for resampler\n        for i in range(len(shape)):\n            new_locs[:,i,...] = 2*(new_locs[:,i,...]/(shape[i]-1) - 0.5)\n\n        if len(shape) == 2:\n            new_locs = new_locs.permute(0, 2, 3, 1)\n            new_locs = new_locs[..., [1,0]]\n        elif len(shape) == 3:\n            new_locs = new_locs.permute(0, 2, 3, 4, 1)\n            new_locs = new_locs[..., [2,1,0]]\n\n        return F.grid_sample(src, new_locs, mode=self.mode, align_corners=True)\n"
  },
  {
    "path": "code/real/bsrt/utils/warp.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef warp(feat, flow, mode='bilinear', padding_mode='zeros'):\n    \"\"\"\n    warp an image/tensor (im2) back to im1, according to the optical flow im1 --> im2\n\n    input flow must be in format (x, y) at every pixel\n    feat: [B, C, H, W] (im2)\n    flow: [B, 2, H, W] flow (x, y)\n\n    \"\"\"\n    B, C, H, W = feat.size()\n    # print(feat.device, flow.device)\n\n    # mesh grid\n    rowv, colv = torch.meshgrid([torch.arange(0.5, H + 0.5), torch.arange(0.5, W + 0.5)])\n    grid = torch.stack((colv, rowv), dim=0).unsqueeze(0).float().to(flow.device)\n    # print(grid.device, flow.device, feat.device)\n    # grid = grid.cuda()\n    grid = grid + flow\n\n    # scale grid to [-1,1]\n    grid_norm_c = 2.0 * grid[:, 0] / W - 1.0\n    grid_norm_r = 2.0 * grid[:, 1] / H - 1.0\n\n    grid_norm = torch.stack((grid_norm_c, grid_norm_r), dim=1).to(flow.device)\n\n    grid_norm = grid_norm.permute(0, 2, 3, 1)\n\n    output = F.grid_sample(feat, grid_norm, mode=mode, align_corners=False, padding_mode=padding_mode)\n\n    return output\n"
  },
  {
    "path": "code/real/bsrt/validate.py",
    "content": "\nimport cv2\nimport torch\nimport numpy as np\nimport os\nfrom tqdm import tqdm\nimport random\nimport utility\nfrom option import args\nimport torchvision.utils as tvutils\nfrom pwcnet.pwcnet import PWCNet\n\nfrom utils.postprocessing_functions import BurstSRPostProcess\nfrom datasets.burstsr_dataset import BurstSRDataset, flatten_raw_image_batch, pack_raw_image\nfrom utils.metrics import AlignedPSNR\nfrom utils.data_format_utils import convert_dict\nfrom data_processing.camera_pipeline import demosaic\nimport model\n\nimport torch.multiprocessing as mp\nimport torch.backends.cudnn as cudnn\nimport torch.utils.data.distributed\nimport time\n\n\ncheckpoint = utility.checkpoint(args)\n\n\ndef main():\n    mp.spawn(main_worker, nprocs=1, args=(1, args))\n\n\ndef main_worker(local_rank, nprocs, args):\n    cudnn.benchmark = True\n    args.local_rank = local_rank\n    utility.setup(local_rank, nprocs)\n    torch.cuda.set_device(local_rank)\n\n    dataset = BurstSRDataset(root=args.root, burst_size=14, crop_sz=80, split='val')\n    # out_dir = 'val/ebsr_real'\n\n    _model = model.Model(args, checkpoint)\n\n    for param in _model.parameters():\n        param.requires_grad = False\n\n    alignment_net = PWCNet(load_pretrained=True,\n                           weights_path='./pwcnet/pwcnet-network-default.pth')\n    alignment_net = alignment_net.to('cuda')\n    for param in alignment_net.parameters():\n        param.requires_grad = False\n\n    aligned_psnr_fn = AlignedPSNR(alignment_net=alignment_net, boundary_ignore=40)\n\n    postprocess_fn = BurstSRPostProcess(return_np=True)\n\n    # os.makedirs(out_dir, exist_ok=True)\n\n    tt = []\n    psnrs, ssims, lpipss = [], [], []\n    for idx in tqdm(range(len(dataset))):\n        burst, gt, meta_info_burst, meta_info_gt = dataset[idx]\n        burst = burst.unsqueeze(0).cuda()\n        gt = gt.unsqueeze(0).cuda()\n\n        with torch.no_grad():\n            tic = time.time()\n            sr = 
_model(burst, 0).float()\n            toc = time.time()\n            tt.append(toc-tic)\n\n            # sr_int = (sr.clamp(0.0, 1.0) * 2 ** 14).short()\n            # sr = sr_int.float() / (2 ** 14)\n\n            psnr, ssim, lpips = aligned_psnr_fn(sr, gt, burst)\n            psnrs.append(psnr.item())\n            ssims.append(ssim.item())\n            lpipss.append(lpips.item())\n\n        # os.makedirs(f'{out_dir}/{idx}', exist_ok=True)\n        # sr_ = postprocess_fn.process(sr[0], meta_info_burst)\n        # sr_ = cv2.cvtColor(sr_, cv2.COLOR_RGB2BGR)\n        # cv2.imwrite('{}/{}_sr.png'.format(out_dir, idx), sr_)\n\n        del burst\n        del sr\n        del gt\n\n\n    print(f'avg PSNR: {np.mean(psnrs):.6f}')\n    print(f'avg SSIM: {np.mean(ssims):.6f}')\n    print(f'avg LPIPS: {np.mean(lpipss):.6f}')\n    print(f' avg time: {np.mean(tt):.6f}')\n\n    # utility.cleanup()\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "code/synthetic/bsrt/README.md",
    "content": "# BSRT: Improving Burst Super-Resolution with Swin Transformer and Flow-Guided Deformable Alignment (Synthetic)\n\n## Dependencies\n- OS: Ubuntu 18.04\n- Python: Python 3.7\n- nvidia :\n   - cuda: 10.1\n   - cudnn: 7.6.1\n- Other reference requirements\n\n## Quick Start\n1.Create a conda virtual environment and activate it\n```python3\nconda create -n pytorch_1.6 python=3.7\nsource activate pytorch_1.6\n```\n2.Install PyTorch and torchvision following the official instructions\n```python3\nconda install pytorch==1.6.0 torchvision==0.7.0 cudatoolkit=10.1 -c pytorch\n```\n3.Install build requirements\n```python3\npip3 install -r requirements.txt\n```\n4.Install DCN\n```python3\ncd DCNv2\npython3 setup.py build develop # build\npython3 test.py # run examples and check\n```\n## Training\n```python3\n# Modify the root path of training dataset and model etc.\n# The number of GPUs should be more than 1\npython main.py --n_GPUs 8 --print_every 40 --lr 0.00003 --decay 50-100 --save bsrt_tiny --model BSRT --fp16 --model_level S --swinfeature --batch_size 32 --burst_size 14 --patch_size 256\n```\n## Test\n```python3\n# Modify the path of test dataset and the path of the trained model\npython test_synburst.py --n_GPUs 1 --model BSRT --model_level S --fp16 --swinfeature --burst_size 14 --patch_size 384 --pre_train ../train_log/bsrt/real_models/bsrt_tiny/bsrt_best_epoch.pth --root /data/dataset/ntire21/burstsr/synthetic\n```"
  },
  {
    "path": "code/synthetic/bsrt/data_processing/__init__.py",
    "content": ""
  },
  {
    "path": "code/synthetic/bsrt/data_processing/camera_pipeline.py",
    "content": "import torch\nimport random\nimport math\nimport cv2 as cv\nimport numpy as np\nimport utils.data_format_utils as df_utils\n\"\"\" Based on http://timothybrooks.com/tech/unprocessing\nFunctions for forward and inverse camera pipeline. All functions input a torch float tensor of shape (c, h, w).\nAdditionally, some also support batch operations, i.e. inputs of shape (b, c, h, w)\n\"\"\"\n\n\ndef random_ccm():\n    \"\"\"Generates random RGB -> Camera color correction matrices.\"\"\"\n    # Takes a random convex combination of XYZ -> Camera CCMs.\n    xyz2cams = [[[1.0234, -0.2969, -0.2266],\n               [-0.5625, 1.6328, -0.0469],\n               [-0.0703, 0.2188, 0.6406]],\n              [[0.4913, -0.0541, -0.0202],\n               [-0.613, 1.3513, 0.2906],\n               [-0.1564, 0.2151, 0.7183]],\n              [[0.838, -0.263, -0.0639],\n               [-0.2887, 1.0725, 0.2496],\n               [-0.0627, 0.1427, 0.5438]],\n              [[0.6596, -0.2079, -0.0562],\n               [-0.4782, 1.3016, 0.1933],\n               [-0.097, 0.1581, 0.5181]]]\n\n    num_ccms = len(xyz2cams)\n    xyz2cams = torch.tensor(xyz2cams)\n\n    weights = torch.FloatTensor(num_ccms, 1, 1).uniform_(0.0, 1.0)\n    weights_sum = weights.sum()\n    xyz2cam = (xyz2cams * weights).sum(dim=0) / weights_sum\n\n    # Multiplies with RGB -> XYZ to get RGB -> Camera CCM.\n    rgb2xyz = torch.tensor([[0.4124564, 0.3575761, 0.1804375],\n                            [0.2126729, 0.7151522, 0.0721750],\n                            [0.0193339, 0.1191920, 0.9503041]])\n    rgb2cam = torch.mm(xyz2cam, rgb2xyz)\n\n    # Normalizes each row.\n    rgb2cam = rgb2cam / rgb2cam.sum(dim=-1, keepdims=True)\n    return rgb2cam\n\n\ndef random_gains():\n    \"\"\"Generates random gains for brightening and white balance.\"\"\"\n    # RGB gain represents brightening.\n    rgb_gain = 1.0 / random.gauss(mu=0.8, sigma=0.1)\n\n    # Red and blue gains represent white balance.\n    red_gain = 
random.uniform(1.9, 2.4)\n    blue_gain = random.uniform(1.5, 1.9)\n    return rgb_gain, red_gain, blue_gain\n\n\ndef apply_smoothstep(image):\n    \"\"\"Apply global tone mapping curve.\"\"\"\n    image_out = 3 * image**2 - 2 * image**3\n    return image_out\n\n\ndef invert_smoothstep(image):\n    \"\"\"Approximately inverts a global tone mapping curve.\"\"\"\n    image = image.clamp(0.0, 1.0)\n    return 0.5 - torch.sin(torch.asin(1.0 - 2.0 * image) / 3.0)\n\n\ndef gamma_expansion(image):\n    \"\"\"Converts from gamma to linear space.\"\"\"\n    # Clamps to prevent numerical instability of gradients near zero.\n    return image.clamp(1e-8) ** 2.2\n\n\ndef gamma_compression(image):\n    \"\"\"Converts from linear to gamma space.\"\"\"\n    # Clamps to prevent numerical instability of gradients near zero.\n    return image.clamp(1e-8) ** (1.0 / 2.2)\n\n\ndef apply_ccm(image, ccm):\n    \"\"\"Applies a color correction matrix.\"\"\"\n    assert image.dim() == 3 and image.shape[0] == 3\n\n    shape = image.shape\n    image = image.view(3, -1)\n    ccm = ccm.to(image.device).type_as(image)\n\n    image = torch.mm(ccm, image)\n\n    return image.view(shape)\n\n\ndef apply_gains(image, rgb_gain, red_gain, blue_gain):\n    \"\"\"Applies white balance and brightness gains.\"\"\"\n    assert image.dim() == 3 and image.shape[0] in [3, 4]\n\n    if image.shape[0] == 3:\n        gains = torch.tensor([red_gain, 1.0, blue_gain]) * rgb_gain\n    else:\n        gains = torch.tensor([red_gain, 1.0, 1.0, blue_gain]) * rgb_gain\n    gains = gains.view(-1, 1, 1)\n    gains = gains.to(image.device).type_as(image)\n\n    return (image * gains).clamp(0.0, 1.0)\n\n\ndef safe_invert_gains(image, rgb_gain, red_gain, blue_gain):\n    \"\"\"Inverts gains while safely handling saturated pixels.\"\"\"\n    assert image.dim() == 3 and image.shape[0] == 3\n\n    gains = torch.tensor([1.0 / red_gain, 1.0, 1.0 / blue_gain]) / rgb_gain\n    gains = gains.view(-1, 1, 1)\n\n    # Prevents 
dimming of saturated pixels by smoothly masking gains near white.\n    gray = image.mean(dim=0, keepdims=True)\n    inflection = 0.9\n    mask = ((gray - inflection).clamp(0.0) / (1.0 - inflection)) ** 2.0\n\n    safe_gains = torch.max(mask + (1.0 - mask) * gains, gains)\n    return image * safe_gains\n\n\ndef mosaic(image, mode='rggb'):\n    \"\"\"Extracts RGGB Bayer planes from an RGB image.\"\"\"\n    shape = image.shape\n    if image.dim() == 3:\n        image = image.unsqueeze(0)\n\n    if mode == 'rggb':\n        red = image[:, 0, 0::2, 0::2]\n        green_red = image[:, 1, 0::2, 1::2]\n        green_blue = image[:, 1, 1::2, 0::2]\n        blue = image[:, 2, 1::2, 1::2]\n        image = torch.stack((red, green_red, green_blue, blue), dim=1)\n    elif mode == 'grbg':\n        green_red = image[:, 1, 0::2, 0::2]\n        red = image[:, 0, 0::2, 1::2]\n        blue = image[:, 2, 1::2, 0::2]\n        green_blue = image[:, 1, 1::2, 1::2]\n\n        image = torch.stack((green_red, red, blue, green_blue), dim=1)\n\n    if len(shape) == 3:\n        return image.view((4, shape[-2] // 2, shape[-1] // 2))\n    else:\n        return image.view((-1, 4, shape[-2] // 2, shape[-1] // 2))\n\n\ndef demosaic(image):\n    assert isinstance(image, torch.Tensor)\n    image = image.clamp(0.0, 1.0) * 255\n\n    if image.dim() == 4:\n        num_images = image.shape[0]  # batch size, not the number of dims\n        batch_input = True\n    else:\n        num_images = 1\n        batch_input = False\n        image = image.unsqueeze(0)\n\n    # Generate single channel input for opencv\n    im_sc = torch.zeros((num_images, image.shape[-2] * 2, image.shape[-1] * 2, 1))\n    im_sc[:, ::2, ::2, 0] = image[:, 0, :, :]\n    im_sc[:, ::2, 1::2, 0] = image[:, 1, :, :]\n    im_sc[:, 1::2, ::2, 0] = image[:, 2, :, :]\n    im_sc[:, 1::2, 1::2, 0] = image[:, 3, :, :]\n\n    im_sc = im_sc.numpy().astype(np.uint8)\n\n    out = []\n\n    for im in im_sc:\n        # cv.imwrite('frames/tmp.png', im)\n        im_dem_np = cv.cvtColor(im, 
cv.COLOR_BAYER_BG2RGB)\n\n        # Convert to torch image\n        im_t = df_utils.npimage_to_torch(im_dem_np, input_bgr=False)\n        out.append(im_t)\n\n    if batch_input:\n        return torch.stack(out, dim=0)\n    else:\n        return out[0]\n\n\ndef random_noise_levels():\n    \"\"\"Generates random noise levels from a log-log linear distribution.\"\"\"\n    log_min_shot_noise = math.log(0.0001)\n    log_max_shot_noise = math.log(0.012)\n    log_shot_noise = random.uniform(log_min_shot_noise, log_max_shot_noise)\n    shot_noise = math.exp(log_shot_noise)\n\n    line = lambda x: 2.18 * x + 1.20\n    log_read_noise = line(log_shot_noise) + random.gauss(mu=0.0, sigma=0.26)\n    read_noise = math.exp(log_read_noise)\n    return shot_noise, read_noise\n\n\ndef add_noise(image, shot_noise=0.01, read_noise=0.0005):\n    \"\"\"Adds random shot (proportional to image) and read (independent) noise.\"\"\"\n    variance = image * shot_noise + read_noise\n    noise = torch.randn_like(image) * variance.sqrt()\n    return image + noise\n\n\ndef process_linear_image_rgb(image, meta_info, return_np=False):\n    image = apply_gains(image, meta_info['rgb_gain'], meta_info['red_gain'], meta_info['blue_gain'])\n    image = apply_ccm(image, meta_info['cam2rgb'])\n\n    if meta_info['gamma']:\n        image = gamma_compression(image)\n\n    if meta_info['smoothstep']:\n        image = apply_smoothstep(image)\n\n    image = image.clamp(0.0, 1.0)\n\n    if return_np:\n        image = df_utils.torch_to_npimage(image)\n    return image\n\n\ndef process_linear_image_raw(image, meta_info):\n    image = apply_gains(image, meta_info['rgb_gain'], meta_info['red_gain'], meta_info['blue_gain'])\n    image = demosaic(image)\n    image = apply_ccm(image, meta_info['cam2rgb'])\n\n    if meta_info['gamma']:\n        image = gamma_compression(image)\n\n    if meta_info['smoothstep']:\n        image = apply_smoothstep(image)\n    return image.clamp(0.0, 
1.0)\n"
  },
  {
    "path": "code/synthetic/bsrt/data_processing/synthetic_burst_generation.py",
"content": "import torch\nimport random\nimport cv2\nimport numpy as np\nimport torch.nn.functional as F\nfrom data_processing.camera_pipeline import *\nfrom utils.data_format_utils import torch_to_numpy, numpy_to_torch\n\n\ndef random_crop(frames, crop_sz):\n    \"\"\" Extract a random crop of size crop_sz from the input frames. If the crop_sz is larger than the input image size,\n    then the largest possible crop of same aspect ratio as crop_sz will be extracted from frames, and upsampled to\n    crop_sz.\n    \"\"\"\n    if not isinstance(crop_sz, (tuple, list)):\n        crop_sz = (crop_sz, crop_sz)\n    crop_sz = torch.tensor(crop_sz).float()\n\n    shape = frames.shape\n\n    # Select scale_factor. Ensure the crop fits inside the image\n    max_scale_factor = torch.tensor(shape[-2:]).float() / crop_sz\n    max_scale_factor = max_scale_factor.min().item()\n\n    if max_scale_factor < 1.0:\n        scale_factor = max_scale_factor\n    else:\n        scale_factor = 1.0\n\n    # Extract the crop\n    orig_crop_sz = (crop_sz * scale_factor).floor()\n\n    assert orig_crop_sz[-2] <= shape[-2] and orig_crop_sz[-1] <= shape[-1], 'Bug in crop size estimation!'\n\n    # Convert tensor bounds to ints; random.randint requires integer arguments.\n    r1 = random.randint(0, shape[-2] - int(orig_crop_sz[-2]))\n    c1 = random.randint(0, shape[-1] - int(orig_crop_sz[-1]))\n\n    r2 = r1 + orig_crop_sz[0].int().item()\n    c2 = c1 + orig_crop_sz[1].int().item()\n\n    frames_crop = frames[:, r1:r2, c1:c2]\n\n    # Resize to crop_sz\n    if scale_factor < 1.0:\n        frames_crop = F.interpolate(frames_crop.unsqueeze(0), size=crop_sz.int().tolist(), mode='bilinear', align_corners=False).squeeze(0)\n    return frames_crop\n\n\ndef rgb2rawburst(image, burst_size, downsample_factor=1, burst_transformation_params=None,\n                 image_processing_params=None, interpolation_type='bilinear'):\n    \"\"\" Generates a synthetic LR RAW burst from the input image. 
The input sRGB image is first converted to linear\n    sensor space using an inverse camera pipeline. An LR burst is then generated by applying random\n    transformations defined by burst_transformation_params to the input image, and downsampling it by the\n    downsample_factor. The generated burst is then mosaicked and corrupted by random noise.\n    \"\"\"\n\n    if image_processing_params is None:\n        image_processing_params = {}\n\n    _defaults = {'random_ccm': True, 'random_gains': True, 'smoothstep': True, 'gamma': True, 'add_noise': True}\n    for k, v in _defaults.items():\n        if k not in image_processing_params:\n            image_processing_params[k] = v\n\n    # Sample camera pipeline params\n    if image_processing_params['random_ccm']:\n        rgb2cam = random_ccm()\n    else:\n        rgb2cam = torch.eye(3).float()\n    cam2rgb = rgb2cam.inverse()\n\n    # Sample gains\n    if image_processing_params['random_gains']:\n        rgb_gain, red_gain, blue_gain = random_gains()\n    else:\n        rgb_gain, red_gain, blue_gain = (1.0, 1.0, 1.0)\n\n    # Approximately inverts global tone mapping.\n    use_smoothstep = image_processing_params['smoothstep']\n    if use_smoothstep:\n        image = invert_smoothstep(image)\n\n    # Inverts gamma compression.\n    use_gamma = image_processing_params['gamma']\n    if use_gamma:\n        image = gamma_expansion(image)\n\n    # Inverts color correction.\n    image = apply_ccm(image, rgb2cam)\n\n    # Approximately inverts white balance and brightening.\n    image = safe_invert_gains(image, rgb_gain, red_gain, blue_gain)\n\n    # Clip saturated pixels.\n    image = image.clamp(0.0, 1.0)\n\n    # Generate LR burst\n    image_burst_rgb, flow_vectors = single2lrburst(image, burst_size=burst_size,\n                                                   downsample_factor=downsample_factor,\n                                                   transformation_params=burst_transformation_params,\n                     
                              interpolation_type=interpolation_type)\n\n    # mosaic\n    image_burst = mosaic(image_burst_rgb.clone())\n\n    # Add noise\n    if image_processing_params['add_noise']:\n        shot_noise_level, read_noise_level = random_noise_levels()\n        image_burst = add_noise(image_burst, shot_noise_level, read_noise_level)\n    else:\n        shot_noise_level = 0\n        read_noise_level = 0\n\n    # Clip saturated pixels.\n    image_burst = image_burst.clamp(0.0, 1.0)\n\n    meta_info = {'rgb2cam': rgb2cam, 'cam2rgb': cam2rgb, 'rgb_gain': rgb_gain, 'red_gain': red_gain,\n                 'blue_gain': blue_gain, 'smoothstep': use_smoothstep, 'gamma': use_gamma,\n                 'shot_noise_level': shot_noise_level, 'read_noise_level': read_noise_level}\n    return image_burst, image, image_burst_rgb, flow_vectors, meta_info\n\n\ndef get_tmat(image_shape, translation, theta, shear_values, scale_factors):\n    \"\"\" Generates a transformation matrix corresponding to the input transformation parameters \"\"\"\n    im_h, im_w = image_shape\n\n    t_mat = np.identity(3)\n\n    t_mat[0, 2] = translation[0]\n    t_mat[1, 2] = translation[1]\n    t_rot = cv2.getRotationMatrix2D((im_w * 0.5, im_h * 0.5), theta, 1.0)\n    t_rot = np.concatenate((t_rot, np.array([0.0, 0.0, 1.0]).reshape(1, 3)))\n\n    t_shear = np.array([[1.0, shear_values[0], -shear_values[0] * 0.5 * im_w],\n                        [shear_values[1], 1.0, -shear_values[1] * 0.5 * im_h],\n                        [0.0, 0.0, 1.0]])\n\n    t_scale = np.array([[scale_factors[0], 0.0, 0.0],\n                        [0.0, scale_factors[1], 0.0],\n                        [0.0, 0.0, 1.0]])\n\n    t_mat = t_scale @ t_rot @ t_shear @ t_mat\n\n    t_mat = t_mat[:2, :]\n\n    return t_mat\n\n\ndef single2lrburst(image, burst_size, downsample_factor=1, transformation_params=None,\n                   interpolation_type='bilinear'):\n    \"\"\" Generates a burst of size burst_size from the input 
image by applying random transformations defined by\n    transformation_params, and downsampling the resulting burst by downsample_factor.\n    \"\"\"\n\n    if interpolation_type == 'bilinear':\n        interpolation = cv2.INTER_LINEAR\n    elif interpolation_type == 'lanczos':\n        interpolation = cv2.INTER_LANCZOS4\n    else:\n        raise ValueError\n\n    normalize = False\n    if isinstance(image, torch.Tensor):\n        if image.max() < 2.0:\n            image = image * 255.0\n            normalize = True\n        image = torch_to_numpy(image).astype(np.uint8)\n\n    burst = []\n    sample_pos_inv_all = []\n\n    rvs, cvs = torch.meshgrid([torch.arange(0, image.shape[0]),\n                               torch.arange(0, image.shape[1])])\n\n    sample_grid = torch.stack((cvs, rvs, torch.ones_like(cvs)), dim=-1).float()\n\n    for i in range(burst_size):\n        if i == 0:\n            # For base image, do not apply any random transformations. We only translate the image to center the\n            # sampling grid\n            shift = (downsample_factor / 2.0) - 0.5\n            translation = (shift, shift)\n            theta = 0.0\n            shear_factor = (0.0, 0.0)\n            scale_factor = (1.0, 1.0)\n        else:\n            # Sample random image transformation parameters\n            max_translation = transformation_params.get('max_translation', 0.0)\n\n            if max_translation <= 0.01:\n                shift = (downsample_factor / 2.0) - 0.5\n                translation = (shift, shift)\n            else:\n                translation = (random.uniform(-max_translation, max_translation),\n                               random.uniform(-max_translation, max_translation))\n\n            max_rotation = transformation_params.get('max_rotation', 0.0)\n            theta = random.uniform(-max_rotation, max_rotation)\n\n            max_shear = transformation_params.get('max_shear', 0.0)\n            shear_x = random.uniform(-max_shear, 
max_shear)\n            shear_y = random.uniform(-max_shear, max_shear)\n            shear_factor = (shear_x, shear_y)\n\n            max_ar_factor = transformation_params.get('max_ar_factor', 0.0)\n            ar_factor = np.exp(random.uniform(-max_ar_factor, max_ar_factor))\n\n            max_scale = transformation_params.get('max_scale', 0.0)\n            scale_factor = np.exp(random.uniform(-max_scale, max_scale))\n\n            scale_factor = (scale_factor, scale_factor * ar_factor)\n\n        output_sz = (image.shape[1], image.shape[0])\n\n        # Generate an affine transformation matrix corresponding to the sampled parameters\n        t_mat = get_tmat((image.shape[0], image.shape[1]), translation, theta, shear_factor, scale_factor)\n        t_mat_tensor = torch.from_numpy(t_mat)\n\n        # Apply the sampled affine transformation\n        image_t = cv2.warpAffine(image, t_mat, output_sz, flags=interpolation,\n                                 borderMode=cv2.BORDER_CONSTANT)\n\n        t_mat_tensor_3x3 = torch.cat((t_mat_tensor.float(), torch.tensor([0.0, 0.0, 1.0]).view(1, 3)), dim=0)\n        t_mat_tensor_inverse = t_mat_tensor_3x3.inverse()[:2, :].contiguous()\n\n        sample_pos_inv = torch.mm(sample_grid.view(-1, 3), t_mat_tensor_inverse.t().float()).view(\n            *sample_grid.shape[:2], -1)\n\n        if transformation_params.get('border_crop') is not None:\n            border_crop = transformation_params.get('border_crop')\n\n            image_t = image_t[border_crop:-border_crop, border_crop:-border_crop, :]\n            sample_pos_inv = sample_pos_inv[border_crop:-border_crop, border_crop:-border_crop, :]\n\n        # Downsample the image\n        image_t = cv2.resize(image_t, None, fx=1.0 / downsample_factor, fy=1.0 / downsample_factor,\n                             interpolation=interpolation)\n        sample_pos_inv = cv2.resize(sample_pos_inv.numpy(), None, fx=1.0 / downsample_factor,\n                                    fy=1.0 / 
downsample_factor,\n                                    interpolation=interpolation)\n\n        sample_pos_inv = torch.from_numpy(sample_pos_inv).permute(2, 0, 1).contiguous()\n\n        if normalize:\n            image_t = numpy_to_torch(image_t).float() / 255.0\n        else:\n            image_t = numpy_to_torch(image_t).float()\n        burst.append(image_t)\n        sample_pos_inv_all.append(sample_pos_inv / downsample_factor)\n\n    burst_images = torch.stack(burst)\n    sample_pos_inv_all = torch.stack(sample_pos_inv_all)\n\n    # Compute the flow vectors to go from the i'th burst image to the base image\n    flow_vectors = sample_pos_inv_all - sample_pos_inv_all[:, :1, ...]\n\n    return burst_images, flow_vectors\n"
  },
  {
    "path": "code/synthetic/bsrt/datasets/__init__.py",
    "content": ""
  },
  {
    "path": "code/synthetic/bsrt/datasets/burstsr_dataset.py",
    "content": "import os\nimport torch\nimport cv2\nimport numpy as np\nimport pickle as pkl\nimport torch.nn.functional as F\nimport random\nimport time\n\nclass SamsungRAWImage:\n    @staticmethod\n    def load(path):\n        im_raw = cv2.imread('{}/im_raw.png'.format(path), cv2.IMREAD_UNCHANGED)\n\n        im_raw = np.transpose(im_raw, (2, 0, 1)).astype(np.int16)\n        im_raw = torch.from_numpy(im_raw)\n\n        meta_data = pkl.load(open('{}/meta_info.pkl'.format(path), \"rb\", -1))\n\n        return SamsungRAWImage(im_raw, meta_data['black_level'], meta_data['cam_wb'],\n                               meta_data['daylight_wb'], meta_data['color_matrix'], meta_data['exif_data'],\n                               meta_data.get('crop_info', None), meta_data.get('im_preview', None))\n\n    def __init__(self, im_raw, black_level, cam_wb, daylight_wb, color_matrix, exif_data, crop_info=None,\n                 im_preview=None):\n        self.im_raw = im_raw\n\n        self.black_level = black_level\n        self.cam_wb = cam_wb\n        self.daylight_wb = daylight_wb\n        self.color_matrix = color_matrix\n        self.exif_data = exif_data\n        self.crop_info = crop_info\n        self.im_preview = im_preview\n\n        self.norm_factor = 1023.0\n\n    def get_all_meta_data(self):\n        return {'black_level': self.black_level, 'cam_wb': self.cam_wb, 'daylight_wb': self.daylight_wb,\n                'color_matrix': self.color_matrix.tolist()}\n\n    def get_exposure_time(self):\n        return self.exif_data['Image ExposureTime'].values[0].decimal()\n\n    def get_noise_profile(self):\n        noise = self.exif_data['Image Tag 0xC761'].values\n        noise = [n[0] for n in noise]\n        noise = np.array(noise).reshape(3, 2)\n        return noise\n\n    def get_f_number(self):\n        return self.exif_data['Image FNumber'].values[0].decimal()\n\n    def get_iso(self):\n        return self.exif_data['Image ISOSpeedRatings'].values[0]\n\n    def 
get_image_data(self, substract_black_level=False, white_balance=False, normalize=False):\n        im_raw = self.im_raw.float()\n\n        if substract_black_level:\n            im_raw = im_raw - torch.tensor(self.black_level).view(4, 1, 1)\n\n        if white_balance:\n            im_raw = im_raw * torch.tensor(self.cam_wb).view(4, 1, 1)\n\n        if normalize:\n            im_raw = im_raw / self.norm_factor\n\n\n        return im_raw\n\n    def shape(self):\n        shape = (4, self.im_raw.shape[1], self.im_raw.shape[2])\n        return shape\n\n    def crop_image(self, r1, r2, c1, c2):\n        self.im_raw = self.im_raw[:, r1:r2, c1:c2]\n\n    def get_crop(self, r1, r2, c1, c2):\n        im_raw = self.im_raw[:, r1:r2, c1:c2]\n\n        if self.im_preview is not None:\n            im_preview = self.im_preview[2*r1:2*r2, 2*c1:2*c2]\n        else:\n            im_preview = None\n\n        return SamsungRAWImage(im_raw, self.black_level, self.cam_wb, self.daylight_wb, self.color_matrix,\n                               self.exif_data, im_preview=im_preview)\n\n    def postprocess(self, return_np=True, norm_factor=None):\n        # Convert to rgb\n        # im = torch.from_numpy(self.im_raw.astype(np.float32))\n        im = self.im_raw\n\n        im = (im - torch.tensor(self.black_level).view(4, 1, 1)) * torch.tensor(self.cam_wb).view(4, 1, 1)\n\n        if norm_factor is None:\n            im = im / im.max()\n        else:\n            im = im / norm_factor\n\n        im = torch.stack((im[0], (im[1] + im[2])/2, im[3]), dim=0)\n        # im = torch.stack((im[0], im[1], im[3]), dim=0)\n\n        im_out = im.clamp(0.0, 1.0)\n\n        if return_np:\n            im_out = im_out.permute(1, 2, 0).numpy() * 255.0\n            im_out = im_out.astype(np.uint8)\n        return im_out\n\n\nclass CanonImage:\n    @staticmethod\n    def load(path, split='train'):\n        im_raw = cv2.imread('{}/im_raw.png'.format(path), cv2.IMREAD_UNCHANGED)\n        im_raw = 
np.transpose(im_raw, (2, 0, 1)).astype(np.int16)\n        im_raw = torch.from_numpy(im_raw)\n        meta_data = pkl.load(open('{}/meta_info.pkl'.format(path), \"rb\", -1))\n\n        return CanonImage(im_raw.float(), meta_data['black_level'], meta_data['cam_wb'],\n                          meta_data['daylight_wb'], meta_data['rgb_xyz_matrix'], meta_data.get('exif_data', None),\n                          meta_data.get('crop_info', None))\n\n    def __init__(self, im_raw, black_level, cam_wb, daylight_wb, rgb_xyz_matrix, exif_data, crop_info=None):\n        super(CanonImage, self).__init__()\n        self.im_raw = im_raw\n\n        if len(black_level) == 4:\n            black_level = [black_level[0], black_level[1], black_level[3]]\n        self.black_level = black_level\n\n        if len(cam_wb) == 4:\n            cam_wb = [cam_wb[0], cam_wb[1], cam_wb[3]]\n        self.cam_wb = cam_wb\n\n        if len(daylight_wb) == 4:\n            daylight_wb = [daylight_wb[0], daylight_wb[1], daylight_wb[3]]\n        self.daylight_wb = daylight_wb\n\n        self.rgb_xyz_matrix = rgb_xyz_matrix\n        self.xyz_srgb_matrix = torch.tensor([3.2404542, -1.5371385, -0.4985314,\n                                             -0.9692660,  1.8760108,  0.0415560,\n                                             0.0556434, -0.2040259,  1.0572252]).view(3, 3)\n        self.exif_data = exif_data\n        self.crop_info = crop_info\n\n        self.norm_factor = 16383\n\n    def shape(self):\n        shape = (3, self.im_raw.shape[1], self.im_raw.shape[2])\n        return shape\n\n    def get_all_meta_data(self):\n        return {'black_level': self.black_level, 'cam_wb': self.cam_wb, 'daylight_wb': self.daylight_wb,\n                'rgb_xyz_matrix': self.rgb_xyz_matrix.tolist(), 'crop_info': self.crop_info,\n                'norm_factor': self.norm_factor}\n\n    def get_exposure_time(self):\n        return self.exif_data['EXIF ExposureTime'].values[0].decimal()\n\n    def 
get_f_number(self):\n        return self.exif_data['EXIF FNumber'].values[0].decimal()\n\n    def get_iso(self):\n        return self.exif_data['EXIF ISOSpeedRatings'].values[0]\n\n    def get_image_data(self, substract_black_level=False, white_balance=False, normalize=False):\n        im_raw = self.im_raw.float()\n\n        if substract_black_level:\n            im_raw = im_raw - torch.tensor(self.black_level).view(3, 1, 1)\n\n        if white_balance:\n            im_raw = im_raw * torch.tensor(self.cam_wb).view(3, 1, 1) / 1024.0\n\n        if normalize:\n            im_raw = im_raw / self.norm_factor\n\n        return im_raw\n\n    def set_image_data(self, im_data):\n        self.im_raw = im_data\n\n    def crop_image(self, r1, r2, c1, c2):\n        self.im_raw = self.im_raw[:, r1:r2, c1:c2]\n\n    def get_crop(self, r1, r2, c1, c2):\n        im_raw = self.im_raw[:, r1:r2, c1:c2]\n        return CanonImage(im_raw, self.black_level, self.cam_wb, self.daylight_wb, self.rgb_xyz_matrix,\n                          self.exif_data, self.crop_info)\n\n    def set_crop_info(self, crop_info):\n        self.crop_info = crop_info\n\n    def resize(self, size=None, scale_factor=None):\n\n        self.im_raw = F.interpolate(self.im_raw.unsqueeze(0), size=size, scale_factor=scale_factor,\n                                    mode='bilinear').squeeze(0)\n\n    def postprocess(self, return_np=True):\n        # Convert to rgb\n        im = self.im_raw\n\n        im = (im - torch.tensor(self.black_level).view(3, 1, 1)).float() * torch.tensor(self.cam_wb).view(3, 1, 1)\n\n        im_out = im / im.max()\n        im_out = im_out.clamp(0.0, 1.0)\n\n        if return_np:\n            im_out = im_out.permute(1, 2, 0).numpy() * 255.0\n            im_out = im_out.astype(np.uint8)\n        return im_out\n\n\ndef load_txt(path):\n    with open(path, 'r') as fh:\n        out = [d.rstrip() for d in fh.readlines()]\n\n    return out\n\n\nclass BurstSRDataset(torch.utils.data.Dataset):\n    
\"\"\" Real-world burst super-resolution dataset. \"\"\"\n    def __init__(self, root, burst_size=8, crop_sz=80, center_crop=False, random_flip=False, split='train'):\n        \"\"\"\n        args:\n            root : path of the root directory\n            burst_size : Burst size. Maximum allowed burst size is 14.\n            crop_sz: Size of the extracted crop. Maximum allowed crop size is 80\n            center_crop: Whether to extract a random crop, or a centered crop.\n            random_flip: Whether to apply random horizontal and vertical flip\n            split: Can be 'train' or 'val'\n        \"\"\"\n        assert burst_size <= 14, 'burst_sz must be less than or equal to 14'\n        assert crop_sz <= 80, 'crop_sz must be less than or equal to 80'\n        assert split in ['train', 'val']\n\n        root = root + '/' + split\n        super().__init__()\n\n        self.burst_size = burst_size\n        self.crop_sz = crop_sz\n        self.split = split\n        self.center_crop = center_crop\n        self.random_flip = random_flip\n\n        self.root = root\n\n        self.substract_black_level = True\n        self.white_balance = False\n\n        self.burst_list = self._get_burst_list()\n\n    def _get_burst_list(self):\n        burst_list = sorted(os.listdir('{}'.format(self.root)))\n        # print(burst_list)\n        return burst_list\n\n    def get_burst_info(self, burst_id):\n        burst_info = {'burst_size': 14, 'burst_name': self.burst_list[burst_id]}\n        return burst_info\n\n    def _get_raw_image(self, burst_id, im_id):\n        raw_image = SamsungRAWImage.load('{}/{}/samsung_{:02d}'.format(self.root, self.burst_list[burst_id], im_id))\n        return raw_image\n\n    def _get_gt_image(self, burst_id):\n        canon_im = CanonImage.load('{}/{}/canon'.format(self.root, self.burst_list[burst_id]), split=self.split)\n        return canon_im\n\n    def get_burst(self, burst_id, im_ids, info=None):\n        frames = 
[self._get_raw_image(burst_id, i) for i in im_ids]\n\n        gt = self._get_gt_image(burst_id)\n        if info is None:\n            info = self.get_burst_info(burst_id)\n\n        return frames, gt, info\n\n    def _sample_images(self):\n        burst_size = 14\n\n        ids = random.sample(range(1, burst_size), k=self.burst_size - 1)\n        ids = [0, ] + ids\n        return ids\n\n    def __len__(self):\n        return len(self.burst_list)\n\n    def __getitem__(self, index):\n        # Sample the images in the burst, in case a burst_size < 14 is used.\n        im_ids = self._sample_images()\n\n        # Read the burst images along with HR ground truth\n        frames, gt, meta_info = self.get_burst(index, im_ids)\n\n        # Extract crop if needed\n        if frames[0].shape()[-1] != self.crop_sz:\n            if getattr(self, 'center_crop', False):\n                r1 = (frames[0].shape()[-2] - self.crop_sz) // 2\n                c1 = (frames[0].shape()[-1] - self.crop_sz) // 2\n            else:\n                r1 = random.randint(0, frames[0].shape()[-2] - self.crop_sz)\n                c1 = random.randint(0, frames[0].shape()[-1] - self.crop_sz)\n            r2 = r1 + self.crop_sz\n            c2 = c1 + self.crop_sz\n\n            scale_factor = gt.shape()[-1] // frames[0].shape()[-1]\n            frames = [im.get_crop(r1, r2, c1, c2) for im in frames]\n\n            gt = gt.get_crop(scale_factor * r1, scale_factor * r2, scale_factor * c1, scale_factor * c2)\n\n        # Load the RAW image data\n        burst_image_data = [im.get_image_data(normalize=True, substract_black_level=self.substract_black_level,\n                                              white_balance=self.white_balance) for im in frames]\n\n        # Convert to tensor\n        gt_image_data = gt.get_image_data(normalize=True, white_balance=self.white_balance,\n                                          substract_black_level=self.substract_black_level)\n\n        if self.random_flip:\n    
        burst_image_data = [flatten_raw_image(im) for im in burst_image_data]\n\n            pad = [0, 0, 0, 0]\n            if random.random() > 0.5:\n                burst_image_data = [im.flip([1, ])[:, 1:-1].contiguous() for im in burst_image_data]\n                gt_image_data = gt_image_data.flip([2, ])[:, :, 2:-2].contiguous()\n                pad[1] = 1\n\n            if random.random() > 0.5:\n                burst_image_data = [im.flip([0, ])[1:-1, :].contiguous() for im in burst_image_data]\n                gt_image_data = gt_image_data.flip([1, ])[:, 2:-2, :].contiguous()\n                pad[3] = 1\n\n            burst_image_data = [pack_raw_image(im) for im in burst_image_data]\n            burst_image_data = [F.pad(im.unsqueeze(0), pad, mode='replicate').squeeze(0) for im in burst_image_data]\n\n            gt_image_data = F.pad(gt_image_data.unsqueeze(0), [4 * p for p in pad], mode='replicate').squeeze(0)\n\n        burst_image_meta_info = frames[0].get_all_meta_data()\n\n        burst_image_meta_info['black_level_subtracted'] = self.substract_black_level\n        burst_image_meta_info['while_balance_applied'] = self.white_balance\n        burst_image_meta_info['norm_factor'] = frames[0].norm_factor\n\n        gt_image_meta_info = gt.get_all_meta_data()\n\n        burst = torch.stack(burst_image_data, dim=0)\n\n        burst_exposure = frames[0].get_exposure_time()\n        canon_exposure = gt.get_exposure_time()\n\n        burst_f_number = frames[0].get_f_number()\n        canon_f_number = gt.get_f_number()\n\n        burst_iso = frames[0].get_iso()\n        canon_iso = gt.get_iso()\n\n        # Normalize the GT image to account for differences in exposure, ISO etc\n        light_factor_burst = burst_exposure * burst_iso / (burst_f_number ** 2)\n        light_factor_canon = canon_exposure * canon_iso / (canon_f_number ** 2)\n\n        exp_scale_factor = (light_factor_burst / light_factor_canon)\n        gt_image_data = gt_image_data * 
exp_scale_factor\n\n        gt_image_meta_info['black_level_subtracted'] = self.substract_black_level\n        gt_image_meta_info['while_balance_applied'] = self.white_balance\n        gt_image_meta_info['norm_factor'] = gt.norm_factor / exp_scale_factor\n\n        burst_image_meta_info['exposure'] = burst_exposure\n        burst_image_meta_info['f_number'] = burst_f_number\n        burst_image_meta_info['iso'] = burst_iso\n\n        gt_image_meta_info['exposure'] = canon_exposure\n        gt_image_meta_info['f_number'] = canon_f_number\n        gt_image_meta_info['iso'] = canon_iso\n\n        burst = burst.float()\n        frame_gt = gt_image_data.float()\n\n        meta_info_burst = burst_image_meta_info\n        meta_info_gt = gt_image_meta_info\n\n        del meta_info_gt['crop_info']\n\n        for k, v in meta_info_gt.items():\n            if isinstance(v, (list, tuple)):\n                meta_info_gt[k] = torch.tensor(v)\n\n        for k, v in meta_info_burst.items():\n            if isinstance(v, (list, tuple)):\n                meta_info_burst[k] = torch.tensor(v)\n\n        meta_info_burst['burst_name'] = meta_info['burst_name']\n        \n        return burst, frame_gt, meta_info_burst, meta_info_gt\n\n\ndef pack_raw_image(im_raw):\n    if isinstance(im_raw, np.ndarray):\n        im_out = np.zeros_like(im_raw, shape=(4, im_raw.shape[0] // 2, im_raw.shape[1] // 2))\n    elif isinstance(im_raw, torch.Tensor):\n        im_out = torch.zeros((4, im_raw.shape[0] // 2, im_raw.shape[1] // 2), dtype=im_raw.dtype).to(im_raw.device)\n    else:\n        raise Exception\n\n    im_out[0, :, :] = im_raw[0::2, 0::2]\n    im_out[1, :, :] = im_raw[0::2, 1::2]\n    im_out[2, :, :] = im_raw[1::2, 0::2]\n    im_out[3, :, :] = im_raw[1::2, 1::2]\n    return im_out\n\n\ndef flatten_raw_image(im_raw_4ch):\n    if isinstance(im_raw_4ch, np.ndarray):\n        im_out = np.zeros_like(im_raw_4ch, shape=(im_raw_4ch.shape[1] * 2, im_raw_4ch.shape[2] * 2))\n    elif 
isinstance(im_raw_4ch, torch.Tensor):\n        im_out = torch.zeros((im_raw_4ch.shape[1] * 2, im_raw_4ch.shape[2] * 2), dtype=im_raw_4ch.dtype).to(im_raw_4ch.device)\n    else:\n        raise Exception\n\n    im_out[0::2, 0::2] = im_raw_4ch[0, :, :]\n    im_out[0::2, 1::2] = im_raw_4ch[1, :, :]\n    im_out[1::2, 0::2] = im_raw_4ch[2, :, :]\n    im_out[1::2, 1::2] = im_raw_4ch[3, :, :]\n\n    return im_out\n\ndef pack_raw_image_batch(im_raw):\n    im_out = torch.zeros((im_raw.shape[0], im_raw.shape[1], 4, im_raw.shape[3] // 2, im_raw.shape[4] // 2), dtype=im_raw.dtype).to(im_raw.device)\n    im_out[:, :, 0, :, :] = im_raw[:, :, 0, 0::2, 0::2]\n    im_out[:, :, 1, :, :] = im_raw[:, :, 0, 0::2, 1::2]\n    im_out[:, :, 2, :, :] = im_raw[:, :, 0, 1::2, 0::2]\n    im_out[:, :, 3, :, :] = im_raw[:, :, 0, 1::2, 1::2]\n    return im_out\n\n\ndef flatten_raw_image_batch(im_raw_4ch):\n    im_out = torch.zeros((im_raw_4ch.shape[0], im_raw_4ch.shape[1], 1, im_raw_4ch.shape[3] * 2, im_raw_4ch.shape[4] * 2), dtype=im_raw_4ch.dtype).to(im_raw_4ch.device)\n    im_out[:, :, 0, 0::2, 0::2] = im_raw_4ch[:, :, 0, :, :]\n    im_out[:, :, 0, 0::2, 1::2] = im_raw_4ch[:, :, 1, :, :]\n    im_out[:, :, 0, 1::2, 0::2] = im_raw_4ch[:, :, 2, :, :]\n    im_out[:, :, 0, 1::2, 1::2] = im_raw_4ch[:, :, 3, :, :]\n\n    return im_out\n"
  },
  {
    "path": "code/synthetic/bsrt/datasets/burstsr_test_dataset.py",
"content": "import os\nimport torch\nimport torch.nn.functional as F\nimport random\nfrom .burstsr_dataset import SamsungRAWImage, flatten_raw_image, pack_raw_image\n\n\nclass BurstSRDataset(torch.utils.data.Dataset):\n    \"\"\" Real-world burst super-resolution dataset. \"\"\"\n    def __init__(self, root, burst_size=8, crop_sz=80, center_crop=False, random_flip=False, split='test'):\n        \"\"\"\n        args:\n            root : path of the root directory\n            burst_size : Burst size. Maximum allowed burst size is 14.\n            crop_sz: Size of the extracted crop. Maximum allowed crop size is 80\n            center_crop: Whether to extract a random crop, or a centered crop.\n            random_flip: Whether to apply random horizontal and vertical flip\n            split: Can only be 'test'\n        \"\"\"\n        assert burst_size <= 14, 'burst_sz must be less than or equal to 14'\n        assert crop_sz <= 80, 'crop_sz must be less than or equal to 80'\n        assert split in ['test']\n\n        root = root + '/' + split\n        super().__init__()\n\n        self.burst_size = burst_size\n        self.crop_sz = crop_sz\n        self.split = split\n        self.center_crop = center_crop\n        self.random_flip = random_flip\n\n        self.root = root\n\n        self.substract_black_level = True\n        self.white_balance = False\n\n        self.burst_list = self._get_burst_list()\n\n    def _get_burst_list(self):\n        burst_list = sorted(os.listdir('{}'.format(self.root)))\n\n        return burst_list\n\n    def get_burst_info(self, burst_id):\n        burst_info = {'burst_size': 14, 'burst_name': self.burst_list[burst_id]}\n        return burst_info\n\n    def _get_raw_image(self, burst_id, im_id):\n        raw_image = SamsungRAWImage.load('{}/{}/samsung_{:02d}'.format(self.root, self.burst_list[burst_id], im_id))\n        return raw_image\n\n    def get_burst(self, burst_id, im_ids, info=None):\n        frames = 
[self._get_raw_image(burst_id, i) for i in im_ids]\n\n        if info is None:\n            info = self.get_burst_info(burst_id)\n\n        return frames, info\n\n    def _sample_images(self):\n        burst_size = 14\n\n        ids = random.sample(range(1, burst_size), k=self.burst_size - 1)\n        ids = [0, ] + ids\n        return ids\n\n    def __len__(self):\n        return len(self.burst_list)\n\n    def __getitem__(self, index):\n        # Sample the images in the burst, in case a burst_size < 14 is used.\n        im_ids = self._sample_images()\n\n        # Read the burst images along with HR ground truth\n        frames, meta_info = self.get_burst(index, im_ids)\n\n        # Extract crop if needed\n        if frames[0].shape()[-1] != self.crop_sz:\n            if getattr(self, 'center_crop', False):\n                r1 = (frames[0].shape()[-2] - self.crop_sz) // 2\n                c1 = (frames[0].shape()[-1] - self.crop_sz) // 2\n            else:\n                r1 = random.randint(0, frames[0].shape()[-2] - self.crop_sz)\n                c1 = random.randint(0, frames[0].shape()[-1] - self.crop_sz)\n            r2 = r1 + self.crop_sz\n            c2 = c1 + self.crop_sz\n\n            frames = [im.get_crop(r1, r2, c1, c2) for im in frames]\n\n        # Load the RAW image data\n        burst_image_data = [im.get_image_data(normalize=True, substract_black_level=self.substract_black_level,\n                                              white_balance=self.white_balance) for im in frames]\n\n        if self.random_flip:\n            burst_image_data = [flatten_raw_image(im) for im in burst_image_data]\n\n            pad = [0, 0, 0, 0]\n            if random.random() > 0.5:\n                burst_image_data = [im.flip([1, ])[:, 1:-1].contiguous() for im in burst_image_data]\n                pad[1] = 1\n\n            if random.random() > 0.5:\n                burst_image_data = [im.flip([0, ])[1:-1, :].contiguous() for im in burst_image_data]\n                
pad[3] = 1\n\n            burst_image_data = [pack_raw_image(im) for im in burst_image_data]\n            burst_image_data = [F.pad(im.unsqueeze(0), pad, mode='replicate').squeeze(0) for im in burst_image_data]\n\n        burst_image_meta_info = frames[0].get_all_meta_data()\n\n        burst_image_meta_info['black_level_subtracted'] = self.substract_black_level\n        burst_image_meta_info['while_balance_applied'] = self.white_balance\n        burst_image_meta_info['norm_factor'] = frames[0].norm_factor\n\n        burst = torch.stack(burst_image_data, dim=0)\n\n        burst_exposure = frames[0].get_exposure_time()\n\n        burst_f_number = frames[0].get_f_number()\n\n        burst_iso = frames[0].get_iso()\n\n        burst_image_meta_info['exposure'] = burst_exposure\n        burst_image_meta_info['f_number'] = burst_f_number\n        burst_image_meta_info['iso'] = burst_iso\n\n        burst = burst.float()\n\n        meta_info_burst = burst_image_meta_info\n\n        for k, v in meta_info_burst.items():\n            if isinstance(v, (list, tuple)):\n                meta_info_burst[k] = torch.tensor(v)\n\n        return burst, meta_info_burst"
  },
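The `_sample_images` helper above always keeps the base frame (index 0) and draws the remaining frame ids without replacement. A minimal stdlib sketch of that sampling logic — `sample_burst_ids` is a name introduced here, and a seedable `random.Random` replaces the module-level RNG the dataset uses:

```python
import random

def sample_burst_ids(burst_size, full_burst_size=14, seed=None):
    # Mirrors BurstSRDataset._sample_images: the base frame (id 0) is always
    # kept, and the rest are drawn without replacement from [1, full_burst_size).
    rng = random.Random(seed)
    ids = rng.sample(range(1, full_burst_size), k=burst_size - 1)
    return [0] + ids
```

Because the base frame is prepended rather than sampled, downstream alignment code can rely on `burst[0]` being the reference image for any `burst_size <= 14`.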
  {
    "path": "code/synthetic/bsrt/datasets/data_sampler.py",
    "content": "\"\"\"\nModified from torch.utils.data.distributed.DistributedSampler\nSupport enlarging the dataset for *iter-oriented* training, for saving time when restart the\ndataloader after each epoch\n\"\"\"\nimport math\n\nimport torch\nimport torch.distributed as dist\nfrom torch.utils.data.sampler import Sampler\n\n\nclass DistIterSampler(Sampler):\n    \"\"\"Sampler that restricts data loading to a subset of the dataset.\n\n    It is especially useful in conjunction with\n    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each\n    process can pass a DistributedSampler instance as a DataLoader sampler,\n    and load a subset of the original dataset that is exclusive to it.\n\n    .. note::\n        Dataset is assumed to be of constant size.\n\n    Arguments:\n        dataset: Dataset used for sampling.\n        num_replicas (optional): Number of processes participating in\n            distributed training.\n        rank (optional): Rank of the current process within num_replicas.\n    \"\"\"\n\n    def __init__(self, dataset, num_replicas=None, rank=None, ratio=100):\n        if num_replicas is None:\n            if not dist.is_available():\n                raise RuntimeError(\"Requires distributed package to be available\")\n            num_replicas = dist.get_world_size()\n        if rank is None:\n            if not dist.is_available():\n                raise RuntimeError(\"Requires distributed package to be available\")\n            rank = dist.get_rank()\n        self.dataset = dataset\n        self.num_replicas = num_replicas\n        self.rank = rank\n        self.epoch = 0\n        self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas))\n        self.total_size = self.num_samples * self.num_replicas\n\n    def __iter__(self):\n        # deterministically shuffle based on epoch\n        g = torch.Generator()\n        g.manual_seed(self.epoch)\n        indices = torch.randperm(\n            
self.total_size, generator=g\n        ).tolist()  # Returns a random permutation of integers from 0 to n - 1\n\n        dsize = len(self.dataset)\n        indices = [v % dsize for v in indices]\n\n        # subsample\n        indices = indices[self.rank : self.total_size : self.num_replicas]\n        assert len(indices) == self.num_samples\n\n        return iter(indices)\n\n    def __len__(self):\n        return self.num_samples\n\n    def set_epoch(self, epoch):\n        self.epoch = epoch\n"
  },
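The index arithmetic in `DistIterSampler.__iter__` can be re-derived with only the standard library (the sampler itself uses `torch.randperm` seeded with the epoch, so the concrete permutation here differs — only the enlarge/wrap/stride structure is the same). `dist_iter_indices` is a name introduced for this sketch:

```python
import math
import random

def dist_iter_indices(dataset_len, num_replicas, rank, epoch, ratio=100):
    # Same arithmetic as DistIterSampler: enlarge the dataset `ratio` times,
    # shuffle deterministically per epoch, wrap indices into the real dataset,
    # then take a rank-strided slice so each process gets a disjoint subset.
    num_samples = math.ceil(dataset_len * ratio / num_replicas)
    total_size = num_samples * num_replicas
    rng = random.Random(epoch)                     # deterministic per epoch
    indices = list(range(total_size))
    rng.shuffle(indices)
    indices = [v % dataset_len for v in indices]   # wrap into [0, dataset_len)
    return indices[rank:total_size:num_replicas]   # rank-strided subsample
```

Each rank receives exactly `num_samples` indices, and re-running with the same epoch reproduces the same order, which is what makes restart-free iter-oriented training possible.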
  {
    "path": "code/synthetic/bsrt/datasets/realworld_burst_test_set.py",
    "content": "import torch\nimport cv2\nimport numpy as np\nimport pickle as pkl\n\n\nclass RealWorldBurstTest(torch.utils.data.Dataset):\n    \"\"\"\n    \"\"\"\n    def __init__(self, root):\n        self.root = root\n        self.burst_list = list(range(20))\n        self.burst_size = 14\n\n    def __len__(self):\n        return len(self.burst_list)\n\n    def _read_burst_image(self, index, image_id):\n        im = cv2.imread('{}/{:04d}/im_raw_{:02d}.png'.format(self.root, index, image_id), cv2.IMREAD_UNCHANGED)\n        im_t = torch.from_numpy(im.astype(np.float32)).permute(2, 0, 1).float() / (2**14)\n        return im_t\n\n    def __getitem__(self, index):\n        \"\"\"\n                args:\n                    index: Index of the burst\n\n                returns:\n                    burst: LR RAW burst, a torch tensor of shape\n                           The 4 channels correspond to 'R', 'G', 'G', and 'B' values in the RGGB bayer mosaick.\n                    meta_info: Meta information about the burst\n                \"\"\"\n        burst_name = '{:04d}'.format(index)\n        burst = [self._read_burst_image(index, i) for i in range(self.burst_size)]\n        burst = torch.stack(burst, 0)\n\n        meta_info = {}\n        meta_info['burst_name'] = burst_name\n\n        return burst, meta_info\n"
  },
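The docstrings above describe 4-channel tensors whose channels are the 'R', 'G', 'G', 'B' values of an RGGB Bayer mosaic. A pure-Python sketch of that packing for a single-channel mosaic stored as a list of lists — `pack_rggb` is a name introduced here and illustrates the usual layout, not the repo's exact `pack_raw_image` implementation:

```python
def pack_rggb(raw):
    # Pack an H x W RGGB mosaic (even H, W) into 4 half-resolution channels
    # [R, G, G, B]: each 2x2 Bayer tile contributes one pixel per channel.
    h, w = len(raw), len(raw[0])
    r  = [[raw[i][j]         for j in range(0, w, 2)] for i in range(0, h, 2)]
    g1 = [[raw[i][j + 1]     for j in range(0, w, 2)] for i in range(0, h, 2)]
    g2 = [[raw[i + 1][j]     for j in range(0, w, 2)] for i in range(0, h, 2)]
    b  = [[raw[i + 1][j + 1] for j in range(0, w, 2)] for i in range(0, h, 2)]
    return [r, g1, g2, b]
```

This halving of each spatial side is why the packed RAW tensors have half the resolution of the LR RGB frames they were mosaicked from.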
  {
    "path": "code/synthetic/bsrt/datasets/synthetic_burst_test_set.py",
    "content": "import torch\nimport cv2\nimport numpy as np\nimport pickle as pkl\n\n\nclass SyntheticBurstTest(torch.utils.data.Dataset):\n    \"\"\" Synthetic burst test set. The test burst have been generated using the same synthetic pipeline as\n    employed in SyntheticBurst dataset.\n    \"\"\"\n    def __init__(self, root):\n        self.root = root\n        self.burst_list = list(range(92))\n        self.burst_size = 14\n\n    def __len__(self):\n        return len(self.burst_list)\n\n    def _read_burst_image(self, index, image_id):\n        im = cv2.imread('{}/{:04d}/im_raw_{:02d}.png'.format(self.root, index, image_id), cv2.IMREAD_UNCHANGED)\n        im_t = torch.from_numpy(im.astype(np.float32)).permute(2, 0, 1).float() / (2**14)\n        return im_t\n\n    def __getitem__(self, index):\n        \"\"\" Generates a synthetic burst\n                args:\n                    index: Index of the burst\n\n                returns:\n                    burst: LR RAW burst, a torch tensor of shape\n                           The 4 channels correspond to 'R', 'G', 'G', and 'B' values in the RGGB bayer mosaick.\n                    meta_info: Meta information about the burst\n                \"\"\"\n        burst_name = '{:04d}'.format(index)\n        burst = [self._read_burst_image(index, i) for i in range(self.burst_size)]\n        burst = torch.stack(burst, 0)\n\n        meta_info = {}\n        meta_info['burst_name'] = burst_name\n\n        return burst, meta_info\n"
  },
  {
    "path": "code/synthetic/bsrt/datasets/synthetic_burst_train_set.py",
    "content": "import torch\nimport numpy as np\nfrom PIL import Image\nfrom data_processing.synthetic_burst_generation import rgb2rawburst, random_crop #syn_burst_utils\nimport torchvision.transforms as tfm\n\n\nclass SyntheticBurst(torch.utils.data.Dataset):\n    \"\"\" Synthetic burst dataset for joint denoising, demosaicking, and super-resolution. RAW Burst sequences are\n    synthetically generated on the fly as follows. First, a single image is loaded from the base_dataset. The sampled\n    image is converted to linear sensor space using the inverse camera pipeline employed in [1]. A burst\n    sequence is then generated by adding random translations and rotations to the converted image. The generated burst\n    is then converted is then mosaicked, and corrupted by random noise to obtain the RAW burst.\n\n    [1] Unprocessing Images for Learned Raw Denoising, Brooks, Tim and Mildenhall, Ben and Xue, Tianfan and Chen,\n    Jiawen and Sharlet, Dillon and Barron, Jonathan T, CVPR 2019\n    \"\"\"\n    def __init__(self, base_dataset, burst_size=8, crop_sz=384, transform=tfm.ToTensor()):\n        self.base_dataset = base_dataset\n\n        self.burst_size = burst_size\n        self.crop_sz = crop_sz\n        self.transform = transform\n\n        self.downsample_factor = 4\n        self.burst_transformation_params = {'max_translation': 24.0,\n                                            'max_rotation': 1.0,\n                                            'max_shear': 0.0,\n                                            'max_scale': 0.0,\n                                            'border_crop': 24}\n\n        self.image_processing_params = {'random_ccm': True, 'random_gains': True, 'smoothstep': True,\n                                        'gamma': True,\n                                        'add_noise': True}\n        self.interpolation_type = 'bilinear'\n\n    def __len__(self):\n        return len(self.base_dataset)\n\n    def __getitem__(self, index):\n        
\"\"\" Generates a synthetic burst\n        args:\n            index: Index of the image in the base_dataset used to generate the burst\n\n        returns:\n            burst: Generated LR RAW burst, a torch tensor of shape\n                   [burst_size, 4, self.crop_sz / (2*self.downsample_factor), self.crop_sz / (2*self.downsample_factor)]\n                   The 4 channels correspond to 'R', 'G', 'G', and 'B' values in the RGGB bayer mosaick.\n                   The extra factor 2 in the denominator (2*self.downsample_factor) corresponds to the mosaicking\n                   operation.\n\n            frame_gt: The HR RGB ground truth in the linear sensor space, a torch tensor of shape\n                      [3, self.crop_sz, self.crop_sz]\n\n            flow_vectors: The ground truth flow vectors between a burst image and the base image (i.e. the first image in the burst).\n                          The flow_vectors can be used to warp the burst images to the base frame, using the 'warp'\n                          function in utils.warp package.\n                          flow_vectors is torch tensor of shape\n                          [burst_size, 2, self.crop_sz / self.downsample_factor, self.crop_sz / self.downsample_factor].\n                          Note that the flow_vectors are in the LR RGB space, before mosaicking. Hence it has twice\n                          the number of rows and columns, compared to the output burst.\n\n                          NOTE: The flow_vectors are only available during training for the purpose of using any\n                                auxiliary losses if needed. The flow_vectors will NOT be provided for the bursts in the\n                                test set\n\n            meta_info: A dictionary containing the parameters used to generate the synthetic burst.\n        \"\"\"\n        frame = self.base_dataset[index]\n\n        # Augmentation, e.g. 
convert to tensor\n        if self.transform is not None:\n            # frame = Image.fromarray(frame)\n            frame = self.transform(frame)\n\n        # Extract a random crop from the image\n        crop_sz = self.crop_sz + 2 * self.burst_transformation_params.get('border_crop', 0)\n        frame_crop = random_crop(frame, crop_sz)\n\n        # Generate RAW burst\n        burst, frame_gt, burst_rgb, flow_vectors, meta_info = rgb2rawburst(frame_crop,\n                                                                           self.burst_size,\n                                                                           self.downsample_factor,\n                                                                           burst_transformation_params=self.burst_transformation_params,\n                                                                           image_processing_params=self.image_processing_params,\n                                                                           interpolation_type=self.interpolation_type\n                                                                           )\n\n        if self.burst_transformation_params.get('border_crop') is not None:\n            border_crop = self.burst_transformation_params.get('border_crop')\n            frame_gt = frame_gt[:, border_crop:-border_crop, border_crop:-border_crop]\n\n        return burst, frame_gt, flow_vectors, meta_info\n"
  },
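The `SyntheticBurst.__getitem__` docstring walks through the output shape arithmetic: the LR RGB frames are `crop_sz / downsample_factor` on a side, and RGGB mosaicking halves each side again while producing 4 channels. A small helper collecting that arithmetic — `synthetic_burst_shapes` is a name introduced here, and it only restates the docstring, it does not call the pipeline:

```python
def synthetic_burst_shapes(crop_sz=384, burst_size=8, downsample_factor=4):
    # Shape arithmetic from the SyntheticBurst docstring.
    lr_sz = crop_sz // downsample_factor   # LR RGB size, before mosaicking
    raw_sz = lr_sz // 2                    # RGGB packing halves each side
    return {
        'burst': (burst_size, 4, raw_sz, raw_sz),
        'frame_gt': (3, crop_sz, crop_sz),
        'flow_vectors': (burst_size, 2, lr_sz, lr_sz),
    }
```

With the defaults (`crop_sz=384`, `downsample_factor=4`) the RAW frames come out 48x48, which agrees with the `[14, 4, 48, 48]` bursts in the validation set below.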
  {
    "path": "code/synthetic/bsrt/datasets/synthetic_burst_val_set.py",
    "content": "import os\nimport torch\nimport cv2\nimport numpy as np\nimport pickle as pkl\n\n\nclass SyntheticBurstVal(torch.utils.data.Dataset):\n    \"\"\" Synthetic burst validation set introduced in [1]. The validation burst have been generated using a\n    synthetic data generation pipeline. The dataset can be downloaded from\n    https://data.vision.ee.ethz.ch/bhatg/SyntheticBurstVal.zip\n\n    [1] Deep Burst Super-Resolution. Goutam Bhat, Martin Danelljan, Luc Van Gool, and Radu Timofte. CVPR 2021\n    \"\"\"\n    def __init__(self, root=None, initialize=True):\n        \"\"\"\n        args:\n            root - Path to root dataset directory\n            initialize - boolean indicating whether to load the meta-data for the dataset\n        \"\"\"\n        self.root = os.path.join(root, 'val')\n        self.burst_list = list(range(300))\n        self.burst_size = 14\n\n    def initialize(self):\n        pass\n\n    def __len__(self):\n        return len(self.burst_list)\n\n    def _read_burst_image(self, index, image_id):\n        im = cv2.imread('{}/bursts/{:04d}/im_raw_{:02d}.png'.format(self.root, index, image_id), cv2.IMREAD_UNCHANGED)\n        im_t = torch.from_numpy(im.astype(np.float32)).permute(2, 0, 1).float() / (2**14)\n\n        return im_t\n\n    def _read_gt_image(self, index):\n        gt = cv2.imread('{}/gt/{:04d}/im_rgb.png'.format(self.root, index), cv2.IMREAD_UNCHANGED)\n        gt_t = (torch.from_numpy(gt.astype(np.float32)) / 2 ** 14).permute(2, 0, 1).float()\n        return gt_t\n\n    def _read_meta_info(self, index):\n        with open('{}/gt/{:04d}/meta_info.pkl'.format(self.root, index), \"rb\") as input_file:\n            meta_info = pkl.load(input_file)\n\n        return meta_info\n\n    def __getitem__(self, index):\n        \"\"\" Generates a synthetic burst\n        args:\n            index: Index of the burst\n\n        returns:\n            burst: LR RAW burst, a torch tensor of shape\n                   [14, 4, 48, 48]\n   
                The 4 channels correspond to 'R', 'G', 'G', and 'B' values in the RGGB bayer mosaick.\n            gt : Ground truth linear image\n            meta_info: Meta info about the burst which can be used to convert gt to sRGB space\n        \"\"\"\n        burst_name = '{:04d}'.format(index)\n        burst = [self._read_burst_image(index, i) for i in range(self.burst_size)]\n        burst = torch.stack(burst, 0)\n\n        gt = self._read_gt_image(index)\n        meta_info = self._read_meta_info(index)\n        meta_info['burst_name'] = burst_name\n        return burst, gt, meta_info\n"
  },
  {
    "path": "code/synthetic/bsrt/datasets/zurich_raw2rgb_dataset.py",
    "content": "import torch\nimport os\nimport numpy as np\nfrom cv2 import imread\n\n\nclass ZurichRAW2RGB(torch.utils.data.Dataset):\n    \"\"\" Canon RGB images from the \"Zurich RAW to RGB mapping\" dataset. You can download the full\n    dataset (22 GB) from http://people.ee.ethz.ch/~ihnatova/pynet.html#dataset. Alternatively, you can only download the\n    Canon RGB images (5.5 GB) from https://data.vision.ee.ethz.ch/bhatg/zurich-raw-to-rgb.zip\n    \"\"\"\n    def __init__(self, root, split='train'):\n        super().__init__()\n\n        if split in ['train', 'test']:\n            self.img_pth = os.path.join(root, split, 'canon')\n        else:\n            raise Exception('Unknown split {}'.format(split))\n\n        self.image_list = self._get_image_list(split)\n\n    def _get_image_list(self, split):\n        if split == 'train':\n            image_list = ['{:d}.jpg'.format(i) for i in range(46839)]\n        elif split == 'test':\n            # image_list = ['{:d}.jpg'.format(int(i)) for i in np.linspace(1, 1200, 400)]\n            image_list = ['{:d}.jpg'.format(i) for i in range(1200)]\n        else:\n            raise Exception\n\n        return image_list\n\n    def _get_image(self, im_id):\n        path = os.path.join(self.img_pth, self.image_list[im_id])\n        img = imread(path)\n        return img\n\n    def get_image(self, im_id):\n        frame = self._get_image(im_id)\n\n        return frame\n\n    def __len__(self):\n        return len(self.image_list)\n\n    def __getitem__(self, index):\n        frame = self._get_image(index)\n\n        return frame\n"
  },
  {
    "path": "code/synthetic/bsrt/demo.sh",
    "content": "#!/usr/bin/env bash\n\n\npython main.py --n_GPUs 8 --print_every 40 --lr 0.0001 --decay 100-200 --save bsrt_tiny --model BSRT --fp16 --model_level S --swinfeature --batch_size 32 --burst_size 14 --patch_size 256\n# python main.py --n_GPUs 8 --print_every 40 --lr 0.0001 --decay 100-200 --save bsrt_large --model BSRT --fp16 --model_level L --swinfeature --batch_size 16 --burst_size 14 --patch_size 256\n\n# python test_synburst.py --n_GPUs 1 --model BSRT --model_level S --fp16 --swinfeature --burst_size 14 --patch_size 384 --pre_train ../train_log/bsrt/real_models/bsrt_tiny/bsrt_best_epoch.pth --root /data/dataset/ntire21/burstsr/synthetic\n# python test_synburst.py --n_GPUs 1 --model BSRT --model_level L --fp16 --swinfeature --burst_size 14 --patch_size 384 --pre_train ../train_log/bsrt/real_models/bsrt_large/bsrt_synburst.pth --root /data/dataset/ntire21/burstsr/synthetic\n"
  },
  {
    "path": "code/synthetic/bsrt/loss/Charbonnier.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass CharbonnierLoss(nn.Module):\n    \"\"\"L1 charbonnier loss.\"\"\"\n\n    def __init__(self, epsilon=1e-3, reduce=True):\n        super(CharbonnierLoss, self).__init__()\n        self.eps = epsilon * epsilon\n        self.reduce = reduce\n\n    def forward(self, X, Y):\n        diff = torch.add(X, -Y)\n        error = torch.sqrt(diff * diff + self.eps)\n        if self.reduce:\n            loss = torch.mean(error)\n        else:\n            loss = error\n        return loss"
  },
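`CharbonnierLoss` computes `sqrt((X - Y)^2 + eps^2)` elementwise, a smooth approximation of the L1 penalty that remains differentiable at zero error. A scalar sketch of the same formula using only the standard library — `charbonnier` is a name introduced here:

```python
import math

def charbonnier(x, y, epsilon=1e-3):
    # Scalar form of CharbonnierLoss: sqrt((x - y)^2 + eps^2).
    # Near zero error it flattens to ~eps (smooth, unlike |x - y|);
    # for large errors it approaches |x - y|.
    return math.sqrt((x - y) ** 2 + epsilon ** 2)
```

The `eps` floor is why the loss never reaches exactly zero, and also why its gradient stays bounded instead of jumping at the origin as plain L1 does.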
  {
    "path": "code/synthetic/bsrt/loss/__init__.py",
    "content": "import os\nfrom importlib import import_module\n\nimport matplotlib\nmatplotlib.use('Agg')\nimport matplotlib.pyplot as plt\n\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Loss(nn.modules.loss._Loss):\n    def __init__(self, args, ckp):\n        super(Loss, self).__init__()\n        if args.local_rank == 0:\n            print('Preparing loss function:')\n\n        self.n_GPUs = args.n_GPUs\n        self.loss = []\n        self.loss_module = nn.ModuleList()\n        for loss in args.loss.split('+'):\n            weight, loss_type = loss.split('*')\n            if loss_type == 'MSE':\n                loss_function = nn.MSELoss()\n            elif loss_type == 'L1':\n                loss_function = nn.L1Loss()\n            elif loss_type.find('VGG') >= 0:\n                module = import_module('loss.vgg')\n                loss_function = getattr(module, 'VGG')(\n                    loss_type[3:],\n                    rgb_range=args.rgb_range\n                )\n            elif loss_type.find('GAN') >= 0:\n                module = import_module('loss.adversarial')\n                loss_function = getattr(module, 'Adversarial')(\n                    args,\n                    loss_type\n                )\n            elif loss_type == 'FILTER':\n                module = import_module('loss.filter')\n                loss_function = getattr(module, 'Filter')(args)\n            elif loss_type == 'SSIM':\n                module = import_module('loss.mssim')\n                loss_function = getattr(module, 'SSIM')(args)\n            elif loss_type == 'MSSSIM':\n                module = import_module('loss.mssim')\n                loss_function = getattr(module, 'MSSSIM')(args)\n\n            self.loss.append({\n                'type': loss_type,\n                'weight': float(weight),\n                'function': loss_function}\n            )\n            if loss_type.find('GAN') >= 0:\n               
 self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})\n\n        if len(self.loss) > 1:\n            self.loss.append({'type': 'Total', 'weight': 0, 'function': None})\n\n        for l in self.loss:\n            if l['function'] is not None:\n                if args.local_rank == 0:\n                    print('{:.3f} * {}'.format(l['weight'], l['type']))\n                self.loss_module.append(l['function'])\n\n        self.log = torch.Tensor()\n\n        device = torch.device('cpu' if args.cpu else 'cuda')\n        self.loss_module.to(device)\n        if args.precision == 'half': self.loss_module.half()\n        if not args.cpu and args.n_GPUs > 1:\n            self.loss_module = nn.DataParallel(\n                self.loss_module, range(args.n_GPUs)\n            )\n\n        if args.load != '': self.load(ckp.dir, cpu=args.cpu)\n\n    def forward(self, sr, hr):\n        losses = []\n        for i, l in enumerate(self.loss):\n            if l['function'] is not None:\n                loss = l['function'](sr, hr)\n                effective_loss = l['weight'] * loss\n                losses.append(effective_loss)\n                self.log[-1, i] += effective_loss.item()\n            elif l['type'] == 'DIS':\n                self.log[-1, i] += self.loss[i - 1]['function'].loss\n\n        loss_sum = sum(losses)\n        if len(self.loss) > 1:\n            self.log[-1, -1] += loss_sum.item()\n\n        return loss_sum\n\n    def step(self):\n        for l in self.get_loss_module():\n            if hasattr(l, 'scheduler'):\n                l.scheduler.step()\n\n    def start_log(self):\n        self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))\n\n    def end_log(self, n_batches):\n        self.log[-1].div_(n_batches)\n\n    def display_loss(self, batch):\n        n_samples = batch + 1\n        log = []\n        for l, c in zip(self.loss, self.log[-1]):\n            log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))\n\n        return 
''.join(log)\n\n    def plot_loss(self, apath, epoch):\n        axis = np.linspace(1, epoch, epoch)\n        for i, l in enumerate(self.loss):\n            label = '{} Loss'.format(l['type'])\n            fig = plt.figure()\n            plt.title(label)\n            plt.plot(axis, self.log[:, i].numpy(), label=label)\n            plt.legend()\n            plt.xlabel('Epochs')\n            plt.ylabel('Loss')\n            plt.grid(True)\n            plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type'])))\n            plt.close(fig)\n\n    def get_loss_module(self):\n        if self.n_GPUs == 1:\n            return self.loss_module\n        else:\n            return self.loss_module.module\n\n    def save(self, apath):\n        torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))\n        torch.save(self.log, os.path.join(apath, 'loss_log.pt'))\n\n    def load(self, apath, cpu=False):\n        if cpu:\n            kwargs = {'map_location': lambda storage, loc: storage}\n        else:\n            kwargs = {}\n\n        self.load_state_dict(torch.load(\n            os.path.join(apath, 'loss.pt'),\n            **kwargs\n        ))\n        self.log = torch.load(os.path.join(apath, 'loss_log.pt'))\n        for l in self.get_loss_module():\n            if hasattr(l, 'scheduler'):\n                for _ in range(len(self.log)): l.scheduler.step()\n\n"
  },
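`Loss.__init__` parses the `--loss` argument by splitting terms on `'+'` and each term into `<weight>*<type>` on `'*'`. The parsing in isolation, with `parse_loss_spec` as a name introduced for this sketch:

```python
def parse_loss_spec(spec):
    # Mirrors the spec parsing in Loss.__init__: terms joined by '+',
    # each term of the form '<weight>*<type>', e.g. '1*L1+0.1*VGG54'.
    terms = []
    for term in spec.split('+'):
        weight, loss_type = term.split('*')
        terms.append((float(weight), loss_type))
    return terms
```

For VGG terms the suffix after the literal `VGG` (e.g. `54`) is passed through as `loss_type[3:]` to select the feature layer, which is why the type string carries the layer code.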
  {
    "path": "code/synthetic/bsrt/loss/adversarial.py",
    "content": "import utility\nfrom types import SimpleNamespace\n\nfrom model import common\nfrom loss import discriminator\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\n\nclass Adversarial(nn.Module):\n    def __init__(self, args, gan_type):\n        super(Adversarial, self).__init__()\n        self.gan_type = gan_type\n        self.gan_k = args.gan_k\n        self.dis = discriminator.Discriminator(args)\n        # if gan_type == 'WGAN_GP':\n        if True:\n            # see https://arxiv.org/pdf/1704.00028.pdf pp.4\n            optim_dict = {\n                'optimizer': 'ADAM',\n                'betas': (0.5, 0.9),\n                'epsilon': 1e-8,\n                'lr': 1e-5,\n                'weight_decay': args.weight_decay,\n                'decay': args.decay,\n                'gamma': args.gamma\n            }\n            optim_args = SimpleNamespace(**optim_dict)\n        else:\n            optim_args = args\n\n        self.optimizer = utility.make_optimizer(optim_args, self.dis)\n\n    def forward(self, fake, real):\n        # updating discriminator...\n        self.loss = 0\n        fake_detach = fake.detach()     # do not backpropagate through G\n        for _ in range(self.gan_k):\n            self.optimizer.zero_grad()\n            # d: B x 1 tensor\n            d_fake = self.dis(fake_detach)\n            d_real = self.dis(real)\n            retain_graph = False\n            if self.gan_type in ['GAN', 'SNGAN']:\n                loss_d = self.bce(d_real, d_fake)\n            elif self.gan_type.find('WGAN') >= 0:\n                loss_d = (d_fake - d_real).mean()\n                if self.gan_type.find('GP') >= 0:\n                    epsilon = torch.rand_like(fake).view(-1, 1, 1, 1)\n                    hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)\n                    hat.requires_grad = True\n                    d_hat = self.dis(hat)\n                    gradients = 
torch.autograd.grad(\n                        outputs=d_hat.sum(), inputs=hat,\n                        retain_graph=True, create_graph=True, only_inputs=True\n                    )[0]\n                    gradients = gradients.view(gradients.size(0), -1)\n                    gradient_norm = gradients.norm(2, dim=1)\n                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()\n                    loss_d += gradient_penalty\n            # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks\n            elif self.gan_type == 'RGAN':\n                better_real = d_real - d_fake.mean(dim=0, keepdim=True)\n                better_fake = d_fake - d_real.mean(dim=0, keepdim=True)\n                loss_d = self.bce(better_real, better_fake)\n                retain_graph = True\n\n            # Discriminator update\n            self.loss += loss_d.item()\n            loss_d.backward(retain_graph=retain_graph)\n            self.optimizer.step()\n\n            if self.gan_type == 'WGAN':\n                for p in self.dis.parameters():\n                    p.data.clamp_(-1, 1)\n\n        self.loss /= self.gan_k\n\n        # updating generator...\n        d_fake_bp = self.dis(fake)      # for backpropagation, use fake as it is\n        if self.gan_type in ['GAN', 'SNGAN']:\n            label_real = torch.ones_like(d_fake_bp)\n            loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)\n        elif self.gan_type.find('WGAN') >= 0:\n            loss_g = -d_fake_bp.mean()\n        elif self.gan_type == 'RGAN':\n            better_real = d_real.detach() - d_fake_bp.mean(dim=0, keepdim=True)\n            better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True).detach()\n            loss_g = self.bce(better_fake, better_real)\n\n        # Generator loss\n        return loss_g\n\n    def state_dict(self, *args, **kwargs):\n        state_discriminator = self.dis.state_dict(*args, **kwargs)\n        state_optimizer = 
self.optimizer.state_dict()\n\n        return dict(**state_discriminator, **state_optimizer)\n\n    def bce(self, real, fake):\n        label_real = torch.ones_like(real)\n        label_fake = torch.zeros_like(fake)\n        bce_real = F.binary_cross_entropy_with_logits(real, label_real)\n        bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)\n        bce_loss = bce_real + bce_fake\n        return bce_loss\n\n# Some references\n# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py\n# OR\n# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py\n"
  },
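`Adversarial.bce` relies on `F.binary_cross_entropy_with_logits`, which combines the sigmoid and the cross-entropy in a numerically stable form. The per-element quantity it averages, written out with the standard library — `bce_with_logits` is a name introduced for this sketch:

```python
import math

def bce_with_logits(logit, label):
    # Numerically stable binary cross-entropy on a raw logit z with target y:
    # max(z, 0) - z*y + log(1 + exp(-|z|)).
    # Equivalent to -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))], but never
    # evaluates log of a quantity that underflows to 0.
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))
```

At `z = 0` the discriminator is maximally uncertain and the loss is `log 2` regardless of the label; a large logit with a matching label drives the loss toward zero.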
  {
    "path": "code/synthetic/bsrt/loss/discriminator.py",
    "content": "from model import common\n\nimport torch.nn as nn\nfrom torch.nn.utils import spectral_norm\n\nclass Discriminator(nn.Module):\n    '''\n        output is not normalized\n    '''\n    def __init__(self, args, gan_type='GAN'):\n        super(Discriminator, self).__init__()\n\n        in_channels = args.n_colors\n        out_channels = 32\n        depth = 6\n\n        def _block(_in_channels, _out_channels, stride=1):\n            Conv = nn.Conv2d(\n                _in_channels,\n                _out_channels,\n                3,\n                padding=1,\n                stride=stride,\n                bias=False\n            )\n\n            if gan_type == 'SNGAN':\n                return nn.Sequential(\n                    spectral_norm(Conv),\n                    nn.BatchNorm2d(_out_channels),\n                    nn.LeakyReLU(negative_slope=0.2, inplace=True)\n                )\n            else:\n                return nn.Sequential(\n                    Conv,\n                    nn.BatchNorm2d(_out_channels),\n                    nn.LeakyReLU(negative_slope=0.2, inplace=True)\n                )\n\n        m_features = [_block(in_channels, out_channels)]\n        for i in range(depth):\n            in_channels = out_channels\n            out_channels *= 2\n            stride = 2\n            m_features.append(_block(in_channels, out_channels, stride=stride))\n\n        # the first block keeps resolution; each of the depth strided blocks halves it\n        patch_size = args.patch_size // 2**depth\n\n        m_classifier = [\n            nn.Flatten(),\n            nn.Linear(out_channels*patch_size**2, 512),\n            nn.LeakyReLU(0.2, True),\n            nn.Linear(512, 1)\n        ]\n\n        self.features = nn.Sequential(*m_features)\n        self.classifier = nn.Sequential(*m_classifier)\n\n    def forward(self, x):\n        features = self.features(x)\n        output = self.classifier(features)\n\n        return output\n"
  },
  {
    "path": "code/synthetic/bsrt/loss/filter.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Filter(nn.Module):\n    def __init__(self, args):\n        super().__init__()\n        self.args = args\n\n        # fixed 3x3 high-pass kernel, broadcast to all conv filters; not trained\n        kernel = torch.tensor([[1, 4, 1], [4, -20, 4], [1, 4, 1]])\n        self.conv = nn.Conv2d(args.n_colors, args.n_colors, 3, padding=1, bias=False)\n        with torch.no_grad():\n            self.conv.weight.copy_(kernel.float())\n        self.conv.weight.requires_grad_(False)\n        self.loss = nn.L1Loss()\n\n    def forward(self, x, y):\n        preds_x = self.conv(x)\n        preds_y = self.conv(y)\n\n        return self.loss(preds_x, preds_y)\n"
  },
  {
    "path": "code/synthetic/bsrt/loss/hist_entropy.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass HistEntropy(nn.Module):\n    def __init__(self, args):\n        super().__init__()\n        self.args = args\n\n    def forward(self, x):\n        p = torch.softmax(x, dim=1)\n        logp = torch.log_softmax(x, dim=1)\n\n        entropy = (-p * logp).sum(dim=(2, 3)).mean()\n\n        return entropy\n"
  },
  {
    "path": "code/synthetic/bsrt/loss/mssim.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom math import exp\nimport numpy as np\n\n\ndef gaussian(window_size, sigma):\n    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])\n    return gauss/gauss.sum()\n\n\ndef create_window(window_size, channel=1):\n    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)\n    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)\n    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()\n    return window\n\n\ndef ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):\n    # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).\n    if val_range is None:\n        if torch.max(img1) > 128:\n            max_val = 255\n        else:\n            max_val = 1\n\n        if torch.min(img1) < -0.5:\n            min_val = -1\n        else:\n            min_val = 0\n        L = max_val - min_val\n    else:\n        L = val_range\n\n    padd = 0\n    (_, channel, height, width) = img1.size()\n    if window is None:\n        real_size = min(window_size, height, width)\n        window = create_window(real_size, channel=channel).to(img1.device)\n\n    mu1 = F.conv2d(img1, window, padding=padd, groups=channel)\n    mu2 = F.conv2d(img2, window, padding=padd, groups=channel)\n\n    mu1_sq = mu1.pow(2)\n    mu2_sq = mu2.pow(2)\n    mu1_mu2 = mu1 * mu2\n\n    sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq\n    sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq\n    sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2\n\n    C1 = (0.01 * L) ** 2\n    C2 = (0.03 * L) ** 2\n\n    v1 = 2.0 * sigma12 + C2\n    v2 = sigma1_sq + sigma2_sq + C2\n    cs = torch.mean(v1 / v2)  # contrast sensitivity\n\n    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * 
v2)\n\n    if size_average:\n        ret = ssim_map.mean()\n    else:\n        ret = ssim_map.mean(1).mean(1).mean(1)\n\n    if full:\n        return ret, cs\n    return ret\n\n\ndef msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=None):\n    device = img1.device\n    weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)\n    levels = weights.size()[0]\n    ssims = []\n    mcs = []\n    for _ in range(levels):\n        sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)\n\n        # Relu normalize (not compliant with original definition)\n        if normalize == \"relu\":\n            ssims.append(torch.relu(sim))\n            mcs.append(torch.relu(cs))\n        else:\n            ssims.append(sim)\n            mcs.append(cs)\n\n        img1 = F.avg_pool2d(img1, (2, 2))\n        img2 = F.avg_pool2d(img2, (2, 2))\n\n    ssims = torch.stack(ssims)\n    mcs = torch.stack(mcs)\n\n    # Simple normalize (not compliant with original definition)\n    # TODO: remove support for normalize == True (kept for backward support)\n    if normalize == \"simple\" or normalize == True:\n        ssims = (ssims + 1) / 2\n        mcs = (mcs + 1) / 2\n\n    pow1 = mcs ** weights\n    pow2 = ssims ** weights\n\n    # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/\n    # product of the contrast terms at all scales but the last, times the SSIM term at the last scale\n    output = torch.prod(pow1[:-1]) * pow2[-1]\n    return output\n\n\n# Classes to re-use window\nclass SSIM(torch.nn.Module):\n    def __init__(self, window_size=11, size_average=True, val_range=None):\n        super(SSIM, self).__init__()\n        self.window_size = window_size\n        self.size_average = size_average\n        self.val_range = val_range\n\n        # Assume 1 channel for SSIM\n        self.channel = 1\n        self.window = create_window(window_size)\n\n    def forward(self, img1, img2):\n        (_, channel, _, _) = img1.size()\n\n        if channel == 
self.channel and self.window.dtype == img1.dtype:\n            window = self.window\n        else:\n            window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)\n            self.window = window\n            self.channel = channel\n\n        return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)\n\nclass MSSSIM(torch.nn.Module):\n    def __init__(self, window_size=11, size_average=True, channel=3):\n        super(MSSSIM, self).__init__()\n        self.window_size = window_size\n        self.size_average = size_average\n        self.channel = channel\n\n    def forward(self, img1, img2):\n        # TODO: store window between calls if possible\n        return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)"
  },
  {
    "path": "code/synthetic/bsrt/loss/vgg.py",
    "content": "from model import common\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision.models as models\n\nclass VGG(nn.Module):\n    def __init__(self, conv_index, rgb_range=1):\n        super(VGG, self).__init__()\n        vgg_features = models.vgg19(pretrained=True).features\n        modules = [m for m in vgg_features]\n        if conv_index.find('22') >= 0:\n            self.vgg = nn.Sequential(*modules[:8])\n        elif conv_index.find('54') >= 0:\n            self.vgg = nn.Sequential(*modules[:35])\n\n        vgg_mean = (0.485, 0.456, 0.406)\n        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)\n        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)\n        for p in self.parameters():\n            p.requires_grad = False\n\n    def forward(self, sr, hr):\n        def _forward(x):\n            # x = self.sub_mean(x)\n            x = self.vgg(x)\n            return x\n\n        sr = sr.repeat(1, 3, 1, 1)\n        hr = hr.repeat(1, 3, 1, 1)\n\n        vgg_sr = _forward(sr)\n        with torch.no_grad():\n            vgg_hr = _forward(hr.detach())\n\n        loss = F.mse_loss(vgg_sr, vgg_hr)\n\n        return loss\n"
  },
  {
    "path": "code/synthetic/bsrt/main.py",
    "content": "import torch\nimport random\nimport numpy as np\nfrom torch.utils.data import DataLoader\nfrom torchvision import transforms as T\n\nimport utility\nimport model\nimport loss\nfrom option import args\nfrom trainer import Trainer\nfrom datasets.synthetic_burst_train_set import SyntheticBurst\nfrom datasets.synthetic_burst_val_set import SyntheticBurstVal\nfrom datasets.zurich_raw2rgb_dataset import ZurichRAW2RGB\nfrom datasets.data_sampler import DistIterSampler\nimport torch.multiprocessing as mp\nimport torch.backends.cudnn as cudnn\nimport torch.utils.data.distributed\n\n# torch.autograd.set_detect_anomaly(True)\n# torch.multiprocessing.set_sharing_strategy('file_system')\n\ndef init_seeds(seed=0, cuda_deterministic=True):\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html\n    if cuda_deterministic:  # slower, more reproducible\n        cudnn.deterministic = True\n        cudnn.benchmark = False\n    else:  # faster, less reproducible\n        cudnn.deterministic = False\n        cudnn.benchmark = True\n\n\ncheckpoint = utility.checkpoint(args)\n\ndef main():\n    if args.n_GPUs > 1:\n        mp.spawn(main_worker, nprocs=args.n_GPUs, args=(args.n_GPUs, args), join=True)\n    else:\n        main_worker(0, args.n_GPUs, args)\n\n\ndef main_worker(local_rank, nprocs, args):\n    if checkpoint.ok:\n        args.local_rank = local_rank\n        if nprocs > 1:\n            init_seeds(local_rank+1)\n            cudnn.benchmark = True\n            utility.setup(local_rank, nprocs)\n        torch.cuda.set_device(args.local_rank)\n\n        batch_size = int(args.batch_size / nprocs)\n        train_zurich_raw2rgb = ZurichRAW2RGB(root=args.root, split='train')\n        train_data = SyntheticBurst(train_zurich_raw2rgb, burst_size=args.burst_size, crop_sz=args.patch_size)\n\n        # valid_zurich_raw2rgb = ZurichRAW2RGB(root=args.root, 
split='test')\n        # valid_data = SyntheticBurst(valid_zurich_raw2rgb, burst_size=14, crop_sz=1024)\n        valid_data = SyntheticBurstVal(root=args.root)\n\n        if local_rank <= 0:\n            print(f\"train data: {len(train_data)}, test data: {len(valid_data)}\")\n            print(f\"Test only: {args.test_only}\")\n\n        if nprocs > 1:\n            train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)\n            # train_sampler = DistIterSampler(train_data, nprocs, local_rank, 1)\n            valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_data, shuffle=False)\n            train_loader = DataLoader(dataset=train_data, batch_size=batch_size, num_workers=8,\n                                      pin_memory=True, drop_last=True, sampler=train_sampler)\n            valid_loader = DataLoader(dataset=valid_data, batch_size=1, num_workers=1,\n                                      pin_memory=True, drop_last=False, sampler=valid_sampler)\n        else:\n            train_sampler = None\n            train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, num_workers=8,\n                                    shuffle=True, pin_memory=True, drop_last=True)  # args.cpus\n            valid_loader = DataLoader(dataset=valid_data, batch_size=args.batch_size, num_workers=4, shuffle=False,\n                                    pin_memory=True, drop_last=False)  # args.cpus\n\n        _model = model.Model(args, checkpoint)\n        _loss = loss.Loss(args, checkpoint) if not args.test_only else None\n        t = Trainer(args, train_loader, train_sampler, valid_loader, _model, _loss, checkpoint)\n        while not t.terminate():\n            t.train()\n\n        del _model\n        del _loss\n        del train_loader\n        del valid_loader\n\n        # utility.cleanup()\n\n        checkpoint.done()\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "code/synthetic/bsrt/model/DCNv2/LICENSE",
    "content": "BSD 3-Clause License\n\nCopyright (c) 2019, Charles Shang\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n1. Redistributions of source code must retain the above copyright notice, this\n   list of conditions and the following disclaimer.\n\n2. Redistributions in binary form must reproduce the above copyright notice,\n   this list of conditions and the following disclaimer in the documentation\n   and/or other materials provided with the distribution.\n\n3. Neither the name of the copyright holder nor the names of its\n   contributors may be used to endorse or promote products derived from\n   this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\nFOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\nDAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\nSERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\nCAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\nOR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."
  },
  {
    "path": "code/synthetic/bsrt/model/DCNv2/README.md",
    "content": "## Deformable Convolutional Networks V2 with PyTorch 1.0\n\n### Build\n```bash\n    ./make.sh         # build\n    python test.py    # run examples and gradient check\n```\n\n### An Example\n- deformable conv\n```python\n    import torch\n    from dcn_v2 import DCN\n    input = torch.randn(2, 64, 128, 128).cuda()\n    # wrap all things (offset and mask) in DCN\n    dcn = DCN(64, 64, kernel_size=(3,3), stride=1, padding=1, deformable_groups=2).cuda()\n    output = dcn(input)\n    print(output.shape)\n```\n- deformable roi pooling\n```python\n    import torch\n    from dcn_v2 import DCNPooling\n    input = torch.randn(2, 32, 64, 64).cuda()\n    batch_inds = torch.randint(2, (20, 1)).cuda().float()\n    x = torch.randint(256, (20, 1)).cuda().float()\n    y = torch.randint(256, (20, 1)).cuda().float()\n    w = torch.randint(64, (20, 1)).cuda().float()\n    h = torch.randint(64, (20, 1)).cuda().float()\n    rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)\n\n    # modulated deformable pooling (V2)\n    # wrap all things (offset and mask) in DCNPooling\n    dpooling = DCNPooling(spatial_scale=1.0 / 4,\n                          pooled_size=7,\n                          output_dim=32,\n                          no_trans=False,\n                          group_size=1,\n                          trans_std=0.1).cuda()\n\n    dout = dpooling(input, rois)\n```\n### Note\nThe master branch now targets PyTorch 1.0 (the new ATen API); you can switch back to PyTorch 0.4 with:\n```bash\ngit checkout pytorch_0.4\n```\n\n### Known Issues:\n\n- [x] Gradient check w.r.t offset (solved)\n- [ ] Backward is not reentrant (minor)\n\nThis is an adaptation of the official [Deformable-ConvNets](https://github.com/msracver/Deformable-ConvNets/tree/master/DCNv2_op).\n\n<s>I have run the gradient check many times with DOUBLE type. Every tensor **except offset** passes.\nHowever, when I set the offset to 0.5, it passes. I'm still wondering what causes this problem. Is it due to some\nnon-differentiable points?</s>\n\nUpdate: all gradient checks pass with double precision.\n\nAnother issue is that it raises `RuntimeError: Backward is not reentrant`. However, the error is very small (`<1e-7` for float, `<1e-15` for double), so it may not be a serious problem (?).\n\nPlease post an issue or PR if you have any comments.\n"
  },
  {
    "path": "code/synthetic/bsrt/model/DCNv2/__init__.py",
    "content": ""
  },
  {
    "path": "code/synthetic/bsrt/model/DCNv2/dcn_v2.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import absolute_import, division, print_function\n\nimport math\n\nimport torch\nfrom torch import nn\nfrom torch.autograd import Function\nfrom torch.autograd.function import once_differentiable\nfrom torch.nn.modules.utils import _pair\nfrom torch.cuda.amp import custom_fwd, custom_bwd\n# from apex import amp\n\nimport _ext as _backend\n\n\nclass _DCNv2(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)\n    # @amp.float_function\n    def forward(\n        ctx, input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups\n    ):\n        ctx.stride = _pair(stride)\n        ctx.padding = _pair(padding)\n        ctx.dilation = _pair(dilation)\n        ctx.kernel_size = _pair(weight.shape[2:4])\n        ctx.deformable_groups = deformable_groups\n        output = _backend.dcn_v2_forward(\n            input,\n            weight,\n            bias,\n            offset,\n            mask,\n            ctx.kernel_size[0],\n            ctx.kernel_size[1],\n            ctx.stride[0],\n            ctx.stride[1],\n            ctx.padding[0],\n            ctx.padding[1],\n            ctx.dilation[0],\n            ctx.dilation[1],\n            ctx.deformable_groups,\n        )\n        ctx.save_for_backward(input, offset, mask, weight, bias)\n        return output\n\n    @staticmethod\n    @once_differentiable\n    @custom_bwd\n    # @amp.float_function\n    def backward(ctx, grad_output):\n        input, offset, mask, weight, bias = ctx.saved_tensors\n        grad_input, grad_offset, grad_mask, grad_weight, grad_bias = _backend.dcn_v2_backward(\n            input,\n            weight,\n            bias,\n            offset,\n            mask,\n            grad_output,\n            ctx.kernel_size[0],\n            ctx.kernel_size[1],\n            ctx.stride[0],\n            ctx.stride[1],\n            ctx.padding[0],\n            ctx.padding[1],\n            ctx.dilation[0],\n      
      ctx.dilation[1],\n            ctx.deformable_groups,\n        )\n\n        return grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None\n\n    @staticmethod\n    def symbolic(\n        g, input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups\n    ):\n        from torch.nn.modules.utils import _pair\n\n        stride = _pair(stride)\n        padding = _pair(padding)\n        dilation = _pair(dilation)\n        # as of trt 7, the dcn operation will be translated again by modifying the onnx file\n        # so the exporting code is kept to resemble the forward()\n        return g.op(\n            \"DCNv2_2\",\n            input,\n            offset,\n            mask,\n            weight,\n            bias,\n            stride_i=stride,\n            padding_i=padding,\n            dilation_i=dilation,\n            deformable_groups_i=deformable_groups,\n        )\n\n\ndcn_v2_conv = _DCNv2.apply\n\n\nclass DCNv2(nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        kernel_size,\n        stride,\n        padding,\n        dilation=1,\n        deformable_groups=1,\n    ):\n        super(DCNv2, self).__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.kernel_size = _pair(kernel_size)\n        self.stride = _pair(stride)\n        self.padding = _pair(padding)\n        self.dilation = _pair(dilation)\n        self.deformable_groups = deformable_groups\n\n        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size))\n        self.bias = nn.Parameter(torch.Tensor(out_channels))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        n = self.in_channels\n        for k in self.kernel_size:\n            n *= k\n        stdv = 1.0 / math.sqrt(n)\n        self.weight.data.uniform_(-stdv, stdv)\n        self.bias.data.zero_()\n\n    def forward(self, input, offset, 
mask):\n        assert (\n            2 * self.deformable_groups * self.kernel_size[0] * self.kernel_size[1]\n            == offset.shape[1]\n        )\n        assert self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == mask.shape[1]\n        return dcn_v2_conv(\n            input,\n            offset,\n            mask,\n            self.weight,\n            self.bias,\n            self.stride,\n            self.padding,\n            self.dilation,\n            self.deformable_groups,\n        )\n\n\nclass DCN(DCNv2):\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        kernel_size,\n        stride,\n        padding,\n        dilation=1,\n        deformable_groups=1,\n    ):\n        super(DCN, self).__init__(\n            in_channels, out_channels, kernel_size, stride, padding, dilation, deformable_groups\n        )\n\n        channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]\n        self.conv_offset_mask = nn.Conv2d(\n            self.in_channels,\n            channels_,\n            kernel_size=self.kernel_size,\n            stride=self.stride,\n            padding=self.padding,\n            bias=True,\n        )\n        self.init_offset()\n\n    def init_offset(self):\n        self.conv_offset_mask.weight.data.zero_()\n        self.conv_offset_mask.bias.data.zero_()\n\n    def forward(self, input):\n        out = self.conv_offset_mask(input)\n        o1, o2, mask = torch.chunk(out, 3, dim=1)\n        offset = torch.cat((o1, o2), dim=1)\n        mask = torch.sigmoid(mask)\n        return dcn_v2_conv(\n            input,\n            offset,\n            mask,\n            self.weight,\n            self.bias,\n            self.stride,\n            self.padding,\n            self.dilation,\n            self.deformable_groups,\n        )\n\n\nclass DCN_sep(DCNv2):\n    '''Use other features to generate offsets and masks'''\n\n    def __init__(self,\n                 
in_channels,\n                 out_channels,\n                 kernel_size,\n                 stride,\n                 padding,\n                 dilation=1,\n                 deformable_groups=1):\n        super(DCN_sep, self).__init__(in_channels, out_channels, kernel_size, stride, padding,\n                                      dilation, deformable_groups)\n\n        channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]\n        self.conv_offset_mask = nn.Conv2d(\n            self.in_channels,\n            channels_,\n            kernel_size=self.kernel_size,\n            stride=self.stride,\n            padding=self.padding,\n            bias=True)\n        self.init_offset()\n\n    def init_offset(self):\n        self.conv_offset_mask.weight.data.zero_()\n        self.conv_offset_mask.bias.data.zero_()\n\n    def forward(self, input, fea):\n        '''input: input features for deformable conv\n        fea: other features used for generating offsets and mask'''\n        out = self.conv_offset_mask(fea)\n        o1, o2, mask = torch.chunk(out, 3, dim=1)\n        offset = torch.cat((o1, o2), dim=1)\n        # offset = torch.clamp(offset, -100, 100)\n\n        offset_mean = torch.mean(torch.abs(offset))\n        if offset_mean > 250:\n            print('DCN_sep: offset mean is {}, larger than 250.'.format(offset_mean))\n            # return None\n            # offset[offset>=150] = 1e-3\n            # offset = offset.clamp(-50, 50)\n\n        mask = torch.sigmoid(mask)\n        return dcn_v2_conv(input, offset, mask, self.weight, self.bias, self.stride, self.padding,\n                           self.dilation, self.deformable_groups)\n\n\nclass FlowGuidedDCN(DCNv2):\n    '''Use other features to generate offsets and masks'''\n\n    def __init__(self,\n                 in_channels,\n                 out_channels,\n                 kernel_size,\n                 stride,\n                 padding,\n                 dilation=1,\n                 
deformable_groups=1):\n        super(FlowGuidedDCN, self).__init__(in_channels, out_channels, kernel_size, stride, padding,\n                                      dilation, deformable_groups)\n\n        channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]\n        self.conv_offset_mask = nn.Conv2d(\n            in_channels, channels_, kernel_size, stride, padding, bias=True)\n\n        self.init_offset()\n\n    def init_offset(self):\n        self.conv_offset_mask.weight.data.zero_()\n        self.conv_offset_mask.bias.data.zero_()\n\n    def forward(self, input, fea, flows):\n        '''input: input features for deformable conv: N, C, H, W.\n           fea: other features used for generating offsets and mask: N, C, H, W.\n           flows: N, 2, H, W.\n        '''\n        out = self.conv_offset_mask(fea)\n        o1, o2, mask = torch.chunk(out, 3, dim=1)\n\n        offset = torch.tanh(torch.cat((o1, o2), dim=1)) * 10 # max_residue_magnitude\n        offset = offset + flows.flip(1).repeat(1, offset.size(1)//2, 1, 1)\n\n        offset_mean = torch.mean(torch.abs(offset))\n        if offset_mean > 250:\n            print('FlowGuidedDCN: offset mean is {}, larger than 250.'.format(offset_mean))\n            # offset = offset.clamp(-50, 50)\n            # return None\n\n        mask = torch.sigmoid(mask)\n        return dcn_v2_conv(input, offset, mask, self.weight, self.bias, self.stride, self.padding,\n                           self.dilation, self.deformable_groups)\n\n\n\nclass InsideFlowGuidedDCN(DCNv2):\n    '''Use other features to generate offsets and masks'''\n\n    def __init__(self,\n                 in_channels,\n                 out_channels,\n                 kernel_size,\n                 stride,\n                 padding,\n                 dilation=1,\n                 deformable_groups=1):\n        super(InsideFlowGuidedDCN, self).__init__(in_channels, out_channels, kernel_size, stride, padding,\n                           
           dilation, deformable_groups)\n\n        channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]\n        self.conv_offset_mask = nn.Sequential(\n            nn.Conv2d(in_channels*2+2, out_channels, kernel_size, stride, padding, bias=True),\n            nn.LeakyReLU(negative_slope=0.1, inplace=True),\n            nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=True),\n            nn.LeakyReLU(negative_slope=0.1, inplace=True),\n            nn.Conv2d(out_channels, channels_, kernel_size, stride, padding, bias=True)\n        )\n\n        self.reset_parameters()\n        self.init_offset()\n\n    def reset_parameters(self):\n        n = self.in_channels\n        for k in self.kernel_size:\n            n *= k\n        stdv = 1.0 / math.sqrt(n)\n        self.weight.data.uniform_(-stdv, stdv)\n        self.bias.data.zero_()\n\n    def init_offset(self):\n        self.conv_offset_mask[-1].weight.data.zero_()\n        self.conv_offset_mask[-1].bias.data.zero_()\n\n    def forward(self, input, warped, ref, flows):\n        '''input: input features for deformable conv: N, C, H, W.\n           warped, ref: warped and reference features used for generating offsets and mask: N, C, H, W.\n           flows: N, 2, H, W.\n        '''\n        out = self.conv_offset_mask(torch.cat([warped, ref, flows], dim=1))\n        o1, o2, mask = torch.chunk(out, 3, dim=1)\n\n        offset = torch.tanh(torch.cat((o1, o2), dim=1)) * 10 # max_residue_magnitude\n        offset = offset + flows.flip(1).repeat(1, offset.size(1)//2, 1, 1)\n\n        offset_mean = torch.mean(torch.abs(offset))\n        if offset_mean > 250:\n            print('InsideFlowGuidedDCN: offset mean is {}, larger than 250.'.format(offset_mean))\n            print('flow mean is {}'.format(torch.abs(flows).mean()))\n            offset = offset.clamp(-50, 50)\n            # return None\n\n        mask = torch.sigmoid(mask)\n        return dcn_v2_conv(input, offset, mask, self.weight, 
self.bias, self.stride, self.padding,\n                           self.dilation, self.deformable_groups)\n\n\n\nclass _DCNv2Pooling(Function):\n    @staticmethod\n    def forward(\n        ctx,\n        input,\n        rois,\n        offset,\n        spatial_scale,\n        pooled_size,\n        output_dim,\n        no_trans,\n        group_size=1,\n        part_size=None,\n        sample_per_part=4,\n        trans_std=0.0,\n    ):\n        ctx.spatial_scale = spatial_scale\n        ctx.no_trans = int(no_trans)\n        ctx.output_dim = output_dim\n        ctx.group_size = group_size\n        ctx.pooled_size = pooled_size\n        ctx.part_size = pooled_size if part_size is None else part_size\n        ctx.sample_per_part = sample_per_part\n        ctx.trans_std = trans_std\n\n        output, output_count = _backend.dcn_v2_psroi_pooling_forward(\n            input,\n            rois,\n            offset,\n            ctx.no_trans,\n            ctx.spatial_scale,\n            ctx.output_dim,\n            ctx.group_size,\n            ctx.pooled_size,\n            ctx.part_size,\n            ctx.sample_per_part,\n            ctx.trans_std,\n        )\n        ctx.save_for_backward(input, rois, offset, output_count)\n        return output\n\n    @staticmethod\n    @once_differentiable\n    def backward(ctx, grad_output):\n        input, rois, offset, output_count = ctx.saved_tensors\n        grad_input, grad_offset = _backend.dcn_v2_psroi_pooling_backward(\n            grad_output,\n            input,\n            rois,\n            offset,\n            output_count,\n            ctx.no_trans,\n            ctx.spatial_scale,\n            ctx.output_dim,\n            ctx.group_size,\n            ctx.pooled_size,\n            ctx.part_size,\n            ctx.sample_per_part,\n            ctx.trans_std,\n        )\n\n        return grad_input, None, grad_offset, None, None, None, None, None, None, None, None\n\n\ndcn_v2_pooling = _DCNv2Pooling.apply\n\n\nclass 
DCNv2Pooling(nn.Module):\n    def __init__(\n        self,\n        spatial_scale,\n        pooled_size,\n        output_dim,\n        no_trans,\n        group_size=1,\n        part_size=None,\n        sample_per_part=4,\n        trans_std=0.0,\n    ):\n        super(DCNv2Pooling, self).__init__()\n        self.spatial_scale = spatial_scale\n        self.pooled_size = pooled_size\n        self.output_dim = output_dim\n        self.no_trans = no_trans\n        self.group_size = group_size\n        self.part_size = pooled_size if part_size is None else part_size\n        self.sample_per_part = sample_per_part\n        self.trans_std = trans_std\n\n    def forward(self, input, rois, offset):\n        assert input.shape[1] == self.output_dim\n        if self.no_trans:\n            offset = input.new()\n        return dcn_v2_pooling(\n            input,\n            rois,\n            offset,\n            self.spatial_scale,\n            self.pooled_size,\n            self.output_dim,\n            self.no_trans,\n            self.group_size,\n            self.part_size,\n            self.sample_per_part,\n            self.trans_std,\n        )\n\n\nclass DCNPooling(DCNv2Pooling):\n    def __init__(\n        self,\n        spatial_scale,\n        pooled_size,\n        output_dim,\n        no_trans,\n        group_size=1,\n        part_size=None,\n        sample_per_part=4,\n        trans_std=0.0,\n        deform_fc_dim=1024,\n    ):\n        super(DCNPooling, self).__init__(\n            spatial_scale,\n            pooled_size,\n            output_dim,\n            no_trans,\n            group_size,\n            part_size,\n            sample_per_part,\n            trans_std,\n        )\n\n        self.deform_fc_dim = deform_fc_dim\n\n        if not no_trans:\n            self.offset_mask_fc = nn.Sequential(\n                nn.Linear(\n                    self.pooled_size * self.pooled_size * self.output_dim, self.deform_fc_dim\n                ),\n                
nn.ReLU(inplace=True),\n                nn.Linear(self.deform_fc_dim, self.deform_fc_dim),\n                nn.ReLU(inplace=True),\n                nn.Linear(self.deform_fc_dim, self.pooled_size * self.pooled_size * 3),\n            )\n            self.offset_mask_fc[4].weight.data.zero_()\n            self.offset_mask_fc[4].bias.data.zero_()\n\n    def forward(self, input, rois):\n        offset = input.new()\n\n        if not self.no_trans:\n\n            # do roi_align first\n            n = rois.shape[0]\n            roi = dcn_v2_pooling(\n                input,\n                rois,\n                offset,\n                self.spatial_scale,\n                self.pooled_size,\n                self.output_dim,\n                True,  # no trans\n                self.group_size,\n                self.part_size,\n                self.sample_per_part,\n                self.trans_std,\n            )\n\n            # build mask and offset\n            offset_mask = self.offset_mask_fc(roi.view(n, -1))\n            offset_mask = offset_mask.view(n, 3, self.pooled_size, self.pooled_size)\n            o1, o2, mask = torch.chunk(offset_mask, 3, dim=1)\n            offset = torch.cat((o1, o2), dim=1)\n            mask = torch.sigmoid(mask)\n\n            # do pooling with offset and mask\n            return (\n                dcn_v2_pooling(\n                    input,\n                    rois,\n                    offset,\n                    self.spatial_scale,\n                    self.pooled_size,\n                    self.output_dim,\n                    self.no_trans,\n                    self.group_size,\n                    self.part_size,\n                    self.sample_per_part,\n                    self.trans_std,\n                )\n                * mask\n            )\n        # only roi_align\n        return dcn_v2_pooling(\n            input,\n            rois,\n            offset,\n            self.spatial_scale,\n            self.pooled_size,\n     
       self.output_dim,\n            self.no_trans,\n            self.group_size,\n            self.part_size,\n            self.sample_per_part,\n            self.trans_std,\n        )\n"
  },
  {
    "path": "code/synthetic/bsrt/model/DCNv2/files.txt",
    "content": "/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/_ext.cpython-37m-x86_64-linux-gnu.so\n/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/_ext.py\n/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/PKG-INFO\n/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/SOURCES.txt\n/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/dependency_links.txt\n/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/native_libs.txt\n/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/not-zip-safe\n/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/top_level.txt\n/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/__pycache__/_ext.cpython-37.pyc\n"
  },
  {
    "path": "code/synthetic/bsrt/model/DCNv2/make.sh",
    "content": "#!/usr/bin/env bash\npython setup.py build develop\n"
  },
  {
    "path": "code/synthetic/bsrt/model/DCNv2/setup.py",
    "content": "#!/usr/bin/env python\n\nimport glob\nimport os\n\nimport torch\nfrom setuptools import find_packages, setup\nfrom torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension\n\nrequirements = [\"torch\", \"torchvision\"]\n\n\ndef get_extensions():\n    this_dir = os.path.dirname(os.path.abspath(__file__))\n    extensions_dir = os.path.join(this_dir, \"src\")\n\n    main_file = glob.glob(os.path.join(extensions_dir, \"*.cpp\"))\n    source_cpu = glob.glob(os.path.join(extensions_dir, \"cpu\", \"*.cpp\"))\n    source_cuda = glob.glob(os.path.join(extensions_dir, \"cuda\", \"*.cu\"))\n\n    os.environ[\"CC\"] = \"g++\"\n    sources = main_file + source_cpu\n    extension = CppExtension\n    extra_compile_args = {\"cxx\": []}\n    define_macros = []\n\n    # Build the CUDA extension only when a CUDA toolchain is available.\n    if torch.cuda.is_available() and CUDA_HOME is not None:\n        extension = CUDAExtension\n        sources += source_cuda\n        define_macros += [(\"WITH_CUDA\", None)]\n        extra_compile_args[\"nvcc\"] = [\n            \"-DCUDA_HAS_FP16=1\",\n            \"-D__CUDA_NO_HALF_OPERATORS__\",\n            \"-D__CUDA_NO_HALF_CONVERSIONS__\",\n            \"-D__CUDA_NO_HALF2_OPERATORS__\",\n        ]\n    else:\n        raise NotImplementedError('Cuda is not available')\n\n    sources = [os.path.join(extensions_dir, s) for s in sources]\n    include_dirs = [extensions_dir]\n    ext_modules = [\n        extension(\n            \"_ext\",\n            sources,\n            include_dirs=include_dirs,\n            define_macros=define_macros,\n            extra_compile_args=extra_compile_args,\n        )\n    ]\n    return ext_modules\n\n\nsetup(\n    name=\"DCNv2\",\n    version=\"0.1\",\n    author=\"charlesshang\",\n    url=\"https://github.com/charlesshang/DCNv2\",\n    description=\"deformable convolutional networks\",\n    packages=find_packages(exclude=(\"configs\", \"tests\")),\n    # install_requires=requirements,\n    ext_modules=get_extensions(),\n    cmdclass={\"build_ext\": torch.utils.cpp_extension.BuildExtension},\n)\n"
  },
  {
    "path": "code/synthetic/bsrt/model/DCNv2/src/cpu/dcn_v2_cpu.cpp",
    "content": "#include <vector>\n#include \"cpu/dcn_v2_im2col_cpu.h\"\n\n#include <ATen/ATen.h>\n//#include <ATen/cuda/CUDAContext.h>\n\n#include <TH/TH.h>\n//#include <THC/THCAtomics.cuh>\n//#include <THC/THCDeviceUtils.cuh>\n\n//extern THCState *state;\n\n// author: Charles Shang\n// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu\n// modified from the CUDA version for CPU use by Daniel K. Suhendro\n\nat::Tensor\ndcn_v2_cpu_forward(const at::Tensor &input,\n                    const at::Tensor &weight,\n                    const at::Tensor &bias,\n                    const at::Tensor &offset,\n                    const at::Tensor &mask,\n                    const int kernel_h,\n                    const int kernel_w,\n                    const int stride_h,\n                    const int stride_w,\n                    const int pad_h,\n                    const int pad_w,\n                    const int dilation_h,\n                    const int dilation_w,\n                    const int deformable_group)\n{\n    // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask));\n    /*AT_ASSERTM(input.type().is_cuda(), \"input must be a CUDA tensor\");\n    AT_ASSERTM(weight.type().is_cuda(), \"weight must be a CUDA tensor\");\n    AT_ASSERTM(bias.type().is_cuda(), \"bias must be a CUDA tensor\");\n    AT_ASSERTM(offset.type().is_cuda(), \"offset must be a CUDA tensor\");\n    AT_ASSERTM(mask.type().is_cuda(), \"mask must be a CUDA tensor\");*/\n\n    const int batch = input.size(0);\n    const int channels = input.size(1);\n    const int height = input.size(2);\n    const int width = input.size(3);\n\n    const int channels_out = weight.size(0);\n    const int channels_kernel = weight.size(1);\n    const int kernel_h_ = weight.size(2);\n    const int kernel_w_ = weight.size(3);\n\n    // printf(\"Kernels: %d %d %d %d\\n\", kernel_h_, kernel_w_, kernel_w, kernel_h);\n    // printf(\"Channels: %d 
%d\\n\", channels, channels_kernel);\n    // printf(\"Channels: %d %d\\n\", channels_out, channels_kernel);\n\n    AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,\n               \"Input shape and kernel shape won't match: (%d x %d vs %d x %d).\", kernel_h, kernel_w, kernel_h_, kernel_w_);\n\n    AT_ASSERTM(channels == channels_kernel,\n               \"Input shape and kernel channels won't match: (%d vs %d).\", channels, channels_kernel);\n\n    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;\n    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;\n\n    auto ones = at::ones({height_out, width_out}, input.options());\n    auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());\n    auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());\n\n    using scalar_t = float;\n    for (int b = 0; b < batch; b++)\n    {\n        auto input_n = input.select(0, b);\n        auto offset_n = offset.select(0, b);\n        auto mask_n = mask.select(0, b);\n        auto output_n = output.select(0, b);\n\n        // Do Bias first:\n        // M,N,K are dims of matrix A and B\n        // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)\n        // (N x 1) (1 x M)\n        long m_ = channels_out;\n        long n_ = height_out * width_out;\n        long k_ = 1;\n        THFloatBlas_gemm('t', 'n', n_, m_, k_, 1.0f,\n                         ones.contiguous().data<scalar_t>(), k_,\n                         bias.contiguous().data<scalar_t>(), k_, 0.0f,\n                         output_n.data<scalar_t>(), n_);\n\n        modulated_deformable_im2col_cpu(input_n.data<scalar_t>(),\n                                         offset_n.data<scalar_t>(),\n                                         mask_n.data<scalar_t>(),\n                                         1, channels, height, width,\n                 
                        height_out, width_out, kernel_h, kernel_w,\n                                         pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,\n                                         deformable_group,\n                                         columns.data<scalar_t>());\n\n        //(k * m)  x  (m * n)\n        // Y = WC\n        long m = channels_out;\n        long n = height_out * width_out;\n        long k = channels * kernel_h * kernel_w;\n        THFloatBlas_gemm('n', 'n', n, m, k, 1.0f,\n                         columns.data<scalar_t>(), n,\n                         weight.data<scalar_t>(), k, 1.0f,\n                         output_n.data<scalar_t>(), n);\n    }\n    return output;\n}\n\nstd::vector<at::Tensor> dcn_v2_cpu_backward(const at::Tensor &input,\n                                             const at::Tensor &weight,\n                                             const at::Tensor &bias,\n                                             const at::Tensor &offset,\n                                             const at::Tensor &mask,\n                                             const at::Tensor &grad_output,\n                                             int kernel_h, int kernel_w,\n                                             int stride_h, int stride_w,\n                                             int pad_h, int pad_w,\n                                             int dilation_h, int dilation_w,\n                                             int deformable_group)\n{\n\n    THArgCheck(input.is_contiguous(), 1, \"input tensor has to be contiguous\");\n    THArgCheck(weight.is_contiguous(), 2, \"weight tensor has to be contiguous\");\n\n    /*AT_ASSERTM(input.type().is_cuda(), \"input must be a CUDA tensor\");\n    AT_ASSERTM(weight.type().is_cuda(), \"weight must be a CUDA tensor\");\n    AT_ASSERTM(bias.type().is_cuda(), \"bias must be a CUDA tensor\");\n    AT_ASSERTM(offset.type().is_cuda(), \"offset must be a CUDA tensor\");\n    
AT_ASSERTM(mask.type().is_cuda(), \"mask must be a CUDA tensor\");*/\n\n    const int batch = input.size(0);\n    const int channels = input.size(1);\n    const int height = input.size(2);\n    const int width = input.size(3);\n\n    const int channels_out = weight.size(0);\n    const int channels_kernel = weight.size(1);\n    const int kernel_h_ = weight.size(2);\n    const int kernel_w_ = weight.size(3);\n\n    AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,\n               \"Input shape and kernel shape won't match: (%d x %d vs %d x %d).\", kernel_h, kernel_w, kernel_h_, kernel_w_);\n\n    AT_ASSERTM(channels == channels_kernel,\n               \"Input shape and kernel channels won't match: (%d vs %d).\", channels, channels_kernel);\n\n    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;\n    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;\n\n    auto ones = at::ones({height_out, width_out}, input.options());\n    auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());\n    auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());\n\n    auto grad_input = at::zeros_like(input);\n    auto grad_weight = at::zeros_like(weight);\n    auto grad_bias = at::zeros_like(bias);\n    auto grad_offset = at::zeros_like(offset);\n    auto grad_mask = at::zeros_like(mask);\n\n    using scalar_t = float;\n\n    for (int b = 0; b < batch; b++)\n    {\n        auto input_n = input.select(0, b);\n        auto offset_n = offset.select(0, b);\n        auto mask_n = mask.select(0, b);\n        auto grad_output_n = grad_output.select(0, b);\n        auto grad_input_n = grad_input.select(0, b);\n        auto grad_offset_n = grad_offset.select(0, b);\n        auto grad_mask_n = grad_mask.select(0, b);\n\n        long m = channels * kernel_h * kernel_w;\n        long n = height_out * width_out;\n        long k = channels_out;\n\n        THFloatBlas_gemm('n', 't', n, m, k, 1.0f,\n                         grad_output_n.data<scalar_t>(), n,\n                         weight.data<scalar_t>(), m, 0.0f,\n                         columns.data<scalar_t>(), n);\n\n        // gradient w.r.t. input coordinate data\n        modulated_deformable_col2im_coord_cpu(columns.data<scalar_t>(),\n                                               input_n.data<scalar_t>(),\n                                               offset_n.data<scalar_t>(),\n                                               mask_n.data<scalar_t>(),\n                                               1, channels, height, width,\n                                               height_out, width_out, kernel_h, kernel_w,\n                                               pad_h, pad_w, stride_h, stride_w,\n                                               dilation_h, dilation_w, deformable_group,\n                                               grad_offset_n.data<scalar_t>(),\n                                               grad_mask_n.data<scalar_t>());\n        // gradient w.r.t. input data\n        modulated_deformable_col2im_cpu(columns.data<scalar_t>(),\n                                         offset_n.data<scalar_t>(),\n                                         mask_n.data<scalar_t>(),\n                                         1, channels, height, width,\n                                         height_out, width_out, kernel_h, kernel_w,\n                                         pad_h, pad_w, stride_h, stride_w,\n                                         dilation_h, dilation_w, deformable_group,\n                                         grad_input_n.data<scalar_t>());\n\n        // gradient w.r.t. 
weight, dWeight should accumulate across the batch and group\n        modulated_deformable_im2col_cpu(input_n.data<scalar_t>(),\n                                         offset_n.data<scalar_t>(),\n                                         mask_n.data<scalar_t>(),\n                                         1, channels, height, width,\n                                         height_out, width_out, kernel_h, kernel_w,\n                                         pad_h, pad_w, stride_h, stride_w,\n                                         dilation_h, dilation_w, deformable_group,\n                                         columns.data<scalar_t>());\n\n        long m_ = channels_out;\n        long n_ = channels * kernel_h * kernel_w;\n        long k_ = height_out * width_out;\n\n        THFloatBlas_gemm('t', 'n', n_, m_, k_, 1.0f,\n                         columns.data<scalar_t>(), k_,\n                         grad_output_n.data<scalar_t>(), k_, 1.0f,\n                         grad_weight.data<scalar_t>(), n_);\n\n        // gradient w.r.t. bias: sum grad_output over the spatial positions,\n        // accumulating across the batch (beta = 1.0f); otherwise grad_bias stays zero\n        THFloatBlas_gemv('t', k_, m_, 1.0f,\n                         grad_output_n.data<scalar_t>(), k_,\n                         ones.data<scalar_t>(), 1, 1.0f,\n                         grad_bias.data<scalar_t>(), 1);\n    }\n\n    return {\n        grad_input, grad_offset, grad_mask, grad_weight, grad_bias\n    };\n}"
  },
  {
    "path": "code/synthetic/bsrt/model/DCNv2/src/cpu/dcn_v2_im2col_cpu.cpp",
    "content": "#include \"dcn_v2_im2col_cpu.h\"\n#include <cstdio>\n#include <algorithm>\n#include <cstring>\n\n#include <ATen/ATen.h>\n//#include <ATen/cuda/CUDAContext.h>\n\n#include <TH/TH.h>\n//#include <THC/THCAtomics.cuh>\n//#include <THC/THCDeviceUtils.cuh>\n\n// modified from the CUDA version for CPU use by Daniel K. Suhendro\n\n/*#define CUDA_KERNEL_LOOP(i, n)                          \\\n  for (int i = blockIdx.x * blockDim.x + threadIdx.x;   \\\n      i < (n);                                          \\\n      i += blockDim.x * gridDim.x)\n\nconst int CUDA_NUM_THREADS = 1024;\ninline int GET_BLOCKS(const int N)\n{\n  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;\n}*/\n\n\nfloat dmcn_im2col_bilinear_cpu(const float *bottom_data, const int data_width,\n                           const int height, const int width, float h, float w)\n{\n  int h_low = floor(h);\n  int w_low = floor(w);\n  int h_high = h_low + 1;\n  int w_high = w_low + 1;\n\n  float lh = h - h_low;\n  float lw = w - w_low;\n  float hh = 1 - lh, hw = 1 - lw;\n\n  float v1 = 0;\n  if (h_low >= 0 && w_low >= 0)\n    v1 = bottom_data[h_low * data_width + w_low];\n  float v2 = 0;\n  if (h_low >= 0 && w_high <= width - 1)\n    v2 = bottom_data[h_low * data_width + w_high];\n  float v3 = 0;\n  if (h_high <= height - 1 && w_low >= 0)\n    v3 = bottom_data[h_high * data_width + w_low];\n  float v4 = 0;\n  if (h_high <= height - 1 && w_high <= width - 1)\n    v4 = bottom_data[h_high * data_width + w_high];\n\n  float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;\n\n  float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);\n  return val;\n}\n\nfloat dmcn_get_gradient_weight_cpu(float argmax_h, float argmax_w,\n                               const int h, const int w, const int height, const int width)\n{\n  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)\n  {\n    //empty\n    return 0;\n  }\n\n  int argmax_h_low = floor(argmax_h);\n  int argmax_w_low = 
floor(argmax_w);\n  int argmax_h_high = argmax_h_low + 1;\n  int argmax_w_high = argmax_w_low + 1;\n\n  float weight = 0;\n  if (h == argmax_h_low && w == argmax_w_low)\n    weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);\n  if (h == argmax_h_low && w == argmax_w_high)\n    weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);\n  if (h == argmax_h_high && w == argmax_w_low)\n    weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);\n  if (h == argmax_h_high && w == argmax_w_high)\n    weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);\n  return weight;\n}\n\nfloat dmcn_get_coordinate_weight_cpu(float argmax_h, float argmax_w,\n                                 const int height, const int width, const float *im_data,\n                                 const int data_width, const int bp_dir)\n{\n  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)\n  {\n    //empty\n    return 0;\n  }\n\n  int argmax_h_low = floor(argmax_h);\n  int argmax_w_low = floor(argmax_w);\n  int argmax_h_high = argmax_h_low + 1;\n  int argmax_w_high = argmax_w_low + 1;\n\n  float weight = 0;\n\n  if (bp_dir == 0)\n  {\n    if (argmax_h_low >= 0 && argmax_w_low >= 0)\n      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];\n    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)\n      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];\n    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)\n      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];\n    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)\n      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];\n  }\n  else if (bp_dir == 1)\n  {\n    if (argmax_h_low >= 0 && argmax_w_low >= 0)\n      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];\n    if (argmax_h_low >= 0 && 
argmax_w_high <= width - 1)\n      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];\n    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)\n      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];\n    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)\n      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];\n  }\n\n  return weight;\n}\n\nvoid modulated_deformable_im2col_cpu_kernel(const int n, const float *data_im, const float *data_offset, const float *data_mask,\n                                                       const int height, const int width, const int kernel_h, const int kernel_w,\n                                                       const int pad_h, const int pad_w,\n                                                       const int stride_h, const int stride_w,\n                                                       const int dilation_h, const int dilation_w,\n                                                       const int channel_per_deformable_group,\n                                                       const int batch_size, const int num_channels, const int deformable_group,\n                                                       const int height_col, const int width_col,\n                                                       float *data_col)\n{\n  // launch channels * batch_size * height_col * width_col cores\n  for(int index=0; index<n; index++)\n  {\n    // NOTE(CharlesShang): different from Dai Jifeng's MXNet implementation, col_buffer is of shape (c*kw*kh, N, oh, ow)\n    // here columns is of shape (N, c*kw*kh, oh * ow), need to adapt axis\n\n    // index index of output matrix\n    const int w_col = index % width_col;\n    const int h_col = (index / width_col) % height_col;\n    // const int b_col = (index / width_col / height_col) % batch_size;\n    const int b_col = (index / width_col / 
height_col / num_channels) % batch_size;\n    // const int c_im = (index / width_col / height_col) / batch_size;\n    const int c_im = (index / width_col / height_col) % num_channels;\n    // const int c_col = c_im * kernel_h * kernel_w;\n    const int c_col = c_im * kernel_h * kernel_w;\n\n    // compute deformable group index\n    const int deformable_group_index = c_im / channel_per_deformable_group;\n\n    const int h_in = h_col * stride_h - pad_h;\n    const int w_in = w_col * stride_w - pad_w;\n\n    //  float *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;\n    float *data_col_ptr = data_col + ((b_col * num_channels * kernel_w * kernel_h + c_col) * height_col + h_col) * width_col + w_col;\n    //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;\n    const float *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;\n    const float *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;\n\n    const float *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;\n\n    for (int i = 0; i < kernel_h; ++i)\n    {\n      for (int j = 0; j < kernel_w; ++j)\n      {\n        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;\n        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;\n        const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;\n        const float offset_h = data_offset_ptr[data_offset_h_ptr];\n        const float offset_w = data_offset_ptr[data_offset_w_ptr];\n        const float mask = data_mask_ptr[data_mask_hw_ptr];\n        float val = static_cast<float>(0);\n        const float h_im = h_in + i * dilation_h + offset_h;\n        const 
float w_im = w_in + j * dilation_w + offset_w;\n        //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {\n        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)\n        {\n          //const float map_h = i * dilation_h + offset_h;\n          //const float map_w = j * dilation_w + offset_w;\n          //const int cur_height = height - h_in;\n          //const int cur_width = width - w_in;\n          //val = dmcn_im2col_bilinear_cpu(data_im_ptr, width, cur_height, cur_width, map_h, map_w);\n          val = dmcn_im2col_bilinear_cpu(data_im_ptr, width, height, width, h_im, w_im);\n        }\n        *data_col_ptr = val * mask;\n        // data_col_ptr += batch_size * height_col * width_col;\n        data_col_ptr += height_col * width_col;\n      }\n    }\n  }\n}\n\nvoid modulated_deformable_col2im_cpu_kernel(const int n, const float *data_col, const float *data_offset, const float *data_mask,\n                                                       const int channels, const int height, const int width,\n                                                       const int kernel_h, const int kernel_w,\n                                                       const int pad_h, const int pad_w,\n                                                       const int stride_h, const int stride_w,\n                                                       const int dilation_h, const int dilation_w,\n                                                       const int channel_per_deformable_group,\n                                                       const int batch_size, const int deformable_group,\n                                                       const int height_col, const int width_col,\n                                                       float *grad_im)\n{\n  for(int index = 0; index < n; index++)\n  {\n    const int j = (index / width_col / height_col / batch_size) % kernel_w;\n    const int i = (index / width_col / height_col / batch_size / 
kernel_w) % kernel_h;\n    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;\n    // compute the start and end of the output\n\n    const int deformable_group_index = c / channel_per_deformable_group;\n\n    int w_out = index % width_col;\n    int h_out = (index / width_col) % height_col;\n    int b = (index / width_col / height_col) % batch_size;\n    int w_in = w_out * stride_w - pad_w;\n    int h_in = h_out * stride_h - pad_h;\n\n    const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;\n    const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;\n    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;\n    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;\n    const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;\n    const float offset_h = data_offset_ptr[data_offset_h_ptr];\n    const float offset_w = data_offset_ptr[data_offset_w_ptr];\n    const float mask = data_mask_ptr[data_mask_hw_ptr];\n    const float cur_inv_h_data = h_in + i * dilation_h + offset_h;\n    const float cur_inv_w_data = w_in + j * dilation_w + offset_w;\n\n    const float cur_top_grad = data_col[index] * mask;\n    const int cur_h = (int)cur_inv_h_data;\n    const int cur_w = (int)cur_inv_w_data;\n    \n    for (int dy = -2; dy <= 2; dy++)\n    {\n      for (int dx = -2; dx <= 2; dx++)\n      {\n        if (cur_h + dy >= 0 && cur_h + dy < height &&\n            cur_w + dx >= 0 && cur_w + dx < width &&\n            abs(cur_inv_h_data - (cur_h + dy)) < 1 &&\n            abs(cur_inv_w_data - (cur_w + dx)) < 1)\n        {\n          int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;\n          float weight = 
dmcn_get_gradient_weight_cpu(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);\n          //atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);\n          *(grad_im + cur_bottom_grad_pos) += weight * cur_top_grad;\n\n        }\n      }\n    }\n  }\n}\n\nvoid modulated_deformable_col2im_coord_cpu_kernel(const int n, const float *data_col, const float *data_im,\n                                                             const float *data_offset, const float *data_mask,\n                                                             const int channels, const int height, const int width,\n                                                             const int kernel_h, const int kernel_w,\n                                                             const int pad_h, const int pad_w,\n                                                             const int stride_h, const int stride_w,\n                                                             const int dilation_h, const int dilation_w,\n                                                             const int channel_per_deformable_group,\n                                                             const int batch_size, const int offset_channels, const int deformable_group,\n                                                             const int height_col, const int width_col,\n                                                             float *grad_offset, float *grad_mask)\n{\n  for(int index = 0; index < n; index++)\n  {\n    float val = 0, mval = 0;\n    int w = index % width_col;\n    int h = (index / width_col) % height_col;\n    int c = (index / width_col / height_col) % offset_channels;\n    int b = (index / width_col / height_col) / offset_channels;\n    // compute the start and end of the output\n\n    const int deformable_group_index = c / (2 * kernel_h * kernel_w);\n    const int col_step = kernel_h * kernel_w;\n    int cnt = 0;\n    const float *data_col_ptr = data_col + 
deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;\n    const float *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;\n    const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;\n    const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;\n\n    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;\n\n    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)\n    {\n      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;\n      const int bp_dir = offset_c % 2;\n\n      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;\n      int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;\n      int w_out = col_pos % width_col;\n      int h_out = (col_pos / width_col) % height_col;\n      int w_in = w_out * stride_w - pad_w;\n      int h_in = h_out * stride_h - pad_h;\n      const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);\n      const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);\n      const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);\n      const float offset_h = data_offset_ptr[data_offset_h_ptr];\n      const float offset_w = data_offset_ptr[data_offset_w_ptr];\n      const float mask = data_mask_ptr[data_mask_hw_ptr];\n      float inv_h = h_in + i * dilation_h + offset_h;\n      float inv_w = w_in + j * dilation_w + offset_w;\n      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)\n      {\n        inv_h = inv_w = -2;\n      }\n      else\n      {\n        mval += 
data_col_ptr[col_pos] * dmcn_im2col_bilinear_cpu(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);\n      }\n      const float weight = dmcn_get_coordinate_weight_cpu(\n          inv_h, inv_w,\n          height, width, data_im_ptr + cnt * height * width, width, bp_dir);\n      val += weight * data_col_ptr[col_pos] * mask;\n      cnt += 1;\n    }\n    // KERNEL_ASSIGN(grad_offset[index], offset_req, val);\n    grad_offset[index] = val;\n    if (offset_c % 2 == 0)\n      // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);\n      grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;\n  }\n}\n\nvoid modulated_deformable_im2col_cpu(const float* data_im, const float* data_offset, const float* data_mask,\n  const int batch_size, const int channels, const int height_im, const int width_im, \n  const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n  const int pad_h, const int pad_w, const int stride_h, const int stride_w, \n  const int dilation_h, const int dilation_w,\n  const int deformable_group, float* data_col) {\n  // num_axes should be smaller than block size\n  const int channel_per_deformable_group = channels / deformable_group;\n  const int num_kernels = channels * batch_size * height_col * width_col;\n  modulated_deformable_im2col_cpu_kernel(\n      num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kernel_w,\n      pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,\n      batch_size, channels, deformable_group, height_col, width_col, data_col);\n  \n  /*cudaError_t err = cudaGetLastError();\n  if (err != cudaSuccess)\n  {\n    printf(\"error in modulated_deformable_im2col_cuda: %s\\n\", cudaGetErrorString(err));\n  }*/\n\n}\n\nvoid 
modulated_deformable_col2im_cpu(const float* data_col, const float* data_offset, const float* data_mask,\n  const int batch_size, const int channels, const int height_im, const int width_im, \n  const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n  const int pad_h, const int pad_w, const int stride_h, const int stride_w, \n  const int dilation_h, const int dilation_w, \n  const int deformable_group, float* grad_im){\n\n  const int channel_per_deformable_group = channels / deformable_group;\n  const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;\n  modulated_deformable_col2im_cpu_kernel(\n        num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im,\n        kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,\n        dilation_h, dilation_w, channel_per_deformable_group,\n        batch_size, deformable_group, height_col, width_col, grad_im);\n  /*cudaError_t err = cudaGetLastError();\n  if (err != cudaSuccess)\n  {\n    printf(\"error in modulated_deformable_col2im_cuda: %s\\n\", cudaGetErrorString(err));\n  }*/\n\n}\n\nvoid modulated_deformable_col2im_coord_cpu(const float* data_col, const float* data_im, const float* data_offset, const float* data_mask,\n  const int batch_size, const int channels, const int height_im, const int width_im, \n  const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n  const int pad_h, const int pad_w, const int stride_h, const int stride_w, \n  const int dilation_h, const int dilation_w, \n  const int deformable_group,\n  float* grad_offset, float* grad_mask) {\n  const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;\n  const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;\n  modulated_deformable_col2im_coord_cpu_kernel(\n        num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im,\n     
   kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,\n        dilation_h, dilation_w, channel_per_deformable_group,\n        batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, \n        grad_offset, grad_mask);\n  /*cudaError_t err = cudaGetLastError();\n  if (err != cudaSuccess)\n  {\n    printf(\"error in modulated_deformable_col2im_coord_cuda: %s\\n\", cudaGetErrorString(err));\n  }*/\n}"
  },
  {
    "path": "code/synthetic/bsrt/model/DCNv2/src/cpu/dcn_v2_im2col_cpu.h",
    "content": "\n/*!\n ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************\n *\n * COPYRIGHT\n *\n * All contributions by the University of California:\n * Copyright (c) 2014-2017 The Regents of the University of California (Regents)\n * All rights reserved.\n *\n * All other contributions:\n * Copyright (c) 2014-2017, the respective contributors\n * All rights reserved.\n *\n * Caffe uses a shared copyright model: each contributor holds copyright over\n * their contributions to Caffe. The project versioning records all such\n * contribution and copyright details. If a contributor wants to further mark\n * their specific copyright on a particular contribution, they should indicate\n * their copyright solely in the commit message of the change when it is\n * committed.\n *\n * LICENSE\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR\n * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n * CONTRIBUTION AGREEMENT\n *\n * By contributing to the BVLC/caffe repository through pull-request, comment,\n * or otherwise, the contributor releases their content to the\n * license and copyright terms herein.\n *\n ***************** END Caffe Copyright Notice and Disclaimer ********************\n *\n * Copyright (c) 2018 Microsoft\n * Licensed under The MIT License [see LICENSE for details]\n * \\file modulated_deformable_im2col.h\n * \\brief Function definitions of converting an image to\n * column matrix based on kernel, padding, dilation, and offset.\n * These functions are mainly used in deformable convolution operators.\n * \\ref: https://arxiv.org/abs/1811.11168\n * \\author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu\n */\n\n/***************** Adapted by Charles Shang *********************/\n// modified from the CUDA version for CPU use by Daniel K. 
Suhendro\n\n#ifndef DCN_V2_IM2COL_CPU\n#define DCN_V2_IM2COL_CPU\n\n#ifdef __cplusplus\nextern \"C\"\n{\n#endif\n\n  void modulated_deformable_im2col_cpu(const float *data_im, const float *data_offset, const float *data_mask,\n                                        const int batch_size, const int channels, const int height_im, const int width_im,\n                                        const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n                                        const int pad_h, const int pad_w, const int stride_h, const int stride_w,\n                                        const int dilation_h, const int dilation_w,\n                                        const int deformable_group, float *data_col);\n\n  void modulated_deformable_col2im_cpu(const float *data_col, const float *data_offset, const float *data_mask,\n                                        const int batch_size, const int channels, const int height_im, const int width_im,\n                                        const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n                                        const int pad_h, const int pad_w, const int stride_h, const int stride_w,\n                                        const int dilation_h, const int dilation_w,\n                                        const int deformable_group, float *grad_im);\n\n  void modulated_deformable_col2im_coord_cpu(const float *data_col, const float *data_im, const float *data_offset, const float *data_mask,\n                                         const int batch_size, const int channels, const int height_im, const int width_im,\n                                         const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n                                         const int pad_h, const int pad_w, const int stride_h, const int stride_w,\n                                         const int dilation_h, const int 
dilation_w,\n                                         const int deformable_group,\n                                         float *grad_offset, float *grad_mask);\n\n#ifdef __cplusplus\n}\n#endif\n\n#endif"
  },
  {
    "path": "code/synthetic/bsrt/model/DCNv2/src/cpu/dcn_v2_psroi_pooling_cpu.cpp",
    "content": "/*!\n * Copyright (c) 2017 Microsoft\n * Licensed under The MIT License [see LICENSE for details]\n * \\file deformable_psroi_pooling.cu\n * \\brief\n * \\author Yi Li, Guodong Zhang, Jifeng Dai\n*/\n/***************** Adapted by Charles Shang *********************/\n// modified from the CUDA version for CPU use by Daniel K. Suhendro\n\n#include <cstdio>\n#include <algorithm>\n#include <cstring>\n\n#include <ATen/ATen.h>\n//#include <ATen/cuda/CUDAContext.h>\n\n#include <TH/TH.h>\n//#include <THC/THCAtomics.cuh>\n//#include <THC/THCDeviceUtils.cuh>\n\n/*#define CUDA_KERNEL_LOOP(i, n)                        \\\n  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \\\n       i < (n);                                       \\\n       i += blockDim.x * gridDim.x)\n\nconst int CUDA_NUM_THREADS = 1024;\ninline int GET_BLOCKS(const int N)\n{\n  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;\n}*/\n\ntemplate <typename T>\nT bilinear_interp_cpu(\n    const T *data,\n    const T x,\n    const T y,\n    const int width,\n    const int height)\n{\n  int x1 = floor(x);\n  int x2 = ceil(x);\n  int y1 = floor(y);\n  int y2 = ceil(y);\n  T dist_x = static_cast<T>(x - x1);\n  T dist_y = static_cast<T>(y - y1);\n  T value11 = data[y1 * width + x1];\n  T value12 = data[y2 * width + x1];\n  T value21 = data[y1 * width + x2];\n  T value22 = data[y2 * width + x2];\n  T value = (1 - dist_x) * (1 - dist_y) * value11 +\n            (1 - dist_x) * dist_y * value12 +\n            dist_x * (1 - dist_y) * value21 +\n            dist_x * dist_y * value22;\n  return value;\n}\n\ntemplate <typename T>\n void DeformablePSROIPoolForwardKernelCpu(\n    const int count,\n    const T *bottom_data,\n    const T spatial_scale,\n    const int channels,\n    const int height, const int width,\n    const int pooled_height, const int pooled_width,\n    const T *bottom_rois, const T *bottom_trans,\n    const int no_trans,\n    const T trans_std,\n    const int sample_per_part,\n    
const int output_dim,\n    const int group_size,\n    const int part_size,\n    const int num_classes,\n    const int channels_each_class,\n    T *top_data,\n    T *top_count)\n{\n  for(int index = 0; index < count; index++)\n  {\n    // The output is in order (n, ctop, ph, pw)\n    int pw = index % pooled_width;\n    int ph = (index / pooled_width) % pooled_height;\n    int ctop = (index / pooled_width / pooled_height) % output_dim;\n    int n = index / pooled_width / pooled_height / output_dim;\n\n    // [start, end) interval for spatial sampling\n    const T *offset_bottom_rois = bottom_rois + n * 5;\n    int roi_batch_ind = offset_bottom_rois[0];\n    T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;\n    T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;\n    T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;\n    T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;\n\n    // Force too small ROIs to be 1x1\n    T roi_width = std::max(roi_end_w - roi_start_w, T(0.1)); //avoid 0\n    T roi_height = std::max(roi_end_h - roi_start_h, T(0.1));\n\n    // Compute w and h at bottom\n    T bin_size_h = roi_height / static_cast<T>(pooled_height);\n    T bin_size_w = roi_width / static_cast<T>(pooled_width);\n\n    T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);\n    T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);\n\n    int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);\n    int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);\n    int class_id = ctop / channels_each_class;\n    T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;\n    T trans_y = no_trans ? 
static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;\n\n    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;\n    wstart += trans_x * roi_width;\n    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;\n    hstart += trans_y * roi_height;\n\n    T sum = 0;\n    int count = 0;\n    int gw = floor(static_cast<T>(pw) * group_size / pooled_width);\n    int gh = floor(static_cast<T>(ph) * group_size / pooled_height);\n    gw = std::min(std::max(gw, 0), group_size - 1);\n    gh = std::min(std::max(gh, 0), group_size - 1);\n\n    const T *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;\n    for (int ih = 0; ih < sample_per_part; ih++)\n    {\n      for (int iw = 0; iw < sample_per_part; iw++)\n      {\n        T w = wstart + iw * sub_bin_size_w;\n        T h = hstart + ih * sub_bin_size_h;\n        // bilinear interpolation\n        if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)\n        {\n          continue;\n        }\n        w = std::min(std::max(w, T(0.)), width - T(1.));\n        h = std::min(std::max(h, T(0.)), height - T(1.));\n        int c = (ctop * group_size + gh) * group_size + gw;\n        T val = bilinear_interp_cpu(offset_bottom_data + c * height * width, w, h, width, height);\n        sum += val;\n        count++;\n      }\n    }\n    top_data[index] = count == 0 ? 
static_cast<T>(0) : sum / count;\n    top_count[index] = count;\n  }\n}\n\ntemplate <typename T>\nvoid DeformablePSROIPoolBackwardAccKernelCpu(\n    const int count,\n    const T *top_diff,\n    const T *top_count,\n    const int num_rois,\n    const T spatial_scale,\n    const int channels,\n    const int height, const int width,\n    const int pooled_height, const int pooled_width,\n    const int output_dim,\n    T *bottom_data_diff, T *bottom_trans_diff,\n    const T *bottom_data,\n    const T *bottom_rois,\n    const T *bottom_trans,\n    const int no_trans,\n    const T trans_std,\n    const int sample_per_part,\n    const int group_size,\n    const int part_size,\n    const int num_classes,\n    const int channels_each_class)\n{\n  for(int index = 0; index < count; index++)\n  {\n    // The output is in order (n, ctop, ph, pw)\n    int pw = index % pooled_width;\n    int ph = (index / pooled_width) % pooled_height;\n    int ctop = (index / pooled_width / pooled_height) % output_dim;\n    int n = index / pooled_width / pooled_height / output_dim;\n\n    // [start, end) interval for spatial sampling\n    const T *offset_bottom_rois = bottom_rois + n * 5;\n    int roi_batch_ind = offset_bottom_rois[0];\n    T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;\n    T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;\n    T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;\n    T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) 
* spatial_scale - 0.5;\n    \n    // Force too small ROIs to be 1x1\n    T roi_width = std::max(roi_end_w - roi_start_w, T(0.1)); //avoid 0\n    T roi_height = std::max(roi_end_h - roi_start_h, T(0.1));\n\n    // Compute w and h at bottom\n    T bin_size_h = roi_height / static_cast<T>(pooled_height);\n    T bin_size_w = roi_width / static_cast<T>(pooled_width);\n\n    T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);\n    T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);\n\n    int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);\n    int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);\n    int class_id = ctop / channels_each_class;\n    T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;\n    T trans_y = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;\n\n    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;\n    wstart += trans_x * roi_width;\n    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;\n    hstart += trans_y * roi_height;\n\n    if (top_count[index] <= 0)\n    {\n      continue;\n    }\n    T diff_val = top_diff[index] / top_count[index];\n    const T *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;\n    T *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;\n    int gw = floor(static_cast<T>(pw) * group_size / pooled_width);\n    int gh = floor(static_cast<T>(ph) * group_size / pooled_height);\n    gw = std::min(std::max(gw, 0), group_size - 1);\n    gh = std::min(std::max(gh, 0), group_size - 1);\n\n    for (int ih = 0; ih < sample_per_part; ih++)\n    {\n      for (int iw = 0; iw < sample_per_part; iw++)\n      {\n        T w = wstart + iw * sub_bin_size_w;\n        T h = hstart + ih * sub_bin_size_h;\n        // bilinear 
interpolation\n        if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)\n        {\n          continue;\n        }\n        w = std::min(std::max(w, T(0.)), width - T(1.));\n        h = std::min(std::max(h, T(0.)), height - T(1.));\n        int c = (ctop * group_size + gh) * group_size + gw;\n        // backward on feature\n        int x0 = floor(w);\n        int x1 = ceil(w);\n        int y0 = floor(h);\n        int y1 = ceil(h);\n        T dist_x = w - x0, dist_y = h - y0;\n        T q00 = (1 - dist_x) * (1 - dist_y);\n        T q01 = (1 - dist_x) * dist_y;\n        T q10 = dist_x * (1 - dist_y);\n        T q11 = dist_x * dist_y;\n        int bottom_index_base = c * height * width;\n        /*atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);\n        atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);\n        atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);\n        atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);*/\n       *(offset_bottom_data_diff + bottom_index_base + y0 * width + x0) += q00 * diff_val;\n       *(offset_bottom_data_diff + bottom_index_base + y1 * width + x0) += q01 * diff_val;\n       *(offset_bottom_data_diff + bottom_index_base + y0 * width + x1) += q10 * diff_val;\n       *(offset_bottom_data_diff + bottom_index_base + y1 * width + x1) += q11 * diff_val;\n\n\n        if (no_trans)\n        {\n          continue;\n        }\n        T U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];\n        T U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];\n        T U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];\n        T U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];\n        T diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;\n        diff_x *= roi_width;\n        T 
diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;\n        diff_y *= roi_height;\n\n        /*atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x);\n        atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);*/\n        *(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w) += diff_x;\n        *(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w) += diff_y;\n      }\n    }\n  }\n}\n\nstd::tuple<at::Tensor, at::Tensor>\ndcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,\n                                  const at::Tensor &bbox,\n                                  const at::Tensor &trans,\n                                  const int no_trans,\n                                  const float spatial_scale,\n                                  const int output_dim,\n                                  const int group_size,\n                                  const int pooled_size,\n                                  const int part_size,\n                                  const int sample_per_part,\n                                  const float trans_std)\n{\n  /*AT_ASSERTM(input.type().is_cuda(), \"input must be a CUDA tensor\");\n  AT_ASSERTM(bbox.type().is_cuda(), \"rois must be a CUDA tensor\");\n  AT_ASSERTM(trans.type().is_cuda(), \"trans must be a CUDA tensor\");*/\n\n  const int batch = input.size(0);\n  const int channels = input.size(1);\n  const int height = input.size(2);\n  const int width = input.size(3);\n  const int channels_trans = no_trans ? 
2 : trans.size(1);\n  const int num_bbox = bbox.size(0);\n\n  AT_ASSERTM(channels == output_dim, \"input channels and output channels must equal\");\n  auto pooled_height = pooled_size;\n  auto pooled_width = pooled_size;\n\n  auto out = at::empty({num_bbox, output_dim, pooled_height, pooled_width}, input.options());\n  long out_size = num_bbox * output_dim * pooled_height * pooled_width;\n  auto top_count = at::zeros({num_bbox, output_dim, pooled_height, pooled_width}, input.options());\n\n  const int num_classes = no_trans ? 1 : channels_trans / 2;\n  const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;\n\n  //cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n  if (out.numel() == 0)\n  {\n    //THCudaCheck(cudaGetLastError());\n    return std::make_tuple(out, top_count);\n  }\n\n  /*dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));\n  dim3 block(512);*/\n\n  AT_DISPATCH_FLOATING_TYPES(input.type(), \"dcn_v2_psroi_pooling_cpu_forward\", [&] {\n    DeformablePSROIPoolForwardKernelCpu<scalar_t>(\n        out_size,\n        input.contiguous().data<scalar_t>(),\n        spatial_scale,\n        channels,\n        height, width,\n        pooled_height,\n        pooled_width,\n        bbox.contiguous().data<scalar_t>(),\n        trans.contiguous().data<scalar_t>(),\n        no_trans,\n        trans_std,\n        sample_per_part,\n        output_dim,\n        group_size,\n        part_size,\n        num_classes,\n        channels_each_class,\n        out.data<scalar_t>(),\n        top_count.data<scalar_t>());\n  });\n  //THCudaCheck(cudaGetLastError());\n  return std::make_tuple(out, top_count);\n}\n\nstd::tuple<at::Tensor, at::Tensor>\ndcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad,\n                                   const at::Tensor &input,\n                                   const at::Tensor &bbox,\n                                   const at::Tensor &trans,\n                                   const at::Tensor 
&top_count,\n                                   const int no_trans,\n                                   const float spatial_scale,\n                                   const int output_dim,\n                                   const int group_size,\n                                   const int pooled_size,\n                                   const int part_size,\n                                   const int sample_per_part,\n                                   const float trans_std)\n{\n  /*AT_ASSERTM(out_grad.type().is_cuda(), \"out_grad must be a CUDA tensor\");\n  AT_ASSERTM(input.type().is_cuda(), \"input must be a CUDA tensor\");\n  AT_ASSERTM(bbox.type().is_cuda(), \"bbox must be a CUDA tensor\");\n  AT_ASSERTM(trans.type().is_cuda(), \"trans must be a CUDA tensor\");\n  AT_ASSERTM(top_count.type().is_cuda(), \"top_count must be a CUDA tensor\");*/\n\n  const int batch = input.size(0);\n  const int channels = input.size(1);\n  const int height = input.size(2);\n  const int width = input.size(3);\n  const int channels_trans = no_trans ? 2 : trans.size(1);\n  const int num_bbox = bbox.size(0);\n\n  AT_ASSERTM(channels == output_dim, \"input channels and output channels must equal\");\n  auto pooled_height = pooled_size;\n  auto pooled_width = pooled_size;\n  long out_size = num_bbox * output_dim * pooled_height * pooled_width;\n  const int num_classes = no_trans ? 1 : channels_trans / 2;\n  const int channels_each_class = no_trans ? 
output_dim : output_dim / num_classes;\n\n  auto input_grad = at::zeros({batch, channels, height, width}, out_grad.options());\n  auto trans_grad = at::zeros_like(trans);\n\n  if (input_grad.numel() == 0)\n  {\n    //THCudaCheck(cudaGetLastError());\n    return std::make_tuple(input_grad, trans_grad);\n  }\n\n  /*dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));\n  dim3 block(512);\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();*/\n\n  AT_DISPATCH_FLOATING_TYPES(out_grad.type(), \"dcn_v2_psroi_pooling_cpu_backward\", [&] {\n    DeformablePSROIPoolBackwardAccKernelCpu<scalar_t>(\n        out_size,\n        out_grad.contiguous().data<scalar_t>(),\n        top_count.contiguous().data<scalar_t>(),\n        num_bbox,\n        spatial_scale,\n        channels,\n        height,\n        width,\n        pooled_height,\n        pooled_width,\n        output_dim,\n        input_grad.contiguous().data<scalar_t>(),\n        trans_grad.contiguous().data<scalar_t>(),\n        input.contiguous().data<scalar_t>(),\n        bbox.contiguous().data<scalar_t>(),\n        trans.contiguous().data<scalar_t>(),\n        no_trans,\n        trans_std,\n        sample_per_part,\n        group_size,\n        part_size,\n        num_classes,\n        channels_each_class);\n  });\n  //THCudaCheck(cudaGetLastError());\n  return std::make_tuple(input_grad, trans_grad);\n}"
  },
  {
    "path": "code/synthetic/bsrt/model/DCNv2/src/cpu/vision.h",
    "content": "#pragma once\n#include <torch/extension.h>\n\nat::Tensor\ndcn_v2_cpu_forward(const at::Tensor &input,\n                    const at::Tensor &weight,\n                    const at::Tensor &bias,\n                    const at::Tensor &offset,\n                    const at::Tensor &mask,\n                    const int kernel_h,\n                    const int kernel_w,\n                    const int stride_h,\n                    const int stride_w,\n                    const int pad_h,\n                    const int pad_w,\n                    const int dilation_h,\n                    const int dilation_w,\n                    const int deformable_group);\n\nstd::vector<at::Tensor>\ndcn_v2_cpu_backward(const at::Tensor &input,\n                     const at::Tensor &weight,\n                     const at::Tensor &bias,\n                     const at::Tensor &offset,\n                     const at::Tensor &mask,\n                     const at::Tensor &grad_output,\n                     int kernel_h, int kernel_w,\n                     int stride_h, int stride_w,\n                     int pad_h, int pad_w,\n                     int dilation_h, int dilation_w,\n                     int deformable_group);\n\n\nstd::tuple<at::Tensor, at::Tensor>\ndcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,\n                                  const at::Tensor &bbox,\n                                  const at::Tensor &trans,\n                                  const int no_trans,\n                                  const float spatial_scale,\n                                  const int output_dim,\n                                  const int group_size,\n                                  const int pooled_size,\n                                  const int part_size,\n                                  const int sample_per_part,\n                                  const float trans_std);\n\nstd::tuple<at::Tensor, at::Tensor>\ndcn_v2_psroi_pooling_cpu_backward(const 
at::Tensor &out_grad,\n                                   const at::Tensor &input,\n                                   const at::Tensor &bbox,\n                                   const at::Tensor &trans,\n                                   const at::Tensor &top_count,\n                                   const int no_trans,\n                                   const float spatial_scale,\n                                   const int output_dim,\n                                   const int group_size,\n                                   const int pooled_size,\n                                   const int part_size,\n                                   const int sample_per_part,\n                                   const float trans_std);"
  },
  {
    "path": "code/synthetic/bsrt/model/DCNv2/src/cuda/dcn_v2_cuda.cu",
    "content": "#include <vector>\n#include \"cuda/dcn_v2_im2col_cuda.h\"\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/CUDABlas.h>\n#include <ATen/Dispatch.h>\n#include <ATen/div_rtn.h>\n#include <THC/THC.h>\n#include <THC/THCAtomics.cuh>\n#include <THC/THCDeviceUtils.cuh>\n#include <ATen/cuda/CUDABlas.h>\n#include <ATen/cuda/Exceptions.h>\n\nTHCState *state = at::globalContext().lazyInitCUDA();\n\nstatic cublasOperation_t _cublasOpFromChar(char op) {\n    switch (op) {\n      case 'n':\n      case 'N':\n        return CUBLAS_OP_N;\n      case 't':\n      case 'T':\n        return CUBLAS_OP_T;\n      case 'c':\n      case 'C':\n        return CUBLAS_OP_C;\n    }\n    AT_ERROR(\n        \"_cublasOpFromChar input should be 't', 'n' or 'c' but got `\", op, \"`\");\n  }\n\n  static void _cublasAdjustLdLevel2(int64_t m, int64_t n, int64_t* lda) {\n    // Note: leading dimensions generally are checked that they are > 0\n    // and at least as big the result requires (even if the value won't\n    // be used).\n  \n    // Q: Why does Level3 check trans but this doesn't?\n    // A: In level 2, the sizes (m, n) specify the size of A\n    // (independent of trans value). In level 3. 
the sizes (m, n, k)\n    // specify the sizes of op(A), op(B) where op depend on trans\n    // values.\n    if (n <= 1)\n      *lda = std::max<int64_t>(m, 1);\n  }\n\n\n\n// author: Charles Shang\n// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu\n\n// [batch gemm]\n// https://github.com/pytorch/pytorch/blob/master/aten/src/THC/generic/THCTensorMathBlas.cu\n\n__global__ void createBatchGemmBuffer(const float **input_b, float **output_b,\n                                      float **columns_b, const float **ones_b,\n                                      const float **weight_b, const float **bias_b,\n                                      float *input, float *output,\n                                      float *columns, float *ones,\n                                      float *weight, float *bias,\n                                      const int input_stride, const int output_stride,\n                                      const int columns_stride, const int ones_stride,\n                                      const int num_batches)\n{\n    const int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx < num_batches)\n    {\n        input_b[idx] = input + idx * input_stride;\n        output_b[idx] = output + idx * output_stride;\n        columns_b[idx] = columns + idx * columns_stride;\n        ones_b[idx] = ones + idx * ones_stride;\n        // share weights and bias within a Mini-Batch\n        weight_b[idx] = weight;\n        bias_b[idx] = bias;\n    }\n}\n\nat::Tensor\ndcn_v2_cuda_forward(const at::Tensor &input,\n                    const at::Tensor &weight,\n                    const at::Tensor &bias,\n                    const at::Tensor &offset,\n                    const at::Tensor &mask,\n                    const int kernel_h,\n                    const int kernel_w,\n                    const int stride_h,\n                    const int stride_w,\n                    const int pad_h,\n                    const 
int pad_w,\n                    const int dilation_h,\n                    const int dilation_w,\n                    const int deformable_group)\n{\n    using scalar_t = float;\n    // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask));\n    AT_ASSERTM(input.type().is_cuda(), \"input must be a CUDA tensor\");\n    AT_ASSERTM(weight.type().is_cuda(), \"weight must be a CUDA tensor\");\n    AT_ASSERTM(bias.type().is_cuda(), \"bias must be a CUDA tensor\");\n    AT_ASSERTM(offset.type().is_cuda(), \"offset must be a CUDA tensor\");\n    AT_ASSERTM(mask.type().is_cuda(), \"mask must be a CUDA tensor\");\n\n    const int batch = input.size(0);\n    const int channels = input.size(1);\n    const int height = input.size(2);\n    const int width = input.size(3);\n\n    const int channels_out = weight.size(0);\n    const int channels_kernel = weight.size(1);\n    const int kernel_h_ = weight.size(2);\n    const int kernel_w_ = weight.size(3);\n\n    // printf(\"Kernels: %d %d %d %d\\n\", kernel_h_, kernel_w_, kernel_w, kernel_h);\n    // printf(\"Channels: %d %d\\n\", channels, channels_kernel);\n    // printf(\"Channels: %d %d\\n\", channels_out, channels_kernel);\n\n    AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,\n               \"Input shape and kernel shape won't match: (%d x %d vs %d x %d).\", kernel_h, kernel_w, kernel_h_, kernel_w_);\n\n    AT_ASSERTM(channels == channels_kernel,\n               \"Input shape and kernel channels won't match: (%d vs %d).\", channels, channels_kernel);\n\n    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;\n    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;\n\n    auto ones = at::ones({batch, height_out, width_out}, input.options());\n    auto columns = at::empty({batch, channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());\n    auto output = at::empty({batch, 
channels_out, height_out, width_out}, input.options());\n\n    // prepare for batch-wise computing, which is significantly faster than instance-wise computing\n    // when batch size is large.\n    // launch batch threads\n    int matrices_size = batch * sizeof(float *);\n    auto input_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));\n    auto output_b = static_cast<float **>(THCudaMalloc(state, matrices_size));\n    auto columns_b = static_cast<float **>(THCudaMalloc(state, matrices_size));\n    auto ones_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));\n    auto weight_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));\n    auto bias_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));\n\n    const int block = 128;\n    const int grid = (batch + block - 1) / block;\n\n    createBatchGemmBuffer<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(\n        input_b, output_b,\n        columns_b, ones_b,\n        weight_b, bias_b,\n        input.data_ptr<scalar_t>(),\n        output.data_ptr<scalar_t>(),\n        columns.data_ptr<scalar_t>(),\n        ones.data_ptr<scalar_t>(),\n        weight.data_ptr<scalar_t>(),\n        bias.data_ptr<scalar_t>(),\n        channels * width * height,\n        channels_out * width_out * height_out,\n        channels * kernel_h * kernel_w * height_out * width_out,\n        height_out * width_out,\n        batch);\n\n    long m_ = channels_out;\n    long n_ = height_out * width_out;\n    long k_ = 1;\n    THCudaBlas_SgemmBatched(state,\n                            't',\n                            'n',\n                            n_,\n                            m_,\n                            k_,\n                            1.0f,\n                            ones_b, k_,\n                            bias_b, k_,\n                            0.0f,\n                            output_b, n_,\n                            batch);\n\n    
modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(),\n                                     input.data_ptr<scalar_t>(),\n                                     offset.data_ptr<scalar_t>(),\n                                     mask.data_ptr<scalar_t>(),\n                                     batch, channels, height, width,\n                                     height_out, width_out, kernel_h, kernel_w,\n                                     pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,\n                                     deformable_group,\n                                     columns.data_ptr<scalar_t>());\n\n    long m = channels_out;\n    long n = height_out * width_out;\n    long k = channels * kernel_h * kernel_w;\n    THCudaBlas_SgemmBatched(state,\n                            'n',\n                            'n',\n                            n,\n                            m,\n                            k,\n                            1.0f,\n                            (const float **)columns_b, n,\n                            weight_b, k,\n                            1.0f,\n                            output_b, n,\n                            batch);\n\n    THCudaFree(state, input_b);\n    THCudaFree(state, output_b);\n    THCudaFree(state, columns_b);\n    THCudaFree(state, ones_b);\n    THCudaFree(state, weight_b);\n    THCudaFree(state, bias_b);\n    return output;\n}\n\n__global__ void createBatchGemmBufferBackward(\n    float **grad_output_b,\n    float **columns_b,\n    float **ones_b,\n    float **weight_b,\n    float **grad_weight_b,\n    float **grad_bias_b,\n    float *grad_output,\n    float *columns,\n    float *ones,\n    float *weight,\n    float *grad_weight,\n    float *grad_bias,\n    const int grad_output_stride,\n    const int columns_stride,\n    const int ones_stride,\n    const int num_batches)\n{\n    const int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx < num_batches)\n    {\n        
grad_output_b[idx] = grad_output + idx * grad_output_stride;\n        columns_b[idx] = columns + idx * columns_stride;\n        ones_b[idx] = ones + idx * ones_stride;\n\n        // share weights and bias within a Mini-Batch\n        weight_b[idx] = weight;\n        grad_weight_b[idx] = grad_weight;\n        grad_bias_b[idx] = grad_bias;\n    }\n}\n\nstd::vector<at::Tensor> dcn_v2_cuda_backward(const at::Tensor &input,\n                                             const at::Tensor &weight,\n                                             const at::Tensor &bias,\n                                             const at::Tensor &offset,\n                                             const at::Tensor &mask,\n                                             const at::Tensor &grad_output,\n                                             int kernel_h, int kernel_w,\n                                             int stride_h, int stride_w,\n                                             int pad_h, int pad_w,\n                                             int dilation_h, int dilation_w,\n                                             int deformable_group)\n{\n\n    THArgCheck(input.is_contiguous(), 1, \"input tensor has to be contiguous\");\n    THArgCheck(weight.is_contiguous(), 2, \"weight tensor has to be contiguous\");\n\n    AT_ASSERTM(input.type().is_cuda(), \"input must be a CUDA tensor\");\n    AT_ASSERTM(weight.type().is_cuda(), \"weight must be a CUDA tensor\");\n    AT_ASSERTM(bias.type().is_cuda(), \"bias must be a CUDA tensor\");\n    AT_ASSERTM(offset.type().is_cuda(), \"offset must be a CUDA tensor\");\n    AT_ASSERTM(mask.type().is_cuda(), \"mask must be a CUDA tensor\");\n\n    const int batch = input.size(0);\n    const int channels = input.size(1);\n    const int height = input.size(2);\n    const int width = input.size(3);\n\n    const int channels_out = weight.size(0);\n    const int channels_kernel = weight.size(1);\n    const int kernel_h_ = weight.size(2);\n    const 
int kernel_w_ = weight.size(3);\n\n    AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,\n               \"Input shape and kernel shape won't match: (%d x %d vs %d x %d).\", kernel_h_, kernel_w_, kernel_h, kernel_w);\n\n    AT_ASSERTM(channels == channels_kernel,\n               \"Input shape and kernel channels won't match: (%d vs %d).\", channels, channels_kernel);\n\n    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;\n    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;\n\n    auto ones = at::ones({height_out, width_out}, input.options());\n    auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());\n    auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());\n\n    auto grad_input = at::zeros_like(input);\n    auto grad_weight = at::zeros_like(weight);\n    auto grad_bias = at::zeros_like(bias);\n    auto grad_offset = at::zeros_like(offset);\n    auto grad_mask = at::zeros_like(mask);\n\n    using scalar_t = float;\n\n    for (int b = 0; b < batch; b++)\n    {\n        auto input_n = input.select(0, b);\n        auto offset_n = offset.select(0, b);\n        auto mask_n = mask.select(0, b);\n        auto grad_output_n = grad_output.select(0, b);\n        auto grad_input_n = grad_input.select(0, b);\n        auto grad_offset_n = grad_offset.select(0, b);\n        auto grad_mask_n = grad_mask.select(0, b);\n\n        long m = channels * kernel_h * kernel_w;\n        long n = height_out * width_out;\n        long k = channels_out;\n\n        THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f,\n                         grad_output_n.data_ptr<scalar_t>(), n,\n                         weight.data_ptr<scalar_t>(), m, 0.0f,\n                         columns.data_ptr<scalar_t>(), n);\n\n        // gradient w.r.t. 
input coordinate data\n        modulated_deformable_col2im_coord_cuda(c10::cuda::getCurrentCUDAStream(),\n                                               columns.data_ptr<scalar_t>(),\n                                               input_n.data_ptr<scalar_t>(),\n                                               offset_n.data_ptr<scalar_t>(),\n                                               mask_n.data_ptr<scalar_t>(),\n                                               1, channels, height, width,\n                                               height_out, width_out, kernel_h, kernel_w,\n                                               pad_h, pad_w, stride_h, stride_w,\n                                               dilation_h, dilation_w, deformable_group,\n                                               grad_offset_n.data_ptr<scalar_t>(),\n                                               grad_mask_n.data_ptr<scalar_t>());\n        // gradient w.r.t. input data\n        modulated_deformable_col2im_cuda(c10::cuda::getCurrentCUDAStream(),\n                                         columns.data_ptr<scalar_t>(),\n                                         offset_n.data_ptr<scalar_t>(),\n                                         mask_n.data_ptr<scalar_t>(),\n                                         1, channels, height, width,\n                                         height_out, width_out, kernel_h, kernel_w,\n                                         pad_h, pad_w, stride_h, stride_w,\n                                         dilation_h, dilation_w, deformable_group,\n                                         grad_input_n.data_ptr<scalar_t>());\n\n        // gradient w.r.t. 
weight, dWeight should accumulate across the batch and group\n        modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(),\n                                         input_n.data_ptr<scalar_t>(),\n                                         offset_n.data_ptr<scalar_t>(),\n                                         mask_n.data_ptr<scalar_t>(),\n                                         1, channels, height, width,\n                                         height_out, width_out, kernel_h, kernel_w,\n                                         pad_h, pad_w, stride_h, stride_w,\n                                         dilation_h, dilation_w, deformable_group,\n                                         columns.data_ptr<scalar_t>());\n\n        long m_ = channels_out;\n        long n_ = channels * kernel_h * kernel_w;\n        long k_ = height_out * width_out;\n\n        THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f,\n                         columns.data_ptr<scalar_t>(), k_,\n                         grad_output_n.data_ptr<scalar_t>(), k_, 1.0f,\n                         grad_weight.data_ptr<scalar_t>(), n_);\n\n        cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n        cublasOperation_t op = _cublasOpFromChar('t');\n        _cublasAdjustLdLevel2(k_, m_, &k_);\n        scalar_t* grad_output_n_float = grad_output_n.data_ptr<scalar_t>();\n        scalar_t* one_float = ones.data_ptr<scalar_t>();\n        scalar_t alpha = 1.0;\n        scalar_t beta = 1.0;\n        cublasSgemv(handle, op, k_, m_, &alpha, grad_output_n_float,k_, one_float,1, &beta, grad_bias.data_ptr<scalar_t>(), 1);\n\n    }\n    \n\n    return {\n        grad_input, grad_offset, grad_mask, grad_weight, grad_bias\n    };\n}\n"
  },
  {
    "path": "code/synthetic/bsrt/model/DCNv2/src/cuda/dcn_v2_im2col_cuda.cu",
    "content": "#include \"dcn_v2_im2col_cuda.h\"\n#include <cstdio>\n#include <algorithm>\n#include <cstring>\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n\n#include <THC/THC.h>\n#include <THC/THCAtomics.cuh>\n#include <THC/THCDeviceUtils.cuh>\n\n#define CUDA_KERNEL_LOOP(i, n)                          \\\n  for (int i = blockIdx.x * blockDim.x + threadIdx.x;   \\\n      i < (n);                                          \\\n      i += blockDim.x * gridDim.x)\n\nconst int CUDA_NUM_THREADS = 1024;\ninline int GET_BLOCKS(const int N)\n{\n  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;\n}\n\n\n__device__ float dmcn_im2col_bilinear_cuda(const float *bottom_data, const int data_width,\n                                      const int height, const int width, float h, float w)\n{\n  int h_low = floor(h);\n  int w_low = floor(w);\n  int h_high = h_low + 1;\n  int w_high = w_low + 1;\n\n  float lh = h - h_low;\n  float lw = w - w_low;\n  float hh = 1 - lh, hw = 1 - lw;\n\n  float v1 = 0;\n  if (h_low >= 0 && w_low >= 0)\n    v1 = bottom_data[h_low * data_width + w_low];\n  float v2 = 0;\n  if (h_low >= 0 && w_high <= width - 1)\n    v2 = bottom_data[h_low * data_width + w_high];\n  float v3 = 0;\n  if (h_high <= height - 1 && w_low >= 0)\n    v3 = bottom_data[h_high * data_width + w_low];\n  float v4 = 0;\n  if (h_high <= height - 1 && w_high <= width - 1)\n    v4 = bottom_data[h_high * data_width + w_high];\n\n  float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;\n\n  float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);\n  return val;\n}\n\n__device__ float dmcn_get_gradient_weight_cuda(float argmax_h, float argmax_w,\n                                          const int h, const int w, const int height, const int width)\n{\n  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)\n  {\n    //empty\n    return 0;\n  }\n\n  int argmax_h_low = floor(argmax_h);\n  int argmax_w_low = floor(argmax_w);\n  int 
argmax_h_high = argmax_h_low + 1;\n  int argmax_w_high = argmax_w_low + 1;\n\n  float weight = 0;\n  if (h == argmax_h_low && w == argmax_w_low)\n    weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);\n  if (h == argmax_h_low && w == argmax_w_high)\n    weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);\n  if (h == argmax_h_high && w == argmax_w_low)\n    weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);\n  if (h == argmax_h_high && w == argmax_w_high)\n    weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);\n  return weight;\n}\n\n__device__ float dmcn_get_coordinate_weight_cuda(float argmax_h, float argmax_w,\n                                            const int height, const int width, const float *im_data,\n                                            const int data_width, const int bp_dir)\n{\n  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)\n  {\n    //empty\n    return 0;\n  }\n\n  int argmax_h_low = floor(argmax_h);\n  int argmax_w_low = floor(argmax_w);\n  int argmax_h_high = argmax_h_low + 1;\n  int argmax_w_high = argmax_w_low + 1;\n\n  float weight = 0;\n\n  if (bp_dir == 0)\n  {\n    if (argmax_h_low >= 0 && argmax_w_low >= 0)\n      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];\n    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)\n      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];\n    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)\n      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];\n    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)\n      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];\n  }\n  else if (bp_dir == 1)\n  {\n    if (argmax_h_low >= 0 && argmax_w_low >= 0)\n      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];\n    if (argmax_h_low >= 0 && 
argmax_w_high <= width - 1)\n      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];\n    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)\n      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];\n    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)\n      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];\n  }\n\n  return weight;\n}\n\n__global__ void modulated_deformable_im2col_gpu_kernel(const int n,\n                                                       const float *data_im, const float *data_offset, const float *data_mask,\n                                                       const int height, const int width, const int kernel_h, const int kernel_w,\n                                                       const int pad_h, const int pad_w,\n                                                       const int stride_h, const int stride_w,\n                                                       const int dilation_h, const int dilation_w,\n                                                       const int channel_per_deformable_group,\n                                                       const int batch_size, const int num_channels, const int deformable_group,\n                                                       const int height_col, const int width_col,\n                                                       float *data_col)\n{\n  // launch channels * batch_size * height_col * width_col cores\n  CUDA_KERNEL_LOOP(index, n)\n  {\n    // NOTE(CharlesShang): different from Dai Jifeng's MXNet implementation, col_buffer is of shape (c*kw*kh, N, oh, ow)\n    // here columns is of shape (N, c*kw*kh, oh * ow), need to adapt axis\n\n    // index index of output matrix\n    const int w_col = index % width_col;\n    const int h_col = (index / width_col) % height_col;\n    // const int b_col = (index / width_col / height_col) % 
batch_size;\n    const int b_col = (index / width_col / height_col / num_channels) % batch_size;\n    // const int c_im = (index / width_col / height_col) / batch_size;\n    const int c_im = (index / width_col / height_col) % num_channels;\n    // const int c_col = c_im * kernel_h * kernel_w;\n    const int c_col = c_im * kernel_h * kernel_w;\n\n    // compute deformable group index\n    const int deformable_group_index = c_im / channel_per_deformable_group;\n\n    const int h_in = h_col * stride_h - pad_h;\n    const int w_in = w_col * stride_w - pad_w;\n\n    //  float *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;\n    float *data_col_ptr = data_col + ((b_col * num_channels * kernel_w * kernel_h + c_col) * height_col + h_col) * width_col + w_col;\n    //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;\n    const float *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;\n    const float *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;\n\n    const float *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;\n\n    for (int i = 0; i < kernel_h; ++i)\n    {\n      for (int j = 0; j < kernel_w; ++j)\n      {\n        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;\n        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;\n        const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;\n        const float offset_h = data_offset_ptr[data_offset_h_ptr];\n        const float offset_w = data_offset_ptr[data_offset_w_ptr];\n        const float mask = data_mask_ptr[data_mask_hw_ptr];\n        float val = static_cast<float>(0);\n        const float 
h_im = h_in + i * dilation_h + offset_h;\n        const float w_im = w_in + j * dilation_w + offset_w;\n        //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {\n        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)\n        {\n          //const float map_h = i * dilation_h + offset_h;\n          //const float map_w = j * dilation_w + offset_w;\n          //const int cur_height = height - h_in;\n          //const int cur_width = width - w_in;\n          //val = dmcn_im2col_bilinear_cuda(data_im_ptr, width, cur_height, cur_width, map_h, map_w);\n          val = dmcn_im2col_bilinear_cuda(data_im_ptr, width, height, width, h_im, w_im);\n        }\n        *data_col_ptr = val * mask;\n        // data_col_ptr += batch_size * height_col * width_col;\n        data_col_ptr += height_col * width_col;\n      }\n    }\n  }\n}\n\n__global__ void modulated_deformable_col2im_gpu_kernel(const int n,\n                                                       const float *data_col, const float *data_offset, const float *data_mask,\n                                                       const int channels, const int height, const int width,\n                                                       const int kernel_h, const int kernel_w,\n                                                       const int pad_h, const int pad_w,\n                                                       const int stride_h, const int stride_w,\n                                                       const int dilation_h, const int dilation_w,\n                                                       const int channel_per_deformable_group,\n                                                       const int batch_size, const int deformable_group,\n                                                       const int height_col, const int width_col,\n                                                       float *grad_im)\n{\n  CUDA_KERNEL_LOOP(index, n)\n  {\n    const int j = (index / width_col 
/ height_col / batch_size) % kernel_w;\n    const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;\n    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;\n    // compute the start and end of the output\n\n    const int deformable_group_index = c / channel_per_deformable_group;\n\n    int w_out = index % width_col;\n    int h_out = (index / width_col) % height_col;\n    int b = (index / width_col / height_col) % batch_size;\n    int w_in = w_out * stride_w - pad_w;\n    int h_in = h_out * stride_h - pad_h;\n\n    const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;\n    const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;\n    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;\n    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;\n    const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;\n    const float offset_h = data_offset_ptr[data_offset_h_ptr];\n    const float offset_w = data_offset_ptr[data_offset_w_ptr];\n    const float mask = data_mask_ptr[data_mask_hw_ptr];\n    const float cur_inv_h_data = h_in + i * dilation_h + offset_h;\n    const float cur_inv_w_data = w_in + j * dilation_w + offset_w;\n\n    const float cur_top_grad = data_col[index] * mask;\n    const int cur_h = (int)cur_inv_h_data;\n    const int cur_w = (int)cur_inv_w_data;\n    for (int dy = -2; dy <= 2; dy++)\n    {\n      for (int dx = -2; dx <= 2; dx++)\n      {\n        if (cur_h + dy >= 0 && cur_h + dy < height &&\n            cur_w + dx >= 0 && cur_w + dx < width &&\n            abs(cur_inv_h_data - (cur_h + dy)) < 1 &&\n            abs(cur_inv_w_data - (cur_w + dx)) < 1)\n        {\n          int cur_bottom_grad_pos = 
((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;\n          float weight = dmcn_get_gradient_weight_cuda(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);\n          atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);\n        }\n      }\n    }\n  }\n}\n\n__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,\n                                                             const float *data_col, const float *data_im,\n                                                             const float *data_offset, const float *data_mask,\n                                                             const int channels, const int height, const int width,\n                                                             const int kernel_h, const int kernel_w,\n                                                             const int pad_h, const int pad_w,\n                                                             const int stride_h, const int stride_w,\n                                                             const int dilation_h, const int dilation_w,\n                                                             const int channel_per_deformable_group,\n                                                             const int batch_size, const int offset_channels, const int deformable_group,\n                                                             const int height_col, const int width_col,\n                                                             float *grad_offset, float *grad_mask)\n{\n  CUDA_KERNEL_LOOP(index, n)\n  {\n    float val = 0, mval = 0;\n    int w = index % width_col;\n    int h = (index / width_col) % height_col;\n    int c = (index / width_col / height_col) % offset_channels;\n    int b = (index / width_col / height_col) / offset_channels;\n    // compute the start and end of the output\n\n    const int deformable_group_index = c / (2 * kernel_h * kernel_w);\n    const int col_step = 
kernel_h * kernel_w;\n    int cnt = 0;\n    const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;\n    const float *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;\n    const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;\n    const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;\n\n    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;\n\n    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)\n    {\n      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;\n      const int bp_dir = offset_c % 2;\n\n      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;\n      int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;\n      int w_out = col_pos % width_col;\n      int h_out = (col_pos / width_col) % height_col;\n      int w_in = w_out * stride_w - pad_w;\n      int h_in = h_out * stride_h - pad_h;\n      const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);\n      const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);\n      const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);\n      const float offset_h = data_offset_ptr[data_offset_h_ptr];\n      const float offset_w = data_offset_ptr[data_offset_w_ptr];\n      const float mask = data_mask_ptr[data_mask_hw_ptr];\n      float inv_h = h_in + i * dilation_h + offset_h;\n      float inv_w = w_in + j * dilation_w + offset_w;\n      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)\n      {\n 
       inv_h = inv_w = -2;\n      }\n      else\n      {\n        mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear_cuda(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);\n      }\n      const float weight = dmcn_get_coordinate_weight_cuda(\n          inv_h, inv_w,\n          height, width, data_im_ptr + cnt * height * width, width, bp_dir);\n      val += weight * data_col_ptr[col_pos] * mask;\n      cnt += 1;\n    }\n    // KERNEL_ASSIGN(grad_offset[index], offset_req, val);\n    grad_offset[index] = val;\n    if (offset_c % 2 == 0)\n      // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);\n      grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;\n  }\n}\n\nvoid modulated_deformable_im2col_cuda(cudaStream_t stream,\n  const float* data_im, const float* data_offset, const float* data_mask,\n  const int batch_size, const int channels, const int height_im, const int width_im, \n  const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n  const int pad_h, const int pad_w, const int stride_h, const int stride_w, \n  const int dilation_h, const int dilation_w,\n  const int deformable_group, float* data_col) {\n  // num_axes should be smaller than block size\n  const int channel_per_deformable_group = channels / deformable_group;\n  const int num_kernels = channels * batch_size * height_col * width_col;\n  modulated_deformable_im2col_gpu_kernel\n      <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS,\n          0, stream>>>(\n      num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kernel_w,\n      pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,\n      batch_size, channels, deformable_group, height_col, width_col, data_col);\n  \n  cudaError_t err = 
cudaGetLastError();\n  if (err != cudaSuccess)\n  {\n    printf(\"error in modulated_deformable_im2col_cuda: %s\\n\", cudaGetErrorString(err));\n  }\n\n}\n\nvoid modulated_deformable_col2im_cuda(cudaStream_t stream,\n  const float* data_col, const float* data_offset, const float* data_mask,\n  const int batch_size, const int channels, const int height_im, const int width_im, \n  const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n  const int pad_h, const int pad_w, const int stride_h, const int stride_w, \n  const int dilation_h, const int dilation_w, \n  const int deformable_group, float* grad_im){\n\n  const int channel_per_deformable_group = channels / deformable_group;\n  const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;\n  modulated_deformable_col2im_gpu_kernel\n      <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS,\n          0, stream>>>(\n        num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im,\n        kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,\n        dilation_h, dilation_w, channel_per_deformable_group,\n        batch_size, deformable_group, height_col, width_col, grad_im);\n  cudaError_t err = cudaGetLastError();\n  if (err != cudaSuccess)\n  {\n    printf(\"error in modulated_deformable_col2im_cuda: %s\\n\", cudaGetErrorString(err));\n  }\n\n}\n\nvoid modulated_deformable_col2im_coord_cuda(cudaStream_t stream,\n  const float* data_col, const float* data_im, const float* data_offset, const float* data_mask,\n  const int batch_size, const int channels, const int height_im, const int width_im, \n  const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n  const int pad_h, const int pad_w, const int stride_h, const int stride_w, \n  const int dilation_h, const int dilation_w, \n  const int deformable_group,\n  float* grad_offset, float* grad_mask) {\n  const int num_kernels = batch_size * height_col * width_col * 2 * 
kernel_h * kernel_w * deformable_group;\n  const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;\n  modulated_deformable_col2im_coord_gpu_kernel\n      <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS,\n        0, stream>>>(\n        num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im,\n        kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,\n        dilation_h, dilation_w, channel_per_deformable_group,\n        batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, \n        grad_offset, grad_mask);\n  cudaError_t err = cudaGetLastError();\n  if (err != cudaSuccess)\n  {\n    printf(\"error in modulated_deformable_col2im_coord_cuda: %s\\n\", cudaGetErrorString(err));\n  }\n}"
  },
  {
    "path": "code/synthetic/bsrt/model/DCNv2/src/cuda/dcn_v2_im2col_cuda.h",
    "content": "\n/*!\n ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************\n *\n * COPYRIGHT\n *\n * All contributions by the University of California:\n * Copyright (c) 2014-2017 The Regents of the University of California (Regents)\n * All rights reserved.\n *\n * All other contributions:\n * Copyright (c) 2014-2017, the respective contributors\n * All rights reserved.\n *\n * Caffe uses a shared copyright model: each contributor holds copyright over\n * their contributions to Caffe. The project versioning records all such\n * contribution and copyright details. If a contributor wants to further mark\n * their specific copyright on a particular contribution, they should indicate\n * their copyright solely in the commit message of the change when it is\n * committed.\n *\n * LICENSE\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR\n * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n * CONTRIBUTION AGREEMENT\n *\n * By contributing to the BVLC/caffe repository through pull-request, comment,\n * or otherwise, the contributor releases their content to the\n * license and copyright terms herein.\n *\n ***************** END Caffe Copyright Notice and Disclaimer ********************\n *\n * Copyright (c) 2018 Microsoft\n * Licensed under The MIT License [see LICENSE for details]\n * \\file modulated_deformable_im2col.h\n * \\brief Function definitions of converting an image to\n * column matrix based on kernel, padding, dilation, and offset.\n * These functions are mainly used in deformable convolution operators.\n * \\ref: https://arxiv.org/abs/1811.11168\n * \\author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu\n */\n\n/***************** Adapted by Charles Shang *********************/\n\n#ifndef DCN_V2_IM2COL_CUDA\n#define DCN_V2_IM2COL_CUDA\n\n#ifdef __cplusplus\nextern \"C\"\n{\n#endif\n\n  void modulated_deformable_im2col_cuda(cudaStream_t stream,\n                                        const float *data_im, const float *data_offset, const float *data_mask,\n                                        const int batch_size, const int channels, const int height_im, const int width_im,\n                                        const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n                                        const int pad_h, const int pad_w, const int stride_h, const 
int stride_w,\n                                        const int dilation_h, const int dilation_w,\n                                        const int deformable_group, float *data_col);\n\n  void modulated_deformable_col2im_cuda(cudaStream_t stream,\n                                        const float *data_col, const float *data_offset, const float *data_mask,\n                                        const int batch_size, const int channels, const int height_im, const int width_im,\n                                        const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n                                        const int pad_h, const int pad_w, const int stride_h, const int stride_w,\n                                        const int dilation_h, const int dilation_w,\n                                        const int deformable_group, float *grad_im);\n\n  void modulated_deformable_col2im_coord_cuda(cudaStream_t stream,\n                                         const float *data_col, const float *data_im, const float *data_offset, const float *data_mask,\n                                         const int batch_size, const int channels, const int height_im, const int width_im,\n                                         const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n                                         const int pad_h, const int pad_w, const int stride_h, const int stride_w,\n                                         const int dilation_h, const int dilation_w,\n                                         const int deformable_group,\n                                         float *grad_offset, float *grad_mask);\n\n#ifdef __cplusplus\n}\n#endif\n\n#endif"
  },
  {
    "path": "code/synthetic/bsrt/model/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda.cu",
    "content": "/*!\n * Copyright (c) 2017 Microsoft\n * Licensed under The MIT License [see LICENSE for details]\n * \\file deformable_psroi_pooling.cu\n * \\brief\n * \\author Yi Li, Guodong Zhang, Jifeng Dai\n*/\n/***************** Adapted by Charles Shang *********************/\n\n#include <cstdio>\n#include <algorithm>\n#include <cstring>\n#include <iostream>\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n\n#include <THC/THC.h>\n#include <THC/THCAtomics.cuh>\n#include <THC/THCDeviceUtils.cuh>\n\n#define CUDA_KERNEL_LOOP(i, n)                        \\\n  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \\\n       i < (n);                                       \\\n       i += blockDim.x * gridDim.x)\n\nconst int CUDA_NUM_THREADS = 1024;\ninline int GET_BLOCKS(const int N)\n{\n  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;\n}\n\ntemplate <typename T>\n__device__ T bilinear_interp_cuda(\n    const T *data,\n    const T x,\n    const T y,\n    const int width,\n    const int height)\n{\n  int x1 = floor(x);\n  int x2 = ceil(x);\n  int y1 = floor(y);\n  int y2 = ceil(y);\n  T dist_x = static_cast<T>(x - x1);\n  T dist_y = static_cast<T>(y - y1);\n  T value11 = data[y1 * width + x1];\n  T value12 = data[y2 * width + x1];\n  T value21 = data[y1 * width + x2];\n  T value22 = data[y2 * width + x2];\n  T value = (1 - dist_x) * (1 - dist_y) * value11 +\n            (1 - dist_x) * dist_y * value12 +\n            dist_x * (1 - dist_y) * value21 +\n            dist_x * dist_y * value22;\n  return value;\n}\n\ntemplate <typename T>\n__global__ void DeformablePSROIPoolForwardKernelCuda(\n    const int count,\n    const T *bottom_data,\n    const T spatial_scale,\n    const int channels,\n    const int height, const int width,\n    const int pooled_height, const int pooled_width,\n    const T *bottom_rois, const T *bottom_trans,\n    const int no_trans,\n    const T trans_std,\n    const int sample_per_part,\n    const int output_dim,\n    const int 
group_size,\n    const int part_size,\n    const int num_classes,\n    const int channels_each_class,\n    T *top_data,\n    T *top_count)\n{\n  CUDA_KERNEL_LOOP(index, count)\n  {\n    // The output is in order (n, ctop, ph, pw)\n    int pw = index % pooled_width;\n    int ph = (index / pooled_width) % pooled_height;\n    int ctop = (index / pooled_width / pooled_height) % output_dim;\n    int n = index / pooled_width / pooled_height / output_dim;\n\n    // [start, end) interval for spatial sampling\n    const T *offset_bottom_rois = bottom_rois + n * 5;\n    int roi_batch_ind = offset_bottom_rois[0];\n    T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;\n    T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;\n    T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;\n    T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;\n\n    // Force too small ROIs to be 1x1\n    T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0\n    T roi_height = max(roi_end_h - roi_start_h, 0.1);\n\n    // Compute w and h at bottom\n    T bin_size_h = roi_height / static_cast<T>(pooled_height);\n    T bin_size_w = roi_width / static_cast<T>(pooled_width);\n\n    T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);\n    T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);\n\n    int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);\n    int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);\n    int class_id = ctop / channels_each_class;\n    T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;\n    T trans_y = no_trans ? 
static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;\n\n    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;\n    wstart += trans_x * roi_width;\n    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;\n    hstart += trans_y * roi_height;\n\n    T sum = 0;\n    int count = 0;\n    int gw = floor(static_cast<T>(pw) * group_size / pooled_width);\n    int gh = floor(static_cast<T>(ph) * group_size / pooled_height);\n    gw = min(max(gw, 0), group_size - 1);\n    gh = min(max(gh, 0), group_size - 1);\n\n    const T *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;\n    for (int ih = 0; ih < sample_per_part; ih++)\n    {\n      for (int iw = 0; iw < sample_per_part; iw++)\n      {\n        T w = wstart + iw * sub_bin_size_w;\n        T h = hstart + ih * sub_bin_size_h;\n        // bilinear interpolation\n        if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)\n        {\n          continue;\n        }\n        w = min(max(w, 0.), width - 1.);\n        h = min(max(h, 0.), height - 1.);\n        int c = (ctop * group_size + gh) * group_size + gw;\n        T val = bilinear_interp_cuda(offset_bottom_data + c * height * width, w, h, width, height);\n        sum += val;\n        count++;\n      }\n    }\n    top_data[index] = count == 0 ? 
static_cast<T>(0) : sum / count;\n    top_count[index] = count;\n  }\n}\n\ntemplate <typename T>\n__global__ void DeformablePSROIPoolBackwardAccKernelCuda(\n    const int count,\n    const T *top_diff,\n    const T *top_count,\n    const int num_rois,\n    const T spatial_scale,\n    const int channels,\n    const int height, const int width,\n    const int pooled_height, const int pooled_width,\n    const int output_dim,\n    T *bottom_data_diff, T *bottom_trans_diff,\n    const T *bottom_data,\n    const T *bottom_rois,\n    const T *bottom_trans,\n    const int no_trans,\n    const T trans_std,\n    const int sample_per_part,\n    const int group_size,\n    const int part_size,\n    const int num_classes,\n    const int channels_each_class)\n{\n  CUDA_KERNEL_LOOP(index, count)\n  {\n    // The output is in order (n, ctop, ph, pw)\n    int pw = index % pooled_width;\n    int ph = (index / pooled_width) % pooled_height;\n    int ctop = (index / pooled_width / pooled_height) % output_dim;\n    int n = index / pooled_width / pooled_height / output_dim;\n\n    // [start, end) interval for spatial sampling\n    const T *offset_bottom_rois = bottom_rois + n * 5;\n    int roi_batch_ind = offset_bottom_rois[0];\n    T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;\n    T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;\n    T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;\n    T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) 
* spatial_scale - 0.5;\n\n    // Force too small ROIs to be 1x1\n    T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0\n    T roi_height = max(roi_end_h - roi_start_h, 0.1);\n\n    // Compute w and h at bottom\n    T bin_size_h = roi_height / static_cast<T>(pooled_height);\n    T bin_size_w = roi_width / static_cast<T>(pooled_width);\n\n    T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);\n    T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);\n\n    int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);\n    int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);\n    int class_id = ctop / channels_each_class;\n    T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;\n    T trans_y = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;\n\n    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;\n    wstart += trans_x * roi_width;\n    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;\n    hstart += trans_y * roi_height;\n\n    if (top_count[index] <= 0)\n    {\n      continue;\n    }\n    T diff_val = top_diff[index] / top_count[index];\n    const T *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;\n    T *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;\n    int gw = floor(static_cast<T>(pw) * group_size / pooled_width);\n    int gh = floor(static_cast<T>(ph) * group_size / pooled_height);\n    gw = min(max(gw, 0), group_size - 1);\n    gh = min(max(gh, 0), group_size - 1);\n\n    for (int ih = 0; ih < sample_per_part; ih++)\n    {\n      for (int iw = 0; iw < sample_per_part; iw++)\n      {\n        T w = wstart + iw * sub_bin_size_w;\n        T h = hstart + ih * sub_bin_size_h;\n        // bilinear interpolation\n        if (w < -0.5 || w 
> width - 0.5 || h < -0.5 || h > height - 0.5)\n        {\n          continue;\n        }\n        w = min(max(w, 0.), width - 1.);\n        h = min(max(h, 0.), height - 1.);\n        int c = (ctop * group_size + gh) * group_size + gw;\n        // backward on feature\n        int x0 = floor(w);\n        int x1 = ceil(w);\n        int y0 = floor(h);\n        int y1 = ceil(h);\n        T dist_x = w - x0, dist_y = h - y0;\n        T q00 = (1 - dist_x) * (1 - dist_y);\n        T q01 = (1 - dist_x) * dist_y;\n        T q10 = dist_x * (1 - dist_y);\n        T q11 = dist_x * dist_y;\n        int bottom_index_base = c * height * width;\n        atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);\n        atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);\n        atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);\n        atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);\n\n        if (no_trans)\n        {\n          continue;\n        }\n        T U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];\n        T U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];\n        T U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];\n        T U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];\n        T diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;\n        diff_x *= roi_width;\n        T diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;\n        diff_y *= roi_height;\n\n        atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x);\n        atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);\n      }\n    }\n  
}\n}\n\nstd::tuple<at::Tensor, at::Tensor>\ndcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input,\n                                  const at::Tensor &bbox,\n                                  const at::Tensor &trans,\n                                  const int no_trans,\n                                  const float spatial_scale,\n                                  const int output_dim,\n                                  const int group_size,\n                                  const int pooled_size,\n                                  const int part_size,\n                                  const int sample_per_part,\n                                  const float trans_std)\n{\n  AT_ASSERTM(input.type().is_cuda(), \"input must be a CUDA tensor\");\n  AT_ASSERTM(bbox.type().is_cuda(), \"rois must be a CUDA tensor\");\n  AT_ASSERTM(trans.type().is_cuda(), \"trans must be a CUDA tensor\");\n\n  const int batch = input.size(0);\n  const int channels = input.size(1);\n  const int height = input.size(2);\n  const int width = input.size(3);\n  const int channels_trans = no_trans ? 2 : trans.size(1);\n  const int num_bbox = bbox.size(0);\n\n  AT_ASSERTM(channels == output_dim, \"input channels and output channels must equal\");\n  auto pooled_height = pooled_size;\n  auto pooled_width = pooled_size;\n\n  auto out = at::empty({num_bbox, output_dim, pooled_height, pooled_width}, input.options());\n  long out_size = num_bbox * output_dim * pooled_height * pooled_width;\n  auto top_count = at::zeros({num_bbox, output_dim, pooled_height, pooled_width}, input.options());\n\n  const int num_classes = no_trans ? 1 : channels_trans / 2;\n  const int channels_each_class = no_trans ? 
output_dim : output_dim / num_classes;\n\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n  if (out.numel() == 0)\n  {\n    THCudaCheck(cudaGetLastError());\n    return std::make_tuple(out, top_count);\n  }\n\n  dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));\n  dim3 block(512);\n\n  AT_DISPATCH_FLOATING_TYPES(input.type(), \"dcn_v2_psroi_pooling_cuda_forward\", [&] {\n    DeformablePSROIPoolForwardKernelCuda<scalar_t><<<grid, block, 0, stream>>>(\n        out_size,\n        input.contiguous().data_ptr<scalar_t>(),\n        spatial_scale,\n        channels,\n        height, width,\n        pooled_height,\n        pooled_width,\n        bbox.contiguous().data_ptr<scalar_t>(),\n        trans.contiguous().data_ptr<scalar_t>(),\n        no_trans,\n        trans_std,\n        sample_per_part,\n        output_dim,\n        group_size,\n        part_size,\n        num_classes,\n        channels_each_class,\n        out.data_ptr<scalar_t>(),\n        top_count.data_ptr<scalar_t>());\n  });\n  THCudaCheck(cudaGetLastError());\n  return std::make_tuple(out, top_count);\n}\n\nstd::tuple<at::Tensor, at::Tensor>\ndcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad,\n                                   const at::Tensor &input,\n                                   const at::Tensor &bbox,\n                                   const at::Tensor &trans,\n                                   const at::Tensor &top_count,\n                                   const int no_trans,\n                                   const float spatial_scale,\n                                   const int output_dim,\n                                   const int group_size,\n                                   const int pooled_size,\n                                   const int part_size,\n                                   const int sample_per_part,\n                                   const float trans_std)\n{\n  AT_ASSERTM(out_grad.type().is_cuda(), \"out_grad must be a CUDA 
tensor\");\n  AT_ASSERTM(input.type().is_cuda(), \"input must be a CUDA tensor\");\n  AT_ASSERTM(bbox.type().is_cuda(), \"bbox must be a CUDA tensor\");\n  AT_ASSERTM(trans.type().is_cuda(), \"trans must be a CUDA tensor\");\n  AT_ASSERTM(top_count.type().is_cuda(), \"top_count must be a CUDA tensor\");\n\n  const int batch = input.size(0);\n  const int channels = input.size(1);\n  const int height = input.size(2);\n  const int width = input.size(3);\n  const int channels_trans = no_trans ? 2 : trans.size(1);\n  const int num_bbox = bbox.size(0);\n\n  AT_ASSERTM(channels == output_dim, \"input channels and output channels must equal\");\n  auto pooled_height = pooled_size;\n  auto pooled_width = pooled_size;\n  long out_size = num_bbox * output_dim * pooled_height * pooled_width;\n  const int num_classes = no_trans ? 1 : channels_trans / 2;\n  const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;\n\n  auto input_grad = at::zeros({batch, channels, height, width}, out_grad.options());\n  auto trans_grad = at::zeros_like(trans);\n\n  if (input_grad.numel() == 0)\n  {\n    THCudaCheck(cudaGetLastError());\n    return std::make_tuple(input_grad, trans_grad);\n  }\n\n  dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));\n  dim3 block(512);\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n  AT_DISPATCH_FLOATING_TYPES(out_grad.type(), \"dcn_v2_psroi_pooling_cuda_backward\", [&] {\n    DeformablePSROIPoolBackwardAccKernelCuda<scalar_t><<<grid, block, 0, stream>>>(\n        out_size,\n        out_grad.contiguous().data_ptr<scalar_t>(),\n        top_count.contiguous().data_ptr<scalar_t>(),\n        num_bbox,\n        spatial_scale,\n        channels,\n        height,\n        width,\n        pooled_height,\n        pooled_width,\n        output_dim,\n        input_grad.contiguous().data_ptr<scalar_t>(),\n        trans_grad.contiguous().data_ptr<scalar_t>(),\n        input.contiguous().data_ptr<scalar_t>(),\n        
bbox.contiguous().data_ptr<scalar_t>(),\n        trans.contiguous().data_ptr<scalar_t>(),\n        no_trans,\n        trans_std,\n        sample_per_part,\n        group_size,\n        part_size,\n        num_classes,\n        channels_each_class);\n  });\n  THCudaCheck(cudaGetLastError());\n  return std::make_tuple(input_grad, trans_grad);\n}"
  },
  {
    "path": "code/synthetic/bsrt/model/DCNv2/src/cuda/vision.h",
    "content": "#pragma once\n#include <torch/extension.h>\n#include <ATen/div_rtn.h>\nat::Tensor\ndcn_v2_cuda_forward(const at::Tensor &input,\n                    const at::Tensor &weight,\n                    const at::Tensor &bias,\n                    const at::Tensor &offset,\n                    const at::Tensor &mask,\n                    const int kernel_h,\n                    const int kernel_w,\n                    const int stride_h,\n                    const int stride_w,\n                    const int pad_h,\n                    const int pad_w,\n                    const int dilation_h,\n                    const int dilation_w,\n                    const int deformable_group);\n\nstd::vector<at::Tensor>\ndcn_v2_cuda_backward(const at::Tensor &input,\n                     const at::Tensor &weight,\n                     const at::Tensor &bias,\n                     const at::Tensor &offset,\n                     const at::Tensor &mask,\n                     const at::Tensor &grad_output,\n                     int kernel_h, int kernel_w,\n                     int stride_h, int stride_w,\n                     int pad_h, int pad_w,\n                     int dilation_h, int dilation_w,\n                     int deformable_group);\n\n\nstd::tuple<at::Tensor, at::Tensor>\ndcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input,\n                                  const at::Tensor &bbox,\n                                  const at::Tensor &trans,\n                                  const int no_trans,\n                                  const float spatial_scale,\n                                  const int output_dim,\n                                  const int group_size,\n                                  const int pooled_size,\n                                  const int part_size,\n                                  const int sample_per_part,\n                                  const float trans_std);\n\nstd::tuple<at::Tensor, 
at::Tensor>\ndcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad,\n                                   const at::Tensor &input,\n                                   const at::Tensor &bbox,\n                                   const at::Tensor &trans,\n                                   const at::Tensor &top_count,\n                                   const int no_trans,\n                                   const float spatial_scale,\n                                   const int output_dim,\n                                   const int group_size,\n                                   const int pooled_size,\n                                   const int part_size,\n                                   const int sample_per_part,\n                                   const float trans_std);"
  },
  {
    "path": "code/synthetic/bsrt/model/DCNv2/src/dcn_v2.h",
    "content": "#pragma once\n\n#include \"cpu/vision.h\"\n\n#ifdef WITH_CUDA\n#include \"cuda/vision.h\"\n#endif\n\nat::Tensor\ndcn_v2_forward(const at::Tensor &input,\n               const at::Tensor &weight,\n               const at::Tensor &bias,\n               const at::Tensor &offset,\n               const at::Tensor &mask,\n               const int kernel_h,\n               const int kernel_w,\n               const int stride_h,\n               const int stride_w,\n               const int pad_h,\n               const int pad_w,\n               const int dilation_h,\n               const int dilation_w,\n               const int deformable_group)\n{\n    if (input.type().is_cuda())\n    {\n#ifdef WITH_CUDA\n        return dcn_v2_cuda_forward(input, weight, bias, offset, mask,\n                                   kernel_h, kernel_w,\n                                   stride_h, stride_w,\n                                   pad_h, pad_w,\n                                   dilation_h, dilation_w,\n                                   deformable_group);\n#else\n        AT_ERROR(\"Not compiled with GPU support\");\n#endif\n    }\n    else{\n        return dcn_v2_cpu_forward(input, weight, bias, offset, mask,\n                                   kernel_h, kernel_w,\n                                   stride_h, stride_w,\n                                   pad_h, pad_w,\n                                   dilation_h, dilation_w,\n                                   deformable_group);\n    }\n}\n\nstd::vector<at::Tensor>\ndcn_v2_backward(const at::Tensor &input,\n                const at::Tensor &weight,\n                const at::Tensor &bias,\n                const at::Tensor &offset,\n                const at::Tensor &mask,\n                const at::Tensor &grad_output,\n                int kernel_h, int kernel_w,\n                int stride_h, int stride_w,\n                int pad_h, int pad_w,\n                int dilation_h, int dilation_w,\n                int 
deformable_group)\n{\n    if (input.type().is_cuda())\n    {\n#ifdef WITH_CUDA\n        return dcn_v2_cuda_backward(input,\n                                    weight,\n                                    bias,\n                                    offset,\n                                    mask,\n                                    grad_output,\n                                    kernel_h, kernel_w,\n                                    stride_h, stride_w,\n                                    pad_h, pad_w,\n                                    dilation_h, dilation_w,\n                                    deformable_group);\n#else\n        AT_ERROR(\"Not compiled with GPU support\");\n#endif\n    }\n    else{\n        return dcn_v2_cpu_backward(input,\n                                    weight,\n                                    bias,\n                                    offset,\n                                    mask,\n                                    grad_output,\n                                    kernel_h, kernel_w,\n                                    stride_h, stride_w,\n                                    pad_h, pad_w,\n                                    dilation_h, dilation_w,\n                                    deformable_group);\n    }\n}\n\nstd::tuple<at::Tensor, at::Tensor>\ndcn_v2_psroi_pooling_forward(const at::Tensor &input,\n                             const at::Tensor &bbox,\n                             const at::Tensor &trans,\n                             const int no_trans,\n                             const float spatial_scale,\n                             const int output_dim,\n                             const int group_size,\n                             const int pooled_size,\n                             const int part_size,\n                             const int sample_per_part,\n                             const float trans_std)\n{\n    if (input.type().is_cuda())\n    {\n#ifdef WITH_CUDA\n        return 
dcn_v2_psroi_pooling_cuda_forward(input,\n                                                 bbox,\n                                                 trans,\n                                                 no_trans,\n                                                 spatial_scale,\n                                                 output_dim,\n                                                 group_size,\n                                                 pooled_size,\n                                                 part_size,\n                                                 sample_per_part,\n                                                 trans_std);\n#else\n        AT_ERROR(\"Not compiled with GPU support\");\n#endif\n    }\n    else{\n        return dcn_v2_psroi_pooling_cpu_forward(input,\n                                                 bbox,\n                                                 trans,\n                                                 no_trans,\n                                                 spatial_scale,\n                                                 output_dim,\n                                                 group_size,\n                                                 pooled_size,\n                                                 part_size,\n                                                 sample_per_part,\n                                                 trans_std);\n    }\n}\n\nstd::tuple<at::Tensor, at::Tensor>\ndcn_v2_psroi_pooling_backward(const at::Tensor &out_grad,\n                              const at::Tensor &input,\n                              const at::Tensor &bbox,\n                              const at::Tensor &trans,\n                              const at::Tensor &top_count,\n                              const int no_trans,\n                              const float spatial_scale,\n                              const int output_dim,\n                              const int group_size,\n                              const 
int pooled_size,\n                              const int part_size,\n                              const int sample_per_part,\n                              const float trans_std)\n{\n    if (input.type().is_cuda())\n    {\n#ifdef WITH_CUDA\n        return dcn_v2_psroi_pooling_cuda_backward(out_grad,\n                                                  input,\n                                                  bbox,\n                                                  trans,\n                                                  top_count,\n                                                  no_trans,\n                                                  spatial_scale,\n                                                  output_dim,\n                                                  group_size,\n                                                  pooled_size,\n                                                  part_size,\n                                                  sample_per_part,\n                                                  trans_std);\n#else\n        AT_ERROR(\"Not compiled with GPU support\");\n#endif\n    }\n    else{\n        return dcn_v2_psroi_pooling_cpu_backward(out_grad,\n                                                  input,\n                                                  bbox,\n                                                  trans,\n                                                  top_count,\n                                                  no_trans,\n                                                  spatial_scale,\n                                                  output_dim,\n                                                  group_size,\n                                                  pooled_size,\n                                                  part_size,\n                                                  sample_per_part,\n                                                  trans_std);\n    }\n}"
  },
  {
    "path": "code/synthetic/bsrt/model/DCNv2/src/vision.cpp",
    "content": "\n#include \"dcn_v2.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"dcn_v2_forward\", &dcn_v2_forward, \"dcn_v2_forward\");\n  m.def(\"dcn_v2_backward\", &dcn_v2_backward, \"dcn_v2_backward\");\n  m.def(\"dcn_v2_psroi_pooling_forward\", &dcn_v2_psroi_pooling_forward, \"dcn_v2_psroi_pooling_forward\");\n  m.def(\"dcn_v2_psroi_pooling_backward\", &dcn_v2_psroi_pooling_backward, \"dcn_v2_psroi_pooling_backward\");\n}\n"
  },
  {
    "path": "code/synthetic/bsrt/model/DCNv2/test.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import absolute_import\nfrom __future__ import print_function\nfrom __future__ import division\n\nimport time\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import gradcheck\n\nfrom dcn_v2 import dcn_v2_conv, DCNv2, DCN\nfrom dcn_v2 import dcn_v2_pooling, DCNv2Pooling, DCNPooling\n\ndeformable_groups = 1\nN, inC, inH, inW = 2, 2, 4, 4\noutC = 2\nkH, kW = 3, 3\n\n\ndef conv_identify(weight, bias):\n    weight.data.zero_()\n    bias.data.zero_()\n    o, i, h, w = weight.shape\n    y = h//2\n    x = w//2\n    for p in range(i):\n        for q in range(o):\n            if p == q:\n                weight.data[q, p, y, x] = 1.0\n\n\ndef check_zero_offset():\n    conv_offset = nn.Conv2d(inC, deformable_groups * 2 * kH * kW,\n                            kernel_size=(kH, kW),\n                            stride=(1, 1),\n                            padding=(1, 1),\n                            bias=True).cuda()\n\n    conv_mask = nn.Conv2d(inC, deformable_groups * 1 * kH * kW,\n                          kernel_size=(kH, kW),\n                          stride=(1, 1),\n                          padding=(1, 1),\n                          bias=True).cuda()\n\n    dcn_v2 = DCNv2(inC, outC, (kH, kW),\n                   stride=1, padding=1, dilation=1,\n                   deformable_groups=deformable_groups).cuda()\n\n    conv_offset.weight.data.zero_()\n    conv_offset.bias.data.zero_()\n    conv_mask.weight.data.zero_()\n    conv_mask.bias.data.zero_()\n    conv_identify(dcn_v2.weight, dcn_v2.bias)\n\n    input = torch.randn(N, inC, inH, inW).cuda()\n    offset = conv_offset(input)\n    mask = conv_mask(input)\n    mask = torch.sigmoid(mask)\n    output = dcn_v2(input, offset, mask)\n    output *= 2\n    d = (input - output).abs().max()\n    if d < 1e-10:\n        print('Zero offset passed')\n    else:\n        print('Zero offset failed')\n        print(input)\n        print(output)\n\ndef 
check_gradient_dconv():\n\n    input = torch.rand(N, inC, inH, inW).cuda() * 0.01\n    input.requires_grad = True\n\n    offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW).cuda() * 2\n    # offset.data.zero_()\n    # offset.data -= 0.5\n    offset.requires_grad = True\n\n    mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW).cuda()\n    # mask.data.zero_()\n    mask.requires_grad = True\n    mask = torch.sigmoid(mask)\n\n    weight = torch.randn(outC, inC, kH, kW).cuda()\n    weight.requires_grad = True\n\n    bias = torch.rand(outC).cuda()\n    bias.requires_grad = True\n\n    stride = 1\n    padding = 1\n    dilation = 1\n\n    print('check_gradient_dconv: ',\n          gradcheck(dcn_v2_conv, (input, offset, mask, weight, bias,\n                    stride, padding, dilation, deformable_groups),\n                    eps=1e-3, atol=1e-4, rtol=1e-2))\n\n\ndef check_pooling_zero_offset():\n\n    input = torch.randn(2, 16, 64, 64).cuda().zero_()\n    input[0, :, 16:26, 16:26] = 1.\n    input[1, :, 10:20, 20:30] = 2.\n    rois = torch.tensor([\n        [0, 65, 65, 103, 103],\n        [1, 81, 41, 119, 79],\n    ]).cuda().float()\n    pooling = DCNv2Pooling(spatial_scale=1.0 / 4,\n                           pooled_size=7,\n                           output_dim=16,\n                           no_trans=True,\n                           group_size=1,\n                           trans_std=0.0).cuda()\n\n    out = pooling(input, rois, input.new())\n    s = ', '.join(['%f' % out[i, :, :, :].mean().item()\n                   for i in range(rois.shape[0])])\n    print(s)\n\n    dpooling = DCNv2Pooling(spatial_scale=1.0 / 4,\n                            pooled_size=7,\n                            output_dim=16,\n                            no_trans=False,\n                            group_size=1,\n                            trans_std=0.0).cuda()\n    offset = torch.randn(20, 2, 7, 7).cuda().zero_()\n    dout = dpooling(input, rois, offset)\n    s = 
', '.join(['%f' % dout[i, :, :, :].mean().item()\n                   for i in range(rois.shape[0])])\n    print(s)\n\n\ndef check_gradient_dpooling():\n    input = torch.randn(2, 3, 5, 5).cuda() * 0.01\n    N = 4\n    batch_inds = torch.randint(2, (N, 1)).cuda().float()\n    x = torch.rand((N, 1)).cuda().float() * 15\n    y = torch.rand((N, 1)).cuda().float() * 15\n    w = torch.rand((N, 1)).cuda().float() * 10\n    h = torch.rand((N, 1)).cuda().float() * 10\n    rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)\n    offset = torch.randn(N, 2, 3, 3).cuda()\n    input.requires_grad = True\n    offset.requires_grad = True\n\n    spatial_scale = 1.0 / 4\n    pooled_size = 3\n    output_dim = 3\n    no_trans = 0\n    group_size = 1\n    trans_std = 0.0\n    sample_per_part = 4\n    part_size = pooled_size\n\n    print('check_gradient_dpooling:',\n          gradcheck(dcn_v2_pooling, (input, rois, offset,\n                                     spatial_scale,\n                                     pooled_size,\n                                     output_dim,\n                                     no_trans,\n                                     group_size,\n                                     part_size,\n                                     sample_per_part,\n                                     trans_std),\n                    eps=1e-4))\n\n\ndef example_dconv():\n    input = torch.randn(2, 64, 128, 128).cuda()\n    # wrap all things (offset and mask) in DCN\n    dcn = DCN(64, 64, kernel_size=(3, 3), stride=1,\n              padding=1, deformable_groups=2).cuda()\n    # print(dcn.weight.shape, input.shape)\n    output = dcn(input)\n    targert = output.new(*output.size())\n    targert.data.uniform_(-0.01, 0.01)\n    error = (targert - output).mean()\n    error.backward()\n    print(output.shape)\n\n\ndef example_dpooling():\n    input = torch.randn(2, 32, 64, 64).cuda()\n    batch_inds = torch.randint(2, (20, 1)).cuda().float()\n    x = torch.randint(256, (20, 
1)).cuda().float()\n    y = torch.randint(256, (20, 1)).cuda().float()\n    w = torch.randint(64, (20, 1)).cuda().float()\n    h = torch.randint(64, (20, 1)).cuda().float()\n    rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)\n    offset = torch.randn(20, 2, 7, 7).cuda()\n    input.requires_grad = True\n    offset.requires_grad = True\n\n    # normal roi_align\n    pooling = DCNv2Pooling(spatial_scale=1.0 / 4,\n                           pooled_size=7,\n                           output_dim=32,\n                           no_trans=True,\n                           group_size=1,\n                           trans_std=0.1).cuda()\n\n    # deformable pooling\n    dpooling = DCNv2Pooling(spatial_scale=1.0 / 4,\n                            pooled_size=7,\n                            output_dim=32,\n                            no_trans=False,\n                            group_size=1,\n                            trans_std=0.1).cuda()\n\n    out = pooling(input, rois, offset)\n    dout = dpooling(input, rois, offset)\n    print(out.shape)\n    print(dout.shape)\n\n    target_out = out.new(*out.size())\n    target_out.data.uniform_(-0.01, 0.01)\n    target_dout = dout.new(*dout.size())\n    target_dout.data.uniform_(-0.01, 0.01)\n    e = (target_out - out).mean()\n    e.backward()\n    e = (target_dout - dout).mean()\n    e.backward()\n\n\ndef example_mdpooling():\n    input = torch.randn(2, 32, 64, 64).cuda()\n    input.requires_grad = True\n    batch_inds = torch.randint(2, (20, 1)).cuda().float()\n    x = torch.randint(256, (20, 1)).cuda().float()\n    y = torch.randint(256, (20, 1)).cuda().float()\n    w = torch.randint(64, (20, 1)).cuda().float()\n    h = torch.randint(64, (20, 1)).cuda().float()\n    rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)\n\n    # mdformable pooling (V2)\n    dpooling = DCNPooling(spatial_scale=1.0 / 4,\n                          pooled_size=7,\n                          output_dim=32,\n                          
no_trans=False,\n                          group_size=1,\n                          trans_std=0.1,\n                          deform_fc_dim=1024).cuda()\n\n    dout = dpooling(input, rois)\n    target = dout.new(*dout.size())\n    target.data.uniform_(-0.1, 0.1)\n    error = (target - dout).mean()\n    error.backward()\n    print(dout.shape)\n\n\nif __name__ == '__main__':\n\n    example_dconv()\n    example_dpooling()\n    example_mdpooling()\n\n    check_pooling_zero_offset()\n    # zero offset check\n    if inC == outC:\n        check_zero_offset()\n\n    check_gradient_dpooling()\n    check_gradient_dconv()\n    # Note: gradcheck's \"backward is not reentrant\" error may not be a serious\n    # problem, since the max error is less than 1e-7; still looking into what\n    # triggers it.\n"
  },
  {
    "path": "code/synthetic/bsrt/model/__init__.py",
    "content": "import os\nfrom importlib import import_module\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel as P\nimport torch.utils.model_zoo\nimport time\n\nclass Model(nn.Module):\n    def __init__(self, args, ckp):\n        super(Model, self).__init__()\n        self.args = args\n        if args.local_rank == 0:\n            print(\"Making model: \", args.model)\n            print(\"Patch size: \", args.patch_size)\n\n\n        self.scale = args.scale\n        self.idx_scale = 0\n        self.input_large = (args.model == 'VDSR')\n        self.self_ensemble = args.self_ensemble\n        self.chop = args.chop\n        self.precision = args.precision\n        self.cpu = args.cpu\n        self.device = torch.device('cpu' if args.cpu else 'cuda:%d' % args.local_rank)\n        self.n_GPUs = args.n_GPUs\n        self.save_models = args.save_models\n\n        module = import_module('model.' + args.model.lower())\n        self.model = module.make_model(args).to(self.device)\n\n        if args.precision == 'half':\n            self.model.half()\n\n        self.load(\n            ckp.get_path('model'),\n            pre_train=args.pre_train,\n            resume=args.resume,\n            cpu=args.cpu\n        )\n\n        # time.sleep(3)\n\n        if args.n_GPUs > 1:\n            self.model = nn.parallel.DistributedDataParallel(self.model,\n                device_ids=[args.local_rank],\n                find_unused_parameters=True\n                )\n\n        print(self.model, file=ckp.log_file)\n\n    def forward(self, x, idx_scale):\n        self.idx_scale = idx_scale\n        if hasattr(self.model, 'set_scale'):\n            self.model.set_scale(idx_scale)\n\n        if self.training:\n            # if self.n_GPUs > 1:\n            return self.model(x)\n        else:\n            if self.chop:\n                forward_function = self.forward_chop\n            else:\n                forward_function = self.model.forward\n\n            if 
self.self_ensemble:\n                return self.forward_x8(x, forward_function=forward_function)\n            else:\n                # return self.model(x)\n                return forward_function(x)\n\n    def save(self, apath, epoch, is_best=False):\n        save_dirs = [os.path.join(apath, 'model_latest.pt')]\n\n        if is_best:\n            save_dirs.append(os.path.join(apath, 'model_best.pt'))\n        if self.save_models:\n            save_dirs.append(\n                os.path.join(apath, 'model_{}.pt'.format(epoch))\n            )\n        # save the unwrapped module so checkpoints load without the DDP 'module.' prefix\n        if self.n_GPUs > 1:\n            model = self.model.module\n        else:\n            model = self.model\n\n        for s in save_dirs:\n            torch.save(model.state_dict(), s)\n\n    def load(self, apath, pre_train='', resume=-1, cpu=False):\n        load_from = None\n        kwargs = {}\n        if cpu:\n            kwargs = {'map_location': lambda storage, loc: storage}\n\n        if resume == -1:\n            load_from = torch.load(\n                os.path.join(apath, 'model_latest.pt'),\n                **kwargs\n            )\n        elif resume == 0:\n            if pre_train == 'download':\n                print('Download the model')\n                dir_model = os.path.join('..', 'models')\n                os.makedirs(dir_model, exist_ok=True)\n                load_from = torch.utils.model_zoo.load_url(\n                    self.model.url,\n                    model_dir=dir_model,\n                    **kwargs\n                )\n            elif pre_train:\n                if self.args.local_rank == 0:\n                    print('Load the model from {}'.format(pre_train))\n                map_location = {'cuda:%d' % 0: 'cuda:%d' % self.args.local_rank}\n                load_from = torch.load(pre_train, map_location=map_location)\n        else:\n            load_from = torch.load(\n                os.path.join(apath, 'model_{}.pt'.format(resume)),\n                **kwargs\n            )\n\n      
  if load_from:\n            self.model.load_state_dict(load_from, strict=True)\n            del load_from\n\n        \n        if self.args.finetune:\n            if self.args.local_rank == 0:\n                print('finetune')\n            for param in self.model.parameters():\n                param.requires_grad = False\n\n            for param in self.model.HRconv.parameters():\n                param.requires_grad = True\n            for param in self.model.conv_last.parameters():\n                param.requires_grad = True\n\n        if self.args.finetune_prelayer:\n            if self.args.local_rank == 0:\n                print('finetune_prelayer')\n            if self.args.swinfeature:\n                if self.args.model == 'MBSRT':\n                    for param in self.model.pre_layer1.parameters():\n                        param.requires_grad = True\n                    for param in self.model.pre_layer2.parameters():\n                        param.requires_grad = True\n                else:\n                    for param in self.model.pre_layers.parameters():\n                        param.requires_grad = True\n            else:\n                for param in self.model.feature_extraction.parameters():\n                    param.requires_grad = True\n\n            for param in self.model.conv_after_pre_layer.parameters():\n                param.requires_grad = True\n\n        if self.args.finetune_align:\n            if self.args.local_rank == 0:\n                print('finetune_align')\n            for param in self.model.align.parameters():\n                param.requires_grad = True\n\n        if self.args.finetune_spynet:\n            if self.args.local_rank == 0:\n                print('finetune_spynet')\n            for param in self.model.spynet.parameters():\n                param.requires_grad = True\n\n        if self.args.finetune_swin:\n            if self.args.local_rank == 0:\n                print('finetune_swin')\n            for param in 
self.model.layers.parameters():\n                param.requires_grad = True\n            for param in self.model.conv_after_body.parameters():\n                param.requires_grad = True\n\n        if self.args.finetune_upconv:\n            if self.args.local_rank == 0:\n                print('finetune_upconv')\n            for param in self.model.upconv1.parameters():\n                param.requires_grad = True\n            for param in self.model.upconv2.parameters():\n                param.requires_grad = True\n            for param in self.model.skipup1.parameters():\n                param.requires_grad = True\n            for param in self.model.skipup2.parameters():\n                param.requires_grad = True\n\n        if self.args.finetune_conv:\n            if self.args.local_rank == 0:\n                print('finetune_conv')\n            # for param in self.model.conv_first.parameters():\n            #     param.requires_grad = True\n            # for param in self.model.conv_flow.parameters():\n            #     param.requires_grad = True\n            # for param in self.model.fea_L2_conv1.parameters():\n            #     param.requires_grad = True\n            # for param in self.model.fea_L3_conv1.parameters():\n            #     param.requires_grad = True\n            # for param in self.model.toplayer.parameters():\n            #     param.requires_grad = True\n            # for param in self.model.smooth1.parameters():\n            #     param.requires_grad = True\n            # for param in self.model.smooth2.parameters():\n            #     param.requires_grad = True\n            # for param in self.model.latlayer1.parameters():\n            #     param.requires_grad = True\n            # for param in self.model.latlayer2.parameters():\n            #     param.requires_grad = True\n            # for param in self.model.fusion.parameters():\n            #     param.requires_grad = True\n            for param in 
self.model.conv_after_body.parameters():\n                param.requires_grad = True\n            \n            \n\n    def forward_chop(self, *args, shave=10, min_size=160000):\n        scale = 1 if self.input_large else self.scale[self.idx_scale]\n        n_GPUs = min(self.n_GPUs, 4)\n        # height, width\n        h, w = args[0].size()[-2:]\n\n        top = slice(0, h//2 + shave)\n        bottom = slice(h - h//2 - shave, h)\n        left = slice(0, w//2 + shave)\n        right = slice(w - w//2 - shave, w)\n        x_chops = [torch.cat([\n            a[..., top, left],\n            a[..., top, right],\n            a[..., bottom, left],\n            a[..., bottom, right]\n        ]) for a in args]\n\n        y_chops = []\n        if h * w < 4 * min_size:\n            for i in range(0, 4, n_GPUs):\n                x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops]\n                y = P.data_parallel(self.model, *x, range(n_GPUs))\n                if not isinstance(y, list): y = [y]\n                if not y_chops:\n                    y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]\n                else:\n                    for y_chop, _y in zip(y_chops, y):\n                        y_chop.extend(_y.chunk(n_GPUs, dim=0))\n        else:\n            for p in zip(*x_chops):\n                y = self.forward_chop(*p, shave=shave, min_size=min_size)\n                if not isinstance(y, list): y = [y]\n                if not y_chops:\n                    y_chops = [[_y] for _y in y]\n                else:\n                    for y_chop, _y in zip(y_chops, y): y_chop.append(_y)\n\n        h *= scale\n        w *= scale\n        top = slice(0, h//2)\n        bottom = slice(h - h//2, h)\n        bottom_r = slice(h//2 - h, None)\n        left = slice(0, w//2)\n        right = slice(w - w//2, w)\n        right_r = slice(w//2 - w, None)\n\n        # batch size, number of color channels\n        b, c = y_chops[0][0].size()[:-2]\n        y = [y_chop[0].new(b, 
c, h, w) for y_chop in y_chops]\n        for y_chop, _y in zip(y_chops, y):\n            _y[..., top, left] = y_chop[0][..., top, left]\n            _y[..., top, right] = y_chop[1][..., top, right_r]\n            _y[..., bottom, left] = y_chop[2][..., bottom_r, left]\n            _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r]\n\n        if len(y) == 1: y = y[0]\n\n        return y\n\n    def forward_x8(self, *args, forward_function=None):\n        def _transform(v, op):\n            if self.precision != 'single': v = v.float()\n\n            v2np = v.data.cpu().numpy()\n            if op == 'v':\n                tfnp = v2np[:, :, :, ::-1].copy()\n            elif op == 'h':\n                tfnp = v2np[:, :, ::-1, :].copy()\n            elif op == 't':\n                tfnp = v2np.transpose((0, 1, 3, 2)).copy()\n\n            ret = torch.Tensor(tfnp).to(self.device)\n            if self.precision == 'half': ret = ret.half()\n\n            return ret\n\n        list_x = []\n        for a in args:\n            x = [a]\n            for tf in 'v', 'h', 't': x.extend([_transform(_x, tf) for _x in x])\n\n            list_x.append(x)\n\n        list_y = []\n        for x in zip(*list_x):\n            y = forward_function(*x)\n            if not isinstance(y, list): y = [y]\n            if not list_y:\n                list_y = [[_y] for _y in y]\n            else:\n                for _list_y, _y in zip(list_y, y): _list_y.append(_y)\n\n        for _list_y in list_y:\n            for i in range(len(_list_y)):\n                if i > 3:\n                    _list_y[i] = _transform(_list_y[i], 't')\n                if i % 4 > 1:\n                    _list_y[i] = _transform(_list_y[i], 'h')\n                if (i % 4) % 2 == 1:\n                    _list_y[i] = _transform(_list_y[i], 'v')\n\n        y = [torch.cat(_y, dim=0).mean(dim=0, keepdim=True) for _y in list_y]\n        if len(y) == 1: y = y[0]\n\n        return y\n"
  },
  {
    "path": "code/synthetic/bsrt/model/arch_util.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.init as init\nimport torch.nn.functional as F\nfrom model import common\nfrom model.utils.psconv import PSGConv2d as PSConv2d, PyConv2d\n\n\ndef initialize_weights(net_l, scale=1):\n    if not isinstance(net_l, list):\n        net_l = [net_l]\n    for net in net_l:\n        for m in net.modules():\n            if isinstance(m, nn.Conv2d):\n                init.kaiming_normal_(m.weight, a=0, mode='fan_in')\n                m.weight.data *= scale  # for residual block\n                if m.bias is not None:\n                    m.bias.data.zero_()\n            elif isinstance(m, nn.Linear):\n                init.kaiming_normal_(m.weight, a=0, mode='fan_in')\n                m.weight.data *= scale\n                if m.bias is not None:\n                    m.bias.data.zero_()\n            elif isinstance(m, nn.BatchNorm2d):\n                init.constant_(m.weight, 1)\n                init.constant_(m.bias.data, 0.0)\n\n\ndef make_layer(block, n_layers):\n    layers = []\n    for _ in range(n_layers):\n        layers.append(block())\n    return nn.Sequential(*layers)\n\n\n###########################\n\ndef conv_layer(in_channels, out_channels, kernel_size, stride=1, padding=0):\n    return nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=True)\n    \n\nclass ESA(nn.Module):\n    def __init__(self, n_feats, conv=conv_layer):\n        super(ESA, self).__init__()\n        f = n_feats // 4\n        self.conv1 = conv(n_feats, f, kernel_size=1)\n        self.conv_f = conv(f, f, kernel_size=1)\n        self.conv_max = conv(f, f, kernel_size=3, padding=1)\n        self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0)\n        self.conv3 = conv(f, f, kernel_size=3, padding=1)\n        self.conv3_ = conv(f, f, kernel_size=3, padding=1)\n        self.conv4 = conv(f, n_feats, kernel_size=1)\n        self.sigmoid = nn.Sigmoid()\n        self.relu = nn.ReLU(inplace=True)\n\n   
 def forward(self, x):\n        c1_ = (self.conv1(x))\n        c1 = self.conv2(c1_)\n        v_max = F.max_pool2d(c1, kernel_size=7, stride=3)\n        v_range = self.relu(self.conv_max(v_max))\n        c3 = self.relu(self.conv3(v_range))\n        c3 = self.conv3_(c3)\n        c3 = F.interpolate(c3, (x.size(2), x.size(3)), mode='bilinear', align_corners=False) \n        cf = self.conv_f(c1_)\n        c4 = self.conv4(c3+cf)\n        m = self.sigmoid(c4)\n        \n        return x * m\n\n\nclass DWConv(nn.Module):\n    def __init__(self, dim=768):\n        super(DWConv, self).__init__()\n        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)\n\n    def forward(self, x):\n        x = self.dwconv(x)\n        return x\n\n##########################\n\nclass SELayer(nn.Module):\n    '''\n    SE-block\n    '''\n    def __init__(self, channel, reduction=16):\n        super(SELayer, self).__init__()\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.fc = nn.Sequential(\n            nn.Linear(channel, channel // reduction, bias=False),\n            nn.ReLU(inplace=True),\n            nn.Linear(channel // reduction, channel, bias=False),\n            # nn.Sigmoid()\n        )\n\n    def forward(self, x):\n        b, c, _, _ = x.size()\n        y = self.avg_pool(x).view(b, c)\n        y = self.fc(y).view(b, c, 1, 1)\n        return x * y.expand_as(x)\n\nclass ResidualBlock_noBN(nn.Module):\n    '''Residual block w/o BN\n    ---Conv-ReLU-Conv-+-\n     |________________|\n    '''\n\n    def __init__(self, nf=64):\n        super(ResidualBlock_noBN, self).__init__()\n        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n\n        # initialization\n        initialize_weights([self.conv1, self.conv2], 0.1)\n\n    def forward(self, x):\n        identity = x\n        out = F.relu(self.conv1(x), inplace=True)\n        out = self.conv2(out)\n        return identity + out\n\n\nclass 
ResidualBlock_SE(nn.Module):\n    '''Residual block w/o BN, with an SE branch and concat fusion:\n    out = Conv1x1(cat(x, f(x), SE(f(x)))), where f = Conv-ReLU-Conv\n    '''\n\n    def __init__(self, nf=64, reduction=16):\n        super(ResidualBlock_SE, self).__init__()\n        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n        self.conv3 = nn.Conv2d(3 * nf, nf, 1, padding=0, dilation=1, bias=True)\n        self.se = SELayer(nf, reduction)\n        # initialization\n        initialize_weights([self.conv1, self.conv2, self.conv3], 0.1)\n\n    def forward(self, x):\n        identity = x\n        basic_out = F.relu(self.conv1(x), inplace=True)\n        basic_out = self.conv2(basic_out)\n        se_out = self.se(basic_out)\n        out = torch.cat((identity, basic_out, se_out), 1)\n        out = self.conv3(out)\n        return out\n\n\nclass _PositionAttentionModule(nn.Module):\n    \"\"\" Position attention module\"\"\"\n\n    def __init__(self, in_channels, **kwargs):\n        super(_PositionAttentionModule, self).__init__()\n        self.conv_b = nn.Conv2d(in_channels, in_channels // 8, 1)\n        self.conv_c = nn.Conv2d(in_channels, in_channels // 8, 1)\n        self.conv_d = nn.Conv2d(in_channels, in_channels, 1)\n        self.alpha = nn.Parameter(torch.zeros(1))\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x):\n        batch_size, _, height, width = x.size()\n        feat_b = self.conv_b(x).view(batch_size, -1, height * width).permute(0, 2, 1)\n        feat_c = self.conv_c(x).view(batch_size, -1, height * width)\n        attention_s = self.softmax(torch.bmm(feat_b, feat_c))\n        feat_d = self.conv_d(x).view(batch_size, -1, height * width)\n        feat_e = torch.bmm(feat_d, attention_s.permute(0, 2, 1)).view(batch_size, -1, height, width)\n        out = self.alpha * feat_e + x\n\n        return out\n\n## Spatial Attention (SA) Layer\nclass SALayer(nn.Module):\n    def __init__(self, wn=None):\n        
super(SALayer,self).__init__()\n        self.body = nn.Sequential(\n            wn(nn.Conv2d(2, 1, 7, 1, 3, bias=False)),\n            nn.Sigmoid()\n        )\n    def forward(self, x):\n        avg_f = torch.mean(x, dim=1, keepdim=True)\n        max_f = torch.max(x, dim=1, keepdim=True)[0]\n        y = torch.cat([avg_f, max_f], dim=1)\n        return self.body(y).expand_as(x) * x\n\n\n## Channel Attention (CA) Layer\nclass CALayerV2(nn.Module):\n    def __init__(self, n_feat, reduction=16, wn=None):\n        super(CALayerV2, self).__init__()\n        # global average pooling: feature --> point\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.max_pool = nn.AdaptiveMaxPool2d(1)\n        # feature channel downscale and upscale --> channel weight\n        self.conv_du = nn.Sequential(\n                wn(nn.Conv2d(n_feat, n_feat//reduction, 1, padding=0, bias=False)),\n                nn.ReLU(inplace=True),\n                wn(nn.Conv2d(n_feat//reduction, n_feat, 1, padding=0, bias=False)),\n                # nn.Sigmoid()\n        )\n\n    def forward(self, x):\n        y1 = self.avg_pool(x)\n        y2 = self.max_pool(x)\n        y1 = self.conv_du(y1)\n        y2 = self.conv_du(y2)\n        return x * torch.sigmoid(y1+y2)\n\nclass DALayer(nn.Module):\n    def __init__(self, channel, reduction, wn):\n        super(DALayer, self).__init__()\n        # global average pooling: feature --> point\n        self.ca = CALayer(channel, reduction, wn)\n        self.sa = SALayer(wn)\n        self.conv = wn(nn.Conv2d(channel*2, channel, 1))\n\n    def forward(self, x):\n        ca = self.ca(x)\n        sa = self.sa(x)\n        res = self.conv(torch.cat([ca, sa], dim=1))\n        return res + x\n\n\n## Channel Attention (CA) Layer\nclass CALayer(nn.Module):\n    def __init__(self, channel, reduction, wn):\n        super(CALayer, self).__init__()\n        # global average pooling: feature --> point\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        # feature 
channel downscale and upscale --> channel weight\n        self.conv_du = nn.Sequential(\n                wn(nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True)),\n                nn.ReLU(inplace=True),\n                wn(nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True)),\n                nn.Sigmoid()\n        )\n\n    def forward(self, x):\n        y = self.avg_pool(x)\n        y = self.conv_du(y)\n        return x * y\n\n\n## Residual Channel Attention Block (RCAB)\nclass RCAB(nn.Module):\n    def __init__(\n        self, conv, n_feat, kernel_size, reduction, wn,\n        bias=True, bn=False, act=nn.ReLU(True), res_scale=1, da=False):\n\n        super(RCAB, self).__init__()\n\n        expand = 6\n        linear = 0.75\n        modules_body = []\n        # for i in range(2):\n        modules_body.append(wn(nn.Conv2d(n_feat, n_feat*expand, 1, bias=bias)))\n        modules_body.append(act)\n        modules_body.append(wn(nn.Conv2d(n_feat*expand, int(n_feat*linear), 1, bias=bias)))\n        modules_body.append(conv(int(n_feat*linear), n_feat, kernel_size, bias=bias))\n        if da:\n            modules_body.append(DALayer(n_feat, reduction, wn))\n        else:\n            modules_body.append(CALayer(n_feat, reduction, wn))\n\n        self.body = nn.Sequential(*modules_body)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        res = self.body(x)\n        #res = self.body(x).mul(self.res_scale)\n        res += x\n        return res\n\n## Residual Group (RG)\nclass ResidualGroup(nn.Module):\n    def __init__(self, n_feat, n_resblocks, da=False):\n        super(ResidualGroup, self).__init__()\n        kernel_size = 3\n        res_scale = 1\n        reduction = 16\n\n        conv = common.default_conv\n        wn = lambda x: torch.nn.utils.weight_norm(x)\n\n        modules_body = []\n        modules_body = [\n            RCAB(\n                conv, n_feat, kernel_size, reduction, wn=wn, bias=True,\n                
bn=False, act=nn.ReLU(True), res_scale=res_scale, da=da) \\\n            for _ in range(n_resblocks)]\n        modules_body.append(wn(conv(n_feat, n_feat, kernel_size)))\n        self.body = nn.Sequential(*modules_body)\n\n    def forward(self, x):\n        res = self.body(x)\n        res += x\n        return res\n\n\n################################################################\n################################################################\n################################################################\n\ndef make_layer_idx(block, n_layers):\n    layers = []\n    for i in range(n_layers):\n        layers.append(block(idx=i))\n    return nn.Sequential(*layers)\n\n## Residual Channel Attention Block (RCAB)\nclass LRSCRCAB(nn.Module):\n    def __init__(\n        self, conv, n_feat, kernel_size, reduction, wn,\n        bias=True, bn=False, act=nn.ReLU(True), res_scale=1, da=False, idx=0):\n        super(LRSCRCAB, self).__init__()\n\n        expand = 6\n        linear = 0.75\n\n        modules_body = [wn(nn.Conv2d(n_feat*(idx+1), n_feat, 1, 1, 0, bias=True))] if idx > 0 else []\n        # for i in range(2):\n        modules_body.append(wn(nn.Conv2d(n_feat, n_feat*expand, 1, bias=bias)))\n        modules_body.append(act)\n        modules_body.append(wn(nn.Conv2d(n_feat*expand, int(n_feat*linear), 1, bias=bias)))\n        modules_body.append(wn(conv(int(n_feat*linear), n_feat, kernel_size, bias=bias)))\n        if da:\n            modules_body.append(DALayer(n_feat, reduction, wn))\n        else:\n            modules_body.append(CALayer(n_feat, reduction, wn))\n\n        self.body = nn.Sequential(*modules_body)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        res = self.body(x)\n        res  = torch.cat([res, x], dim=1)\n        return res\n\n\n## Residual Channel Attention Block (RCAB)\nclass LRSCPYRCAB(nn.Module):\n    def __init__(\n        self, conv, n_feat, kernel_size, reduction, wn,\n        bias=True, bn=False, act=nn.ReLU(True), 
res_scale=1, da=False, idx=0):\n        super(LRSCPYRCAB, self).__init__()\n\n        expand = 6\n        linear = 0.75\n\n        modules_body = [wn(nn.Conv2d(n_feat*(idx+1), n_feat, 1, 1, 0, bias=True))] if idx > 0 else []\n        # for i in range(2):\n        modules_body.append(wn(nn.Conv2d(n_feat, n_feat*expand, 1, bias=bias)))\n        modules_body.append(act)\n        modules_body.append(wn(nn.Conv2d(n_feat*expand, int(n_feat*linear), 1, bias=bias)))\n        modules_body.append(\n            PyConv2d(in_channels=int(n_feat*linear),\n                out_channels=[n_feat//4, n_feat//4, n_feat//2],\n                pyconv_kernels=[3, 5, 7],\n                pyconv_groups=[1, 4, 8]))\n        if da:\n            modules_body.append(DALayer(n_feat, reduction, wn))\n        else:\n            modules_body.append(CALayer(n_feat, reduction, wn))\n\n        self.body = nn.Sequential(*modules_body)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        res = self.body(x)\n        res  = torch.cat([res, x], dim=1)\n        return res\n\n## Long-Range Skip-connect Residual Group (RG)\nclass LRSCResidualGroup(nn.Module):\n    def __init__(self, n_feat, n_resblocks, da=False, idx=0):\n        super(LRSCResidualGroup, self).__init__()\n        kernel_size = 3\n        res_scale = 1\n        reduction = 16\n\n        conv = common.default_conv\n        wn = lambda x: torch.nn.utils.weight_norm(x)\n\n        modules_head = [wn(conv(n_feat*(idx+1), n_feat, 1, bias=True))] if idx > 0 else []\n        modules_body = [\n            LRSCRCAB(\n                conv, n_feat, kernel_size, reduction, wn=wn, bias=True,\n                bn=False, act=nn.ReLU(True), res_scale=res_scale, da=da, idx=i) \\\n            for i in range(n_resblocks)]\n        modules_body.append(wn(conv(n_feat*(n_resblocks+1), n_feat, kernel_size)))\n        self.head = nn.Sequential(*modules_head)\n        self.body = nn.Sequential(*modules_body)\n\n    def forward(self, x):\n        res 
= self.head(x)\n        res = self.body(res)\n        res  = torch.cat([res, x], dim=1)\n        return res\n\n\n## Long-Range Skip-connect Residual Group (RG)\nclass LRSCPSResidualGroup(nn.Module):\n    def __init__(self, n_feat, n_resblocks, da=False, idx=0):\n        super(LRSCPSResidualGroup, self).__init__()\n        kernel_size = 3\n        res_scale = 1\n        reduction = 16\n\n        conv = PSConv2d\n        wn = lambda x: torch.nn.utils.weight_norm(x)\n\n        modules_head = [wn(nn.Conv2d(n_feat*(idx+1), n_feat, 1, 1, 0, bias=True))] if idx > 0 else []\n        modules_body = [\n            LRSCRCAB(\n                conv, n_feat, kernel_size, reduction, wn=wn, bias=True,\n                bn=False, act=nn.ReLU(True), res_scale=res_scale, da=da, idx=i) \\\n            for i in range(n_resblocks)]\n        modules_tail = [wn(conv(n_feat*(n_resblocks+1), n_feat, kernel_size))]\n        self.head = nn.Sequential(*modules_head)\n        self.body = nn.Sequential(*modules_body)\n        self.tail = nn.Sequential(*modules_tail)\n\n    def forward(self, x):\n        res = self.head(x)\n        res = self.body(res)\n        res = self.tail(res)\n        res  = torch.cat([res, x], dim=1)\n        return res\n\n\n## Long-Range Skip-connect Residual Group (RG)\nclass LRSCPyResidualGroup(nn.Module):\n    def __init__(self, n_feat, n_resblocks, da=False, idx=0):\n        super(LRSCPyResidualGroup, self).__init__()\n        kernel_size = 3\n        res_scale = 1\n        reduction = 16\n\n        conv = PyConv2d\n        wn = lambda x: torch.nn.utils.weight_norm(x)\n\n        modules_head = [wn(nn.Conv2d(n_feat*(idx+1), n_feat, 1, 1, 0, bias=True))] if idx > 0 else []\n        modules_body = [\n            LRSCPYRCAB(\n                conv, n_feat, kernel_size, reduction, wn=wn, bias=True,\n                bn=False, act=nn.ReLU(True), res_scale=res_scale, da=da, idx=i) \\\n            for i in range(n_resblocks)]\n        modules_tail = 
[wn(nn.Conv2d(n_feat*(n_resblocks+1), n_feat, 1))]\n        self.head = nn.Sequential(*modules_head)\n        self.body = nn.Sequential(*modules_body)\n        self.tail = nn.Sequential(*modules_tail)\n\n    def forward(self, x):\n        res = self.head(x)\n        res = self.body(res)\n        res = self.tail(res)\n        res  = torch.cat([res, x], dim=1)\n        return res\n\nclass LRSCWideActResBlock(nn.Module):\n    def __init__(self, nf=64, idx=0):\n        super(LRSCWideActResBlock, self).__init__()\n        self.res_scale = 1\n\n        expand = 6\n        linear = 0.8\n        kernel_size = 3\n        wn = lambda x: torch.nn.utils.weight_norm(x)\n        act=nn.ReLU(True)\n        head = [wn(nn.Conv2d(nf*(idx+1), nf, 1, bias=True))] if idx > 0 else []\n\n        body = []\n        body.append(\n            wn(nn.Conv2d(nf, nf*expand, 1, padding=1//2)))\n        body.append(act)\n        body.append(\n            wn(nn.Conv2d(nf*expand, int(nf*linear), 1, padding=1//2)))\n        body.append(\n            wn(nn.Conv2d(int(nf*linear), nf, kernel_size, padding=kernel_size//2)))\n\n        self.head = nn.Sequential(*head)\n        self.body = nn.Sequential(*body)\n\n    def forward(self, x):\n        res = self.head(x)\n        res = self.body(res)\n        res  = torch.cat([res, x], dim=1)\n        return res\n\nclass LRSCPyWideActResBlock(nn.Module):\n    def __init__(self, nf=64, idx=0):\n        super(LRSCPyWideActResBlock, self).__init__()\n        self.res_scale = 1\n\n        expand = 6\n        linear = 0.75\n        kernel_size = 3\n        wn = lambda x: torch.nn.utils.weight_norm(x)\n        act=nn.ReLU(True)\n        head = [wn(nn.Conv2d(nf*(idx+1), nf, 1, bias=True))] if idx > 0 else []\n\n        body = []\n        body.append(\n            wn(nn.Conv2d(nf, nf*expand, 1, padding=1//2)))\n        body.append(act)\n        body.append(\n            wn(nn.Conv2d(nf*expand, int(nf*linear), 1, padding=1//2)))\n        body.append(\n            
PyConv2d(in_channels=int(nf*linear),\n                out_channels=[nf//4, nf//4, nf//2],\n                pyconv_kernels=[3, 5, 7],\n                pyconv_groups=[1, 4, 8]))\n\n        self.head = nn.Sequential(*head)\n        self.body = nn.Sequential(*body)\n\n    def forward(self, x):\n        res = self.head(x)\n        res = self.body(res)\n        res  = torch.cat([res, x], dim=1)\n        return res\n\n\n## Long-Range Skip-connect Residual Group (RG)\nclass LRSCPyWideActResGroup(nn.Module):\n    def __init__(self, nf, n_resblocks, idx=0):\n        super(LRSCPyWideActResGroup, self).__init__()\n        kernel_size = 3\n\n        conv = PyConv2d\n        wn = lambda x: torch.nn.utils.weight_norm(x)\n\n        modules_head = [wn(nn.Conv2d(nf*(idx+1), nf, 1, 1, 0, bias=True))] if idx > 0 else []\n        modules_body = [\n            LRSCPyWideActResBlock(nf=nf, idx=i) for i in range(n_resblocks)]\n        modules_tail = [wn(nn.Conv2d(nf*(n_resblocks+1), nf, 1))]\n        self.head = nn.Sequential(*modules_head)\n        self.body = nn.Sequential(*modules_body)\n        self.tail = nn.Sequential(*modules_tail)\n\n    def forward(self, x):\n        res = self.head(x)\n        res = self.body(res)\n        res = self.tail(res)\n        res  = torch.cat([res, x], dim=1)\n        return res\n\n\n## Long-Range Skip-connect Residual Group (RG)\nclass LRSCWideActResGroup(nn.Module):\n    def __init__(self, nf, n_resblocks, idx=0):\n        super(LRSCWideActResGroup, self).__init__()\n        kernel_size = 3\n\n        conv = PyConv2d\n        wn = lambda x: torch.nn.utils.weight_norm(x)\n\n        modules_head = [wn(nn.Conv2d(nf*(idx+1), nf, 1, 1, 0, bias=True))] if idx > 0 else []\n        modules_body = [\n            LRSCWideActResBlock(nf=nf, idx=i) for i in range(n_resblocks)]\n        modules_tail = [wn(nn.Conv2d(nf*(n_resblocks+1), nf, 1))]\n        self.head = nn.Sequential(*modules_head)\n        self.body = nn.Sequential(*modules_body)\n        self.tail = 
nn.Sequential(*modules_tail)\n\n    def forward(self, x):\n        res = self.head(x)\n        res = self.body(res)\n        res = self.tail(res)\n        res  = torch.cat([res, x], dim=1)\n        return res\n\n################################################################\n################################################################\n################################################################\n\n\n## Residual Channel Attention Block (RCAB)\nclass PYRCAB(nn.Module):\n    def __init__(\n        self, conv, n_feat, kernel_size, reduction, wn,\n        bias=True, bn=False, act=nn.ReLU(True), res_scale=1, da=False):\n        super(PYRCAB, self).__init__()\n\n        expand = 6\n        linear = 0.75\n        modules_body = []\n        # for i in range(2):\n        modules_body.append(wn(nn.Conv2d(n_feat, n_feat*expand, 1, bias=bias)))\n        modules_body.append(act)\n        modules_body.append(wn(nn.Conv2d(n_feat*expand, int(n_feat*linear), 1, bias=bias)))\n        # modules_body.append(conv(, n_feat, kernel_size, bias=bias))\n        modules_body.append(PyConv2d(in_channels=int(n_feat*linear),\n                out_channels=[n_feat//4, n_feat//4, n_feat//2],\n                pyconv_kernels=[3, 5, 7],\n                pyconv_groups=[1, 4, 8], bias=bias))\n        if da:\n            modules_body.append(DALayer(n_feat, reduction, wn))\n        else:\n            modules_body.append(CALayer(n_feat, reduction, wn))\n\n        self.body = nn.Sequential(*modules_body)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        res = self.body(x)\n        res += x\n        return res\n\n## Residual Group (RG)\nclass PyResidualGroup(nn.Module):\n    def __init__(self, n_feat, n_resblocks, da=False):\n        super(PyResidualGroup, self).__init__()\n        kernel_size = 3\n        res_scale = 1\n        reduction = 16\n\n        conv = PyConv2d\n        wn = lambda x: torch.nn.utils.weight_norm(x)\n\n        modules_body = []\n        modules_body 
= [\n            PYRCAB(\n                conv, n_feat, kernel_size, reduction, wn=wn, bias=True,\n                bn=False, act=nn.ReLU(True), res_scale=res_scale, da=da) \\\n            for _ in range(n_resblocks)]\n        modules_body.append(\n            PyConv2d(in_channels=n_feat,\n                out_channels=[n_feat//4, n_feat//4, n_feat//2],\n                pyconv_kernels=[3, 5, 7],\n                pyconv_groups=[1, 4, 8]))\n        self.body = nn.Sequential(*modules_body)\n\n    def forward(self, x):\n        res = self.body(x)\n        res += x\n        return res\n\nclass WideActResBlock(nn.Module):\n    def __init__(self, nf=64):\n        super(WideActResBlock, self).__init__()\n        self.res_scale = 1\n        body = []\n        expand = 6\n        linear = 0.8\n        kernel_size = 3\n        wn = lambda x: torch.nn.utils.weight_norm(x)\n        act=nn.ReLU(True)\n\n        body.append(\n            wn(nn.Conv2d(nf, nf*expand, 1, padding=1//2)))\n        body.append(act)\n        body.append(\n            wn(nn.Conv2d(nf*expand, int(nf*linear), 1, padding=1//2)))\n        body.append(\n            wn(nn.Conv2d(int(nf*linear), nf, kernel_size, padding=kernel_size//2)))\n\n        self.body = nn.Sequential(*body)\n\n    def forward(self, x):\n        res = self.body(x) * self.res_scale\n        res += x\n        return res\n\n\nclass PSWideActResBlock(nn.Module):\n    def __init__(self, nf=64):\n        super(PSWideActResBlock, self).__init__()\n        self.res_scale = 1\n        body = []\n        expand = 6\n        linear = 0.75\n        kernel_size = 3\n        wn = lambda x: torch.nn.utils.weight_norm(x)\n        act=nn.ReLU(True)\n\n        body.append(\n            wn(nn.Conv2d(nf, nf*expand, 1, padding=1//2)))\n        body.append(act)\n        body.append(\n            wn(nn.Conv2d(nf*expand, int(nf*linear), 1, padding=1//2)))\n        body.append(\n            wn(PSConv2d(int(nf*linear), nf, kernel_size, padding=kernel_size//2)))\n\n  
      self.body = nn.Sequential(*body)\n\n    def forward(self, x):\n        res = self.body(x) * self.res_scale\n        res += x\n        return res\n\n\nclass PyWideActResBlock(nn.Module):\n    def __init__(self, nf=64):\n        super(PyWideActResBlock, self).__init__()\n        self.res_scale = 1\n        body = []\n        expand = 6\n        linear = 0.75\n        kernel_size = 3\n        wn = lambda x: torch.nn.utils.weight_norm(x)\n        act=nn.ReLU(True)\n        expand_nf = nf*expand\n        linear_nf = int(nf * linear)\n\n        body.append(\n            wn(nn.Conv2d(nf, nf*expand, 1, padding=1//2)))\n        body.append(act)\n        body.append(\n            wn(nn.Conv2d(nf*expand, int(nf*linear), 1, padding=1//2)))\n        body.append(\n            PyConv2d(in_channels=linear_nf,\n                out_channels=[nf//4, nf//4, nf//2],\n                pyconv_kernels=[3, 5, 7],\n                pyconv_groups=[1, 4, 8]))\n\n        self.body = nn.Sequential(*body)\n\n    def forward(self, x):\n        res = self.body(x) * self.res_scale\n        res += x\n        return res\n\n\ndef flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True, use_pad_mask=False):\n    \"\"\"Warp an image or feature map with optical flow.\n\n    Args:\n        x (Tensor): Tensor with size (n, c, h, w).\n        flow (Tensor): Tensor with size (n, h, w, 2), normal value.\n        interp_mode (str): 'nearest' or 'bilinear' or 'nearest4'. Default: 'bilinear'.\n        padding_mode (str): 'zeros' or 'border' or 'reflection'.\n            Default: 'zeros'.\n        align_corners (bool): Before pytorch 1.3, the default value is\n            align_corners=True. After pytorch 1.3, the default value is\n            align_corners=False. 
Here, we use the True as default.\n        use_pad_mask (bool): only used for PWCNet, x is first padded with ones along the channel dimension.\n            The mask is generated according to the grid_sample results of the padded dimension.\n\n\n    Returns:\n        Tensor: Warped image or feature map.\n    \"\"\"\n    # assert x.size()[-2:] == flow.size()[1:3] # temporaily turned off for image-wise shift\n    n, _, h, w = x.size()\n    x = x.float()\n    # create mesh grid\n    # grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x)) # an illegal memory access on TITAN RTX + PyTorch1.9.1\n    grid_y, grid_x = torch.meshgrid(torch.arange(0, h, dtype=x.dtype, device=x.device), torch.arange(0, w, dtype=x.dtype, device=x.device))\n    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2\n    grid.requires_grad = False\n    grid = grid.type_as(x)\n    vgrid = grid + flow\n\n    # if use_pad_mask: # for PWCNet\n    #     x = F.pad(x, (0,0,0,0,0,1), mode='constant', value=1)\n\n    # scale grid to [-1,1]\n    if interp_mode == 'nearest4': # todo: bug, no gradient for flow model in this case!!! 
but the result is good\n        vgrid_x_floor = 2.0 * torch.floor(vgrid[:, :, :, 0]) / max(w - 1, 1) - 1.0\n        vgrid_x_ceil = 2.0 * torch.ceil(vgrid[:, :, :, 0]) / max(w - 1, 1) - 1.0\n        vgrid_y_floor = 2.0 * torch.floor(vgrid[:, :, :, 1]) / max(h - 1, 1) - 1.0\n        vgrid_y_ceil = 2.0 * torch.ceil(vgrid[:, :, :, 1]) / max(h - 1, 1) - 1.0\n\n        output00 = F.grid_sample(x, torch.stack((vgrid_x_floor, vgrid_y_floor), dim=3), mode='nearest', padding_mode=padding_mode, align_corners=align_corners)\n        output01 = F.grid_sample(x, torch.stack((vgrid_x_floor, vgrid_y_ceil), dim=3), mode='nearest', padding_mode=padding_mode, align_corners=align_corners)\n        output10 = F.grid_sample(x, torch.stack((vgrid_x_ceil, vgrid_y_floor), dim=3), mode='nearest', padding_mode=padding_mode, align_corners=align_corners)\n        output11 = F.grid_sample(x, torch.stack((vgrid_x_ceil, vgrid_y_ceil), dim=3), mode='nearest', padding_mode=padding_mode, align_corners=align_corners)\n\n        return torch.cat([output00, output01, output10, output11], 1)\n\n    else:\n        vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0\n        vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0\n        vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)\n        output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)\n\n        # if use_pad_mask: # for PWCNet\n        #     output = _flow_warp_masking(output)\n\n        # TODO, what if align_corners=False\n        return output\n\n\n# def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):\n#     \"\"\"Warp an image or feature map with optical flow\n#     Args:\n#         x (Tensor): size (N, C, H, W)\n#         flow (Tensor): size (N, H, W, 2), normal value\n#         interp_mode (str): 'nearest' or 'bilinear'\n#         padding_mode (str): 'zeros' or 'border' or 'reflection'\n\n#     Returns:\n#         Tensor: warped image or feature 
map\n#     \"\"\"\n#     assert x.size()[-2:] == flow.size()[1:3]\n#     B, C, H, W = x.size()\n#     # mesh grid\n#     grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))\n#     grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2\n#     grid.requires_grad = False\n#     grid = grid.type_as(x)\n#     vgrid = grid + flow\n#     # scale grid to [-1,1]\n#     vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0\n#     vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0\n#     vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)\n#     output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)\n#     return output\n"
  },
  {
    "path": "code/synthetic/bsrt/model/bsrt.py",
    "content": "import functools\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport model.arch_util as arch_util\nfrom torch.cuda.amp import autocast\nimport model.swin_util as swu\nimport time\nimport os\nimport math\nfrom utils.debayer import Debayer3x3\nimport torchvision.utils as tvutils\nfrom datasets.burstsr_dataset import pack_raw_image, flatten_raw_image_batch\n\ntry:\n    from model.non_local.non_local_cross_dot_product import NONLocalBlock2D as NonLocalCross\n    from model.non_local.non_local_dot_product import NONLocalBlock2D as NonLocal\nexcept ImportError:\n    raise ImportError('Failed to import Non_Local module.')\n\ntry:\n    from model.DCNv2.dcn_v2 import DCN_sep as DCN, FlowGuidedDCN, InsideFlowGuidedDCN\nexcept ImportError:\n    raise ImportError('Failed to import DCNv2 module.')\n\n\ndef make_model(args, parent=False):\n    nframes = args.burst_size\n    img_size = args.patch_size // args.scale[0]\n    patch_size = 1\n    in_chans = args.burst_channel\n    out_chans = args.n_colors\n    \n    if args.model_level == \"S\":\n        depths = [6]*1 + [6] * 4\n        num_heads = [6]*1 + [6] * 4\n        embed_dim = 60\n    elif args.model_level == \"L\":\n        depths = [6]*1 + [8] * 6\n        num_heads = [6]*1 + [6] * 6\n        embed_dim = 180\n    window_size = 8\n    mlp_ratio = 2\n    upscale = args.scale[0]\n    non_local = args.non_local\n    use_checkpoint=args.use_checkpoint\n\n    if args.local_rank <= 0:\n        print(\"depths: \", depths)\n\n    return BSRT(args=args,nframes=nframes,\n                   img_size=img_size,\n                   patch_size=patch_size,\n                   in_chans=in_chans,\n                   out_chans=out_chans,\n                   embed_dim=embed_dim,\n                   depths=depths,\n                   num_heads=num_heads,\n                   window_size=window_size,\n                   mlp_ratio=mlp_ratio,\n                   upscale=upscale,\n                   
non_local=non_local,\n                   use_checkpoint=use_checkpoint)\n\n\nclass BasicModule(nn.Module):\n    \"\"\"Basic Module for SpyNet.\n    \"\"\"\n\n    def __init__(self):\n        super(BasicModule, self).__init__()\n\n        self.basic_module = nn.Sequential(\n            nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),\n            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),\n            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),\n            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),\n            nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))\n\n    def forward(self, tensor_input):\n        return self.basic_module(tensor_input)\n\n\nclass SpyNet(nn.Module):\n    \"\"\"SpyNet architecture.\n\n    Args:\n        load_path (str): path for pretrained SpyNet. Default: None.\n        return_levels (list[int]): return flows of different levels. 
Default: [5].\n    \"\"\"\n\n    def __init__(self, load_path=None, return_levels=[5]):\n        super(SpyNet, self).__init__()\n        self.return_levels = return_levels\n        self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)])\n        if load_path:\n            if not os.path.exists(load_path):\n                import requests\n                url = 'https://github.com/JingyunLiang/VRT/releases/download/v0.0/spynet_sintel_final-3d2a1287.pth'\n                r = requests.get(url, allow_redirects=True)\n                print(f'downloading SpyNet pretrained model from {url}')\n                os.makedirs(os.path.dirname(load_path), exist_ok=True)\n                open(load_path, 'wb').write(r.content)\n\n            self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])\n\n        self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))\n        self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))\n\n    def preprocess(self, tensor_input):\n        tensor_output = (tensor_input - self.mean) / self.std\n        return tensor_output\n\n    def process(self, ref, supp, w, h, w_floor, h_floor):\n        flow_list = []\n\n        ref = [self.preprocess(ref)]\n        supp = [self.preprocess(supp)]\n\n        # ref = [ref]\n        # supp = [supp]\n\n        for level in range(5):\n            ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False))\n            supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False))\n\n        flow = ref[0].new_zeros(\n            [ref[0].size(0), 2,\n             int(math.floor(ref[0].size(2) / 2.0)),\n             int(math.floor(ref[0].size(3) / 2.0))])\n\n        for level in range(len(ref)):\n            upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0\n\n            if upsampled_flow.size(2) 
!= ref[level].size(2):\n                upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate')\n            if upsampled_flow.size(3) != ref[level].size(3):\n                upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate')\n\n            flow = self.basic_module[level](torch.cat([\n                ref[level],\n                arch_util.flow_warp(\n                    supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'),\n                upsampled_flow\n            ], 1)) + upsampled_flow\n\n            if level in self.return_levels:\n                scale = 2**(5-level) # level=5 (scale=1), level=4 (scale=2), level=3 (scale=4), level=2 (scale=8)\n                flow_out = F.interpolate(input=flow, size=(h//scale, w//scale), mode='bilinear', align_corners=False)\n                flow_out[:, 0, :, :] *= float(w//scale) / float(w_floor//scale)\n                flow_out[:, 1, :, :] *= float(h//scale) / float(h_floor//scale)\n                if torch.abs(flow_out).mean() > 200:\n                    print(f\"level {level}, flow > 200: {torch.abs(flow_out).mean():.4f}\")\n                    # return None\n                    flow_out = flow_out.clamp(-250, 250)  # clamp() is out-of-place; assign the result (a bare clamp() call is a no-op)\n                flow_list.insert(0, flow_out)\n\n        return flow_list\n\n    def forward(self, ref, supp):\n        assert ref.size() == supp.size()\n\n        h, w = ref.size(2), ref.size(3)\n        w_floor = math.floor(math.ceil(w / 32.0) * 32.0)\n        h_floor = math.floor(math.ceil(h / 32.0) * 32.0)\n\n        ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False)\n        supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False)\n\n        flow_list = self.process(ref, supp, w, h, w_floor, h_floor)\n\n        return flow_list[0] if len(flow_list) == 1 else flow_list\n\n\n\nclass FlowGuidedPCDAlign(nn.Module):\n    ''' Alignment module 
using Pyramid, Cascading and Deformable convolution\n    with 3 pyramid levels. [From EDVR]\n    '''\n\n    def __init__(self, nf=64, groups=8):\n        super(FlowGuidedPCDAlign, self).__init__()\n        # L3: level 3, 1/4 spatial size\n        self.L3_offset_conv1 = nn.Conv2d(nf * 2 + 2, nf, 3, 1, 1, bias=True)  # concat for diff\n        self.L3_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n        self.L3_dcnpack = FlowGuidedDCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)\n\n        # L2: level 2, 1/2 spatial size\n        self.L2_offset_conv1 = nn.Conv2d(nf * 2 + 2, nf, 3, 1, 1, bias=True)  # concat for diff\n        self.L2_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset\n        self.L2_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n        self.L2_dcnpack = FlowGuidedDCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)\n        self.L2_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea\n\n        # L1: level 1, original spatial size\n        self.L1_offset_conv1 = nn.Conv2d(nf * 2 + 2, nf, 3, 1, 1, bias=True)  # concat for diff\n        self.L1_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset\n        self.L1_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n        self.L1_dcnpack = FlowGuidedDCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)\n        self.L1_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea\n\n        # Cascading DCN\n        self.cas_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff\n        self.cas_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n        self.cas_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)\n\n        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)\n\n    def forward(self, nbr_fea_l, nbr_fea_warped_l, ref_fea_l, flows_l):\n        '''align other 
neighboring frames to the reference frame in the feature level\n        nbr_fea_l, ref_fea_l: [L1, L2, L3], each with [B,C,H,W] features\n        '''\n        # L3\n        L3_offset = torch.cat([nbr_fea_warped_l[2], ref_fea_l[2], flows_l[2]], dim=1)\n        L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset))\n        L3_offset = self.lrelu(self.L3_offset_conv2(L3_offset))\n        L3_fea = self.lrelu(self.L3_dcnpack(nbr_fea_l[2], L3_offset, flows_l[2]))\n        # L2\n        L3_offset = F.interpolate(L3_offset, scale_factor=2, mode='bilinear', align_corners=False)\n        L2_offset = torch.cat([nbr_fea_warped_l[1], ref_fea_l[1], flows_l[1]], dim=1)\n        L2_offset = self.lrelu(self.L2_offset_conv1(L2_offset))\n        L2_offset = self.lrelu(self.L2_offset_conv2(torch.cat([L2_offset, L3_offset*2], dim=1)))\n        L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset))\n        L2_fea = self.L2_dcnpack(nbr_fea_l[1], L2_offset, flows_l[1])\n        L3_fea = F.interpolate(L3_fea, scale_factor=2, mode='bilinear', align_corners=False)\n        L2_fea = self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L3_fea], dim=1)))\n        # L1\n        L2_offset = F.interpolate(L2_offset, scale_factor=2, mode='bilinear', align_corners=False)\n        L1_offset = torch.cat([nbr_fea_warped_l[0], ref_fea_l[0], flows_l[0]], dim=1)\n        L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset))\n        L1_offset = self.lrelu(self.L1_offset_conv2(torch.cat([L1_offset, L2_offset * 2], dim=1)))\n        L1_offset = self.lrelu(self.L1_offset_conv3(L1_offset))\n        L1_fea = self.L1_dcnpack(nbr_fea_l[0], L1_offset, flows_l[0])\n        L2_fea = F.interpolate(L2_fea, scale_factor=2, mode='bilinear', align_corners=False)\n        L1_fea = self.L1_fea_conv(torch.cat([L1_fea, L2_fea], dim=1))\n\n        # Cascading\n        offset = torch.cat([L1_fea, ref_fea_l[0]], dim=1)\n        offset = self.lrelu(self.cas_offset_conv1(offset))\n        offset = 
self.lrelu(self.cas_offset_conv2(offset))\n        L1_fea = self.cas_dcnpack(L1_fea, offset)\n\n        return L1_fea\n\n\nclass CrossNonLocal_Fusion(nn.Module):\n    ''' Cross Non Local fusion module\n    '''\n    def __init__(self, nf=64, out_feat=96, nframes=5, center=2):\n        super(CrossNonLocal_Fusion, self).__init__()\n        self.center = center\n\n        self.non_local_T = nn.ModuleList()\n        self.non_local_F = nn.ModuleList()\n\n        for i in range(nframes):\n            self.non_local_T.append(NonLocalCross(nf, inter_channels=nf//2, sub_sample=True, bn_layer=False))\n            self.non_local_F.append(NonLocal(nf, inter_channels=nf//2, sub_sample=True, bn_layer=False))\n\n        # fusion conv: 3x3 conv fusing the concatenated cross and self non-local features\n        self.fea_fusion = nn.Conv2d(nframes * nf*2, out_feat, 3, 1, 1, bias=True)\n\n        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)\n\n    def forward(self, aligned_fea):\n        B, N, C, H, W = aligned_fea.size()  # N video frames\n        ref = aligned_fea[:, self.center, :, :, :].clone()\n\n        cor_l = []\n        non_l = []\n        for i in range(N):\n            nbr = aligned_fea[:, i, :, :, :]\n            non_l.append(self.non_local_F[i](nbr))\n            cor_l.append(self.non_local_T[i](nbr, ref))\n\n        aligned_fea_T = torch.cat(cor_l, dim=1)\n        aligned_fea_F = torch.cat(non_l, dim=1)\n        aligned_fea = torch.cat([aligned_fea_T, aligned_fea_F], dim=1)\n\n        #### fusion\n        fea = self.fea_fusion(aligned_fea)\n\n        return fea\n\n\n\nclass BSRT(nn.Module):\n    def __init__(self, args, nframes=8, img_size=64, patch_size=1, in_chans=3, out_chans=3,\n                 embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],\n                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,\n     
            use_checkpoint=False, upscale=4, non_local=False,\n                 **kwargs):\n        super(BSRT, self).__init__()\n        num_in_ch = in_chans\n        num_out_ch = out_chans\n        num_feat = 64\n        groups = 8\n        # embed_dim = num_feat\n        back_RBs = 5\n        n_resblocks = 6\n\n        self.args = args\n        self.center = 0\n        self.upscale = upscale\n        self.window_size = window_size\n        self.non_local = non_local\n        self.nframes = nframes\n\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.ape = ape\n        self.patch_norm = patch_norm\n        self.num_features = embed_dim\n        self.mlp_ratio = mlp_ratio\n\n        spynet_path='/home/luoziwei/.pretrained_models/spynet_sintel_final-3d2a1287.pth'\n        self.spynet = SpyNet(spynet_path, [3, 4, 5])\n        self.conv_flow = nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1)\n        self.flow_ps = nn.PixelShuffle(2)\n\n        # split image into non-overlapping patches\n        self.patch_embed = swu.PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None)\n        num_patches = self.patch_embed.num_patches\n        patches_resolution = self.patch_embed.patches_resolution\n        self.patches_resolution = patches_resolution\n\n        # merge non-overlapping patches into image\n        self.patch_unembed = swu.PatchUnEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None)\n\n        #####################################################################################################\n        ################################### 1, shallow feature extraction ###################################\n        self.conv_first = nn.Conv2d(num_in_ch*(1+2*0), embed_dim, 3, 1, 1, bias=True)\n       
 \n        # # stochastic depth\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n\n        if args.swinfeature:\n            if self.args.local_rank <= 0:\n                print(\"using swinfeature\")\n            self.pre_layers = nn.ModuleList()\n            for i_layer in range(depths[0]):\n                layer = swu.SwinTransformerBlock(dim=embed_dim, \n                            input_resolution=(patches_resolution[0]//2,\n                                              patches_resolution[1]//2),\n                             num_heads=num_heads[0], window_size=window_size,\n                             shift_size=0 if (i_layer % 2 == 0) else window_size // 2,\n                             mlp_ratio=mlp_ratio,\n                             qkv_bias=qkv_bias, qk_scale=qk_scale,\n                             drop=drop_rate, attn_drop=attn_drop_rate,\n                             drop_path=dpr[i_layer],\n                             norm_layer=norm_layer)\n                self.pre_layers.append(layer)\n\n            self.pre_norm = norm_layer(embed_dim)\n        else:\n            WARB = functools.partial(arch_util.WideActResBlock, nf=embed_dim)\n            self.feature_extraction = arch_util.make_layer(WARB, 5)\n\n        self.conv_after_pre_layer = nn.Conv2d(embed_dim, num_feat*4, 3, 1, 1, bias=True)\n        self.mid_ps = nn.PixelShuffle(2)\n\n        self.fea_L2_conv1 = nn.Conv2d(num_feat, num_feat*2, 3, 2, 1, bias=True)\n        self.fea_L3_conv1 = nn.Conv2d(num_feat*2, num_feat*4, 3, 2, 1, bias=True)\n\n        #####################################################################################################\n        ################################### 2, Feature Enhanced PCD Align ###################################\n\n        # Top layers\n        self.toplayer = nn.Conv2d(num_feat*4, num_feat, kernel_size=1, stride=1, padding=0)\n        # Smooth layers\n        self.smooth1 = 
nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1)\n        self.smooth2 = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1)\n        # Lateral layers\n        self.latlayer1 = nn.Conv2d(num_feat*2, num_feat, kernel_size=1, stride=1, padding=0)\n        self.latlayer2 = nn.Conv2d(num_feat*1, num_feat, kernel_size=1, stride=1, padding=0)\n\n        # self.align = PCD_Align(nf=num_feat, groups=groups)\n        self.align = FlowGuidedPCDAlign(nf=num_feat, groups=groups)\n        #####################################################################################################\n        ################################### 3, Multi-frame Feature Fusion  ##################################\n\n        if self.non_local:\n            if self.args.local_rank <= 0:\n                print(\"using non_local\")\n            self.fusion = CrossNonLocal_Fusion(nf=num_feat, out_feat=embed_dim, nframes=nframes, center=self.center)\n        else:\n            self.fusion = nn.Conv2d(nframes * num_feat, embed_dim, 1, 1, bias=True)\n\n        #####################################################################################################\n        ################################### 4, deep feature extraction ######################################\n\n        # absolute position embedding\n        if self.ape:\n            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))\n            swu.trunc_normal_(self.absolute_pos_embed, std=.02)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # build Residual Swin Transformer blocks (RSTB)\n        self.layers = nn.ModuleList()\n        for i_layer in range(1, self.num_layers):\n            layer = swu.RSTB(dim=embed_dim,\n                         input_resolution=(patches_resolution[0],\n                                           patches_resolution[1]),\n                         depth=depths[i_layer],\n                         num_heads=num_heads[i_layer],\n          
               window_size=window_size,\n                         mlp_ratio=self.mlp_ratio,\n                         qkv_bias=qkv_bias, qk_scale=qk_scale,\n                         drop=drop_rate, attn_drop=attn_drop_rate,\n                         drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results\n                         norm_layer=norm_layer,\n                         downsample=None,\n                         use_checkpoint=use_checkpoint,\n                         img_size=img_size,\n                         patch_size=patch_size\n                         )\n            self.layers.append(layer)\n        \n        self.norm = norm_layer(self.num_features)\n\n        # build the last conv layer in deep feature extraction\n        self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)\n\n        #####################################################################################################\n        ################################ 5, high quality image reconstruction ################################\n\n        self.upconv1 = nn.Conv2d(embed_dim, num_feat * 4, 3, 1, 1, bias=True)\n        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)\n        self.pixel_shuffle = nn.PixelShuffle(2)\n        self.HRconv = nn.Conv2d(64, 64, 3, 1, 1, bias=True)\n        self.conv_last = nn.Conv2d(64, args.n_colors, 3, 1, 1, bias=True)\n\n        #### skip #############\n        self.skip_pixel_shuffle = nn.PixelShuffle(2)\n        self.skipup1 = nn.Conv2d(num_in_ch//4, num_feat * 4, 3, 1, 1, bias=True)\n        self.skipup2 = nn.Conv2d(num_feat, args.n_colors * 4, 3, 1, 1, bias=True)\n\n        #### activation function\n        self.lrelu = nn.LeakyReLU(0.1, inplace=True)\n        self.lrelu2 = nn.LeakyReLU(0.1, inplace=True)\n\n        self.apply(self._init_weights)\n\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            swu.trunc_normal_(m.weight, std=.02)\n            if 
isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'absolute_pos_embed'}\n\n    @torch.jit.ignore\n    def no_weight_decay_keywords(self):\n        return {'relative_position_bias_table'}\n\n    def _upsample_add(self, x, y):\n        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + y\n\n    def check_image_size(self, x):\n        _, _, h, w = x.size()\n        mod_pad_h = (self.window_size - h % self.window_size) % self.window_size\n        mod_pad_w = (self.window_size - w % self.window_size) % self.window_size\n        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')\n        return x\n\n    def pre_forward_features(self, x):\n        if self.args.swinfeature:\n            x_size = (x.shape[-2], x.shape[-1])\n            x = self.patch_embed(x, use_norm=True)\n            if self.ape:\n                x = x + self.absolute_pos_embed\n            x = self.pos_drop(x)\n\n            for idx, layer in enumerate(self.pre_layers):\n                x = layer(x, x_size)\n\n            x = self.pre_norm(x)\n            x = self.patch_unembed(x, x_size)\n\n        else:\n            x = self.feature_extraction(x)\n\n        return x\n\n    def forward_features(self, x):\n        x_size = (x.shape[-2], x.shape[-1])\n        x = self.patch_embed(x)\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n\n        for idx, layer in enumerate(self.layers):\n            x = layer(x, x_size)\n            if torch.any(torch.isinf(x)) or torch.any(torch.isnan(x)):\n                print('layer: ', idx)\n\n        x = self.norm(x)  # B L C\n        x = self.patch_unembed(x, x_size)\n\n        return x\n\n    @autocast()\n    def forward(self, x, 
print_time=False):\n        B, N, C, H, W = x.size()  # N video frames\n        x_center = x[:, self.center, :, :, :].contiguous()\n\n        #### skip module ########\n        skip1 = self.lrelu2(self.skip_pixel_shuffle(self.skipup1(self.skip_pixel_shuffle(x_center))))\n        skip2 = self.skip_pixel_shuffle(self.skipup2(skip1))\n\n        x_ = self.conv_flow(self.flow_ps(x.view(B*N, C, H, W))).view(B, N, -1, H*2, W*2)\n        \n        # calculate flows\n        ref_flows = self.get_ref_flows(x_)\n\n        #### extract LR features\n        x = self.lrelu(self.conv_first(x.view(B*N, -1, H, W)))\n\n        L1_fea = self.mid_ps(self.conv_after_pre_layer(self.pre_forward_features(x)))\n        _, _, H, W = L1_fea.size()\n\n        L2_fea = self.lrelu(self.fea_L2_conv1(L1_fea))\n        L3_fea = self.lrelu(self.fea_L3_conv1(L2_fea))\n\n        # FPN enhance features\n        L3_fea = self.lrelu(self.toplayer(L3_fea))\n        L2_fea = self.smooth1(self._upsample_add(L3_fea, self.latlayer1(L2_fea)))\n        L1_fea = self.smooth2(self._upsample_add(L2_fea, self.latlayer2(L1_fea)))\n\n        L1_fea = L1_fea.view(B, N, -1, H, W).contiguous()\n        L2_fea = L2_fea.view(B, N, -1, H // 2, W // 2 ).contiguous()\n        L3_fea = L3_fea.view(B, N, -1, H // 4, W // 4).contiguous()\n\n        #### PCD align\n        # ref feature list\n        ref_fea_l = [\n            L1_fea[:, self.center, :, :, :].clone(), \n            L2_fea[:, self.center, :, :, :].clone(),\n            L3_fea[:, self.center, :, :, :].clone()\n        ]\n        aligned_fea = []\n        for i in range(N):\n            nbr_fea_l = [\n                L1_fea[:, i, :, :, :].clone(), \n                L2_fea[:, i, :, :, :].clone(),\n                L3_fea[:, i, :, :, :].clone()\n            ]\n            flows_l = [\n                ref_flows[0][:, i, :, :, :].clone(), \n                ref_flows[1][:, i, :, :, :].clone(), \n                ref_flows[2][:, i, :, :, :].clone()\n            ]\n         
   # print(nbr_fea_l[0].shape, flows_l[0].shape)\n            nbr_warped_l = [\n                arch_util.flow_warp(nbr_fea_l[0], flows_l[0].permute(0, 2, 3, 1), 'bilinear'),\n                arch_util.flow_warp(nbr_fea_l[1], flows_l[1].permute(0, 2, 3, 1), 'bilinear'),\n                arch_util.flow_warp(nbr_fea_l[2], flows_l[2].permute(0, 2, 3, 1), 'bilinear')\n            ]\n            aligned_fea.append(self.align(nbr_fea_l, nbr_warped_l, ref_fea_l, flows_l))\n\n        aligned_fea = torch.stack(aligned_fea, dim=1)  # [B, N, C, H, W] --> [B, T, C, H, W]\n\n        if not self.non_local:\n            aligned_fea = aligned_fea.view(B, -1, H, W)\n\n        x = self.lrelu(self.fusion(aligned_fea))\n\n        x = self.lrelu(self.conv_after_body(self.forward_features(x))) + x\n\n        x = self.lrelu(self.pixel_shuffle(self.upconv1(x)))\n        x = skip1 + x\n        x = self.lrelu(self.pixel_shuffle(self.upconv2(x)))\n        x = self.lrelu(self.HRconv(x))\n        x = self.conv_last(x)\n\n        x = skip2 + x\n        return x\n\n\n    def get_ref_flows(self, x):\n        '''Get flow between frames ref and other'''\n\n        b, n, c, h, w = x.size()\n        x_nbr = x.reshape(-1, c, h, w)\n        x_ref = x[:, self.center:self.center+1, :, :, :].repeat(1, n, 1, 1, 1).reshape(-1, c, h, w)\n\n        # backward\n        flows = self.spynet(x_ref, x_nbr)\n        flows_list = [flow.view(b, n, 2, h // (2 ** (i)), w // (2 ** (i))) for flow, i in\n                          zip(flows, range(3))]\n\n        return flows_list\n\n\n\n\n\n\n\n"
  },
  {
    "path": "code/synthetic/bsrt/model/checkpoint.py",
"content": "import torch\nimport warnings\n\n\ndef detach_variable(inputs):\n    if isinstance(inputs, tuple):\n        out = []\n        for inp in inputs:\n            x = inp.detach()\n            x.requires_grad = inp.requires_grad\n            out.append(x)\n        return tuple(out)\n    else:\n        raise RuntimeError(\n            \"Only tuple of tensors is supported. Got unsupported input type: \" + type(inputs).__name__)\n\n\ndef check_backward_validity(inputs):\n    if not any(inp.requires_grad for inp in inputs):\n        warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n\n\nclass CheckpointFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, run_function, length, *args):\n        ctx.run_function = run_function\n        ctx.input_tensors = list(args[:length])\n        ctx.input_params = list(args[length:])\n        with torch.no_grad():\n            output_tensors = ctx.run_function(*ctx.input_tensors)\n        return output_tensors\n\n    @staticmethod\n    def backward(ctx, *output_grads):\n        for i in range(len(ctx.input_tensors)):\n            temp = ctx.input_tensors[i]\n            ctx.input_tensors[i] = temp.detach()\n            ctx.input_tensors[i].requires_grad = temp.requires_grad\n        with torch.enable_grad():\n            output_tensors = ctx.run_function(*ctx.input_tensors)\n        input_grads = torch.autograd.grad(output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True)\n        return (None, None) + input_grads\n"
  },
  {
    "path": "code/synthetic/bsrt/model/common.py",
    "content": "import math\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef default_conv(in_channels, out_channels, kernel_size, bias=True):\n    return nn.Conv2d(\n        in_channels, out_channels, kernel_size,\n        padding=(kernel_size // 2), bias=bias)\n\n\nclass MeanShift(nn.Conv2d):\n    def __init__(\n            self, rgb_range,\n            rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):\n        super(MeanShift, self).__init__(3, 3, kernel_size=1)\n        std = torch.Tensor(rgb_std)\n        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)\n        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std\n        for p in self.parameters():\n            p.requires_grad = False\n\n\nclass BasicBlock(nn.Sequential):\n    def __init__(\n            self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False,\n            bn=True, act=nn.ReLU(True)):\n\n        m = [conv(in_channels, out_channels, kernel_size, bias=bias)]\n        if bn:\n            m.append(nn.BatchNorm2d(out_channels))\n        if act is not None:\n            m.append(act)\n\n        super(BasicBlock, self).__init__(*m)\n\n\nclass ResBlock(nn.Module):\n    def __init__(\n            self, conv, n_feats, kernel_size,\n            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):\n\n        super(ResBlock, self).__init__()\n        m = []\n        for i in range(2):\n            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))\n            if bn:\n                m.append(nn.BatchNorm2d(n_feats))\n            if i == 0:\n                m.append(act)\n\n        self.body = nn.Sequential(*m)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        res = self.body(x).mul(self.res_scale)\n        res += x\n\n        return res\n\n\nclass Upsampler(nn.Sequential):\n    def __init__(self, conv, scale, n_feats, bn=False, act=False, 
bias=True):\n\n        m = []\n        if (scale & (scale - 1)) == 0:  # Is scale = 2^n?\n            for _ in range(int(math.log(scale, 2))):\n                m.append(conv(n_feats, 4 * n_feats, 3, bias))\n                m.append(nn.PixelShuffle(2))\n                if bn:\n                    m.append(nn.BatchNorm2d(n_feats))\n                if act == 'relu':\n                    m.append(nn.ReLU(True))\n                elif act == 'prelu':\n                    m.append(nn.PReLU(n_feats))\n\n        elif scale == 3:\n            m.append(conv(n_feats, 9 * n_feats, 3, bias))\n            m.append(nn.PixelShuffle(3))\n            if bn:\n                m.append(nn.BatchNorm2d(n_feats))\n            if act == 'relu':\n                m.append(nn.ReLU(True))\n            elif act == 'prelu':\n                m.append(nn.PReLU(n_feats))\n        else:\n            raise NotImplementedError\n\n        super(Upsampler, self).__init__(*m)\n\n\nclass UpOnly(nn.Sequential):\n    def __init__(self, scale):\n\n        m = []\n        if (scale & (scale - 1)) == 0:  # Is scale = 2^n?\n            for _ in range(int(math.log(scale, 2))):\n                m.append(nn.PixelShuffle(2))\n\n\n        elif scale == 3:\n\n            m.append(nn.PixelShuffle(3))\n\n        else:\n            raise NotImplementedError\n\n        super(UpOnly, self).__init__(*m)\n\n\ndef lanczos_kernel(dx, a=3, N=None, dtype=None, device=None):\n    '''\n    Generates 1D Lanczos kernels for translation and interpolation.\n    Args:\n        dx : float, tensor (batch_size, 1), the translation in pixels to shift an image.\n        a : int, number of lobes in the kernel support.\n            If N is None, then the width is the kernel support (length of all lobes),\n            S = 2(a + ceil(dx)) + 1.\n        N : int, width of the kernel.\n            If smaller than S then N is set to S.\n    Returns:\n        k: tensor (?, ?), lanczos kernel\n    '''\n\n    if not torch.is_tensor(dx):\n        dx = 
torch.tensor(dx, dtype=dtype, device=device)\n\n    if device is None:\n        device = dx.device\n\n    if dtype is None:\n        dtype = dx.dtype\n\n    D = dx.abs().ceil().int()\n    S = 2 * (a + D) + 1  # width of kernel support\n\n    S_max = S.max() if hasattr(S, 'shape') else S\n\n    if (N is None) or (N < S_max):\n        N = S\n\n    Z = (N - S) // 2  # width of zeros beyond kernel support\n\n    start = (-(a + D + Z)).min()\n    end = (a + D + Z + 1).max()\n    x = torch.arange(start, end, dtype=dtype, device=device).view(1, -1) - dx\n    px = (np.pi * x) + 1e-3\n\n    sin_px = torch.sin(px)\n    sin_pxa = torch.sin(px / a)\n\n    k = a * sin_px * sin_pxa / px ** 2  # sinc(x) masked by sinc(x/a)\n\n    return k\n\n\ndef lanczos_shift(img, shift, p=5, a=3):\n    '''\n    Shifts an image by convolving it with a Lanczos kernel.\n    Lanczos interpolation is an approximation to ideal sinc interpolation,\n    by windowing a sinc kernel with another sinc function extending over a\n    few of its lobes (typically a=3).\n\n    Args:\n        img : tensor (batch_size, channels, height, width), the images to be shifted\n        shift : tensor (batch_size, 2) of translation parameters (dy, dx)\n        p : int, padding width prior to convolution (default=5)\n        a : int, number of lobes in the Lanczos interpolation kernel (default=3)\n    Returns:\n        I_s: tensor (batch_size, channels, height, width), shifted images\n    '''\n    img = img.transpose(0, 1)\n    dtype = img.dtype\n\n    if len(img.shape) == 2:\n        img = img[None, None].repeat(1, shift.shape[0], 1, 1)  # batch of one image\n    elif len(img.shape) == 3:  # one image per shift\n        assert img.shape[0] == shift.shape[0]\n        img = img[None,]\n\n    # Apply padding\n\n    padder = torch.nn.ReflectionPad2d(p)  # reflect pre-padding\n    I_padded = padder(img)\n\n    # Create 1D shifting kernels\n\n    y_shift = shift[:, [0]]\n    x_shift = shift[:, [1]]\n\n    k_y = 
(lanczos_kernel(y_shift, a=a, N=None, dtype=dtype)\n           .flip(1)  # flip axis of convolution\n           )[:, None, :, None]  # expand dims to get shape (batch, channels, y_kernel, 1)\n    k_x = (lanczos_kernel(x_shift, a=a, N=None, dtype=dtype)\n           .flip(1)\n           )[:, None, None, :]  # shape (batch, channels, 1, x_kernel)\n\n    # Apply kernels\n    # print(I_padded.shape, k_y.shape)\n    I_s = torch.conv1d(I_padded,\n                       groups=k_y.shape[0],\n                       weight=k_y,\n                       padding=[k_y.shape[2] // 2, 0])  # same padding\n    I_s = torch.conv1d(I_s,\n                       groups=k_x.shape[0],\n                       weight=k_x,\n                       padding=[0, k_x.shape[3] // 2])\n\n    I_s = I_s[..., p:-p, p:-p]  # remove padding\n\n    # print(I_s.shape)\n    return I_s.transpose(0, 1)  # , k.squeeze()\n"
  },
  {
    "path": "code/synthetic/bsrt/model/ebsr.py",
    "content": "import functools\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport model.arch_util as arch_util\nfrom torch.cuda.amp import autocast\nimport model.swin_util as swu\nimport time\nimport os\nimport math\nfrom utils.debayer import Debayer3x3\nimport torchvision.utils as tvutils\nfrom datasets.burstsr_dataset import pack_raw_image, flatten_raw_image_batch\n\ntry:\n    from model.non_local.non_local_cross_dot_product import NONLocalBlock2D as NonLocalCross\n    from model.non_local.non_local_dot_product import NONLocalBlock2D as NonLocal\nexcept ImportError:\n    raise ImportError('Failed to import Non_Local module.')\n\ntry:\n    from model.DCNv2.dcn_v2 import DCN_sep as DCN, FlowGuidedDCN, InsideFlowGuidedDCN\nexcept ImportError:\n    raise ImportError('Failed to import DCNv2 module.')\n\n\ndef make_model(args, parent=False):\n    nframes = args.burst_size\n    img_size = args.patch_size // args.scale[0]\n    patch_size = 1\n    in_chans = args.burst_channel\n    out_chans = args.n_colors\n    embed_dim = args.n_feats\n    depths = [6]*1 + [8] * 6\n    num_heads = [6]*1 + [6] * 6\n    window_size = 8\n    mlp_ratio = 2\n    upscale = args.scale[0]\n    non_local = args.non_local\n\n    if args.local_rank <= 0:\n        print(\"depths: \", depths)\n\n    return EBSR(args=args,nframes=nframes,\n                   img_size=img_size,\n                   patch_size=patch_size,\n                   in_chans=in_chans,\n                   out_chans=out_chans,\n                   embed_dim=embed_dim,\n                   depths=depths,\n                   num_heads=num_heads,\n                   window_size=window_size,\n                   mlp_ratio=mlp_ratio,\n                   upscale=upscale,\n                   non_local=non_local)\n\n\nclass BasicModule(nn.Module):\n    \"\"\"Basic Module for SpyNet.\n    \"\"\"\n\n    def __init__(self):\n        super(BasicModule, self).__init__()\n\n        self.basic_module = nn.Sequential(\n 
           nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),\n            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),\n            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),\n            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),\n            nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))\n\n    def forward(self, tensor_input):\n        return self.basic_module(tensor_input)\n\n\nclass SpyNet(nn.Module):\n    \"\"\"SpyNet architecture.\n\n    Args:\n        load_path (str): path for pretrained SpyNet. Default: None.\n        return_levels (list[int]): return flows of different levels. Default: [5].\n    \"\"\"\n\n    def __init__(self, load_path=None, return_levels=[5]):\n        super(SpyNet, self).__init__()\n        self.return_levels = return_levels\n        self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)])\n        if load_path:\n            if not os.path.exists(load_path):\n                import requests\n                url = 'https://github.com/JingyunLiang/VRT/releases/download/v0.0/spynet_sintel_final-3d2a1287.pth'\n                r = requests.get(url, allow_redirects=True)\n                print(f'downloading SpyNet pretrained model from {url}')\n                os.makedirs(os.path.dirname(load_path), exist_ok=True)\n                open(load_path, 'wb').write(r.content)\n\n            self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])\n\n        self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))\n        self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))\n\n    def preprocess(self, tensor_input):\n        tensor_output = (tensor_input - self.mean) / self.std\n    
    return tensor_output\n\n    def process(self, ref, supp, w, h, w_floor, h_floor):\n        flow_list = []\n\n        ref = [self.preprocess(ref)]\n        supp = [self.preprocess(supp)]\n\n        # ref = [ref]\n        # supp = [supp]\n\n        for level in range(5):\n            ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False))\n            supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False))\n\n        flow = ref[0].new_zeros(\n            [ref[0].size(0), 2,\n             int(math.floor(ref[0].size(2) / 2.0)),\n             int(math.floor(ref[0].size(3) / 2.0))])\n\n        for level in range(len(ref)):\n            upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0\n\n            if upsampled_flow.size(2) != ref[level].size(2):\n                upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate')\n            if upsampled_flow.size(3) != ref[level].size(3):\n                upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate')\n\n            flow = self.basic_module[level](torch.cat([\n                ref[level],\n                arch_util.flow_warp(\n                    supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'),\n                upsampled_flow\n            ], 1)) + upsampled_flow\n\n            if level in self.return_levels:\n                scale = 2**(5-level) # level=5 (scale=1), level=4 (scale=2), level=3 (scale=4), level=2 (scale=8)\n                flow_out = F.interpolate(input=flow, size=(h//scale, w//scale), mode='bilinear', align_corners=False)\n                flow_out[:, 0, :, :] *= float(w//scale) / float(w_floor//scale)\n                flow_out[:, 1, :, :] *= float(h//scale) / float(h_floor//scale)\n                if torch.abs(flow_out).mean() > 200:\n                    print(f\"level {level}, flow 
> 200: {torch.abs(flow_out).mean():.4f}\")\n                    # return None\n                    flow_out = flow_out.clamp(-50, 50)  # clamp is out-of-place; keep the result\n                flow_list.insert(0, flow_out)\n\n        return flow_list\n\n    def forward(self, ref, supp):\n        assert ref.size() == supp.size()\n\n        h, w = ref.size(2), ref.size(3)\n        w_floor = math.floor(math.ceil(w / 32.0) * 32.0)\n        h_floor = math.floor(math.ceil(h / 32.0) * 32.0)\n\n        ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False)\n        supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False)\n\n        flow_list = self.process(ref, supp, w, h, w_floor, h_floor)\n\n        return flow_list[0] if len(flow_list) == 1 else flow_list\n\n\nclass PCD_Align(nn.Module):\n    ''' Alignment module using Pyramid, Cascading and Deformable convolution\n    with 3 pyramid levels. [From EDVR]\n    '''\n\n    def __init__(self, nf=64, groups=8, wn=None):\n        super(PCD_Align, self).__init__()\n        if wn is None:\n            wn = lambda x: torch.nn.utils.weight_norm(x)\n        # L3: level 3, 1/4 spatial size\n        self.L3_offset_conv1 = wn(nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True))  # concat for diff\n        self.L3_offset_conv2 = wn(nn.Conv2d(nf, nf, 3, 1, 1, bias=True))\n        # self.L3_shift = ShiftAlign(nf)\n        self.L3_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)\n                              # extra_offset_mask=True)\n        # L2: level 2, 1/2 spatial size\n        self.L2_offset_conv1 = wn(nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True))  # concat for diff\n        self.L2_offset_conv2 = wn(nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True))  # concat for offset\n        self.L2_offset_conv3 = wn(nn.Conv2d(nf, nf, 3, 1, 1, bias=True))\n        # self.L2_shift = ShiftAlign(nf)\n        self.L2_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)\n         
                     # extra_offset_mask=True)\n        self.L2_fea_conv = wn(nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True))  # concat for fea\n        # L1: level 1, original spatial size\n        self.L1_offset_conv1 = wn(nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True))  # concat for diff\n        self.L1_offset_conv2 = wn(nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True))  # concat for offset\n        self.L1_offset_conv3 = wn(nn.Conv2d(nf, nf, 3, 1, 1, bias=True))\n        # self.L1_shift = ShiftAlign(nf)\n        self.L1_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)\n                              # extra_offset_mask=True)\n        self.L1_fea_conv = wn(nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True))  # concat for fea\n        # Cascading DCN\n        self.cas_offset_conv1 = wn(nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True))  # concat for diff\n        self.cas_offset_conv2 = wn(nn.Conv2d(nf, nf, 3, 1, 1, bias=True))\n\n        self.cas_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)\n\n        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)\n\n    def forward(self, nbr_fea_l, ref_fea_l):\n        '''align other neighboring frames to the reference frame in the feature level\n        nbr_fea_l, ref_fea_l: [L1, L2, L3], each with [B,C,H,W] features\n        '''\n        # L3\n        L3_offset = torch.cat([nbr_fea_l[2], ref_fea_l[2]], dim=1)\n        L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset))\n        L3_offset = self.lrelu(self.L3_offset_conv2(L3_offset))\n        # L3_nbr_fea = self.L3_shift(L3_offset, nbr_fea_l[2])\n        L3_fea = self.lrelu(self.L3_dcnpack(nbr_fea_l[2], L3_offset))\n        # L2\n        L3_offset = F.interpolate(L3_offset, scale_factor=2, mode='bilinear', align_corners=False)\n        L2_offset = torch.cat([nbr_fea_l[1], ref_fea_l[1]], dim=1)\n        L2_offset = self.lrelu(self.L2_offset_conv1(L2_offset))\n        L2_offset = 
self.lrelu(self.L2_offset_conv2(torch.cat([L2_offset, L3_offset*2], dim=1)))\n        L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset))\n        # L2_nbr_fea = self.L2_shift(L2_offset, nbr_fea_l[1])\n        L2_fea = self.L2_dcnpack(nbr_fea_l[1], L2_offset)\n        L3_fea = F.interpolate(L3_fea, scale_factor=2, mode='bilinear', align_corners=False)\n        L2_fea = self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L3_fea], dim=1)))\n        # L1\n        L2_offset = F.interpolate(L2_offset, scale_factor=2, mode='bilinear', align_corners=False)\n        L1_offset = torch.cat([nbr_fea_l[0], ref_fea_l[0]], dim=1)\n        L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset))\n        L1_offset = self.lrelu(self.L1_offset_conv2(torch.cat([L1_offset, L2_offset * 2], dim=1)))\n        L1_offset = self.lrelu(self.L1_offset_conv3(L1_offset))\n        # L1_nbr_fea = self.L1_shift(L1_offset, nbr_fea_l[0])\n        L1_fea = self.L1_dcnpack(nbr_fea_l[0], L1_offset)\n        L2_fea = F.interpolate(L2_fea, scale_factor=2, mode='bilinear', align_corners=False)\n        L1_fea = self.L1_fea_conv(torch.cat([L1_fea, L2_fea], dim=1))\n        # Cascading\n        offset = torch.cat([L1_fea, ref_fea_l[0]], dim=1)\n        offset = self.lrelu(self.cas_offset_conv1(offset))\n        offset = self.lrelu(self.cas_offset_conv2(offset))\n        L1_fea = self.cas_dcnpack(L1_fea, offset)\n\n        return L1_fea\n\n\nclass FlowGuidedPCDAlign(nn.Module):\n    ''' Alignment module using Pyramid, Cascading and Deformable convolution\n    with 3 pyramid levels. 
[From EDVR]\n    '''\n\n    def __init__(self, nf=64, groups=8):\n        super(FlowGuidedPCDAlign, self).__init__()\n        # L3: level 3, 1/4 spatial size\n        self.L3_offset_conv1 = nn.Conv2d(nf * 2 + 2, nf, 3, 1, 1, bias=True)  # concat for diff\n        self.L3_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n        self.L3_dcnpack = FlowGuidedDCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)\n\n        # L2: level 2, 1/2 spatial size\n        self.L2_offset_conv1 = nn.Conv2d(nf * 2 + 2, nf, 3, 1, 1, bias=True)  # concat for diff\n        self.L2_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset\n        self.L2_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n        self.L2_dcnpack = FlowGuidedDCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)\n        self.L2_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea\n\n        # L1: level 1, original spatial size\n        self.L1_offset_conv1 = nn.Conv2d(nf * 2 + 2, nf, 3, 1, 1, bias=True)  # concat for diff\n        self.L1_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset\n        self.L1_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n        self.L1_dcnpack = FlowGuidedDCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)\n        self.L1_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea\n        # Cascading DCN\n        # self.cas_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff\n        # self.cas_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n        # self.cas_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)\n\n        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)\n\n    def forward(self, nbr_fea_l, nbr_fea_warped_l, ref_fea_l, flows_l):\n        '''align other neighboring frames to the reference frame in the feature level\n        nbr_fea_l, 
ref_fea_l: [L1, L2, L3], each with [B,C,H,W] features\n        '''\n        # L3\n        L3_offset = torch.cat([nbr_fea_warped_l[2], ref_fea_l[2], flows_l[2]], dim=1)\n        L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset))\n        L3_offset = self.lrelu(self.L3_offset_conv2(L3_offset))\n        L3_fea = self.lrelu(self.L3_dcnpack(nbr_fea_l[2], L3_offset, flows_l[2]))\n\n        # L2\n        L3_offset = F.interpolate(L3_offset, scale_factor=2, mode='bilinear', align_corners=False)\n        L2_offset = torch.cat([nbr_fea_warped_l[1], ref_fea_l[1], flows_l[1]], dim=1)\n        L2_offset = self.lrelu(self.L2_offset_conv1(L2_offset))\n        L2_offset = self.lrelu(self.L2_offset_conv2(torch.cat([L2_offset, L3_offset*2], dim=1)))\n        L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset))\n        L2_fea = self.L2_dcnpack(nbr_fea_l[1], L2_offset, flows_l[1])\n        L3_fea = F.interpolate(L3_fea, scale_factor=2, mode='bilinear', align_corners=False)\n        L2_fea = self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L3_fea], dim=1)))\n\n        # L1\n        L2_offset = F.interpolate(L2_offset, scale_factor=2, mode='bilinear', align_corners=False)\n        L1_offset = torch.cat([nbr_fea_warped_l[0], ref_fea_l[0], flows_l[0]], dim=1)\n        L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset))\n        L1_offset = self.lrelu(self.L1_offset_conv2(torch.cat([L1_offset, L2_offset * 2], dim=1)))\n        L1_offset = self.lrelu(self.L1_offset_conv3(L1_offset))\n        L1_fea = self.L1_dcnpack(nbr_fea_l[0], L1_offset, flows_l[0])\n        L2_fea = F.interpolate(L2_fea, scale_factor=2, mode='bilinear', align_corners=False)\n        L1_fea = self.L1_fea_conv(torch.cat([L1_fea, L2_fea], dim=1))\n        # Cascading\n        # offset = torch.cat([L1_fea, ref_fea_l[0]], dim=1)\n        # offset = self.lrelu(self.cas_offset_conv1(offset))\n        # offset = self.lrelu(self.cas_offset_conv2(offset))\n        # L1_fea = self.cas_dcnpack(L1_fea, offset)\n\n        
return L1_fea\n\n\nclass CrossNonLocal_Fusion(nn.Module):\n    ''' Cross Non Local fusion module\n    '''\n    def __init__(self, nf=64, out_feat=96, nframes=5, center=2):\n        super(CrossNonLocal_Fusion, self).__init__()\n        self.center = center\n\n        self.non_local_T = nn.ModuleList()\n        self.non_local_F = nn.ModuleList()\n\n        for i in range(nframes):\n            self.non_local_T.append(NonLocalCross(nf, inter_channels=nf//2, sub_sample=True, bn_layer=False))\n            self.non_local_F.append(NonLocal(nf, inter_channels=nf//2, sub_sample=True, bn_layer=False))\n\n        # fusion conv: 3x3 conv fusing the concatenated cross/self non-local features\n        self.fea_fusion = nn.Conv2d(nframes * nf * 2, out_feat, 3, 1, 1, bias=True)\n\n        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)\n\n    def forward(self, aligned_fea):\n        B, N, C, H, W = aligned_fea.size()  # N video frames\n        ref = aligned_fea[:, self.center, :, :, :].clone()\n\n        cor_l = []\n        non_l = []\n        for i in range(N):\n            nbr = aligned_fea[:, i, :, :, :]\n            non_l.append(self.non_local_F[i](nbr))\n            cor_l.append(self.non_local_T[i](nbr, ref))\n\n        aligned_fea_T = torch.cat(cor_l, dim=1)\n        aligned_fea_F = torch.cat(non_l, dim=1)\n        aligned_fea = torch.cat([aligned_fea_T, aligned_fea_F], dim=1)\n\n        #### fusion\n        fea = self.fea_fusion(aligned_fea)\n\n        return fea\n\n\nclass EBSR(nn.Module):\n    r\"\"\" EBSR: burst super-resolution network with flow-guided PCD alignment and cross non-local fusion.\n    \"\"\"\n\n    def __init__(self, args, nframes=8, img_size=64, patch_size=1, in_chans=3, out_chans=3,\n                 embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],\n                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,\n                 use_checkpoint=False, upscale=4, non_local=False,\n      
           **kwargs):\n        super(EBSR, self).__init__()\n        num_in_ch = in_chans\n        num_out_ch = out_chans\n        num_feat = 128\n        groups = 8\n        back_RBs = 5\n        n_resblocks = 8\n        embed_dim = num_feat\n\n        self.args = args\n        self.center = 0\n        self.upscale = upscale\n        self.window_size = window_size\n        self.non_local = non_local\n        self.nframes = nframes\n\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.ape = ape\n        self.patch_norm = patch_norm\n        self.num_features = embed_dim\n        self.mlp_ratio = mlp_ratio\n\n        spynet_path='/home/luoziwei/.pretrained_models/spynet_sintel_final-3d2a1287.pth'\n        self.spynet = SpyNet(spynet_path, [3, 4, 5])\n        self.conv_flow = nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1)\n\n        self.flow_ps = nn.PixelShuffle(2)\n        # self.debayer = Debayer3x3()\n        \n\n        # split image into non-overlapping patches\n        self.patch_embed = swu.PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None)\n        num_patches = self.patch_embed.num_patches\n        patches_resolution = self.patch_embed.patches_resolution\n        self.patches_resolution = patches_resolution\n\n        # merge non-overlapping patches into image\n        self.patch_unembed = swu.PatchUnEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None)\n\n        #####################################################################################################\n        ################################### 1, shallow feature extraction ###################################\n        self.conv_first = nn.Conv2d(num_in_ch*(1+2*0), embed_dim, 3, 1, 1, bias=True)\n        \n        # # 
stochastic depth\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n\n        if args.swinfeature:\n            if self.args.local_rank <= 0:\n                print(\"using swinfeature\")\n            self.pre_layers = nn.ModuleList()\n            for i_layer in range(depths[0]):\n                layer = swu.SwinTransformerBlock(dim=embed_dim,\n                             input_resolution=(patches_resolution[0]//2,\n                                               patches_resolution[1]//2),\n                             # all blocks of this stage share num_heads[0]; num_heads[i_layer]\n                             # would index out of range once depths[0] > len(num_heads)\n                             num_heads=num_heads[0], window_size=window_size,\n                             shift_size=0 if (i_layer % 2 == 0) else window_size // 2,\n                             mlp_ratio=mlp_ratio,\n                             qkv_bias=qkv_bias, qk_scale=qk_scale,\n                             drop=drop_rate, attn_drop=attn_drop_rate,\n                             drop_path=dpr[i_layer],\n                             norm_layer=norm_layer)\n                self.pre_layers.append(layer)\n\n            # self.pre_linear = nn.Linear(embed_dim, embed_dim)\n            self.pre_norm = norm_layer(embed_dim)\n        else:\n            WARB = functools.partial(arch_util.WideActResBlock, nf=embed_dim)\n            self.feature_extraction = arch_util.make_layer(WARB, 5)\n\n        self.conv_after_pre_layer = nn.Conv2d(embed_dim, num_feat*4, 3, 1, 1, bias=True)\n        self.mid_ps = nn.PixelShuffle(2)\n\n        self.fea_L2_conv1 = nn.Conv2d(num_feat, num_feat*2, 3, 2, 1, bias=True)\n        self.fea_L3_conv1 = nn.Conv2d(num_feat*2, num_feat*4, 3, 2, 1, bias=True)\n\n        #####################################################################################################\n        ################################### 2, Feature Enhanced PCD Align ###################################\n\n        # Top layers\n        self.toplayer = nn.Conv2d(num_feat*4, num_feat, kernel_size=1, stride=1, padding=0)\n  
      # Smooth layers\n        self.smooth1 = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1)\n        self.smooth2 = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1)\n        # Lateral layers\n        self.latlayer1 = nn.Conv2d(num_feat*2, num_feat, kernel_size=1, stride=1, padding=0)\n        self.latlayer2 = nn.Conv2d(num_feat*1, num_feat, kernel_size=1, stride=1, padding=0)\n\n        # self.align = PCD_Align(nf=num_feat, groups=groups)\n        self.align = FlowGuidedPCDAlign(nf=num_feat, groups=groups)\n        #####################################################################################################\n        ################################### 3, Multi-frame Feature Fusion  ##################################\n\n        if self.non_local:\n            if self.args.local_rank <= 0:\n                print(\"using non_local\")\n            self.fusion = CrossNonLocal_Fusion(nf=num_feat, out_feat=embed_dim, nframes=nframes, center=self.center)\n        else:\n            self.fusion = nn.Conv2d(nframes * num_feat, embed_dim, 1, 1, bias=True)\n\n        #####################################################################################################\n        ################################### 4, deep feature extraction ######################################\n\n        # absolute position embedding\n        # if self.ape:\n        #     self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))\n        #     swu.trunc_normal_(self.absolute_pos_embed, std=.02)\n\n        # self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # # build Residual Swin Transformer blocks (RSTB)\n        # self.layers = nn.ModuleList()\n        # for i_layer in range(1, self.num_layers):\n        #     layer = swu.RSTB(dim=embed_dim,\n        #                  input_resolution=(patches_resolution[0],\n        #                                    patches_resolution[1]),\n        #                  
depth=depths[i_layer],\n        #                  num_heads=num_heads[i_layer],\n        #                  window_size=window_size,\n        #                  mlp_ratio=self.mlp_ratio,\n        #                  qkv_bias=qkv_bias, qk_scale=qk_scale,\n        #                  drop=drop_rate, attn_drop=attn_drop_rate,\n        #                  drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results\n        #                  norm_layer=norm_layer,\n        #                  downsample=None,\n        #                  use_checkpoint=use_checkpoint,\n        #                  img_size=img_size,\n        #                  patch_size=patch_size\n        #                  )\n        #     self.layers.append(layer)\n        \n        # self.norm = norm_layer(self.num_features)\n\n        LRCN = functools.partial(arch_util.LRSCWideActResGroup, n_resblocks=n_resblocks, nf=embed_dim)\n        self.post_feature_extraction = nn.Sequential(arch_util.make_layer_idx(LRCN, back_RBs),\n                nn.Conv2d(embed_dim*(back_RBs+1), num_feat, 1))\n\n        # self.post_feature_extraction = nn.Sequential(\n        #         arch_util.make_layer(WARB, 20),\n        #         nn.Conv2d(embed_dim, embed_dim, 3, 1, 1))\n\n        # build the last conv layer in deep feature extraction\n        # self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)\n\n        #####################################################################################################\n        ################################ 5, high quality image reconstruction ################################\n\n        self.upconv1 = nn.Conv2d(embed_dim, num_feat * 4, 3, 1, 1, bias=True)\n        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)\n        self.pixel_shuffle = nn.PixelShuffle(2)\n        self.HRconv = nn.Conv2d(64, 64, 3, 1, 1, bias=True)\n        self.conv_last = nn.Conv2d(64, args.n_colors, 3, 1, 1, bias=True)\n\n        #### skip 
#############\n        self.skip_pixel_shuffle = nn.PixelShuffle(2)\n        self.skipup1 = nn.Conv2d(num_in_ch//4, num_feat * 4, 3, 1, 1, bias=True)\n        self.skipup2 = nn.Conv2d(num_feat, args.n_colors * 4, 3, 1, 1, bias=True)\n\n        #### activation function\n        self.lrelu = nn.LeakyReLU(0.1, inplace=True)\n        self.lrelu2 = nn.LeakyReLU(0.1, inplace=True)\n\n        # self.apply(self._init_weights)\n\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            swu.trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'absolute_pos_embed'}\n\n    @torch.jit.ignore\n    def no_weight_decay_keywords(self):\n        return {'relative_position_bias_table'}\n\n    def _upsample_add(self, x, y):\n        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + y\n\n    def check_image_size(self, x):\n        _, _, h, w = x.size()\n        mod_pad_h = (self.window_size - h % self.window_size) % self.window_size\n        mod_pad_w = (self.window_size - w % self.window_size) % self.window_size\n        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')\n        return x\n\n    def pre_forward_features(self, x):\n        if self.args.swinfeature:\n            x_size = (x.shape[-2], x.shape[-1])\n            x = self.patch_embed(x, use_norm=True)\n            if self.ape:\n                x = x + self.absolute_pos_embed\n            x = self.pos_drop(x)\n\n            for idx, layer in enumerate(self.pre_layers):\n                x = layer(x, x_size)\n\n            x = self.pre_norm(x)\n            x = self.patch_unembed(x, x_size)\n\n        else:\n            x = self.feature_extraction(x)\n\n        return x\n\n    
def forward_features(self, x):\n        # x_size = (x.shape[-2], x.shape[-1])\n        # x = self.patch_embed(x)\n        # if self.ape:\n        #     x = x + self.absolute_pos_embed\n        # x = self.pos_drop(x)\n\n        # for idx, layer in enumerate(self.layers):\n        #     x = layer(x, x_size)\n\n        # x = self.norm(x)  # B L C\n        # x = self.patch_unembed(x, x_size)\n\n        x = self.post_feature_extraction(x)\n\n        return x\n\n    @autocast()\n    def forward(self, x, print_time=False):\n        B, N, C, H, W = x.size()  # N video frames\n        x_center = x[:, self.center, :, :, :].contiguous()\n\n        #### skip module ########\n        skip1 = self.lrelu2(self.skip_pixel_shuffle(self.skipup1(self.skip_pixel_shuffle(x_center))))\n        skip2 = self.skip_pixel_shuffle(self.skipup2(skip1))\n\n        x_ = self.conv_flow(self.flow_ps(x.view(B*N, C, H, W))).view(B, N, -1, H*2, W*2)\n        \n        # calculate flows\n        ref_flows = self.get_ref_flows(x_)\n        # flows_backward, flows_forward = self.get_flow_2frames(x_)\n        # # # warp input\n        # x_backward, x_forward = self.get_aligned_image_2frames(x,  flows_backward[1], flows_forward[1])\n        # x = torch.cat([x, x_backward, x_forward], 2)\n\n        #### extract LR features\n        x = self.lrelu(self.conv_first(x.view(B*N, -1, H, W)))\n        L1_fea = self.mid_ps(self.conv_after_pre_layer(self.pre_forward_features(x)))\n        _, _, H, W = L1_fea.size()\n\n        L2_fea = self.lrelu(self.fea_L2_conv1(L1_fea))\n        L3_fea = self.lrelu(self.fea_L3_conv1(L2_fea))\n\n        # FPN enhance features\n        L3_fea = self.lrelu(self.toplayer(L3_fea))\n        L2_fea = self.smooth1(self._upsample_add(L3_fea, self.latlayer1(L2_fea)))\n        L1_fea = self.smooth2(self._upsample_add(L2_fea, self.latlayer2(L1_fea)))\n\n        L1_fea = L1_fea.view(B, N, -1, H, W).contiguous()\n        L2_fea = L2_fea.view(B, N, -1, H // 2, W // 2 ).contiguous()\n        
L3_fea = L3_fea.view(B, N, -1, H // 4, W // 4).contiguous()\n\n        #### PCD align\n        # ref feature list\n        ref_fea_l = [\n            L1_fea[:, self.center, :, :, :].clone(), \n            L2_fea[:, self.center, :, :, :].clone(),\n            L3_fea[:, self.center, :, :, :].clone()\n        ]\n        aligned_fea = []\n        for i in range(N):\n            nbr_fea_l = [\n                L1_fea[:, i, :, :, :].clone(), \n                L2_fea[:, i, :, :, :].clone(),\n                L3_fea[:, i, :, :, :].clone()\n            ]\n            flows_l = [\n                ref_flows[0][:, i, :, :, :].clone(), \n                ref_flows[1][:, i, :, :, :].clone(), \n                ref_flows[2][:, i, :, :, :].clone()\n            ]\n            # print(nbr_fea_l[0].shape, flows_l[0].shape)\n            nbr_warped_l = [\n                arch_util.flow_warp(nbr_fea_l[0], flows_l[0].permute(0, 2, 3, 1), 'bilinear'),\n                arch_util.flow_warp(nbr_fea_l[1], flows_l[1].permute(0, 2, 3, 1), 'bilinear'),\n                arch_util.flow_warp(nbr_fea_l[2], flows_l[2].permute(0, 2, 3, 1), 'bilinear')\n            ]\n            aligned_fea.append(self.align(nbr_fea_l, nbr_warped_l, ref_fea_l, flows_l))\n            # aligned_fea.append(self.align(nbr_fea_l, ref_fea_l))\n        aligned_fea = torch.stack(aligned_fea, dim=1)  # [B, N, C, H, W] --> [B, T, C, H, W]\n\n        if not self.non_local:\n            aligned_fea = aligned_fea.view(B, -1, H, W)\n\n        x = self.lrelu(self.fusion(aligned_fea))\n        x = self.forward_features(x)\n\n        x = self.lrelu(self.pixel_shuffle(self.upconv1(x)))\n        x = skip1 + x\n        x = self.lrelu(self.pixel_shuffle(self.upconv2(x)))\n        x = self.lrelu(self.HRconv(x))\n        x = self.conv_last(x)\n\n        x = skip2 + x\n        return x\n\n\n    def get_ref_flows(self, x):\n        '''Get flow between frames ref and other'''\n\n        b, n, c, h, w = x.size()\n        x_nbr = x.reshape(-1, c, h, 
w)\n        x_ref = x[:, self.center:self.center+1, :, :, :].repeat(1, n, 1, 1, 1).reshape(-1, c, h, w)\n\n        # backward\n        flows = self.spynet(x_ref, x_nbr)\n        flows_list = [flow.view(b, n, 2, h // (2 ** (i)), w // (2 ** (i))) for flow, i in\n                          zip(flows, range(3))]\n\n        return flows_list\n\n\n    def get_flow_2frames(self, x):\n        '''Get flow between frames t and t+1 from x.'''\n\n        b, n, c, h, w = x.size()\n        x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)\n        x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)\n\n        # backward\n        flows_backward = self.spynet(x_1, x_2)\n        flows_backward = [flow.view(b, n-1, 2, h // (2 ** (i)), w // (2 ** (i))) for flow, i in\n                          zip(flows_backward, range(3))]\n\n        # forward\n        flows_forward = self.spynet(x_2, x_1)\n        flows_forward = [flow.view(b, n-1, 2, h // (2 ** (i)), w // (2 ** (i))) for flow, i in\n                         zip(flows_forward, range(3))]\n\n        return flows_backward, flows_forward\n\n\n    def get_aligned_image_2frames(self, x, flows_backward, flows_forward):\n        '''Parallel feature warping for 2 frames.'''\n\n        # backward\n        n = x.size(1)\n        x_backward = [torch.zeros_like(x[:, -1, ...]).repeat(1, 4, 1, 1)]\n        for i in range(n - 1, 0, -1):\n            x_i = x[:, i, ...]\n            flow = flows_backward[:, i - 1, ...]\n            x_backward.insert(0, arch_util.flow_warp(x_i, flow.permute(0, 2, 3, 1), 'nearest4')) # frame i+1 aligned towards i\n\n        # forward\n        x_forward = [torch.zeros_like(x[:, 0, ...]).repeat(1, 4, 1, 1)]\n        for i in range(0, n - 1):\n            x_i = x[:, i, ...]\n            flow = flows_forward[:, i, ...]\n            x_forward.append(arch_util.flow_warp(x_i, flow.permute(0, 2, 3, 1), 'nearest4')) # frame i-1 aligned towards i\n\n        return [torch.stack(x_backward, 1), torch.stack(x_forward, 1)]\n\n\n    def 
get_aligned_feature_2frames(self, x):\n        '''Parallel feature warping for 2 frames.'''\n\n        # backward\n        n = x.size(1)\n        x_backward = [torch.zeros_like(x[:, -1, ...])]\n        for i in range(n - 1, 0, -1):\n            # x_i = x[:, i, ...]\n            # flow = flows_backward[0][:, i - 1, ...]\n            # x_i_warped = arch_util.flow_warp(x_i, flow.permute(0, 2, 3, 1), 'bilinear')  # frame i+1 aligned towards i\n            # x_backward.insert(0, self.FDCN(x_i, x_i_warped, x[:, i - 1, ...], flow))\n            offset = self.offset_conv(torch.cat([x[:, i, ...], x[:, i - 1, ...]], dim=1))\n            x_backward.insert(0, self.FDCN(x[:, i, ...].clone(), offset))\n\n        # forward\n        x_forward = [torch.zeros_like(x[:, 0, ...])]\n        for i in range(0, n - 1):\n            # x_i = x[:, i, ...]\n            # flow = flows_forward[0][:, i, ...]\n            # x_i_warped = arch_util.flow_warp(x_i, flow.permute(0, 2, 3, 1), 'bilinear')  # frame i-1 aligned towards i\n            # x_forward.append(self.FDCN(x_i, x_i_warped, x[:, i + 1, ...], flow))\n            offset = self.offset_conv(torch.cat([x[:, i, ...], x[:, i + 1, ...]], dim=1))\n            # append (not insert(0)) to keep temporal order, matching the commented\n            # reference above and get_aligned_image_2frames\n            x_forward.append(self.FDCN(x[:, i, ...].clone(), offset))\n\n        return [torch.stack(x_backward, 1), torch.stack(x_forward, 1)]\n"
  },
  {
    "path": "code/synthetic/bsrt/model/non_local/network.py",
    "content": "from torch import nn\n# from lib.non_local_concatenation import NONLocalBlock2D\n# from lib.non_local_gaussian import NONLocalBlock2D\nfrom lib.non_local_embedded_gaussian import NONLocalBlock2D\n# from lib.non_local_dot_product import NONLocalBlock2D\n\n\nclass Network(nn.Module):\n    def __init__(self):\n        super(Network, self).__init__()\n\n        self.conv_1 = nn.Sequential(\n            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),\n            nn.BatchNorm2d(32),\n            nn.ReLU(),\n            nn.MaxPool2d(2),\n        )\n\n        self.nl_1 = NONLocalBlock2D(in_channels=32)\n        self.conv_2 = nn.Sequential(\n            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),\n            nn.BatchNorm2d(64),\n            nn.ReLU(),\n            nn.MaxPool2d(2),\n        )\n\n        self.nl_2 = NONLocalBlock2D(in_channels=64)\n        self.conv_3 = nn.Sequential(\n            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),\n            nn.BatchNorm2d(128),\n            nn.ReLU(),\n            nn.MaxPool2d(2),\n        )\n\n        self.fc = nn.Sequential(\n            nn.Linear(in_features=128*3*3, out_features=256),\n            nn.ReLU(),\n            nn.Dropout(0.5),\n\n            nn.Linear(in_features=256, out_features=10)\n        )\n\n    def forward(self, x):\n        batch_size = x.size(0)\n\n        feature_1 = self.conv_1(x)\n        nl_feature_1 = self.nl_1(feature_1)\n\n        feature_2 = self.conv_2(nl_feature_1)\n        nl_feature_2 = self.nl_2(feature_2)\n\n        output = self.conv_3(nl_feature_2).view(batch_size, -1)\n        output = self.fc(output)\n\n        return output\n\n    def forward_with_nl_map(self, x):\n        batch_size = x.size(0)\n\n        feature_1 = self.conv_1(x)\n        nl_feature_1, nl_map_1 = self.nl_1(feature_1, return_nl_map=True)\n\n        feature_2 = self.conv_2(nl_feature_1)\n        
nl_feature_2, nl_map_2 = self.nl_2(feature_2, return_nl_map=True)\n\n        output = self.conv_3(nl_feature_2).view(batch_size, -1)\n        output = self.fc(output)\n\n        return output, [nl_map_1, nl_map_2]\n\n\nif __name__ == '__main__':\n    import torch\n\n    img = torch.randn(3, 1, 28, 28)\n    net = Network()\n    out = net(img)\n    print(out.size())\n\n"
  },
  {
    "path": "code/synthetic/bsrt/model/non_local/non_local_concatenation.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass _NonLocalBlockND(nn.Module):\n    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):\n        super(_NonLocalBlockND, self).__init__()\n\n        assert dimension in [1, 2, 3]\n\n        self.dimension = dimension\n        self.sub_sample = sub_sample\n\n        self.in_channels = in_channels\n        self.inter_channels = inter_channels\n\n        if self.inter_channels is None:\n            self.inter_channels = in_channels // 2\n            if self.inter_channels == 0:\n                self.inter_channels = 1\n\n        if dimension == 3:\n            conv_nd = nn.Conv3d\n            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))\n            bn = nn.BatchNorm3d\n        elif dimension == 2:\n            conv_nd = nn.Conv2d\n            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))\n            bn = nn.BatchNorm2d\n        else:\n            conv_nd = nn.Conv1d\n            max_pool_layer = nn.MaxPool1d(kernel_size=(2))\n            bn = nn.BatchNorm1d\n\n        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                         kernel_size=1, stride=1, padding=0)\n\n        if bn_layer:\n            self.W = nn.Sequential(\n                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,\n                        kernel_size=1, stride=1, padding=0),\n                bn(self.in_channels)\n            )\n            nn.init.constant_(self.W[1].weight, 0)\n            nn.init.constant_(self.W[1].bias, 0)\n        else:\n            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,\n                             kernel_size=1, stride=1, padding=0)\n            nn.init.constant_(self.W.weight, 0)\n            nn.init.constant_(self.W.bias, 0)\n\n        self.theta = conv_nd(in_channels=self.in_channels, 
out_channels=self.inter_channels,\n                             kernel_size=1, stride=1, padding=0)\n\n        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                           kernel_size=1, stride=1, padding=0)\n\n        self.concat_project = nn.Sequential(\n            nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),\n            nn.ReLU()\n        )\n\n        if sub_sample:\n            self.g = nn.Sequential(self.g, max_pool_layer)\n            self.phi = nn.Sequential(self.phi, max_pool_layer)\n\n    def forward(self, x, return_nl_map=False):\n        '''\n        :param x: (b, c, t, h, w)\n        :param return_nl_map: if True return z, nl_map, else only return z.\n        :return:\n        '''\n\n        batch_size = x.size(0)\n\n        g_x = self.g(x).view(batch_size, self.inter_channels, -1)\n        g_x = g_x.permute(0, 2, 1)\n\n        # (b, c, N, 1)\n        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)\n        # (b, c, 1, N)\n        phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1)\n\n        h = theta_x.size(2)\n        w = phi_x.size(3)\n        theta_x = theta_x.repeat(1, 1, 1, w)\n        phi_x = phi_x.repeat(1, 1, h, 1)\n\n        concat_feature = torch.cat([theta_x, phi_x], dim=1)\n        f = self.concat_project(concat_feature)\n        b, _, h, w = f.size()\n        f = f.view(b, h, w)\n\n        N = f.size(-1)\n        f_div_C = f / N\n\n        y = torch.matmul(f_div_C, g_x)\n        y = y.permute(0, 2, 1).contiguous()\n        y = y.view(batch_size, self.inter_channels, *x.size()[2:])\n        W_y = self.W(y)\n        z = W_y + x\n\n        if return_nl_map:\n            return z, f_div_C\n        return z\n\n\nclass NONLocalBlock1D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock1D, self).__init__(in_channels,\n                                            
  inter_channels=inter_channels,\n                                              dimension=1, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nclass NONLocalBlock2D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock2D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=2, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nclass NONLocalBlock3D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True,):\n        super(NONLocalBlock3D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=3, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nif __name__ == '__main__':\n    import torch\n\n    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:\n        img = torch.zeros(2, 3, 20)\n        net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n        img = torch.zeros(2, 3, 20, 20)\n        net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n        img = torch.randn(2, 3, 8, 20, 20)\n        net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n"
  },
  {
    "path": "code/synthetic/bsrt/model/non_local/non_local_cross_dot_product.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass _NonLocalBlockND(nn.Module):\n    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):\n        super(_NonLocalBlockND, self).__init__()\n\n        assert dimension in [1, 2, 3]\n\n        self.dimension = dimension\n        self.sub_sample = sub_sample\n\n        self.in_channels = in_channels\n        self.inter_channels = inter_channels\n\n        if self.inter_channels is None:\n            self.inter_channels = in_channels // 2\n            if self.inter_channels == 0:\n                self.inter_channels = 1\n\n        if dimension == 3:\n            conv_nd = nn.Conv3d\n            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 4, 4))\n            bn = nn.BatchNorm3d\n        elif dimension == 2:\n            conv_nd = nn.Conv2d\n            max_pool_layer = nn.MaxPool2d(kernel_size=(4, 4))\n            bn = nn.BatchNorm2d\n        else:\n            conv_nd = nn.Conv1d\n            max_pool_layer = nn.MaxPool1d(kernel_size=(4))\n            bn = nn.BatchNorm1d\n\n        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                         kernel_size=1, stride=1, padding=0)\n\n        if bn_layer:\n            self.W = nn.Sequential(\n                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,\n                        kernel_size=1, stride=1, padding=0),\n                bn(self.in_channels)\n            )\n            nn.init.constant_(self.W[1].weight, 0)\n            nn.init.constant_(self.W[1].bias, 0)\n        else:\n            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,\n                             kernel_size=1, stride=1, padding=0)\n            nn.init.constant_(self.W.weight, 0)\n            nn.init.constant_(self.W.bias, 0)\n\n        self.theta = conv_nd(in_channels=self.in_channels, 
out_channels=self.inter_channels,\n                             kernel_size=1, stride=1, padding=0)\n\n        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                           kernel_size=1, stride=1, padding=0)\n\n        if sub_sample:\n            self.g = nn.Sequential(self.g, max_pool_layer)\n            self.phi = nn.Sequential(self.phi, max_pool_layer)\n\n    def forward(self, x, ref, return_nl_map=False):\n        \"\"\"\n        :param x: (b, c, t, h, w)\n        :param return_nl_map: if True return z, nl_map, else only return z.\n        :return:\n        \"\"\"\n\n        batch_size = x.size(0)\n\n        g_x = self.g(x).view(batch_size, self.inter_channels, -1)\n        g_x = g_x.permute(0, 2, 1)\n\n        theta_ref = self.theta(ref).view(batch_size, self.inter_channels, -1)\n        theta_ref = theta_ref.permute(0, 2, 1)\n        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)\n        f = torch.matmul(theta_ref, phi_x)\n        N = f.size(-1)\n        f_div_C = f / N\n\n        y = torch.matmul(f_div_C, g_x)\n        y = y.permute(0, 2, 1).contiguous()\n        y = y.view(batch_size, self.inter_channels, *x.size()[2:])\n        W_y = self.W(y)\n        z = W_y + x\n\n        if return_nl_map:\n            return z, f_div_C\n        return z\n\n\nclass NONLocalBlock1D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock1D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=1, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nclass NONLocalBlock2D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock2D, self).__init__(in_channels,\n                                          
    inter_channels=inter_channels,\n                                              dimension=2, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nclass NONLocalBlock3D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock3D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=3, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nif __name__ == '__main__':\n    import torch\n\n    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:\n        # forward() requires a reference tensor; use the input as its own reference here\n        img = torch.zeros(2, 3, 20)\n        net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img, img)\n        print(out.size())\n\n        img = torch.zeros(2, 3, 20, 20)\n        net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img, img)\n        print(out.size())\n\n        img = torch.randn(2, 3, 8, 20, 20)\n        net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img, img)\n        print(out.size())\n\n\n\n"
  },
  {
    "path": "code/synthetic/bsrt/model/non_local/non_local_dot_product.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass _NonLocalBlockND(nn.Module):\n    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):\n        super(_NonLocalBlockND, self).__init__()\n\n        assert dimension in [1, 2, 3]\n\n        self.dimension = dimension\n        self.sub_sample = sub_sample\n\n        self.in_channels = in_channels\n        self.inter_channels = inter_channels\n\n        if self.inter_channels is None:\n            self.inter_channels = in_channels // 2\n            if self.inter_channels == 0:\n                self.inter_channels = 1\n\n        if dimension == 3:\n            conv_nd = nn.Conv3d\n            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 4, 4))\n            bn = nn.BatchNorm3d\n        elif dimension == 2:\n            conv_nd = nn.Conv2d\n            max_pool_layer = nn.MaxPool2d(kernel_size=(4, 4))\n            bn = nn.BatchNorm2d\n        else:\n            conv_nd = nn.Conv1d\n            max_pool_layer = nn.MaxPool1d(kernel_size=(2))\n            bn = nn.BatchNorm1d\n\n        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                         kernel_size=1, stride=1, padding=0)\n\n        if bn_layer:\n            self.W = nn.Sequential(\n                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,\n                        kernel_size=1, stride=1, padding=0),\n                bn(self.in_channels)\n            )\n            nn.init.constant_(self.W[1].weight, 0)\n            nn.init.constant_(self.W[1].bias, 0)\n        else:\n            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,\n                             kernel_size=1, stride=1, padding=0)\n            nn.init.constant_(self.W.weight, 0)\n            nn.init.constant_(self.W.bias, 0)\n\n        self.theta = conv_nd(in_channels=self.in_channels, 
out_channels=self.inter_channels,\n                             kernel_size=1, stride=1, padding=0)\n\n        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                           kernel_size=1, stride=1, padding=0)\n\n        if sub_sample:\n            self.g = nn.Sequential(self.g, max_pool_layer)\n            self.phi = nn.Sequential(self.phi, max_pool_layer)\n\n    def forward(self, x, return_nl_map=False):\n        \"\"\"\n        :param x: (b, c, t, h, w)\n        :param return_nl_map: if True return z, nl_map, else only return z.\n        :return:\n        \"\"\"\n\n        batch_size = x.size(0)\n\n        g_x = self.g(x).view(batch_size, self.inter_channels, -1)\n        g_x = g_x.permute(0, 2, 1)\n\n        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)\n        theta_x = theta_x.permute(0, 2, 1)\n        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)\n        f = torch.matmul(theta_x, phi_x)\n        N = f.size(-1)\n        f_div_C = f / N\n\n        y = torch.matmul(f_div_C, g_x)\n        y = y.permute(0, 2, 1).contiguous()\n        y = y.view(batch_size, self.inter_channels, *x.size()[2:])\n        W_y = self.W(y)\n        z = W_y + x\n\n        if return_nl_map:\n            return z, f_div_C\n        return z\n\n\nclass NONLocalBlock1D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock1D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=1, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nclass NONLocalBlock2D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock2D, self).__init__(in_channels,\n                                              
inter_channels=inter_channels,\n                                              dimension=2, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nclass NONLocalBlock3D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock3D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=3, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nif __name__ == '__main__':\n    import torch\n\n    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:\n        img = torch.zeros(2, 3, 20)\n        net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n        img = torch.zeros(2, 3, 20, 20)\n        net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n        img = torch.randn(2, 3, 8, 20, 20)\n        net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n\n\n"
  },
  {
    "path": "code/synthetic/bsrt/model/non_local/non_local_embedded_gaussian.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass _NonLocalBlockND(nn.Module):\n    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):\n        \"\"\"\n        :param in_channels:\n        :param inter_channels:\n        :param dimension:\n        :param sub_sample:\n        :param bn_layer:\n        \"\"\"\n\n        super(_NonLocalBlockND, self).__init__()\n\n        assert dimension in [1, 2, 3]\n\n        self.dimension = dimension\n        self.sub_sample = sub_sample\n\n        self.in_channels = in_channels\n        self.inter_channels = inter_channels\n\n        if self.inter_channels is None:\n            self.inter_channels = in_channels // 2\n            if self.inter_channels == 0:\n                self.inter_channels = 1\n\n        if dimension == 3:\n            conv_nd = nn.Conv3d\n            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))\n            bn = nn.BatchNorm3d\n        elif dimension == 2:\n            conv_nd = nn.Conv2d\n            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))\n            bn = nn.BatchNorm2d\n        else:\n            conv_nd = nn.Conv1d\n            max_pool_layer = nn.MaxPool1d(kernel_size=(2))\n            bn = nn.BatchNorm1d\n\n        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                         kernel_size=1, stride=1, padding=0)\n\n        if bn_layer:\n            self.W = nn.Sequential(\n                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,\n                        kernel_size=1, stride=1, padding=0),\n                bn(self.in_channels)\n            )\n            nn.init.constant_(self.W[1].weight, 0)\n            nn.init.constant_(self.W[1].bias, 0)\n        else:\n            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,\n                             kernel_size=1, stride=1, 
padding=0)\n            nn.init.constant_(self.W.weight, 0)\n            nn.init.constant_(self.W.bias, 0)\n\n        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                             kernel_size=1, stride=1, padding=0)\n        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                           kernel_size=1, stride=1, padding=0)\n\n        if sub_sample:\n            self.g = nn.Sequential(self.g, max_pool_layer)\n            self.phi = nn.Sequential(self.phi, max_pool_layer)\n\n    def forward(self, x, return_nl_map=False):\n        \"\"\"\n        :param x: (b, c, t, h, w)\n        :param return_nl_map: if True return z, nl_map, else only return z.\n        :return:\n        \"\"\"\n\n        batch_size = x.size(0)\n\n        g_x = self.g(x).view(batch_size, self.inter_channels, -1)\n        g_x = g_x.permute(0, 2, 1)\n\n        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)\n        theta_x = theta_x.permute(0, 2, 1)\n        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)\n        f = torch.matmul(theta_x, phi_x)\n        f_div_C = F.softmax(f, dim=-1)\n\n        y = torch.matmul(f_div_C, g_x)\n        y = y.permute(0, 2, 1).contiguous()\n        y = y.view(batch_size, self.inter_channels, *x.size()[2:])\n        W_y = self.W(y)\n        z = W_y + x\n\n        if return_nl_map:\n            return z, f_div_C\n        return z\n\n\nclass NONLocalBlock1D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock1D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=1, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nclass NONLocalBlock2D(_NonLocalBlockND):\n    def __init__(self, in_channels, 
inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock2D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=2, sub_sample=sub_sample,\n                                              bn_layer=bn_layer,)\n\n\nclass NONLocalBlock3D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock3D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=3, sub_sample=sub_sample,\n                                              bn_layer=bn_layer,)\n\n\nif __name__ == '__main__':\n    import torch\n\n    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:\n        img = torch.zeros(2, 3, 20)\n        net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n        img = torch.zeros(2, 3, 20, 20)\n        net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n        img = torch.randn(2, 3, 8, 20, 20)\n        net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n\n"
  },
  {
    "path": "code/synthetic/bsrt/model/non_local/non_local_gaussian.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass _NonLocalBlockND(nn.Module):\n    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):\n        super(_NonLocalBlockND, self).__init__()\n\n        assert dimension in [1, 2, 3]\n\n        self.dimension = dimension\n        self.sub_sample = sub_sample\n\n        self.in_channels = in_channels\n        self.inter_channels = inter_channels\n\n        if self.inter_channels is None:\n            self.inter_channels = in_channels // 2\n            if self.inter_channels == 0:\n                self.inter_channels = 1\n\n        if dimension == 3:\n            conv_nd = nn.Conv3d\n            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))\n            bn = nn.BatchNorm3d\n        elif dimension == 2:\n            conv_nd = nn.Conv2d\n            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))\n            bn = nn.BatchNorm2d\n        else:\n            conv_nd = nn.Conv1d\n            max_pool_layer = nn.MaxPool1d(kernel_size=(2))\n            bn = nn.BatchNorm1d\n\n        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                         kernel_size=1, stride=1, padding=0)\n\n        if bn_layer:\n            self.W = nn.Sequential(\n                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,\n                        kernel_size=1, stride=1, padding=0),\n                bn(self.in_channels)\n            )\n            nn.init.constant_(self.W[1].weight, 0)\n            nn.init.constant_(self.W[1].bias, 0)\n        else:\n            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,\n                             kernel_size=1, stride=1, padding=0)\n            nn.init.constant_(self.W.weight, 0)\n            nn.init.constant_(self.W.bias, 0)\n\n        if sub_sample:\n            self.g = nn.Sequential(self.g, 
max_pool_layer)\n            self.phi = max_pool_layer\n\n    def forward(self, x, return_nl_map=False):\n        \"\"\"\n        :param x: (b, c, t, h, w)\n        :param return_nl_map: if True return z, nl_map, else only return z.\n        :return:\n        \"\"\"\n\n        batch_size = x.size(0)\n\n        g_x = self.g(x).view(batch_size, self.inter_channels, -1)\n\n        g_x = g_x.permute(0, 2, 1)\n\n        theta_x = x.view(batch_size, self.in_channels, -1)\n        theta_x = theta_x.permute(0, 2, 1)\n\n        if self.sub_sample:\n            phi_x = self.phi(x).view(batch_size, self.in_channels, -1)\n        else:\n            phi_x = x.view(batch_size, self.in_channels, -1)\n\n        f = torch.matmul(theta_x, phi_x)\n        f_div_C = F.softmax(f, dim=-1)\n\n        # if self.store_last_batch_nl_map:\n        #     self.nl_map = f_div_C\n\n        y = torch.matmul(f_div_C, g_x)\n        y = y.permute(0, 2, 1).contiguous()\n        y = y.view(batch_size, self.inter_channels, *x.size()[2:])\n        W_y = self.W(y)\n        z = W_y + x\n\n        if return_nl_map:\n            return z, f_div_C\n        return z\n\n\nclass NONLocalBlock1D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock1D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=1, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nclass NONLocalBlock2D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock2D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=2, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nclass 
NONLocalBlock3D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NONLocalBlock3D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=3, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nif __name__ == '__main__':\n    import torch\n\n    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:\n        img = torch.zeros(2, 3, 20)\n        net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n        img = torch.zeros(2, 3, 20, 20)\n        net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n        img = torch.randn(2, 3, 8, 20, 20)\n        net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)\n        out = net(img)\n        print(out.size())\n\n\n\n\n\n\n"
  },
  {
    "path": "code/synthetic/bsrt/model/swin_util.py",
    "content": "# -----------------------------------------------------------------------------------\n# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257\n# Originally Written by Ze Liu, Modified by Jingyun Liang.\n# -----------------------------------------------------------------------------------\n\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n# import torch.utils.checkpoint as checkpoint\nfrom model.checkpoint import CheckpointFunction as checkpoint\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\nfrom functools import reduce, lru_cache\nimport time\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\ndef window_partition(x, window_size):\n    \"\"\"\n    Args:\n        x: (B, H, W, C)\n        window_size (int): window size\n\n    Returns:\n        windows: (num_windows*B, window_size, window_size, C)\n    \"\"\"\n    B, H, W, C = x.shape\n    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n    return windows\n\n\ndef window_reverse(windows, window_size, H, W):\n    \"\"\"\n    Args:\n        windows: (num_windows*B, window_size, window_size, C)\n        window_size (int): Window size\n        H (int): Height of image\n        W (int): Width of image\n\n    Returns:\n   
     x: (B, H, W, C)\n    \"\"\"\n    B = int(windows.shape[0] / (H * W / window_size / window_size))\n    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n    return x\n\n\nclass WindowAttention(nn.Module):\n    r\"\"\" Window based multi-head self attention (W-MSA) module with relative position bias.\n    It supports both of shifted and non-shifted window.\n\n    Args:\n        dim (int): Number of input channels.\n        window_size (tuple[int]): The height and width of the window.\n        num_heads (int): Number of attention heads.\n        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. Default: 0.0\n    \"\"\"\n\n    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):\n\n        super().__init__()\n        self.dim = dim\n        self.window_size = window_size  # Wh, Ww\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        # define a parameter table of relative position bias\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(self.window_size[0])\n        coords_w = torch.arange(self.window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - 
coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += self.window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        trunc_normal_(self.relative_position_bias_table, std=.02)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x, mask=None):\n        \"\"\"\n        Args:\n            x: input features with shape of (num_windows*B, N, C)\n            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n        \"\"\"\n\n        B_, N, C = x.shape\n        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))\n\n        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH\n        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n        attn = attn + relative_position_bias.unsqueeze(0)\n\n        if mask is not None:\n            nW = mask.shape[0]\n            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, N, N)\n            
attn = self.softmax(attn)\n        else:\n            attn = self.softmax(attn)\n\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'\n\n    def flops(self, N):\n        # calculate flops for 1 window with token length of N\n        flops = 0\n        # qkv = self.qkv(x)\n        flops += N * self.dim * 3 * self.dim\n        # attn = (q @ k.transpose(-2, -1))\n        flops += self.num_heads * N * (self.dim // self.num_heads) * N\n        #  x = (attn @ v)\n        flops += self.num_heads * N * N * (self.dim // self.num_heads)\n        # x = self.proj(x)\n        flops += N * self.dim * self.dim\n        return flops\n\n@lru_cache()\ndef calculate_mask(x_size, window_size, shift_size):\n    # calculate attention mask for SW-MSA\n    H, W = x_size\n    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1\n    h_slices = (slice(0, -window_size),\n                slice(-window_size, -shift_size),\n                slice(-shift_size, None))\n    w_slices = (slice(0, -window_size),\n                slice(-window_size, -shift_size),\n                slice(-shift_size, None))\n    cnt = 0\n    for h in h_slices:\n        for w in w_slices:\n            img_mask[:, h, w, :] = cnt\n            cnt += 1\n\n    mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1\n    mask_windows = mask_windows.view(-1, window_size * window_size)\n    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n\n    return attn_mask\n\n\nclass SwinTransformerBlock(nn.Module):\n    r\"\"\" Swin Transformer Block.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution 
(tuple[int]): Input resolution.\n        num_heads (int): Number of attention heads.\n        window_size (int): Window size.\n        shift_size (int): Shift size for SW-MSA.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,\n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_checkpoint=False):\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.shift_size = shift_size\n        self.mlp_ratio = mlp_ratio\n        self.use_checkpoint = use_checkpoint\n        if min(self.input_resolution) <= self.window_size:\n            # if window size is larger than input resolution, we don't partition windows\n            self.shift_size = 0\n            self.window_size = min(self.input_resolution)\n        assert 0 <= self.shift_size < self.window_size, \"shift_size must be in 0-window_size\"\n\n        self.norm1 = norm_layer(dim)\n        self.attn = WindowAttention(\n            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,\n            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n\n    def forward(self, x, x_size):\n        H, W = x_size\n        B, L, C = x.shape\n        # assert L == H * W, \"input feature has wrong size\"\n\n        shortcut = x\n        x = self.norm1(x)\n        x = x.view(B, H, W, C)\n\n        # cyclic shift\n        if self.shift_size > 0:\n            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n        else:\n            shifted_x = x\n\n        # partition windows\n        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C\n        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C\n\n        # W-MSA/SW-MSA (the mask is recomputed per input size, so testing works on images\n        # whose shapes are not multiples of the window size)\n\n        attn_mask = calculate_mask(x_size, self.window_size, self.shift_size).to(x.device)\n        attn_windows = self.attn(x_windows, mask=attn_mask)\n\n        # merge windows\n        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)\n        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C\n\n        # reverse cyclic shift\n        if self.shift_size > 0:\n            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))\n        else:\n            x = shifted_x\n        x = x.view(B, H * W, C)\n\n        # FFN\n        x = shortcut + self.drop_path(x)\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, \" \\\n               f\"window_size={self.window_size}, 
shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}\"\n\n    def flops(self):\n        flops = 0\n        H, W = self.input_resolution\n        # norm1\n        flops += self.dim * H * W\n        # W-MSA/SW-MSA\n        nW = H * W / self.window_size / self.window_size\n        flops += nW * self.attn.flops(self.window_size * self.window_size)\n        # mlp\n        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio\n        # norm2\n        flops += self.dim * H * W\n        return flops\n\n\nclass PatchMerging(nn.Module):\n    r\"\"\" Patch Merging Layer.\n\n    Args:\n        input_resolution (tuple[int]): Resolution of input feature.\n        dim (int): Number of input channels.\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n        self.norm = norm_layer(4 * dim)\n\n    def forward(self, x):\n        \"\"\"\n        x: B, H*W, C\n        \"\"\"\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n        assert H % 2 == 0 and W % 2 == 0, f\"x size ({H}*{W}) are not even.\"\n\n        x = x.view(B, H, W, C)\n\n        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C\n        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C\n        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C\n        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C\n        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C\n        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C\n\n        x = self.norm(x)\n        x = self.reduction(x)\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"input_resolution={self.input_resolution}, dim={self.dim}\"\n\n    def flops(self):\n        H, W = self.input_resolution\n        flops = H * W * 
self.dim\n        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim\n        return flops\n\n\nclass BasicLayer(nn.Module):\n    \"\"\" A basic Swin Transformer layer for one stage.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        window_size (int): Local window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False.\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, depth, num_heads, window_size,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):\n\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        # build blocks\n        self.blocks = nn.ModuleList([\n            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,\n                                 num_heads=num_heads, window_size=window_size,\n                                 shift_size=0 if (i % 2 == 0) else window_size // 2,\n                                 mlp_ratio=mlp_ratio,\n                                 qkv_bias=qkv_bias, qk_scale=qk_scale,\n                                 drop=drop, attn_drop=attn_drop,\n                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                                 norm_layer=norm_layer, use_checkpoint=use_checkpoint)\n            for i in range(depth)])\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)\n        else:\n            self.downsample = None\n\n    def forward(self, x, x_size):\n        for blk in self.blocks:\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x, x_size)\n            else:\n                x = blk(x, x_size)\n        if self.downsample is not None:\n            x = self.downsample(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n\n    def flops(self):\n        flops = 0\n        for blk in 
self.blocks:\n            flops += blk.flops()\n        if self.downsample is not None:\n            flops += self.downsample.flops()\n        return flops\n\n\nclass RSTB(nn.Module):\n    \"\"\"Residual Swin Transformer Block (RSTB).\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        window_size (int): Local window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.\n        img_size: Input image size.\n        patch_size: Patch size.\n        resi_connection: The convolutional block before residual connection.\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, depth, num_heads, window_size,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,\n                 img_size=224, patch_size=4, resi_connection='1conv'):\n        super(RSTB, self).__init__()\n\n        # print(f'dim: {dim}, input_resolution: {input_resolution}, depth: {depth}, num_heads: {num_heads}, window_size: {window_size}, img_size: {img_size}. 
patch_size: {patch_size}')\n\n        self.dim = dim\n        self.input_resolution = input_resolution\n\n        self.residual_group = BasicLayer(dim=dim,\n                                         input_resolution=input_resolution,\n                                         depth=depth,\n                                         num_heads=num_heads,\n                                         window_size=window_size,\n                                         mlp_ratio=mlp_ratio,\n                                         qkv_bias=qkv_bias, qk_scale=qk_scale,\n                                         drop=drop, attn_drop=attn_drop,\n                                         drop_path=drop_path,\n                                         norm_layer=norm_layer,\n                                         downsample=downsample,\n                                         use_checkpoint=use_checkpoint)\n\n        if resi_connection == '1conv':\n            self.conv = nn.Conv2d(dim, dim, 3, 1, 1)\n\n        elif resi_connection == '3conv':\n            # to save parameters and memory\n            self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),\n                                      nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),\n                                      nn.LeakyReLU(negative_slope=0.2, inplace=True),\n                                      nn.Conv2d(dim // 4, dim, 3, 1, 1))\n\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,\n            norm_layer=None)\n\n        self.patch_unembed = PatchUnEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,\n            norm_layer=None)\n\n    def forward(self, x, x_size):\n        x = self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x\n        return x\n\n\n    def flops(self):\n        flops = 0\n        flops += 
self.residual_group.flops()\n        H, W = self.input_resolution\n        flops += H * W * self.dim * self.dim * 9\n        flops += self.patch_embed.flops()\n        flops += self.patch_unembed.flops()\n\n        return flops\n\n\nclass PatchEmbed(nn.Module):\n    r\"\"\" Image to Patch Embedding\n\n    Args:\n        img_size (int): Image size.  Default: 224.\n        patch_size (int): Patch token size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. Default: None\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.patches_resolution = patches_resolution\n        self.num_patches = patches_resolution[0] * patches_resolution[1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x, use_norm=True):\n        x = x.flatten(2).transpose(1, 2)  # B Ph*Pw C\n        if use_norm and self.norm is not None:\n            x = self.norm(x)\n        return x\n\n    def flops(self):\n        flops = 0\n        H, W = self.img_size\n        if self.norm is not None:\n            flops += H * W * self.embed_dim\n        return flops\n\n\nclass PatchUnEmbed(nn.Module):\n    r\"\"\" Image to Patch Unembedding\n\n    Args:\n        img_size (int): Image size.  Default: 224.\n        patch_size (int): Patch token size. Default: 4.\n        in_chans (int): Number of input image channels. 
Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. Default: None\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.patches_resolution = patches_resolution\n        self.num_patches = patches_resolution[0] * patches_resolution[1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n    def forward(self, x, x_size):\n        B, HW, C = x.shape\n        x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1])  # B Ph*Pw C\n        return x\n\n    def flops(self):\n        flops = 0\n        return flops\n\n"
  },
  {
    "path": "code/synthetic/bsrt/model/utils/interp_methods.py",
    "content": "from math import pi\n\ntry:\n    import torch\nexcept ImportError:\n    torch = None\n\ntry:\n    import numpy\nexcept ImportError:\n    numpy = None\n\nif numpy is None and torch is None:\n    raise ImportError(\"Must have either Numpy or PyTorch but both not found\")\n\n\ndef set_framework_dependencies(x):\n    if type(x) is numpy.ndarray:\n        to_dtype = lambda a: a\n        fw = numpy\n    else:\n        to_dtype = lambda a: a.to(x.dtype)\n        fw = torch\n    eps = fw.finfo(fw.float32).eps\n    return fw, to_dtype, eps\n\n\ndef support_sz(sz):\n    def wrapper(f):\n        f.support_sz = sz\n        return f\n    return wrapper\n\n@support_sz(4)\ndef cubic(x):\n    fw, to_dtype, eps = set_framework_dependencies(x)\n    absx = fw.abs(x)\n    absx2 = absx ** 2\n    absx3 = absx ** 3\n    return ((1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) +\n            (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) *\n            to_dtype((1. < absx) & (absx <= 2.)))\n\n@support_sz(4)\ndef lanczos2(x):\n    fw, to_dtype, eps = set_framework_dependencies(x)\n    return (((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) /\n            ((pi**2 * x**2 / 2) + eps)) * to_dtype(abs(x) < 2))\n\n@support_sz(6)\ndef lanczos3(x):\n    fw, to_dtype, eps = set_framework_dependencies(x)\n    return (((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) /\n            ((pi**2 * x**2 / 3) + eps)) * to_dtype(abs(x) < 3))\n\n@support_sz(2)\ndef linear(x):\n    fw, to_dtype, eps = set_framework_dependencies(x)\n    return ((x + 1) * to_dtype((-1 <= x) & (x < 0)) + (1 - x) *\n            to_dtype((0 <= x) & (x <= 1)))\n\n@support_sz(1)\ndef box(x):\n    fw, to_dtype, eps = set_framework_dependencies(x)\n    return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1))\n"
  },
  {
    "path": "code/synthetic/bsrt/model/utils/psconv.py",
    "content": "import torch\nimport torch.nn as nn\n\nclass PyConv2d(nn.Module):\n    \"\"\"PyConv2d with padding (general case). Applies a 2D PyConv over an input signal composed of several input planes.\n    Args:\n        in_channels (int): Number of channels in the input image\n        out_channels (list): Number of channels for each pyramid level produced by the convolution\n        pyconv_kernels (list): Spatial size of the kernel for each pyramid level\n        pyconv_groups (list): Number of blocked connections from input channels to output channels for each pyramid level\n        stride (int or tuple, optional): Stride of the convolution. Default: 1\n        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1\n        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``False``\n    Example::\n        >>> # PyConv with two pyramid levels, kernels: 3x3, 5x5\n        >>> m = PyConv2d(in_channels=64, out_channels=[32, 32], pyconv_kernels=[3, 5], pyconv_groups=[1, 4])\n        >>> input = torch.randn(4, 64, 56, 56)\n        >>> output = m(input)\n        >>> # PyConv with three pyramid levels, kernels: 3x3, 5x5, 7x7\n        >>> m = PyConv2d(in_channels=64, out_channels=[16, 16, 32], pyconv_kernels=[3, 5, 7], pyconv_groups=[1, 4, 8])\n        >>> input = torch.randn(4, 64, 56, 56)\n        >>> output = m(input)\n    \"\"\"\n    def __init__(self, in_channels, out_channels, pyconv_kernels, pyconv_groups, stride=1, dilation=1, bias=False):\n        super(PyConv2d, self).__init__()\n\n        assert len(out_channels) == len(pyconv_kernels) == len(pyconv_groups)\n\n        self.pyconv_levels = [None] * len(pyconv_kernels)\n        for i in range(len(pyconv_kernels)):\n            self.pyconv_levels[i] = nn.Conv2d(in_channels, out_channels[i], kernel_size=pyconv_kernels[i],\n                                              stride=stride, padding=pyconv_kernels[i] // 2, groups=pyconv_groups[i],\n                
                              dilation=dilation, bias=bias)\n        self.pyconv_levels = nn.ModuleList(self.pyconv_levels)\n\n    def forward(self, x):\n        out = []\n        for level in self.pyconv_levels:\n            out.append(level(x))\n\n        return torch.cat(out, 1)\n\n################################################################\n\nclass PSConv2d(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, parts=4, bias=False):\n        super(PSConv2d, self).__init__()\n        self.gwconv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, dilation, dilation, groups=parts, bias=bias)\n        self.gwconv_shift = nn.Conv2d(in_channels, out_channels, kernel_size, stride, 2 * dilation, 2 * dilation, groups=parts, bias=bias)\n        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)\n\n        def backward_hook(grad):\n            out = grad.clone()\n            out[self.mask] = 0\n            return out\n\n        # boolean mask (uint8 mask indexing is deprecated; matches PSGConv2d below)\n        self.mask = torch.zeros(self.conv.weight.shape).bool().cuda()\n        _in_channels = in_channels // parts\n        _out_channels = out_channels // parts\n        for i in range(parts):\n            self.mask[i * _out_channels: (i + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, :, :] = 1\n            self.mask[(i + parts // 2) % parts * _out_channels: ((i + parts // 2) % parts + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, :, :] = 1\n        self.conv.weight.data[self.mask] = 0\n        self.conv.weight.register_hook(backward_hook)\n\n        self.weight = self.conv.weight\n        self.bias = self.conv.bias\n\n    def forward(self, x):\n        x1, x2 = x.chunk(2, dim=1)\n        x_shift = self.gwconv_shift(torch.cat((x2, x1), dim=1))\n        return self.gwconv(x) + self.conv(x) + x_shift\n\n\n# PSConv-based Group Convolution\nclass PSGConv2d(nn.Module):\n    def __init__(self, in_channels, out_channels, 
kernel_size=3, stride=1, padding=1, dilation=1, groups=1, parts=4, bias=False):\n        super(PSGConv2d, self).__init__()\n        self.gwconv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups=groups * parts, bias=bias)\n        self.gwconv_shift = nn.Conv2d(in_channels, out_channels, kernel_size, stride, 2 * padding, 2 * dilation, groups=groups * parts, bias=bias)\n        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=bias)\n\n        def backward_hook(grad):\n            out = grad.clone()\n            out[self.mask] = 0\n            return out\n\n        self.mask = torch.zeros(self.conv.weight.shape).bool().cuda()\n        _in_channels = in_channels // (groups * parts)\n        _out_channels = out_channels // (groups * parts)\n        for i in range(parts):\n            for j in range(groups):\n                self.mask[(i + j * groups) * _out_channels: (i + j * groups + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, : , :] = 1\n                self.mask[((i + parts // 2) % parts + j * groups) * _out_channels: ((i + parts // 2) % parts + j * groups + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, :, :] = 1\n        self.conv.weight.data[self.mask] = 0\n        self.conv.weight.register_hook(backward_hook)\n        self.groups = groups\n\n        self.weight = self.conv.weight\n        self.bias = self.conv.bias\n\n    def forward(self, x):\n        x_split = (z.chunk(2, dim=1) for z in x.chunk(self.groups, dim=1))\n        x_merge = torch.cat(tuple(torch.cat((x2, x1), dim=1) for (x1, x2) in x_split), dim=1)\n        x_shift = self.gwconv_shift(x_merge)\n        gx = self.gwconv(x)\n        cx = self.conv(x)\n        # print(x.shape, gx.shape, cx.shape, x_merge.shape, x_shift.shape)\n        return gx + cx + x_shift\n\n"
  },
  {
    "path": "code/synthetic/bsrt/model/utils/resize_right.py",
    "content": "import warnings\nfrom math import ceil\nimport model.utils.interp_methods as interp_methods\n\n\nclass NoneClass:\n    pass\n\ntry:\n    import torch\n    from torch import nn\n    nnModuleWrapped = nn.Module\nexcept ImportError:\n    warnings.warn('No PyTorch found, will work only with Numpy')\n    torch = None\n    nnModuleWrapped = NoneClass\n\ntry:\n    import numpy\nexcept ImportError:\n    warnings.warn('No Numpy found, will work only with PyTorch')\n    numpy = None\n\n\nif numpy is None and torch is None:\n    raise ImportError(\"Must have either Numpy or PyTorch but both not found\")\n\n\ndef resize(input, scale_factors=None, out_shape=None,\n           interp_method=interp_methods.cubic, support_sz=None,\n           antialiasing=True):\n    # get properties of the input tensor\n    in_shape, n_dims = input.shape, input.ndim\n\n    # fw stands for framework that can be either numpy or torch,\n    # determined by the input type (guard against numpy being unavailable)\n    fw = numpy if numpy is not None and type(input) is numpy.ndarray else torch\n    eps = fw.finfo(fw.float32).eps\n\n    # set missing scale factors or output shape, one according to another,\n    # scream if both missing\n    scale_factors, out_shape = set_scale_and_out_sz(in_shape, out_shape,\n                                                    scale_factors, fw)\n\n    # sort indices of dimensions according to scale of each dimension.\n    # since we are going dim by dim this is efficient\n    sorted_filtered_dims_and_scales = [(dim, scale_factors[dim])\n                                       for dim in sorted(range(n_dims),\n                                       key=lambda ind: scale_factors[ind])\n                                       if scale_factors[dim] != 1.]\n\n    # unless support size is specified by the user, it is an attribute\n    # of the interpolation method\n    if support_sz is None:\n        support_sz = interp_method.support_sz\n\n    # when using pytorch, we need to know what is the input tensor device\n    if fw 
is torch:\n        device = input.device\n    else:\n        device = None\n\n    # output begins identical to input and changes with each iteration\n    output = input\n\n    # iterate over dims\n    for dim, scale_factor in sorted_filtered_dims_and_scales:\n\n        # get 1d set of weights and fields of view for each output location\n        # along this dim\n        field_of_view, weights = prepare_weights_and_field_of_view_1d(\n            dim, scale_factor, in_shape[dim], out_shape[dim], interp_method,\n            support_sz, antialiasing, fw, eps, device)\n\n        # multiply the weights by the values in the field of view and\n        # aggregate\n        output = apply_weights(output, field_of_view, weights, dim, n_dims,\n                               fw)\n    return output\n\n\nclass ResizeLayer(nnModuleWrapped):\n    def __init__(self, in_shape, scale_factors=None, out_shape=None,\n                 interp_method=interp_methods.cubic, support_sz=None,\n                 antialiasing=True):\n        super(ResizeLayer, self).__init__()\n\n        # fw stands for framework, that can be either numpy or torch. 
since\n        # this is a torch layer, only one option in this case.\n        fw = torch\n        eps = fw.finfo(fw.float32).eps\n\n        # set missing scale factors or output shape, one according to another,\n        # scream if both missing\n        scale_factors, out_shape = set_scale_and_out_sz(in_shape, out_shape,\n                                                        scale_factors, fw)\n\n        # unless support size is specified by the user, it is an attribute\n        # of the interpolation method\n        if support_sz is None:\n            support_sz = interp_method.support_sz\n\n        self.n_dims = len(in_shape)\n\n        # sort indices of dimensions according to scale of each dimension.\n        # since we are going dim by dim this is efficient\n        self.sorted_filtered_dims_and_scales = [(dim, scale_factors[dim])\n                                                for dim in\n                                                sorted(range(self.n_dims),\n                                                key=lambda ind:\n                                                scale_factors[ind])\n                                                if scale_factors[dim] != 1.]\n\n        # iterate over dims\n        field_of_view_list = []\n        weights_list = []\n        for dim, scale_factor in self.sorted_filtered_dims_and_scales:\n\n            # get 1d set of weights and fields of view for each output\n            # location along this dim (no input tensor exists yet, so the\n            # buffers are built on cpu and follow the module's .to()/.cuda())\n            field_of_view, weights = prepare_weights_and_field_of_view_1d(\n                dim, scale_factor, in_shape[dim], out_shape[dim],\n                interp_method, support_sz, antialiasing, fw, eps)\n\n            # keep weights and fields of view for all dims\n            weights_list.append(nn.Parameter(weights, requires_grad=False))\n            field_of_view_list.append(nn.Parameter(field_of_view,\n                                      requires_grad=False))\n\n        self.field_of_view = 
nn.ParameterList(field_of_view_list)\n        self.weights = nn.ParameterList(weights_list)\n        self.in_shape = in_shape\n\n    def forward(self, input):\n        # output begins identical to input and changes with each iteration\n        output = input\n\n        for (dim, scale_factor), field_of_view, weights in zip(\n                self.sorted_filtered_dims_and_scales,\n                self.field_of_view,\n                self.weights):\n            # multiply the weights by the values in the field of view and\n            # aggreagate\n            output = apply_weights(output, field_of_view, weights, dim,\n                                   self.n_dims, torch)\n        return output\n\n\ndef prepare_weights_and_field_of_view_1d(dim, scale_factor, in_sz, out_sz,\n                                         interp_method, support_sz,\n                                         antialiasing, fw, eps, device=None):\n    # If antialiasing is taking place, we modify the window size and the\n    # interpolation method (see inside function)\n    interp_method, cur_support_sz = apply_antialiasing_if_needed(\n                                                             interp_method,\n                                                             support_sz,\n                                                             scale_factor,\n                                                             antialiasing)\n\n    # STEP 1- PROJECTED GRID: The non-integer locations of the projection of\n    # output pixel locations to the input tensor\n    projected_grid = get_projected_grid(in_sz, out_sz, scale_factor, fw, device)\n\n    # STEP 2- FIELDS OF VIEW: for each output pixels, map the input pixels\n    # that influence it\n    field_of_view = get_field_of_view(projected_grid, cur_support_sz, in_sz,\n                                      fw, eps)\n\n    # STEP 3- CALCULATE WEIGHTS: Match a set of weights to the pixels in the\n    # field of view for each output pixel\n    
weights = get_weights(interp_method, projected_grid, field_of_view)\n\n    return field_of_view, weights\n\n\ndef apply_weights(input, field_of_view, weights, dim, n_dims, fw):\n    # STEP 4- APPLY WEIGHTS: Each output pixel is calculated by multiplying\n    # its set of weights with the pixel values in its field of view.\n    # We now multiply the fields of view with their matching weights.\n    # We do this by tensor multiplication and broadcasting.\n    # this step is separated to a different function, so that it can be\n    # repeated with the same calculated weights and fields.\n\n    # for this operations we assume the resized dim is the first one.\n    # so we transpose and will transpose back after multiplying\n    tmp_input = fw_swapaxes(input, dim, 0, fw)\n\n    # field_of_view is a tensor of order 2: for each output (1d location\n    # along cur dim)- a list of 1d neighbors locations.\n    # note that this whole operations is applied to each dim separately,\n    # this is why it is all in 1d.\n    # neighbors = tmp_input[field_of_view] is a tensor of order image_dims+1:\n    # for each output pixel (this time indicated in all dims), these are the\n    # values of the neighbors in the 1d field of view. note that we only\n    # consider neighbors along the current dim, but such set exists for every\n    # multi-dim location, hence the final tensor order is image_dims+1.\n    neighbors = tmp_input[field_of_view]\n\n    # weights is an order 2 tensor: for each output location along 1d- a list\n    # of weighs matching the field of view. 
we augment it with ones, for\n    # broadcasting, so that when multiplies some tensor the weights affect\n    # only its first dim.\n    tmp_weights = fw.reshape(weights, (*weights.shape, * [1] * (n_dims - 1)))\n\n    # now we simply multiply the weights with the neighbors, and then sum\n    # along the field of view, to get a single value per out pixel\n    tmp_output = (neighbors * tmp_weights).sum(1)\n\n    # we transpose back the resized dim to its original position\n    return fw_swapaxes(tmp_output, 0, dim, fw)\n\n\ndef set_scale_and_out_sz(in_shape, out_shape, scale_factors, fw):\n    # eventually we must have both scale-factors and out-sizes for all in/out\n    # dims. however, we support many possible partial arguments\n    if scale_factors is None and out_shape is None:\n        raise ValueError(\"either scale_factors or out_shape should be \"\n                         \"provided\")\n    if out_shape is not None:\n        # if out_shape has less dims than in_shape, we defaultly resize the\n        # first dims for numpy and last dims for torch\n        out_shape = (list(out_shape) + list(in_shape[:-len(out_shape)])\n                     if fw is numpy\n                     else list(in_shape[:-len(out_shape)]) + list(out_shape))\n        if scale_factors is None:\n            # if no scale given, we calculate it as the out to in ratio\n            # (not recomended)\n            scale_factors = [out_sz / in_sz for out_sz, in_sz\n                             in zip(out_shape, in_shape)]\n    if scale_factors is not None:\n        # by default, if a single number is given as scale, we assume resizing\n        # two dims (most common are images with 2 spatial dims)\n        scale_factors = (scale_factors\n                         if isinstance(scale_factors, (list, tuple))\n                         else [scale_factors, scale_factors])\n        # if less scale_factors than in_shape dims, we defaultly resize the\n        # first dims for numpy and last dims 
for torch\n        scale_factors = (list(scale_factors) + [1] *\n                         (len(in_shape) - len(scale_factors)) if fw is numpy\n                         else [1] * (len(in_shape) - len(scale_factors)) +\n                         list(scale_factors))\n        if out_shape is None:\n            # when no out_shape given, it is calculated by multiplying the\n            # scale by the in_shape (not recommended)\n            out_shape = [ceil(scale_factor * in_sz)\n                         for scale_factor, in_sz in\n                         zip(scale_factors, in_shape)]\n        # next line intentionally after out_shape determined for stability\n        scale_factors = [float(sf) for sf in scale_factors]\n    return scale_factors, out_shape\n\n\ndef get_projected_grid(in_sz, out_sz, scale_factor, fw, device=None):\n    # we start by having the output coordinates which are just integer locations\n    out_coordinates = fw.arange(out_sz)\n\n    # if using torch we need to match the grid tensor device to the input device\n    out_coordinates = fw_set_device(out_coordinates, device, fw)\n\n    # This is projecting the output pixel locations in 1d to the input tensor,\n    # as non-integer locations.\n    # the following formula is derived in the paper\n    # \"From Discrete to Continuous Convolutions\" by Shocher et al.\n    return (out_coordinates / scale_factor +\n            (in_sz - 1) / 2 - (out_sz - 1) / (2 * scale_factor))\n\n\ndef get_field_of_view(projected_grid, cur_support_sz, in_sz, fw, eps):\n    # for each output pixel, map which input pixels influence it, in 1d.\n    # we start by calculating the leftmost neighbor, using half of the window\n    # size (eps is for when boundary is exact int)\n    left_boundaries = fw_ceil(projected_grid - cur_support_sz / 2 - eps, fw)\n\n    # then we simply take all the pixel centers in the field by counting\n    # window size pixels from the left boundary\n    ordinal_numbers = fw.arange(ceil(cur_support_sz - 
eps))\n    # in case using torch we need to match the device\n    ordinal_numbers = fw_set_device(ordinal_numbers, projected_grid.device, fw)\n    field_of_view = left_boundaries[:, None] + ordinal_numbers\n\n    # next we do a trick instead of padding, we map the field of view so that\n    # it would be like mirror padding, without actually padding\n    # (which would require enlarging the input tensor)\n    mirror = fw_cat((fw.arange(in_sz), fw.arange(in_sz - 1, -1, step=-1)), fw)\n    field_of_view = mirror[fw.remainder(field_of_view, mirror.shape[0])]\n    field_of_view = fw_set_device(field_of_view,projected_grid.device, fw)\n    return field_of_view\n\n\ndef get_weights(interp_method, projected_grid, field_of_view):\n    # the set of weights per each output pixels is the result of the chosen\n    # interpolation method applied to the distances between projected grid\n    # locations and the pixel-centers in the field of view (distances are\n    # directed, can be positive or negative)\n    weights = interp_method(projected_grid[:, None] - field_of_view)\n\n    # we now carefully normalize the weights to sum to 1 per each output pixel\n    sum_weights = weights.sum(1, keepdims=True)\n    sum_weights[sum_weights == 0] = 1\n    return weights / sum_weights\n\n\ndef apply_antialiasing_if_needed(interp_method, support_sz, scale_factor,\n                                 antialiasing):\n    # antialiasing is \"stretching\" the field of view according to the scale\n    # factor (only for downscaling). this is low-pass filtering. 
this\n    # requires modifying both the interpolation (stretching the 1d\n    # function and multiplying by the scale-factor) and the window size.\n    if scale_factor >= 1.0 or not antialiasing:\n        return interp_method, support_sz\n    cur_interp_method = (lambda arg: scale_factor *\n                         interp_method(scale_factor * arg))\n    cur_support_sz = support_sz / scale_factor\n    return cur_interp_method, cur_support_sz\n\n\ndef fw_ceil(x, fw):\n    if fw is numpy:\n        return fw.int_(fw.ceil(x))\n    else:\n        return x.ceil().long()\n\n\ndef fw_cat(x, fw):\n    if fw is numpy:\n        return fw.concatenate(x)\n    else:\n        return fw.cat(x)\n\n\ndef fw_swapaxes(x, ax_1, ax_2, fw):\n    if fw is numpy:\n        return fw.swapaxes(x, ax_1, ax_2)\n    else:\n        return x.transpose(ax_1, ax_2)\n\ndef fw_set_device(x, device, fw):\n    if fw is numpy:\n        return x\n    else:\n        return x.to(device)\n"
  },
  {
    "path": "code/synthetic/bsrt/option.py",
    "content": "import argparse\n\nparser = argparse.ArgumentParser(description='EDSR and MDSR')\n\nparser.add_argument('--n_resblocks', type=int, default=16,\n                    help='number of residual blocks')\nparser.add_argument('--n_feats', type=int, default=64,\n                    help='number of feature maps')\nparser.add_argument('--n_colors', type=int, default=3,\n                    help='number of color channels to use')\nparser.add_argument('--lr', type=float, default=1e-4,\n                    help='learning rate')\nparser.add_argument('--burst_size', type=int, default=14,\n                    help='burst size, max 14')\nparser.add_argument('--burst_channel', type=int, default=4,\n                    help='burst size, max 14')\nparser.add_argument('--swinfeature', action='store_true',\n                    help='use swin transformer to extract features')\nparser.add_argument('--model_level', type=str, default='S',\n                    help='S: small, L: large')\n\n################## fine-tune ##################\nparser.add_argument('--finetune', action='store_true',\n                    help='finetune model')\nparser.add_argument('--finetune_align', action='store_true',\n                    help='finetune alignment module')\nparser.add_argument('--finetune_swin', action='store_true',\n                    help='finetune swin trans module')\nparser.add_argument('--finetune_conv', action='store_true',\n                    help='finetune rest convs')\nparser.add_argument('--finetune_prelayer', action='store_true',\n                    help='finetune finetune pre feature extract layer')\nparser.add_argument('--finetune_upconv', action='store_true',\n                    help='finetune finetune up conv layer')\nparser.add_argument('--finetune_spynet', action='store_true',\n                    help='finetune finetune up conv layer')\n\n# Hardware specifications\nparser.add_argument('--n_threads', type=int, default=6,\n                    help='number of 
threads for data loading')\nparser.add_argument('--cpu', action='store_true',\n                    help='use cpu only')\nparser.add_argument('--n_GPUs', type=int, default=2,\n                    help='number of GPUs')\nparser.add_argument('--seed', type=int, default=1,\n                    help='random seed')\nparser.add_argument('--local_rank', type=int, default=-1,\n                    help='proc index')\nparser.add_argument('--fp16', action='store_true',\n                    help='use fp16 only')\nparser.add_argument('--use_checkpoint', action='store_true',\n                    help='use use_checkpoint in swin transformer')\n\n# Data specifications\nparser.add_argument('--root', type=str, default='/data/dataset/ntire21/burstsr/synthetic',\n                    help='dataset directory')\nparser.add_argument('--mode', type=str, default='train',\n                    help='demo image directory')\nparser.add_argument('--scale', type=str, default='4',\n                    help='super resolution scale')\nparser.add_argument('--patch_size', type=int, default=256,\n                    help='output patch size')\nparser.add_argument('--rgb_range', type=int, default=1,\n                    help='maximum value of RGB')\n\nparser.add_argument('--chop', action='store_true',\n                    help='enable memory-efficient forward')\nparser.add_argument('--no_augment', action='store_true',\n                    help='do not use data augmentation')\n\n# Model specifications\nparser.add_argument('--model', default='LRSC_EDVR',\n                    help='model name')\n\nparser.add_argument('--act', type=str, default='relu',\n                    help='activation function')\nparser.add_argument('--pre_train', type=str, default='',\n                    help='pre-trained model directory')\nparser.add_argument('--extend', type=str, default='.',\n                    help='pre-trained model directory')\n\nparser.add_argument('--res_scale', type=float, default=1,\n                    
help='residual scaling')\nparser.add_argument('--shift_mean', default=True,\n                    help='subtract pixel mean from the input')\nparser.add_argument('--dilation', action='store_true',\n                    help='use dilated convolution')\nparser.add_argument('--precision', type=str, default='single',\n                    choices=('single', 'half'),\n                    help='FP precision for test (single | half)')\n\n\n# Option for Residual channel attention network (RCAN)\nparser.add_argument('--n_resgroups', type=int, default=20,\n                    help='number of residual groups')\nparser.add_argument('--reduction', type=int, default=16,\n                    help='number of feature maps reduction')\nparser.add_argument('--DA', action='store_true',\n                    help='use Dual Attention')\nparser.add_argument('--CA', action='store_true',\n                    help='use Channel Attention')\nparser.add_argument('--non_local', action='store_true',\n                    help='use Dual Attention')\n\n# Training specifications\nparser.add_argument('--reset', action='store_true',\n                    help='reset the training')\nparser.add_argument('--test_every', type=int, default=1000,\n                    help='do test per every N batches')\nparser.add_argument('--epochs', type=int, default=300,\n                    help='number of epochs to train')\nparser.add_argument('--batch_size', type=int, default=8,\n                    help='input batch size for training')\nparser.add_argument('--split_batch', type=int, default=1,\n                    help='split the batch into smaller chunks')\nparser.add_argument('--self_ensemble', action='store_true',\n                    help='use self-ensemble method for test')\nparser.add_argument('--test_only', action='store_true',\n                    help='set this option to test the model')\nparser.add_argument('--gan_k', type=int, default=1,\n                    help='k value for adversarial loss')\n\n# 
Optimization specifications\n\nparser.add_argument('--decay', type=str, default='100-200',\n                    help='learning rate decay type')\nparser.add_argument('--gamma', type=float, default=0.5,\n                    help='learning rate decay factor for step decay')\nparser.add_argument('--optimizer', default='ADAM',\n                    choices=('SGD', 'ADAM', 'RMSprop'),\n                    help='optimizer to use (SGD | ADAM | RMSprop)')\nparser.add_argument('--momentum', type=float, default=0.9,\n                    help='SGD momentum')\nparser.add_argument('--betas', type=tuple, default=(0.9, 0.999),\n                    help='ADAM beta')\nparser.add_argument('--epsilon', type=float, default=1e-8,\n                    help='ADAM epsilon for numerical stability')\nparser.add_argument('--weight_decay', type=float, default=0,\n                    help='weight decay')\nparser.add_argument('--gclip', type=float, default=0,\n                    help='gradient clipping threshold (0 = no clipping)')\n\n# Loss specifications\nparser.add_argument('--loss', type=str, default='1*L1',\n                    help='loss function configuration')\nparser.add_argument('--skip_threshold', type=float, default='1e8',\n                    help='skipping batch that has large error')\n\n# Log specifications\nparser.add_argument('--save', type=str, default='test',\n                    help='file name to save')\nparser.add_argument('--load', type=str, default='',\n                    help='file name to load')\nparser.add_argument('--resume', type=int, default=0,\n                    help='resume from specific checkpoint')\nparser.add_argument('--save_models', action='store_true',\n                    help='save all intermediate models')\nparser.add_argument('--print_every', type=int, default=20,\n                    help='how many batches to wait before logging training status')\nparser.add_argument('--save_results', action='store_true',\n                    help='save output 
results')\nparser.add_argument('--save_gt', action='store_true',\n                    help='save low-resolution and high-resolution images together')\n\nargs = parser.parse_args()\n\nargs.scale = list(map(lambda x: int(x), args.scale.split('+')))\n\nif args.epochs == 0:\n    args.epochs = 1e8\n\nfor arg in vars(args):\n    if vars(args)[arg] == 'True':\n        vars(args)[arg] = True\n    elif vars(args)[arg] == 'False':\n        vars(args)[arg] = False\n\n"
  },
  {
    "path": "code/synthetic/bsrt/requirements.txt",
    "content": "matplotlib\nimageio\nopencv-python\ntensorboardX\n"
  },
  {
    "path": "code/synthetic/bsrt/scripts/__init__.py",
    "content": ""
  },
  {
    "path": "code/synthetic/bsrt/scripts/cal_mean_std.py",
    "content": "import torch\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom datasets.burstsr_dataset import BurstSRDataset, flatten_raw_image\nfrom datasets.synthetic_burst_train_set import SyntheticBurst\nfrom datasets.zurich_raw2rgb_dataset import ZurichRAW2RGB\n\ndef main():\n    train_zurich_raw2rgb = ZurichRAW2RGB(root='/data/dataset/ntire21/burstsr/synthetic', split='train')\n    train_data = SyntheticBurst(train_zurich_raw2rgb, burst_size=14, crop_sz=384)\n    means = []\n    stds = []\n\n    for data in tqdm(train_data):\n        print(data.shape)\n        break\n\n\nif __name__ == '__main__':\n    # if not args.cpu: torch.cuda.set_device(0)\n    main()\n"
  },
  {
    "path": "code/synthetic/bsrt/scripts/demo.sh",
    "content": "set -ex\nrlaunch --cpu=4 --gpu=1 --memory=10240 -- python ./scripts/evaluate_burstsr_val.py\n"
  },
  {
    "path": "code/synthetic/bsrt/scripts/download_burstsr_dataset.py",
    "content": "import os\nimport urllib.request\nimport zipfile\nimport shutil\nimport argparse\n\n\ndef download_burstsr_dataset(download_path):\n    out_dir = download_path + '/burstsr_dataset'\n\n    # Download train folders\n    for i in range(9):\n        if not os.path.isfile('{}/train_{:02d}.zip'.format(out_dir, i)):\n            print('Downloading train_{:02d}'.format(i))\n\n            urllib.request.urlretrieve('https://data.vision.ee.ethz.ch/bhatg/BurstSRChallenge/train_{:02d}.zip'.format(i),\n                                       '{}/tmp.zip'.format(out_dir))\n\n            os.rename('{}/tmp.zip'.format(out_dir), '{}/train_{:02d}.zip'.format(out_dir, i))\n\n    # Download val folder\n    if not os.path.isfile('{}/val.zip'.format(out_dir)):\n        print('Downloading val')\n\n        urllib.request.urlretrieve('https://data.vision.ee.ethz.ch/bhatg/BurstSRChallenge/val.zip',\n                                   '{}/tmp.zip'.format(out_dir))\n\n        os.rename('{}/tmp.zip'.format(out_dir), '{}/val.zip'.format(out_dir))\n\n    # Unpack train set\n    for i in range(9):\n        print('Unpacking train_{:02d}'.format(i))\n        with zipfile.ZipFile('{}/train_{:02d}.zip'.format(out_dir, i), 'r') as zip_ref:\n            zip_ref.extractall('{}'.format(out_dir))\n\n    # Move files to a common directory\n    os.makedirs('{}/train'.format(out_dir), exist_ok=True)\n\n    for i in range(9):\n        file_list = os.listdir('{}/train_{:02d}'.format(out_dir, i))\n\n        for b in file_list:\n            source_dir = '{}/train_{:02d}/{}'.format(out_dir, i, b)\n            dst_dir = '{}/train/{}'.format(out_dir, b)\n\n            if os.path.isdir(source_dir):\n                shutil.move(source_dir, dst_dir)\n\n    # Delete individual subsets\n    for i in range(9):\n        shutil.rmtree('{}/train_{:02d}'.format(out_dir, i))\n\n    # Unpack val set\n    print('Unpacking val')\n    with zipfile.ZipFile('{}/val.zip'.format(out_dir), 'r') as zip_ref:\n        
zip_ref.extractall('{}'.format(out_dir))\n\n\ndef main():\n    parser = argparse.ArgumentParser(description='Downloads and unpacks BurstSR dataset')\n    parser.add_argument('path', type=str, help='Path where the dataset will be downloaded')\n\n    args = parser.parse_args()\n\n    download_burstsr_dataset(args.path)\n\n\nif __name__ == '__main__':\n    main()\n\n\n"
  },
  {
    "path": "code/synthetic/bsrt/scripts/evaluate.sh",
    "content": "set -ex\nrlaunch --cpu=4 --gpu=1 --memory=10240 -- python scripts/evaluate_burstsr_val.py\n"
  },
  {
    "path": "code/synthetic/bsrt/scripts/evaluate_burstsr_val.py",
    "content": "import torch.nn.functional as F\nfrom datasets.burstsr_dataset import BurstSRDataset\nfrom utils.metrics import AlignedPSNR\nfrom pwcnet.pwcnet import PWCNet\n\nroot = '/data/dataset/ntire21/burstsr/real/NTIRE/burstsr_dataset'\n\nclass SimpleBaseline:\n    def __init__(self):\n        pass\n\n    def __call__(self, burst):\n        burst_rgb = burst[:, 0, [0, 1, 3]]\n        burst_rgb = burst_rgb.view(-1, *burst_rgb.shape[-3:])\n        burst_rgb = F.interpolate(burst_rgb, scale_factor=8, mode='bilinear')\n        return burst_rgb\n\n\ndef main():\n    # Load dataset\n    dataset = BurstSRDataset(root=root,\n                             split='val', burst_size=14, crop_sz=80, random_flip=False)\n\n    # TODO Set your network here\n    net = SimpleBaseline()\n\n    device = 'cuda'\n\n    # Load alignment network, used in AlignedPSNR\n    alignment_net = PWCNet(load_pretrained=True,\n                           weights_path='PATH_TO_PWCNET_WEIGHTS')\n    alignment_net = alignment_net.to(device)\n    aligned_psnr_fn = AlignedPSNR(alignment_net=alignment_net, boundary_ignore=40)\n\n    scores_all = []\n    for idx in range(len(dataset)):\n        burst, frame_gt, meta_info_burst, meta_info_gt = dataset[idx]\n        burst = burst.unsqueeze(0).to(device)\n        frame_gt = frame_gt.unsqueeze(0).to(device)\n\n        net_pred = net(burst)\n\n        # Calculate Aligned PSNR\n        score = aligned_psnr_fn(net_pred, frame_gt, burst)\n\n        scores_all.append(score)\n\n    mean_psnr = sum(scores_all) / len(scores_all)\n\n    print('Mean PSNR is {:0.3f}'.format(mean_psnr.item()))\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "code/synthetic/bsrt/scripts/save_results_synburst_val.py",
    "content": "import torch.nn.functional as F\nimport cv2\nfrom datasets.synthetic_burst_val_set import SyntheticBurstVal\nimport torch\nimport numpy as np\nimport os\n\n\nclass SimpleBaseline:\n    def __init__(self):\n        pass\n\n    def __call__(self, burst):\n        burst_rgb = burst[:, 0, [0, 1, 3]]\n        burst_rgb = burst_rgb.view(-1, *burst_rgb.shape[-3:])\n        burst_rgb = F.interpolate(burst_rgb, scale_factor=8, mode='bilinear')\n        return burst_rgb\n\n\ndef main():\n    dataset = SyntheticBurstVal('PATH_TO_SyntheticBurstVal')\n    out_dir = 'PATH_WHERE_RESULTS_ARE_SAVED'\n\n    # TODO Set your network here\n    net = SimpleBaseline()\n\n    device = 'cuda'\n    os.makedirs(out_dir, exist_ok=True)\n\n    for idx in range(len(dataset)):\n        burst, burst_name = dataset[idx]\n\n        burst = burst.to(device).unsqueeze(0)\n\n        with torch.no_grad():\n            net_pred = net(burst)\n\n        # Normalize to 0  2^14 range and convert to numpy array\n        net_pred_np = (net_pred.squeeze(0).permute(1, 2, 0).clamp(0.0, 1.0) * 2 ** 14).cpu().numpy().astype(np.uint16)\n\n        # Save predictions as png\n        cv2.imwrite('{}/{}.png'.format(out_dir, burst_name), net_pred_np)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "code/synthetic/bsrt/scripts/test_burstsr_dataset.py",
    "content": "import torch.nn.functional as F\nimport cv2\nfrom datasets.burstsr_dataset import BurstSRDataset\nfrom torch.utils.data.dataloader import DataLoader\nfrom utils.metrics import AlignedPSNR\nfrom utils.postprocessing_functions import BurstSRPostProcess\nfrom utils.data_format_utils import convert_dict\nfrom pwcnet.pwcnet import PWCNet\n\n\ndef main():\n    # Load dataset\n    dataset = BurstSRDataset(root='PATH_TO_BURST_SR',\n                             split='val', burst_size=3, crop_sz=56, random_flip=False)\n\n    data_loader = DataLoader(dataset, batch_size=2)\n\n    # Load alignment network, used in AlignedPSNR\n    alignment_net = PWCNet(load_pretrained=True,\n                           weights_path='PATH_TO_PWCNET_WEIGHTS')\n    alignment_net = alignment_net.to('cuda')\n\n    aligned_psnr_fn = AlignedPSNR(alignment_net=alignment_net, boundary_ignore=40)\n\n    # Postprocessing function to obtain sRGB images\n    postprocess_fn = BurstSRPostProcess(return_np=True)\n\n    for d in data_loader:\n        burst, frame_gt, meta_info_burst, meta_info_gt = d\n\n        # A simple baseline which upsamples the base image using bilinear upsampling\n        burst_rgb = burst[:, 0, [0, 1, 3]]\n        burst_rgb = burst_rgb.view(-1, *burst_rgb.shape[-3:])\n        burst_rgb = F.interpolate(burst_rgb, scale_factor=8, mode='bilinear')\n\n        # Calculate Aligned PSNR\n        score = aligned_psnr_fn(burst_rgb.cuda(), frame_gt.cuda(), burst.cuda())\n        print('PSNR is {:0.3f}'.format(score))\n\n        meta_info_gt = convert_dict(meta_info_gt, burst.shape[0])\n\n        # Apply simple post-processing to obtain RGB images\n        pred_0 = postprocess_fn.process(burst_rgb[0], meta_info_gt[0])\n        gt_0 = postprocess_fn.process(frame_gt[0], meta_info_gt[0])\n\n        pred_0 = cv2.cvtColor(pred_0, cv2.COLOR_RGB2BGR)\n        gt_0 = cv2.cvtColor(gt_0, cv2.COLOR_RGB2BGR)\n\n        # Visualize input, ground truth\n        cv2.imshow('Input (Demosaicekd 
+ Upsampled)', pred_0)\n        cv2.imshow('GT', gt_0)\n\n        input_key = cv2.waitKey(0)\n        if input_key == ord('q'):\n            return\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "code/synthetic/bsrt/scripts/test_synthetic_bursts.py",
    "content": "import torch.nn.functional as F\nimport cv2\nfrom datasets.synthetic_burst_train_set import SyntheticBurst\nfrom torch.utils.data.dataloader import DataLoader\nfrom utils.metrics import PSNR\nfrom utils.postprocessing_functions import SimplePostProcess\nfrom utils.data_format_utils import convert_dict\nfrom datasets.zurich_raw2rgb_dataset import ZurichRAW2RGB\n\n\ndef main():\n    zurich_raw2rgb = ZurichRAW2RGB(root='PATH_TO_ZURICH_RAW_TO_RGB', split='test')\n    dataset = SyntheticBurst(zurich_raw2rgb, burst_size=3, crop_sz=256)\n\n    data_loader = DataLoader(dataset, batch_size=2)\n\n    # Function to calculate PSNR. Note that the boundary pixels (40 pixels) will be ignored during PSNR computation\n    psnr_fn = PSNR(boundary_ignore=40)\n\n    # Postprocessing function to obtain sRGB images\n    postprocess_fn = SimplePostProcess(return_np=True)\n\n    for d in data_loader:\n        burst, frame_gt, flow_vectors, meta_info = d\n\n        # A simple baseline which upsamples the base image using bilinear upsampling\n        burst_rgb = burst[:, 0, [0, 1, 3]]\n        burst_rgb = burst_rgb.view(-1, *burst_rgb.shape[-3:])\n        burst_rgb = F.interpolate(burst_rgb, scale_factor=8, mode='bilinear')\n\n        # Calculate PSNR\n        score = psnr_fn(burst_rgb, frame_gt)\n\n        print('PSNR is {:0.3f}'.format(score))\n\n        meta_info = convert_dict(meta_info, burst.shape[0])\n\n        # Apply simple post-processing to obtain RGB images\n        pred_0 = postprocess_fn.process(burst_rgb[0], meta_info[0])\n        gt_0 = postprocess_fn.process(frame_gt[0], meta_info[0])\n\n        pred_0 = cv2.cvtColor(pred_0, cv2.COLOR_RGB2BGR)\n        gt_0 = cv2.cvtColor(gt_0, cv2.COLOR_RGB2BGR)\n\n        # Visualize input, ground truth\n        cv2.imshow('Input (Demosaicekd + Upsampled)', pred_0)\n        cv2.imshow('GT', gt_0)\n\n        input_key = cv2.waitKey(0)\n        if input_key == ord('q'):\n            return\n\n\nif __name__ == '__main__':\n   
 main()\n"
  },
  {
    "path": "code/synthetic/bsrt/test.py",
    "content": "\nimport cv2\nimport torch\nimport numpy as np\nimport os\nfrom tqdm import tqdm\nimport random\nimport utility\nfrom option import args\n\nfrom datasets.synthetic_burst_test_set import SyntheticBurstTest\nfrom datasets.burstsr_dataset import flatten_raw_image_batch, pack_raw_image_batch\nimport model\n\nimport torch.multiprocessing as mp\nimport torch.backends.cudnn as cudnn\nimport torch.utils.data.distributed\nimport time\n\n\ncheckpoint = utility.checkpoint(args)\n\ndef ttaup(burst):\n    burst0 = flatten_raw_image_batch(burst) # B, T, C, H, W\n    burst1 = utility.bayer_aug(burst0, flip_h=False, flip_w=False, transpose=True)\n    burst0 = pack_raw_image_batch(burst0)\n    burst1 = pack_raw_image_batch(burst1)\n\n    return [burst0, burst1]\n\n\ndef ttadown(bursts):\n    burst0 = bursts[0]\n    burst1 = bursts[1].permute(0, 1, 3, 2)\n    out = (burst0 + burst1) / 2\n    return out\n\n\ndef main():\n    mp.spawn(main_worker, nprocs=1, args=(1, args))\n\n\ndef main_worker(local_rank, nprocs, args):\n    device = 'cuda'\n    cudnn.benchmark = True\n    args.local_rank = local_rank\n    utility.setup(local_rank, nprocs)\n    torch.cuda.set_device(local_rank)\n\n    dataset = SyntheticBurstTest(args.root)\n    out_dir = 'bsrt_synburst'\n    os.makedirs(out_dir, exist_ok=True)\n\n    _model = model.Model(args, checkpoint)\n\n    tt = []\n    for idx in tqdm(range(len(dataset))):\n        burst, meta_info = dataset[idx]\n        burst_name = meta_info['burst_name']\n\n        burst = burst.to(device).unsqueeze(0)\n        bursts = ttaup(burst)\n\n        srs = []\n        with torch.no_grad():\n            for x in bursts:\n                tic = time.time()\n                sr = _model(x, 0)\n                toc = time.time()\n                tt.append(toc-tic)\n                srs.append(sr)\n\n        sr = ttadown(srs)\n        # Normalize to 0  2^14 range and convert to numpy array\n        net_pred_np = (sr.squeeze(0).permute(1, 2, 0).clamp(0.0, 
1.0) * 2 ** 14).cpu().numpy().astype(np.uint16)\n        cv2.imwrite('{}/{}.png'.format(out_dir, burst_name), net_pred_np)\n\n    print('avg time: {:.4f}'.format(np.mean(tt)))\n    utility.cleanup()\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "code/synthetic/bsrt/test_synburst.py",
    "content": "\nimport cv2\nimport torch\nimport numpy as np\nimport os\nfrom tqdm import tqdm\nimport random\nimport utility\nfrom option import args\n\nfrom utils.postprocessing_functions import SimplePostProcess\nfrom datasets.burstsr_dataset import flatten_raw_image_batch, pack_raw_image, pack_raw_image_batch\nfrom datasets.synthetic_burst_val_set import SyntheticBurstVal\nfrom utils.metrics import PSNR\nfrom utils.data_format_utils import convert_dict\nfrom data_processing.camera_pipeline import demosaic\nimport model\n\nimport torch.multiprocessing as mp\nimport torch.backends.cudnn as cudnn\nimport torch.utils.data.distributed\nimport time\n\n# from torchsummaryX import summary\n\n\ncheckpoint = utility.checkpoint(args)\n\ndef ttaup(burst):\n    # burst0 = flatten_raw_image_batch(burst) # B, T, C, H, W\n    # burst1 = utility.bayer_aug(burst0, flip_h=False, flip_w=False, transpose=True)\n    # burst1 = pack_raw_image_batch(burst1)\n    return [burst]\n\ndef ttadown(bursts):\n    burst0 = bursts[0]\n    # burst1 = bursts[1].permute(0, 1, 3, 2)\n    # out = (burst0 + burst1) / 2\n    out = burst0\n    return out\n\ndef main():\n    mp.spawn(main_worker, nprocs=1, args=(1, args))\n\n\ndef main_worker(local_rank, nprocs, args):\n    cudnn.benchmark = True\n    args.local_rank = local_rank\n    utility.setup(local_rank, nprocs)\n    torch.cuda.set_device(local_rank)\n\n    dataset = SyntheticBurstVal(root=args.root)\n    out_dir = 'val/bsrt_synburst'\n\n    _model = model.Model(args, checkpoint)\n\n    for param in _model.parameters():\n        param.requires_grad = False\n\n    psnr_fn = PSNR(boundary_ignore=40)\n\n    postprocess_fn = SimplePostProcess(return_np=True)\n\n    os.makedirs(out_dir, exist_ok=True)\n\n    tt = []\n    psnrs, ssims, lpipss = [], [], []\n    for idx in tqdm(range(len(dataset))):\n        burst_, gt, meta_info = dataset[idx]\n        burst_ = burst_.unsqueeze(0).cuda()\n        gt = gt.unsqueeze(0).cuda()\n        name = 
meta_info['burst_name']\n\n        bursts = ttaup(burst_)\n\n        srs = []\n        with torch.no_grad():\n            for x in bursts:\n                tic = time.time()\n                sr = _model(x, 0).float()\n                toc = time.time()\n                tt.append(toc-tic)\n                srs.append(sr)\n\n            sr = ttadown(srs)\n\n            # sr_int = (sr.clamp(0.0, 1.0) * 2 ** 14).short()\n            # sr = sr_int.float() / (2 ** 14)\n\n        psnr, ssim, lpips = psnr_fn(sr, gt)\n        psnrs.append(psnr.item())\n        ssims.append(ssim.item())\n        lpipss.append(lpips.item())\n\n\n        # lrs = burst_[0]\n        # os.makedirs(f'{out_dir}/{name}', exist_ok=True)\n        # for i, lr in enumerate(lrs):\n        #     # print(lr[[0, 1, 3],...].shape)\n        #     lr = postprocess_fn.process(lr[[0, 1, 3],...], meta_info)\n        #     lr = cv2.cvtColor(lr, cv2.COLOR_RGB2BGR)\n        #     cv2.imwrite('{}/{}/{:2d}.png'.format(out_dir, name, i), lr)\n\n        # gt = postprocess_fn.process(gt[0], meta_info)\n        # gt = cv2.cvtColor(gt, cv2.COLOR_RGB2BGR)\n        # cv2.imwrite('{}/{}_gt.png'.format(out_dir, name), gt)\n\n        # sr_ = postprocess_fn.process(sr[0], meta_info)\n        # sr_ = cv2.cvtColor(sr_, cv2.COLOR_RGB2BGR)\n        # cv2.imwrite('{}/{}_bsrt.png'.format(out_dir, name), sr_)\n\n        del burst_\n        del sr\n        del gt\n\n\n    print(f'avg PSNR: {np.mean(psnrs):.6f}')\n    print(f'avg SSIM: {np.mean(ssims):.6f}')\n    print(f'avg LPIPS: {np.mean(lpipss):.6f}')\n    print(f' avg time: {np.mean(tt):.6f}')\n\n    # utility.cleanup()\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "code/synthetic/bsrt/trainer.py",
    "content": "import os, sys\nfrom decimal import Decimal\nimport cv2\nimport utility\nimport random\n\nimport torch\nfrom tensorboardX import SummaryWriter\n\nfrom utils.postprocessing_functions import SimplePostProcess\nfrom utils.data_format_utils import convert_dict\nfrom utils.metrics import PSNR, L1, L2, CharbonnierLoss, MSSSIMLoss\nfrom datasets.burstsr_dataset import pack_raw_image, flatten_raw_image_batch, pack_raw_image_batch\nfrom data_processing.camera_pipeline import demosaic\nfrom tqdm import tqdm\nimport time\n\nfrom torch.cuda.amp import autocast as autocast, GradScaler\n\ntrain_log_dir = '../train_log/'\n\nexp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1]\ntfboard_name = exp_name + \"_\"\nexp_train_log_dir = os.path.join(train_log_dir, exp_name)\n\nLOG_DIR = os.path.join(exp_train_log_dir, 'logs')\n\n# save img path\nIMG_SAVE_DIR = os.path.join(exp_train_log_dir, 'img_log')\n# Where to load model\nLOAD_MODEL_DIR = os.path.join(exp_train_log_dir, 'models')\n# Where to save new model\nSAVE_MODEL_DIR = os.path.join(exp_train_log_dir, 'real_models')\n\nSAVE_STATE_DIR = os.path.join(exp_train_log_dir, 'training_states')\n\n# Where to save visualization images (for report)\nRESULTS_DIR = os.path.join(exp_train_log_dir, 'report')\n\n# print(SAVE_STATE_DIR)\nutility.mkdir(SAVE_STATE_DIR)\nutility.mkdir(SAVE_MODEL_DIR)\nutility.mkdir(IMG_SAVE_DIR)\nutility.mkdir(LOG_DIR)\n\n\nclass Trainer():\n    def __init__(self, args, train_loader, train_sampler, valid_loader, my_model, my_loss, ckp):\n        self.args = args\n        self.scale = args.scale[0]\n\n        self.ckp = ckp\n        self.loader_train = train_loader\n        self.loader_valid = valid_loader\n        self.train_sampler = train_sampler\n        self.model = my_model\n        self.loss = my_loss\n        self.optimizer = utility.make_optimizer(args, self.model)\n\n        ###################################\n        if args.pre_train == \"\":\n            
self.fix_unflagged = True\n        else:\n            self.fix_unflagged = False\n        self.fix_epoch = 5\n        self.fix_keys = [\"spynet\", \"dcnpack\"]\n\n        ###################################\n\n        self.psnr_fn = PSNR(boundary_ignore=40)\n        # Postprocessing function to obtain sRGB images\n        self.postprocess_fn = SimplePostProcess(return_np=True)\n\n        if 'L1' in args.loss:\n            self.aligned_loss = L1(boundary_ignore=None).cuda(args.local_rank)\n        elif 'MSE' in args.loss:\n            self.aligned_loss = L2(boundary_ignore=None).cuda(args.local_rank)\n        elif 'CB' in args.loss:\n            self.aligned_loss = CharbonnierLoss(boundary_ignore=None).cuda(args.local_rank)\n        elif 'MSSSIM' in args.loss:\n            self.aligned_loss = MSSSIMLoss(boundary_ignore=None).cuda(args.local_rank)\n\n        if self.args.fp16:\n            self.scaler = GradScaler()\n\n        self.best_psnr = 0.\n        self.best_epoch = 0\n\n        self.error_last = 1e8\n        self.glob_iter = 0\n\n        self.log_dir = LOG_DIR + \"/\" + args.save\n        self.img_save_dir = IMG_SAVE_DIR + \"/\" + args.save\n        # Where to load model\n        self.load_model_dir = LOAD_MODEL_DIR + \"/\" + args.save\n        # Where to save new model\n        self.save_model_dir = SAVE_MODEL_DIR + \"/\" + args.save\n        self.save_state_dir = SAVE_STATE_DIR + \"/\" + args.save\n\n        # Where to save visualization images (for report)\n        self.results_dir = RESULTS_DIR + \"/\" + args.save\n\n        if self.args.load != '':\n            self.optimizer.load(self.save_state_dir, epoch=int(self.args.load))\n\n        utility.mkdir(self.save_state_dir)\n        utility.mkdir(self.save_model_dir)\n        utility.mkdir(self.img_save_dir)\n        utility.mkdir(self.log_dir)\n        utility.mkdir('frames')\n\n        # self.writer = SummaryWriter(log_dir=self.log_dir)\n        if self.args.local_rank <= 0:\n            
number_parameters = sum(map(lambda x: x.numel(), self.model.parameters()))\n            print(\"number of parameters: \", number_parameters)\n\n\n    def train(self):\n        self.loss.step()\n        epoch = self.optimizer.get_last_epoch() + 1\n        lr = self.optimizer.get_lr()\n\n        if self.train_sampler:\n            self.train_sampler.set_epoch(epoch)\n        if epoch % 200 == 0:\n            self.ckp.write_log(\n                '[Epoch {}]\\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))\n            )\n        self.loss.start_log()\n\n        # train alignment module after 5 epochs.\n        if self.args.pre_train == \"\":\n            if self.fix_unflagged and epoch < self.fix_epoch:\n                if self.args.local_rank <= 0:\n                    print(f'Fix keys: {self.fix_keys} for the first {self.fix_epoch} epochs.')\n                self.fix_unflagged = False\n                for name, param in self.model.named_parameters():\n                    if any([key in name for key in self.fix_keys]):\n                        param.requires_grad_(False)\n            elif epoch == self.fix_epoch:\n                if self.args.local_rank <= 0:\n                    print(f'Train all the parameters from {self.fix_epoch} epochs.')\n                self.model.requires_grad_(True)\n\n        # self.test()\n        self.model.train()\n        if self.args.local_rank == 0:\n            timer_data, timer_model, timer_epoch = utility.timer(), utility.timer(), utility.timer()\n            timer_epoch.tic()\n        \n        for batch, batch_value in enumerate(self.loader_train):\n\n            burst, gt, flow_vectors, meta_info = batch_value\n\n            burst, gt, flow_vectors = self.prepare(burst, gt, flow_vectors)\n            # burst = flatten_raw_image_batch(burst)\n            if self.args.local_rank == 0:\n                timer_data.hold()\n                timer_model.tic()\n\n            if self.args.fp16:\n                with autocast():\n       
             sr = self.model(burst, 0)\n                    loss = self.aligned_loss(sr, gt)\n            else:\n                sr = self.model(burst, 0)\n                loss = self.aligned_loss(sr, gt)\n\n            if self.args.n_GPUs > 1:\n                torch.distributed.barrier()\n                reduced_loss = utility.reduce_mean(loss, self.args.n_GPUs)\n            else:\n                reduced_loss = loss\n\n            self.optimizer.zero_grad()\n            \n            if self.args.fp16:\n                self.scaler.scale(loss).backward()\n                # torch.nn.utils.clip_grad_value_(self.model.parameters(), .02)\n                if torch.isinf(sr).sum() + torch.isnan(sr).sum() <= 0:\n                    self.scaler.step(self.optimizer)\n                    self.scaler.update()\n                else:\n                    print(f'Nan num: {torch.isnan(sr).sum()}, inf num: {torch.isinf(sr).sum()}')\n                    reduced_loss = None\n                    os._exit(0)\n                    sys.exit(0)\n            else:\n                loss.backward()\n                # torch.nn.utils.clip_grad_value_(self.model.parameters(), .02)\n                if torch.isinf(sr).sum() + torch.isnan(sr).sum() <= 0:\n                    self.optimizer.step()\n                else:\n                    print(f'Nan num: {torch.isnan(sr).sum()}, inf num: {torch.isinf(sr).sum()}')\n                    reduced_loss = None\n\n            if self.args.local_rank == 0:\n                timer_model.hold()\n                if (batch + 1) % self.args.print_every == 0:\n                    self.ckp.write_log('[{}/{}]\\t[{:.4f}]\\t{:.1f}+{:.1f}s'.format(\n                        (batch + 1) * self.args.batch_size,\n                        len(self.loader_train.dataset),\n                        reduced_loss.item(),\n                        timer_model.release(),\n                        timer_data.release()))\n\n                self.glob_iter += 1\n                
timer_data.tic()\n\n            if self.args.local_rank <= 0 and (batch + 1) % 2000 == 0:\n                if not self.args.test_only:\n                    filename = exp_name + '_latest' + '.pth'\n                    self.save_model(filename)\n\n        if self.args.local_rank <= 0:\n            timer_epoch.hold()\n            print('Epoch {} cost time: {:.1f}s, lr: {:5f}'.format(epoch, timer_epoch.release(), lr))\n            if (epoch) % 1 == 0 and not self.args.test_only:\n                filename = exp_name + '_epoch_' + str(epoch) + '.pth'\n                self.save_model(filename)\n\n            if not self.args.test_only:\n                filename = exp_name + '_latest' + '.pth'\n                self.save_model(filename)\n\n        torch.cuda.synchronize()\n        torch.cuda.empty_cache()\n        self.test()\n        self.loss.end_log(len(self.loader_train))\n        self.error_last = self.loss.log[-1, -1]\n        self.optimizer.schedule()\n\n    def test(self, print_time=False):\n\n\n        def ttaup(burst):\n            # burst0 = flatten_raw_image_batch(burst) # B, T, C, H, W\n            # burst1 = utility.bayer_aug(burst0, flip_h=False, flip_w=False, transpose=True)\n            # burst1 = pack_raw_image_batch(burst1)\n            return [burst]\n\n        def ttadown(bursts):\n            burst0 = bursts[0]\n            # burst1 = bursts[1].permute(0, 1, 3, 2)\n            # out = (burst0 + burst1) / 2\n            out = burst0\n            return out\n\n        torch.set_grad_enabled(False)\n\n        epoch = self.optimizer.get_last_epoch() + 1\n        self.model.eval()\n        if self.args.local_rank == 0:\n            timer_test = utility.timer()\n        if epoch == 1 or epoch % 1 == 0:\n            self.model.eval()\n            total_psnr = 0\n            total_ssim = 0\n            total_lpips = 0\n            count = 0\n            if self.args.local_rank <= 0:\n                print(\"Testing...\")\n            for i, batch_value in 
enumerate(self.loader_valid):\n                burst_, gt, meta_info = batch_value\n                burst_, gt = self.prepare(burst_, gt)\n\n                bursts = ttaup(burst_)\n\n                # burst_ = flatten_raw_image_batch(burst_)\n                if print_time and self.args.local_rank <= 0:\n                    tic = time.time()\n                with torch.no_grad():\n                    srs = []\n                    for burst in bursts:\n                        if self.args.fp16:\n                            with autocast():\n                                sr = self.model(burst, 0).float()\n                        else:\n                            sr = self.model(burst, 0).float()\n                        srs.append(sr)\n                    sr = ttadown(srs)\n\n                if print_time and self.args.local_rank <= 0:\n                    toc = time.time()\n                    print(f'model pass time: {toc-tic:.4f}')\n\n                psnr_score, ssim_score, lpips_score = self.psnr_fn(sr, gt)\n\n                if self.args.n_GPUs > 1:\n                    torch.distributed.barrier()\n                    psnr_score = utility.reduce_mean(psnr_score, self.args.n_GPUs)\n                    ssim_score = utility.reduce_mean(ssim_score, self.args.n_GPUs)\n                    lpips_score = utility.reduce_mean(lpips_score, self.args.n_GPUs)\n\n                total_psnr += psnr_score\n                total_ssim += ssim_score\n                total_lpips += lpips_score\n                count += 1\n\n            total_psnr = total_psnr / count\n            total_ssim = total_ssim / count\n            total_lpips = total_lpips / count\n            if self.args.local_rank == 0:\n                print(\"[Epoch: {}][PSNR: {:.4f}][SSIM: {:.4f}][LPIPS: {:.4f}][Best PSNR: {:.4f}][Best Epoch: {}]\"\n                    .format(epoch, total_psnr, total_ssim, total_lpips, self.best_psnr, self.best_epoch))\n                if epoch > 1 and total_psnr > 
self.best_psnr:\n                    self.best_psnr = total_psnr\n                    self.best_epoch = epoch\n                    filename = exp_name + '_best_epoch.pth'\n                    self.save_model(filename)\n                # self.writer.add_scalars('PSNR', {tfboard_name + '_PSNR': total_psnr}, self.glob_iter)\n\n                print('Forward: {:.2f}s\\n'.format(timer_test.toc()))\n\n        torch.cuda.synchronize()\n        torch.set_grad_enabled(True)\n        torch.cuda.empty_cache()\n\n    def save_model(self, filename):\n        print('save model...')\n        net_save_path = os.path.join(self.save_model_dir, filename)\n\n        model = self.model.model\n        if self.args.n_GPUs > 1:\n            model = model.module\n\n        # self.optimizer.save(self.save_state_dir)\n        torch.save(model.state_dict(), net_save_path)\n\n\n    def prepare(self, *args):\n        device = torch.device('cpu' if self.args.cpu else 'cuda:{}'.format(self.args.local_rank))\n\n        def _prepare(tensor):\n            if self.args.precision == 'half': tensor = tensor.half()\n            return tensor.to(device)\n\n        # print(_prepare(args[0]).device)\n        return [_prepare(a) for a in args]\n\n    def terminate(self):\n        if self.args.test_only:\n            self.test()\n            return True\n        else:\n            epoch = self.optimizer.get_last_epoch() + 1\n            return epoch >= self.args.epochs\n"
  },
  {
    "path": "code/synthetic/bsrt/utility.py",
    "content": "import math\nimport time\nimport datetime\nfrom multiprocessing import Process\nfrom multiprocessing import Queue\nimport torch\nimport torch.nn.functional as F\nimport matplotlib.pyplot as plt\nimport torch.multiprocessing as mp\n\nimport numpy as np\nimport imageio\nimport os\nimport sys\n\nimport torch.optim as optim\nimport torch.optim.lr_scheduler as lrs\n\nimport torch.distributed as dist\nimport matplotlib\n\nmatplotlib.use('Agg')\n\n\ndef reduce_mean(tensor, nprocs):\n    rt = tensor.clone()\n    dist.all_reduce(rt, op=dist.ReduceOp.SUM)\n    rt /= nprocs\n    return rt\n\ndef gradient(data):\n    D_dy = data[:, :, 1:] - data[:, :, :-1]\n    D_dx = data[:, :, :, 1:] - data[:, :, :, :-1]\n    return D_dx, D_dy\n\n\ndef smooth_grad_1st(flo, image, alpha):\n    img_dx, img_dy = gradient(image)\n    weights_x = torch.exp(-torch.mean(torch.abs(img_dx), 1, keepdims=True) * alpha)\n    weights_y = torch.exp(-torch.mean(torch.abs(img_dy), 1, keepdims=True) * alpha)\n\n    dx, dy = gradient(flo)\n\n    loss_x = weights_x * torch.abs(dx) / 2.0\n    loss_y = weights_y * torch.abs(dy) / 2.0\n    return torch.mean(loss_x) / 2.0 + torch.mean(loss_y) / 2.0\n\ndef smooth_loss(flow, img):\n    loss = smooth_grad_1st(flow, img, 10)\n    return sum([torch.mean(loss)])\n    \n\ndef setup(rank, world_size):\n    if sys.platform == 'win32':\n        # Distributed package only covers collective communications with Gloo\n        # backend and FileStore on Windows platform. 
Set init_method parameter\n        # in init_process_group to a local file.\n        # Example init_method=\"file:///f:/libtmp/some_file\"\n        init_method = \"tcp://localhost:1234\"\n\n        # initialize the process group\n        dist.init_process_group(\n            \"gloo\",\n            init_method=init_method,\n            rank=rank,\n            world_size=world_size\n        )\n    else:\n        os.environ['MASTER_ADDR'] = 'localhost'\n        os.environ['MASTER_PORT'] = '4321'\n        # if mp.get_start_method(allow_none=True) is None:\n        if (\n            mp.get_start_method(allow_none=True) != \"spawn\"\n        ):  # Return the name of start method used for starting processes\n            mp.set_start_method(\"spawn\", force=True)  ##'spawn' is the default on Windows\n        # initialize the process group\n        dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n\n\ndef cleanup():\n    dist.destroy_process_group()\n\n\ndef mkdir(path):\n    if not os.path.exists(path):\n        os.makedirs(path, exist_ok=True)\n\n\nclass timer():\n    def __init__(self):\n        self.acc = 0\n        self.tic()\n\n    def tic(self):\n        self.t0 = time.time()\n\n    def toc(self, restart=False):\n        diff = time.time() - self.t0\n        if restart: self.t0 = time.time()\n        return diff\n\n    def hold(self):\n        self.acc += self.toc()\n\n    def release(self):\n        ret = self.acc\n        self.acc = 0\n\n        return ret\n\n    def reset(self):\n        self.acc = 0\n\n\nclass checkpoint():\n    def __init__(self, args):\n        self.args = args\n        self.ok = True\n        self.log = torch.Tensor()\n        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')\n\n        if not args.load:\n            if not args.save:\n                args.save = now\n            self.dir = os.path.join('..', 'experiment', args.save)\n        else:\n            self.dir = os.path.join('..', 'experiment', 
args.load)\n            if os.path.exists(self.dir):\n                self.log = torch.load(self.get_path('psnr_log.pt'))\n                print('Continue from epoch {}...'.format(len(self.log)))\n            else:\n                args.load = ''\n\n        if args.reset:\n            os.system('rm -rf ' + self.dir)\n            args.load = ''\n\n        os.makedirs(self.dir, exist_ok=True)\n        os.makedirs(self.get_path('model'), exist_ok=True)\n        # for d in args.data_test:\n        #     os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True)\n\n        open_type = 'a' if os.path.exists(self.get_path('log.txt')) else 'w'\n        self.log_file = open(self.get_path('log.txt'), open_type)\n        with open(self.get_path('config.txt'), open_type) as f:\n            f.write(now + '\\n\\n')\n            for arg in vars(args):\n                f.write('{}: {}\\n'.format(arg, getattr(args, arg)))\n            f.write('\\n')\n\n        self.n_processes = 8\n\n    def get_path(self, *subdir):\n        return os.path.join(self.dir, *subdir)\n\n    def save(self, trainer, epoch, is_best=False):\n        trainer.model.save(self.get_path('model'), epoch, is_best=is_best)\n        trainer.loss.save(self.dir)\n        trainer.loss.plot_loss(self.dir, epoch)\n\n        self.plot_psnr(epoch)\n        trainer.optimizer.save(self.dir)\n        torch.save(self.log, self.get_path('psnr_log.pt'))\n\n    def add_log(self, log):\n        self.log = torch.cat([self.log, log])\n\n    def write_log(self, log, refresh=False):\n        print(log)\n        self.log_file.write(log + '\\n')\n        if refresh:\n            self.log_file.close()\n            self.log_file = open(self.get_path('log.txt'), 'a')\n\n    def done(self):\n        self.log_file.close()\n\n    def plot_psnr(self, epoch):\n        axis = np.linspace(1, epoch, epoch)\n        for idx_data, d in enumerate(self.args.data_test):\n            label = 'SR on {}'.format(d)\n            fig = plt.figure()\n 
           plt.title(label)\n            for idx_scale, scale in enumerate(self.args.scale):\n                plt.plot(\n                    axis,\n                    self.log[:, idx_data, idx_scale].numpy(),\n                    label='Scale {}'.format(scale)\n                )\n            plt.legend()\n            plt.xlabel('Epochs')\n            plt.ylabel('PSNR')\n            plt.grid(True)\n            plt.savefig(self.get_path('test_{}.pdf'.format(d)))\n            plt.close(fig)\n\n    def begin_background(self):\n        self.queue = Queue()\n\n        def bg_target(queue):\n            while True:\n                if not queue.empty():\n                    filename, tensor = queue.get()\n                    if filename is None: break\n                    imageio.imwrite(filename, tensor.numpy())\n\n        self.process = [\n            Process(target=bg_target, args=(self.queue,)) \\\n            for _ in range(self.n_processes)\n        ]\n\n        for p in self.process: p.start()\n\n    def end_background(self):\n        for _ in range(self.n_processes): self.queue.put((None, None))\n        while not self.queue.empty(): time.sleep(1)\n        for p in self.process: p.join()\n\n    def save_results(self, dataset, filename, save_list, scale):\n        if self.args.save_results:\n            filename = self.get_path(\n                'results-{}'.format(dataset.dataset.name),\n                '{}_x{}_'.format(filename, scale)\n            )\n\n            postfix = ('SR', 'LR', 'HR')\n            for v, p in zip(save_list, postfix):\n                normalized = v[0].mul(255 / self.args.rgb_range)\n                tensor_cpu = normalized.byte().permute(1, 2, 0).cpu()\n                self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu))\n\n\ndef quantize(img, rgb_range):\n    pixel_range = 255 / rgb_range\n    return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)\n\n\ndef calc_psnr(sr, hr, scale, rgb_range, dataset=None):\n    if 
hr.nelement() == 1: return 0\n\n    diff = (sr - hr) / rgb_range\n    if dataset and dataset.dataset.benchmark:\n        shave = scale\n        if diff.size(1) > 1:\n            gray_coeffs = [65.738, 129.057, 25.064]\n            convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256\n            diff = diff.mul(convert).sum(dim=1)\n    else:\n        shave = scale + 6\n\n    valid = diff[..., shave:-shave, shave:-shave]\n    mse = valid.pow(2).mean()\n\n    return -10 * math.log10(mse)\n\n\ndef make_optimizer(args, target):\n    '''\n        make optimizer and scheduler together\n    '''\n    # optimizer\n    trainable = filter(lambda x: x.requires_grad, target.parameters())\n    kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay}\n\n    if args.optimizer == 'SGD':\n        optimizer_class = optim.SGD\n        kwargs_optimizer['momentum'] = args.momentum\n    elif args.optimizer == 'ADAM':\n        optimizer_class = optim.Adam\n        kwargs_optimizer['betas'] = args.betas\n        kwargs_optimizer['eps'] = args.epsilon\n    elif args.optimizer == 'RMSprop':\n        optimizer_class = optim.RMSprop\n        kwargs_optimizer['eps'] = args.epsilon\n\n    # scheduler\n    milestones = list(map(lambda x: int(x), args.decay.split('-')))\n    kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma}\n    scheduler_class = lrs.MultiStepLR\n\n    class CustomOptimizer(optimizer_class):\n        def __init__(self, *args, **kwargs):\n            super(CustomOptimizer, self).__init__(*args, **kwargs)\n\n        def _register_scheduler(self, scheduler_class, **kwargs):\n            self.scheduler = scheduler_class(self, **kwargs)\n\n        def save(self, save_dir):\n            torch.save(self.state_dict(), self.get_dir(save_dir))\n\n        def load(self, load_dir, epoch=1):\n            self.load_state_dict(torch.load(self.get_dir(load_dir)))\n            if epoch > 1:\n                for _ in range(epoch): self.scheduler.step()\n\n     
   def get_dir(self, dir_path):\n            return os.path.join(dir_path, 'optimizer.pt')\n\n        def schedule(self):\n            self.scheduler.step()\n\n        def get_lr(self):\n            return self.scheduler.get_last_lr()[0]\n\n        def get_last_epoch(self):\n            return self.scheduler.last_epoch\n\n    optimizer = CustomOptimizer(trainable, **kwargs_optimizer)\n    optimizer._register_scheduler(scheduler_class, **kwargs_scheduler)\n    return optimizer\n\n\ndef write_gray_to_tfboard(img):\n    img_debug = img[0, ...].detach().cpu().numpy()\n\n    # img_debug = cv2.normalize(img_debug, None, 0, 255,\n    #                           cv2.NORM_MINMAX, cv2.CV_8U)\n    img_debug = img_debug * 255\n    img_debug = np.clip(img_debug, 0, 255)\n    img_debug = img_debug.astype(np.uint8)\n    return img_debug[0, ...]\n\n\n\n\n\n######################## BayerUnifyAug ############################\n\nBAYER_PATTERNS = [\"RGGB\", \"BGGR\", \"GRBG\", \"GBRG\"]\nNORMALIZATION_MODE = [\"crop\", \"pad\"]\n\n\ndef bayer_unify(raw, input_pattern, target_pattern, mode) -> torch.Tensor:\n    \"\"\"\n    Convert a bayer raw image from one bayer pattern to another.\n    mode: {\"crop\", \"pad\"}\n        The way to handle submosaic shift. \"crop\" abandons the outermost pixels,\n        and \"pad\" introduces extra pixels. 
Use \"crop\" in training and \"pad\" in\n        testing.\n    \"\"\"\n\n    if input_pattern == target_pattern:\n        h_offset, w_offset = 0, 0\n    elif input_pattern[0] == target_pattern[2] and input_pattern[1] == target_pattern[3]:\n        h_offset, w_offset = 1, 0\n    elif input_pattern[0] == target_pattern[1] and input_pattern[2] == target_pattern[3]:\n        h_offset, w_offset = 0, 1\n    elif input_pattern[0] == target_pattern[3] and input_pattern[1] == target_pattern[2]:\n        h_offset, w_offset = 1, 1\n    else:  # This is not happening in [\"RGGB\", \"BGGR\", \"GRBG\", \"GBRG\"]\n        raise RuntimeError('Unexpected pair of input and target bayer pattern!')\n\n    if mode == \"pad\":\n        # out = np.pad(raw, [[h_offset, h_offset], [w_offset, w_offset]], 'reflect')\n        out = F.pad(raw, (w_offset, w_offset, h_offset, h_offset), mode='reflect')\n    elif mode == \"crop\":\n        _, _, _, h, w = raw.shape\n        out = raw[..., h_offset:h - h_offset, w_offset:w - w_offset]\n    else:\n        raise ValueError('Unknown normalization mode!')\n\n    return out\n\n\ndef bayer_aug(raw, flip_h=False, flip_w=False, transpose=False, input_pattern='RGGB') -> torch.Tensor:\n    \"\"\"\n    Apply augmentation to a bayer raw image.\n    \"\"\"\n\n    aug_pattern, target_pattern = input_pattern, input_pattern\n\n    out = raw\n    if flip_h:\n        out = torch.flip(out, [3]) # GBRG, RGGB\n        aug_pattern = aug_pattern[2] + aug_pattern[3] + aug_pattern[0] + aug_pattern[1]\n    if flip_w:\n        out = torch.flip(out, [4])\n        aug_pattern = aug_pattern[1] + aug_pattern[0] + aug_pattern[3] + aug_pattern[2]\n    if transpose:\n        out = out.permute(0, 1, 2, 4, 3)\n        aug_pattern = aug_pattern[0] + aug_pattern[2] + aug_pattern[1] + aug_pattern[3]\n\n    out = bayer_unify(out, aug_pattern, target_pattern, \"crop\")\n    return out\n"
  },
  {
    "path": "code/synthetic/bsrt/utils/__init__.py",
    "content": ""
  },
  {
    "path": "code/synthetic/bsrt/utils/data_format_utils.py",
    "content": "import numpy as np\nimport torch\nimport cv2 as cv\n\n\ndef numpy_to_torch(a: np.ndarray):\n    return torch.from_numpy(a).float().permute(2, 0, 1)\n\n\ndef torch_to_numpy(a: torch.Tensor):\n    return a.permute(1, 2, 0).cpu().numpy()\n\n\ndef torch_to_npimage(a: torch.Tensor, unnormalize=True):\n    a_np = torch_to_numpy(a)\n\n    if unnormalize:\n        a_np = a_np * 255\n    a_np = a_np.astype(np.uint8)\n    return cv.cvtColor(a_np, cv.COLOR_RGB2BGR)\n\n\ndef npimage_to_torch(a, normalize=True, input_bgr=True):\n    if input_bgr:\n        a = cv.cvtColor(a, cv.COLOR_BGR2RGB)\n    a_t = numpy_to_torch(a)\n\n    if normalize:\n        a_t = a_t / 255.0\n\n    return a_t\n\n\ndef convert_dict(base_dict, batch_sz):\n    out_dict = []\n    for b_elem in range(batch_sz):\n        b_info = {}\n        for k, v in base_dict.items():\n            if isinstance(v, (list, torch.Tensor)):\n                b_info[k] = v[b_elem]\n        out_dict.append(b_info)\n\n    return out_dict"
  },
  {
    "path": "code/synthetic/bsrt/utils/debayer.py",
    "content": "import torch\nimport torch.nn\nimport torch.nn.functional\n\nclass Debayer3x3(torch.nn.Module):\n    '''Demosaicing of Bayer images using 3x3 convolutions.\n\n    Requires BG-Bayer color filter array layout. That is,\n    the image[1,1]='B', image[1,2]='G'. This corresponds\n    to OpenCV naming conventions.\n\n    Compared to Debayer2x2 this method does not use upsampling.\n    Instead, we identify five 3x3 interpolation kernels that\n    are sufficient to reconstruct every color channel at every\n    pixel location.\n\n    We convolve the image with these 5 kernels using stride=1\n    and a one pixel replication padding. Finally, we gather\n    the correct channel values for each pixel location. Todo so,\n    we recognize that the Bayer pattern repeats horizontally and\n    vertically every 2 pixels. Therefore, we define the correct\n    index lookups for a 2x2 grid cell and then repeat to image\n    dimensions.\n\n    Note, in every 2x2 grid cell we have red, blue and two greens\n    (G1,G2). 
The lookups for the two greens differ.\n    '''\n\n    def __init__(self):\n        super(Debayer3x3, self).__init__()\n\n        self.kernels = torch.nn.Parameter(\n            torch.tensor([\n                [0,0,0],\n                [0,1,0],\n                [0,0,0],\n\n                [0, 0.25, 0],\n                [0.25, 0, 0.25],\n                [0, 0.25, 0],\n\n                [0.25, 0, 0.25],\n                [0, 0, 0],\n                [0.25, 0, 0.25],\n\n                [0, 0, 0],\n                [0.5, 0, 0.5],\n                [0, 0, 0],\n\n                [0, 0.5, 0],\n                [0, 0, 0],\n                [0, 0.5, 0],\n            ]).view(5,1,3,3), requires_grad=False\n        )\n\n\n        self.index = torch.nn.Parameter(\n            torch.tensor([\n                # dest channel r\n                [0, 3], # pixel is R,G1\n                [4, 2], # pixel is G2,B\n                # dest channel g\n                [1, 0], # pixel is R,G1\n                [0, 1], # pixel is G2,B\n                # dest channel b\n                [2, 4], # pixel is R,G1\n                [3, 0], # pixel is G2,B\n            ]).view(1,3,2,2), requires_grad=False\n        )\n\n    def forward(self, x):\n        '''Debayer image.\n\n        Parameters\n        ----------\n        x : Bx1xHxW tensor\n            Images to debayer\n\n        Returns\n        -------\n        rgb : Bx3xHxW tensor\n            Color images in RGB channel order.\n        '''\n        B,C,H,W = x.shape\n\n        x = torch.nn.functional.pad(x, (1,1,1,1), mode='replicate')\n        c = torch.nn.functional.conv2d(x, self.kernels, stride=1)\n        rgb = torch.gather(c, 1, self.index.repeat(B,1,H//2,W//2))\n        return rgb\n\nclass Debayer2x2(torch.nn.Module):\n    '''Demosaicing of Bayer images using 2x2 convolutions.\n\n    Requires BG-Bayer color filter array layout. That is,\n    the image[1,1]='B', image[1,2]='G'. 
This corresponds\n    to OpenCV naming conventions.\n    '''\n\n    def __init__(self):\n        super(Debayer2x2, self).__init__()\n\n        self.kernels = torch.nn.Parameter(\n            torch.tensor([\n                [1, 0],\n                [0, 0],\n\n                [0, 0.5],\n                [0.5, 0],\n\n                [0, 0],\n                [0, 1],\n            ]).view(3,1,2,2), requires_grad=False\n        )\n\n    def forward(self, x):\n        '''Debayer image.\n\n        Parameters\n        ----------\n        x : Bx1xHxW tensor\n            Images to debayer\n\n        Returns\n        -------\n        rgb : Bx3xHxW tensor\n            Color images in RGB channel order.\n        '''\n\n        x = torch.nn.functional.conv2d(x, self.kernels, stride=2)\n        x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)\n        return x\n\nclass DebayerSplit(torch.nn.Module):\n    '''Demosaicing of Bayer images using 3x3 green convolution and red,blue upsampling.\n\n    Requires BG-Bayer color filter array layout. That is,\n    the image[1,1]='B', image[1,2]='G'. 
This corresponds\n    to OpenCV naming conventions.\n    '''\n    def __init__(self):\n        super().__init__()\n\n        self.pad = torch.nn.ReflectionPad2d(1)\n        self.kernel = torch.nn.Parameter(\n            torch.tensor([\n                [0,1,0],\n                [1,0,1],\n                [0,1,0]\n            ])[None, None] * 0.25)\n\n    def forward(self, x):\n        '''Debayer image.\n\n        Parameters\n        ----------\n        x : Bx1xHxW tensor\n            Images to debayer\n\n        Returns\n        -------\n        rgb : Bx3xHxW tensor\n            Color images in RGB channel order.\n        '''\n        B,_,H,W = x.shape\n        red = x[:, :, ::2, ::2]\n        blue = x[:, :, 1::2, 1::2]\n\n        green = torch.nn.functional.conv2d(self.pad(x), self.kernel)\n        green[:, :, ::2, 1::2] = x[:, :, ::2, 1::2]\n        green[:, :, 1::2, ::2] = x[:, :, 1::2, ::2]\n\n        return torch.cat((\n            torch.nn.functional.interpolate(red, size=(H, W), mode='bilinear', align_corners=False),\n            green,\n            torch.nn.functional.interpolate(blue, size=(H, W), mode='bilinear', align_corners=False)),\n            dim=1)"
  },
  {
    "path": "code/synthetic/bsrt/utils/interp_methods.py",
    "content": "from math import pi\n\ntry:\n    import torch\nexcept ImportError:\n    torch = None\n\ntry:\n    import numpy\nexcept ImportError:\n    numpy = None\n\nif numpy is None and torch is None:\n    raise ImportError(\"Must have either Numpy or PyTorch but both not found\")\n\n\ndef set_framework_dependencies(x):\n    if type(x) is numpy.ndarray:\n        to_dtype = lambda a: a\n        fw = numpy\n    else:\n        to_dtype = lambda a: a.to(x.dtype)\n        fw = torch\n    eps = fw.finfo(fw.float32).eps\n    return fw, to_dtype, eps\n\n\ndef support_sz(sz):\n    def wrapper(f):\n        f.support_sz = sz\n        return f\n    return wrapper\n\n@support_sz(4)\ndef cubic(x):\n    fw, to_dtype, eps = set_framework_dependencies(x)\n    absx = fw.abs(x)\n    absx2 = absx ** 2\n    absx3 = absx ** 3\n    return ((1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) +\n            (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) *\n            to_dtype((1. < absx) & (absx <= 2.)))\n\n@support_sz(4)\ndef lanczos2(x):\n    fw, to_dtype, eps = set_framework_dependencies(x)\n    return (((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) /\n            ((pi**2 * x**2 / 2) + eps)) * to_dtype(abs(x) < 2))\n\n@support_sz(6)\ndef lanczos3(x):\n    fw, to_dtype, eps = set_framework_dependencies(x)\n    return (((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) /\n            ((pi**2 * x**2 / 3) + eps)) * to_dtype(abs(x) < 3))\n\n@support_sz(2)\ndef linear(x):\n    fw, to_dtype, eps = set_framework_dependencies(x)\n    return ((x + 1) * to_dtype((-1 <= x) & (x < 0)) + (1 - x) *\n            to_dtype((0 <= x) & (x <= 1)))\n\n@support_sz(1)\ndef box(x):\n    fw, to_dtype, eps = set_framework_dependencies(x)\n    return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1))\n"
  },
  {
    "path": "code/synthetic/bsrt/utils/metrics.py",
    "content": "import math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport utils.spatial_color_alignment as sca_utils\nfrom utils.spatial_color_alignment import get_gaussian_kernel, match_colors\nfrom utils.warp import warp\nfrom torch.cuda.amp import autocast\nfrom loss.Charbonnier import CharbonnierLoss as CBLoss\nfrom loss.mssim import MSSSIM\nfrom pytorch_msssim import ssim\nimport lpips\n\n\nclass MSSSIMLoss(nn.Module):\n    def __init__(self, boundary_ignore=None):\n        super().__init__()\n        self.boundary_ignore = boundary_ignore\n        self.msssim = MSSSIM()\n\n    def forward(self, pred, gt, valid=None):\n        if self.boundary_ignore is not None:\n            pred = pred[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n            gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n        pred_m = pred\n        gt_m = gt\n\n        loss = self.msssim(pred_m, gt_m)\n\n        return loss\n\nclass CharbonnierLoss(nn.Module):\n    def __init__(self, boundary_ignore=None):\n        super().__init__()\n        self.boundary_ignore = boundary_ignore\n        self.charbonnier_loss = CBLoss(reduce=True)\n\n    def forward(self, pred, gt, valid=None):\n        if self.boundary_ignore is not None:\n            pred = pred[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n            gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n        pred_m = pred\n        gt_m = gt\n\n        loss = self.charbonnier_loss(pred_m, gt_m)\n\n        return loss\n\nclass L1(nn.Module):\n    def __init__(self, boundary_ignore=None):\n        super().__init__()\n        self.boundary_ignore = boundary_ignore\n\n    def forward(self, pred, gt, valid=None):\n        if self.boundary_ignore is not None:\n            pred = pred[..., 
self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n            gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n            if valid is not None:\n                valid = valid[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n        pred_m = pred\n        gt_m = gt\n\n        if valid is None:\n            mse = F.l1_loss(pred_m, gt_m)\n        else:\n            mse = F.l1_loss(pred_m, gt_m, reduction='none')\n\n            eps = 1e-12\n            elem_ratio = mse.numel() / valid.numel()\n            mse = (mse * valid.float()).sum() / (valid.float().sum()*elem_ratio + eps)\n\n        return mse\n\nclass L2(nn.Module):\n    def __init__(self, boundary_ignore=None):\n        super().__init__()\n        self.boundary_ignore = boundary_ignore\n        self.loss_fn = lpips.LPIPS(net='alex').cuda()\n\n    def forward(self, pred, gt, valid=None):\n        if self.boundary_ignore is not None:\n            pred = pred[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n            gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n            if valid is not None:\n                valid = valid[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n        pred_m = pred\n        gt_m = gt\n\n        if valid is None:\n            mse = F.mse_loss(pred_m, gt_m)\n        else:\n            mse = F.mse_loss(pred_m, gt_m, reduction='none')\n\n            eps = 1e-12\n            elem_ratio = mse.numel() / valid.numel()\n            mse = (mse * valid.float()).sum() / (valid.float().sum()*elem_ratio + eps)\n\n        ss = ssim(pred_m.contiguous(), gt_m.contiguous(), data_range=1.0, size_average=True)\n        lp = self.loss_fn(pred_m.contiguous(), gt_m.contiguous()).squeeze()\n\n        return mse, ss, 
lp\n\n\nclass PSNR(nn.Module):\n    def __init__(self, boundary_ignore=None, max_value=1.0):\n        super().__init__()\n        self.l2 = L2(boundary_ignore=boundary_ignore)\n        self.max_value = max_value\n\n    def psnr(self, pred, gt, valid=None):\n        mse, ss, lp = self.l2(pred, gt, valid=valid)\n\n        psnr = 20 * math.log10(self.max_value) - 10.0 * mse.log10()\n\n        return psnr, ss, lp\n\n    def forward(self, pred, gt, valid=None):\n        assert pred.dim() == 4 and pred.shape == gt.shape\n        if valid is None:\n            all_scores = [self.psnr(p.unsqueeze(0), g.unsqueeze(0)) for p, g in zip(pred, gt)]\n        else:\n            all_scores = [self.psnr(p.unsqueeze(0), g.unsqueeze(0), v.unsqueeze(0)) for p, g, v in zip(pred, gt, valid)]\n        # psnr, ss, lp = sum(psnr_all) / len(psnr_all)\n        psnr = sum([score[0] for score in all_scores]) / len(all_scores)\n        ssim_ = sum([score[1] for score in all_scores]) / len(all_scores)\n        lpips_ = sum([score[2] for score in all_scores]) / len(all_scores)\n        return psnr, ssim_, lpips_\n\n\nclass AlignedL1(nn.Module):\n    def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None):\n        super().__init__()\n        self.sr_factor = sr_factor\n        self.boundary_ignore = boundary_ignore\n        self.alignment_net = alignment_net\n\n        self.gauss_kernel, self.ksz = get_gaussian_kernel(sd=1.5)\n\n    def forward(self, pred, gt, burst_input):\n        # Estimate flow between the prediction and the ground truth\n        with torch.no_grad():\n            flow = self.alignment_net(pred / (pred.max() + 1e-6), gt / (gt.max() + 1e-6))\n\n        # Warp the prediction to the ground truth coordinates\n        pred_warped = warp(pred, flow)\n\n        # Warp the base input frame to the ground truth. 
This will be used to estimate the color transformation between\n        # the input and the ground truth\n        sr_factor = self.sr_factor\n        ds_factor = 1.0 / float(2.0 * sr_factor)\n        flow_ds = F.interpolate(flow, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False) * ds_factor\n\n        burst_0 = burst_input[:, 0, [0, 1, 3]].contiguous()\n        burst_0_warped = warp(burst_0, flow_ds)\n        frame_gt_ds = F.interpolate(gt, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False)\n\n        # Match the colorspace between the prediction and ground truth\n        pred_warped_m, valid = match_colors(frame_gt_ds, burst_0_warped, pred_warped, self.ksz,\n                                                      self.gauss_kernel)\n\n        # Ignore boundary pixels if specified\n        if self.boundary_ignore is not None:\n            pred_warped_m = pred_warped_m[..., self.boundary_ignore:-self.boundary_ignore,\n                            self.boundary_ignore:-self.boundary_ignore]\n            gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n            valid = valid[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n        pred_warped_m = pred_warped_m.contiguous()\n        gt = gt.contiguous()\n        # Compute the masked L1 error\n        l1 = F.l1_loss(pred_warped_m, gt, reduction='none')\n\n        eps = 1e-12\n        elem_ratio = l1.numel() / valid.numel()\n        l1 = (l1 * valid.float()).sum() / (valid.float().sum() * elem_ratio + eps)\n\n        return l1\n\n\nclass AlignedL2(nn.Module):\n    def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None):\n        super().__init__()\n        self.sr_factor = sr_factor\n        self.boundary_ignore = boundary_ignore\n        self.alignment_net = alignment_net\n        self.loss_fn = lpips.LPIPS(net='alex').cuda()\n\n
        self.gauss_kernel, self.ksz = sca_utils.get_gaussian_kernel(sd=1.5)\n\n    def forward(self, pred, gt, burst_input):\n        # Estimate flow between the prediction and the ground truth\n        with torch.no_grad():\n            flow = self.alignment_net(pred / (pred.max() + 1e-6), gt / (gt.max() + 1e-6))\n\n        # Warp the prediction to the ground truth coordinates\n        pred_warped = warp(pred, flow)\n\n        # Warp the base input frame to the ground truth. This will be used to estimate the color transformation between\n        # the input and the ground truth\n        sr_factor = self.sr_factor\n        ds_factor = 1.0 / float(2.0 * sr_factor)\n        flow_ds = F.interpolate(flow, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False) * ds_factor\n\n        burst_0 = burst_input[:, 0, [0, 1, 3]].contiguous()\n        burst_0_warped = warp(burst_0, flow_ds)\n        frame_gt_ds = F.interpolate(gt, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False)\n\n        # Match the colorspace between the prediction and ground truth\n        pred_warped_m, valid = sca_utils.match_colors(frame_gt_ds, burst_0_warped, pred_warped, self.ksz,\n                                                      self.gauss_kernel)\n\n        # Ignore boundary pixels if specified\n        if self.boundary_ignore is not None:\n            pred_warped_m = pred_warped_m[..., self.boundary_ignore:-self.boundary_ignore,\n                            self.boundary_ignore:-self.boundary_ignore]\n            gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n            valid = valid[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n        # Compute the masked MSE\n        mse = F.mse_loss(pred_warped_m.contiguous(), gt.contiguous(), reduction='none')\n\n        eps = 1e-12\n        elem_ratio = mse.numel() / valid.numel()\n
        mse = (mse * valid.float()).sum() / (valid.float().sum() * elem_ratio + eps)\n\n        ss = ssim(pred_warped_m.contiguous(), gt.contiguous(), data_range=1.0, size_average=True)\n\n        lp = self.loss_fn(pred_warped_m.contiguous(), gt.contiguous()).squeeze()\n\n        return mse, ss, lp\n\n\nclass AlignedPSNR(nn.Module):\n    def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None, max_value=1.0):\n        super().__init__()\n        self.l2 = AlignedL2(alignment_net=alignment_net, sr_factor=sr_factor, boundary_ignore=boundary_ignore)\n        self.max_value = max_value\n\n    def psnr(self, pred, gt, burst_input):\n        mse, ss, lp = self.l2(pred, gt, burst_input)\n\n        psnr = 20 * math.log10(self.max_value) - 10.0 * mse.log10()\n\n        return psnr, ss, lp\n\n    def forward(self, pred, gt, burst_input):\n        all_scores = [self.psnr(p.unsqueeze(0), g.unsqueeze(0), bi.unsqueeze(0)) for p, g, bi in zip(pred, gt, burst_input)]\n        psnr = sum([score[0] for score in all_scores]) / len(all_scores)\n        ssim_ = sum([score[1] for score in all_scores]) / len(all_scores)\n        lpips_ = sum([score[2] for score in all_scores]) / len(all_scores)\n        return psnr, ssim_, lpips_\n\n\nclass AlignedSSIM(nn.Module):\n    def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None):\n        super().__init__()\n        self.sr_factor = sr_factor\n        self.boundary_ignore = boundary_ignore\n        self.alignment_net = alignment_net\n\n        self.gauss_kernel, self.ksz = sca_utils.get_gaussian_kernel(sd=1.5)\n\n    def _ssim(self, pred, gt, burst_input):\n        # Estimate flow between the prediction and the ground truth\n        with torch.no_grad():\n            flow = self.alignment_net(pred / (pred.max() + 1e-6), gt / (gt.max() + 1e-6))\n\n        # Warp the 
prediction to the ground truth coordinates\n        pred_warped = warp(pred, flow)\n\n        # Warp the base input frame to the ground truth. This will be used to estimate the color transformation between\n        # the input and the ground truth\n        sr_factor = self.sr_factor\n        ds_factor = 1.0 / float(2.0 * sr_factor)\n        flow_ds = F.interpolate(flow, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False) * ds_factor\n\n        burst_0 = burst_input[:, 0, [0, 1, 3]].contiguous()\n        burst_0_warped = warp(burst_0, flow_ds)\n        frame_gt_ds = F.interpolate(gt, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False)\n\n        # Match the colorspace between the prediction and ground truth\n        pred_warped_m, valid = sca_utils.match_colors(frame_gt_ds, burst_0_warped, pred_warped, self.ksz,\n                                                      self.gauss_kernel)\n\n        # Ignore boundary pixels if specified\n        if self.boundary_ignore is not None:\n            pred_warped_m = pred_warped_m[..., self.boundary_ignore:-self.boundary_ignore,\n                            self.boundary_ignore:-self.boundary_ignore]\n            gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n            valid = valid[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n        # Compute SSIM\n        ss = ssim(pred_warped_m.contiguous(), gt.contiguous(), data_range=1.0, size_average=True)\n\n        return ss\n\n    def forward(self, pred, gt, burst_input):\n        ssim_all = [self._ssim(p.unsqueeze(0), g.unsqueeze(0), bi.unsqueeze(0)) for p, g, bi in zip(pred, gt, burst_input)]\n        _ssim = 
sum(ssim_all) / len(ssim_all)\n        return _ssim\n\n\nclass AlignedLPIPS(nn.Module):\n    def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None):\n        super().__init__()\n        self.sr_factor = sr_factor\n        self.boundary_ignore = boundary_ignore\n        self.alignment_net = alignment_net\n        self.loss_fn = lpips.LPIPS(net='alex').cuda()\n\n        self.gauss_kernel, self.ksz = sca_utils.get_gaussian_kernel(sd=1.5)\n\n    def _lpips(self, pred, gt, burst_input):\n        # Estimate flow between the prediction and the ground truth\n        with torch.no_grad():\n            flow = self.alignment_net(pred / (pred.max() + 1e-6), gt / (gt.max() + 1e-6))\n\n        # Warp the prediction to the ground truth coordinates\n        pred_warped = warp(pred, flow)\n\n        # Warp the base input frame to the ground truth. This will be used to estimate the color transformation between\n        # the input and the ground truth\n        sr_factor = self.sr_factor\n        ds_factor = 1.0 / float(2.0 * sr_factor)\n        flow_ds = F.interpolate(flow, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False) * ds_factor\n\n        burst_0 = burst_input[:, 0, [0, 1, 3]].contiguous()\n        burst_0_warped = warp(burst_0, flow_ds)\n        frame_gt_ds = F.interpolate(gt, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False)\n\n        # Match the colorspace between the prediction and ground truth\n        pred_warped_m, valid = sca_utils.match_colors(frame_gt_ds, burst_0_warped, pred_warped, self.ksz,\n                                                      self.gauss_kernel)\n\n        # Ignore boundary pixels if specified\n        if self.boundary_ignore is not None:\n            pred_warped_m = pred_warped_m[..., self.boundary_ignore:-self.boundary_ignore,\n                            self.boundary_ignore:-self.boundary_ignore]\n            gt = gt[..., 
self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n            valid = valid[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]\n\n        # Compute LPIPS\n        lp = self.loss_fn(pred_warped_m.contiguous(), gt.contiguous()).squeeze()\n        return lp\n\n    def forward(self, pred, gt, burst_input):\n        lpips_all = [self._lpips(p.unsqueeze(0), g.unsqueeze(0), bi.unsqueeze(0)) for p, g, bi in zip(pred, gt, burst_input)]\n        _lpips = sum(lpips_all) / len(lpips_all)\n        return _lpips\n"
  },
  {
    "path": "code/synthetic/bsrt/utils/postprocessing_functions.py",
    "content": "import torch\nimport numpy as np\nimport utils.data_format_utils as df_utils\nfrom data_processing.camera_pipeline import apply_gains, apply_ccm, apply_smoothstep, gamma_compression\n\n\nclass SimplePostProcess:\n    def __init__(self, gains=True, ccm=True, gamma=True, smoothstep=True, return_np=False):\n        self.gains = gains\n        self.ccm = ccm\n        self.gamma = gamma\n        self.smoothstep = smoothstep\n        self.return_np = return_np\n\n    def process(self, image, meta_info):\n        return process_linear_image_rgb(image, meta_info, self.gains, self.ccm, self.gamma,\n                                        self.smoothstep, self.return_np)\n\n\ndef process_linear_image_rgb(image, meta_info, gains=True, ccm=True, gamma=True, smoothstep=True, return_np=False):\n    if gains:\n        image = apply_gains(image, meta_info['rgb_gain'], meta_info['red_gain'], meta_info['blue_gain'])\n\n    if ccm:\n        image = apply_ccm(image, meta_info['cam2rgb'])\n\n    if meta_info['gamma'] and gamma:\n        image = gamma_compression(image)\n\n    if meta_info['smoothstep'] and smoothstep:\n        image = apply_smoothstep(image)\n\n    image = image.clamp(0.0, 1.0)\n\n    if return_np:\n        image = df_utils.torch_to_npimage(image)\n    return image\n\n\nclass BurstSRPostProcess:\n    def __init__(self, no_white_balance=False, gamma=True, smoothstep=True, return_np=False):\n        self.no_white_balance = no_white_balance\n        self.gamma = gamma\n        self.smoothstep = smoothstep\n        self.return_np = return_np\n\n    def process(self, image, meta_info, external_norm_factor=None):\n        return process_burstsr_image_rgb(image, meta_info, external_norm_factor=external_norm_factor,\n                                         no_white_balance=self.no_white_balance, gamma=self.gamma,\n                                         smoothstep=self.smoothstep, return_np=self.return_np)\n\n\ndef process_burstsr_image_rgb(im, meta_info, 
return_np=False, external_norm_factor=None, gamma=True, smoothstep=True,\n                              no_white_balance=False):\n    im = im * meta_info.get('norm_factor', 1.0)\n\n    if not meta_info.get('black_level_subtracted', False):\n        im = (im - torch.tensor(meta_info['black_level'])[[0, 1, -1]].view(3, 1, 1).to(im.device))\n\n    if not meta_info.get('while_balance_applied', False) and not no_white_balance:\n        im = im * (meta_info['cam_wb'][[0, 1, -1]].view(3, 1, 1) / meta_info['cam_wb'][1]).to(im.device)\n\n    im_out = im\n\n    if external_norm_factor is None:\n        im_out = im_out / im_out.max()\n    else:\n        im_out = im_out / external_norm_factor\n\n    im_out = im_out.clamp(0.0, 1.0)\n\n    if gamma:\n        im_out = im_out ** (1.0 / 2.2)\n\n    if smoothstep:\n        # Smooth curve\n        im_out = 3 * im_out ** 2 - 2 * im_out ** 3\n\n    if return_np:\n        im_out = im_out.permute(1, 2, 0).cpu().numpy() * 255.0\n        im_out = im_out.astype(np.uint8)\n\n    return im_out\n"
  },
  {
    "path": "code/synthetic/bsrt/utils/resize_right.py",
    "content": "import warnings\nfrom math import ceil\nimport interp_methods\n\n\nclass NoneClass:\n    pass\n\ntry:\n    import torch\n    from torch import nn\n    nnModuleWrapped = nn.Module\nexcept ImportError:\n    warnings.warn('No PyTorch found, will work only with Numpy')\n    torch = None\n    nnModuleWrapped = NoneClass\n\ntry:\n    import numpy\nexcept ImportError:\n    warnings.warn('No Numpy found, will work only with PyTorch')\n    numpy = None\n\n\nif numpy is None and torch is None:\n    raise ImportError(\"Must have either Numpy or PyTorch but both not found\")\n\n\ndef resize(input, scale_factors=None, out_shape=None,\n           interp_method=interp_methods.cubic, support_sz=None,\n           antialiasing=True):\n    # get properties of the input tensor\n    in_shape, n_dims = input.shape, input.ndim\n\n    # fw stands for framework that can be either numpy or torch,\n    # determined by the input type\n    fw = numpy if type(input) is numpy.ndarray else torch\n    eps = fw.finfo(fw.float32).eps\n\n    # set missing scale factors or output shapem one according to another,\n    # scream if both missing\n    scale_factors, out_shape = set_scale_and_out_sz(in_shape, out_shape,\n                                                    scale_factors, fw)\n\n    # sort indices of dimensions according to scale of each dimension.\n    # since we are going dim by dim this is efficient\n    sorted_filtered_dims_and_scales = [(dim, scale_factors[dim])\n                                       for dim in sorted(range(n_dims),\n                                       key=lambda ind: scale_factors[ind])\n                                       if scale_factors[dim] != 1.]\n\n    # unless support size is specified by the user, it is an attribute\n    # of the interpolation method\n    if support_sz is None:\n        support_sz = interp_method.support_sz\n\n    # when using pytorch, we need to know what is the input tensor device\n    if fw is torch:\n        device = 
input.device\n    else:\n        device = None\n\n    # output begins identical to input and changes with each iteration\n    output = input\n\n    # iterate over dims\n    for dim, scale_factor in sorted_filtered_dims_and_scales:\n\n        # get 1d set of weights and fields of view for each output location\n        # along this dim\n        field_of_view, weights = prepare_weights_and_field_of_view_1d(\n            dim, scale_factor, in_shape[dim], out_shape[dim], interp_method,\n            support_sz, antialiasing, fw, eps, device)\n\n        # multiply the weights by the values in the field of view and\n        # aggregate\n        output = apply_weights(output, field_of_view, weights, dim, n_dims,\n                               fw)\n    return output\n\n\nclass ResizeLayer(nnModuleWrapped):\n    def __init__(self, in_shape, scale_factors=None, out_shape=None,\n                 interp_method=interp_methods.cubic, support_sz=None,\n                 antialiasing=True):\n        super(ResizeLayer, self).__init__()\n\n        # fw stands for framework, that can be either numpy or torch. 
since\n        # this is a torch layer, only one option in this case.\n        fw = torch\n        eps = fw.finfo(fw.float32).eps\n\n        # set missing scale factors or output shape, one according to another,\n        # scream if both missing\n        scale_factors, out_shape = set_scale_and_out_sz(in_shape, out_shape,\n                                                        scale_factors, fw)\n\n        # unless support size is specified by the user, it is an attribute\n        # of the interpolation method\n        if support_sz is None:\n            support_sz = interp_method.support_sz\n\n        self.n_dims = len(in_shape)\n\n        # sort indices of dimensions according to scale of each dimension.\n        # since we are going dim by dim this is efficient\n        self.sorted_filtered_dims_and_scales = [(dim, scale_factors[dim])\n                                                for dim in\n                                                sorted(range(self.n_dims),\n                                                key=lambda ind:\n                                                scale_factors[ind])\n                                                if scale_factors[dim] != 1.]\n\n        # iterate over dims\n        field_of_view_list = []\n        weights_list = []\n        for dim, scale_factor in self.sorted_filtered_dims_and_scales:\n\n            # get 1d set of weights and fields of view for each output\n            # location along this dim. no input tensor exists yet inside\n            # __init__, so pass device=None; the resulting Parameters move\n            # together with the module\n            field_of_view, weights = prepare_weights_and_field_of_view_1d(\n                dim, scale_factor, in_shape[dim], out_shape[dim],\n                interp_method, support_sz, antialiasing, fw, eps, None)\n\n            # keep weights and fields of views for all dims\n            weights_list.append(nn.Parameter(weights, requires_grad=False))\n            field_of_view_list.append(nn.Parameter(field_of_view,\n                                      requires_grad=False))\n\n        self.field_of_view = 
nn.ParameterList(field_of_view_list)\n        self.weights = nn.ParameterList(weights_list)\n        self.in_shape = in_shape\n\n    def forward(self, input):\n        # output begins identical to input and changes with each iteration\n        output = input\n\n        for (dim, scale_factor), field_of_view, weights in zip(\n                self.sorted_filtered_dims_and_scales,\n                self.field_of_view,\n                self.weights):\n            # multiply the weights by the values in the field of view and\n            # aggregate\n            output = apply_weights(output, field_of_view, weights, dim,\n                                   self.n_dims, torch)\n        return output\n\n\ndef prepare_weights_and_field_of_view_1d(dim, scale_factor, in_sz, out_sz,\n                                         interp_method, support_sz,\n                                         antialiasing, fw, eps, device=None):\n    # If antialiasing is taking place, we modify the window size and the\n    # interpolation method (see inside function)\n    interp_method, cur_support_sz = apply_antialiasing_if_needed(\n                                                             interp_method,\n                                                             support_sz,\n                                                             scale_factor,\n                                                             antialiasing)\n\n    # STEP 1- PROJECTED GRID: The non-integer locations of the projection of\n    # output pixel locations to the input tensor\n    projected_grid = get_projected_grid(in_sz, out_sz, scale_factor, fw, device)\n\n    # STEP 2- FIELDS OF VIEW: for each output pixel, map the input pixels\n    # that influence it\n    field_of_view = get_field_of_view(projected_grid, cur_support_sz, in_sz,\n                                      fw, eps)\n\n    # STEP 3- CALCULATE WEIGHTS: Match a set of weights to the pixels in the\n    # field of view for each output pixel\n
    weights = get_weights(interp_method, projected_grid, field_of_view)\n\n    return field_of_view, weights\n\n\ndef apply_weights(input, field_of_view, weights, dim, n_dims, fw):\n    # STEP 4- APPLY WEIGHTS: Each output pixel is calculated by multiplying\n    # its set of weights with the pixel values in its field of view.\n    # We now multiply the fields of view with their matching weights.\n    # We do this by tensor multiplication and broadcasting.\n    # this step is separated to a different function, so that it can be\n    # repeated with the same calculated weights and fields.\n\n    # for this operation we assume the resized dim is the first one.\n    # so we transpose and will transpose back after multiplying\n    tmp_input = fw_swapaxes(input, dim, 0, fw)\n\n    # field_of_view is a tensor of order 2: for each output (1d location\n    # along cur dim)- a list of 1d neighbor locations.\n    # note that this whole operation is applied to each dim separately,\n    # this is why it is all in 1d.\n    # neighbors = tmp_input[field_of_view] is a tensor of order image_dims+1:\n    # for each output pixel (this time indicated in all dims), these are the\n    # values of the neighbors in the 1d field of view. note that we only\n    # consider neighbors along the current dim, but such set exists for every\n    # multi-dim location, hence the final tensor order is image_dims+1.\n    neighbors = tmp_input[field_of_view]\n\n    # weights is an order 2 tensor: for each output location along 1d- a list\n    # of weights matching the field of view. 
we augment it with ones, for\n    # broadcasting, so that when it multiplies some tensor the weights affect\n    # only its first dim.\n    tmp_weights = fw.reshape(weights, (*weights.shape, * [1] * (n_dims - 1)))\n\n    # now we simply multiply the weights with the neighbors, and then sum\n    # along the field of view, to get a single value per out pixel\n    tmp_output = (neighbors * tmp_weights).sum(1)\n\n    # we transpose back the resized dim to its original position\n    return fw_swapaxes(tmp_output, 0, dim, fw)\n\n\ndef set_scale_and_out_sz(in_shape, out_shape, scale_factors, fw):\n    # eventually we must have both scale-factors and out-sizes for all in/out\n    # dims. however, we support many possible partial arguments\n    if scale_factors is None and out_shape is None:\n        raise ValueError(\"either scale_factors or out_shape should be \"\n                         \"provided\")\n    if out_shape is not None:\n        # if out_shape has fewer dims than in_shape, we default to resizing\n        # the first dims for numpy and the last dims for torch\n        out_shape = (list(out_shape) + list(in_shape[:-len(out_shape)])\n                     if fw is numpy\n                     else list(in_shape[:-len(out_shape)]) + list(out_shape))\n        if scale_factors is None:\n            # if no scale is given, we calculate it as the out to in ratio\n            # (not recommended)\n            scale_factors = [out_sz / in_sz for out_sz, in_sz\n                             in zip(out_shape, in_shape)]\n    if scale_factors is not None:\n        # by default, if a single number is given as scale, we assume resizing\n        # two dims (most common are images with 2 spatial dims)\n        scale_factors = (scale_factors\n                         if isinstance(scale_factors, (list, tuple))\n                         else [scale_factors, scale_factors])\n        # if there are fewer scale_factors than in_shape dims, we default to\n        # resizing the first dims for numpy and the last dims 
for torch\n        scale_factors = (list(scale_factors) + [1] *\n                         (len(in_shape) - len(scale_factors)) if fw is numpy\n                         else [1] * (len(in_shape) - len(scale_factors)) +\n                         list(scale_factors))\n        if out_shape is None:\n            # when no out_shape is given, it is calculated by multiplying the\n            # scale by the in_shape (not recommended)\n            out_shape = [ceil(scale_factor * in_sz)\n                         for scale_factor, in_sz in\n                         zip(scale_factors, in_shape)]\n        # next line intentionally after out_shape determined for stability\n        scale_factors = [float(sf) for sf in scale_factors]\n    return scale_factors, out_shape\n\n\ndef get_projected_grid(in_sz, out_sz, scale_factor, fw, device=None):\n    # we start by having the output coordinates which are just integer locations\n    out_coordinates = fw.arange(out_sz)\n\n    # if using torch we need to match the grid tensor device to the input device\n    out_coordinates = fw_set_device(out_coordinates, device, fw)\n\n    # This is projecting the output pixel locations in 1d to the input tensor,\n    # as non-integer locations.\n    # the following formula is derived in the paper\n    # \"From Discrete to Continuous Convolutions\" by Shocher et al.\n    return (out_coordinates / scale_factor +\n            (in_sz - 1) / 2 - (out_sz - 1) / (2 * scale_factor))\n\n\ndef get_field_of_view(projected_grid, cur_support_sz, in_sz, fw, eps):\n    # for each output pixel, map which input pixels influence it, in 1d.\n    # we start by calculating the leftmost neighbor, using half of the window\n    # size (eps is for when boundary is exact int)\n    left_boundaries = fw_ceil(projected_grid - cur_support_sz / 2 - eps, fw)\n\n    # then we simply take all the pixel centers in the field by counting\n    # window size pixels from the left boundary\n    ordinal_numbers = fw.arange(ceil(cur_support_sz - 
eps))\n    # in case using torch we need to match the device; numpy arrays have no\n    # .device attribute, so only query it on the torch path\n    device = projected_grid.device if fw is torch else None\n    ordinal_numbers = fw_set_device(ordinal_numbers, device, fw)\n    field_of_view = left_boundaries[:, None] + ordinal_numbers\n\n    # next we do a trick instead of padding, we map the field of view so that\n    # it would be like mirror padding, without actually padding\n    # (which would require enlarging the input tensor)\n    mirror = fw_cat((fw.arange(in_sz), fw.arange(in_sz - 1, -1, step=-1)), fw)\n    field_of_view = mirror[fw.remainder(field_of_view, mirror.shape[0])]\n    field_of_view = fw_set_device(field_of_view, device, fw)\n    return field_of_view\n\n\ndef get_weights(interp_method, projected_grid, field_of_view):\n    # the set of weights per output pixel is the result of the chosen\n    # interpolation method applied to the distances between projected grid\n    # locations and the pixel-centers in the field of view (distances are\n    # directed, can be positive or negative)\n    weights = interp_method(projected_grid[:, None] - field_of_view)\n\n    # we now carefully normalize the weights to sum to 1 per output pixel\n    sum_weights = weights.sum(1, keepdims=True)\n    sum_weights[sum_weights == 0] = 1\n    return weights / sum_weights\n\n\ndef apply_antialiasing_if_needed(interp_method, support_sz, scale_factor,\n                                 antialiasing):\n    # antialiasing is \"stretching\" the field of view according to the scale\n    # factor (only for downscaling). this is low-pass filtering. 
this\n    # requires modifying both the interpolation (stretching the 1d\n    # function and multiplying by the scale-factor) and the window size.\n    if scale_factor >= 1.0 or not antialiasing:\n        return interp_method, support_sz\n    cur_interp_method = (lambda arg: scale_factor *\n                         interp_method(scale_factor * arg))\n    cur_support_sz = support_sz / scale_factor\n    return cur_interp_method, cur_support_sz\n\n\ndef fw_ceil(x, fw):\n    if fw is numpy:\n        return fw.int_(fw.ceil(x))\n    else:\n        return x.ceil().long()\n\n\ndef fw_cat(x, fw):\n    if fw is numpy:\n        return fw.concatenate(x)\n    else:\n        return fw.cat(x)\n\n\ndef fw_swapaxes(x, ax_1, ax_2, fw):\n    if fw is numpy:\n        return fw.swapaxes(x, ax_1, ax_2)\n    else:\n        return x.transpose(ax_1, ax_2)\n\n\ndef fw_set_device(x, device, fw):\n    # numpy arrays carry no device; device=None leaves the tensor where it is\n    if fw is numpy or device is None:\n        return x\n    return x.to(device)\n"
  },
  {
    "path": "code/synthetic/bsrt/utils/spatial_color_alignment.py",
    "content": "import math\nimport torch\nimport torch.nn.functional as F\n\n\ndef gauss_1d(sz, sigma, center, end_pad=0, density=False):\n    \"\"\" Returns a 1-D Gaussian \"\"\"\n    k = torch.arange(-(sz-1)/2, (sz+1)/2 + end_pad).reshape(1, -1)\n    gauss = torch.exp(-1.0/(2*sigma**2) * (k - center.reshape(-1, 1))**2)\n    if density:\n        gauss /= math.sqrt(2*math.pi) * sigma\n    return gauss\n\n\ndef gauss_2d(sz, sigma, center, end_pad=(0, 0), density=False):\n    \"\"\" Returns a 2-D Gaussian \"\"\"\n    if isinstance(sigma, (float, int)):\n        sigma = (sigma, sigma)\n    if isinstance(sz, int):\n        sz = (sz, sz)\n\n    if isinstance(center, (list, tuple)):\n        center = torch.tensor(center).view(1, 2)\n\n    return gauss_1d(sz[0], sigma[0], center[:, 0], end_pad[0], density).reshape(center.shape[0], 1, -1) * \\\n           gauss_1d(sz[1], sigma[1], center[:, 1], end_pad[1], density).reshape(center.shape[0], -1, 1)\n\n\ndef get_gaussian_kernel(sd):\n    \"\"\" Returns a Gaussian kernel with standard deviation sd \"\"\"\n    ksz = int(4 * sd + 1)\n    assert ksz % 2 == 1\n    K = gauss_2d(ksz, sd, (0.0, 0.0), density=True)\n    K = K / K.sum()\n    return K.unsqueeze(0), ksz\n\n\ndef apply_kernel(im, ksz, gauss_kernel):\n    shape = im.shape\n    im = im.view(-1, 1, *im.shape[-2:])\n\n    pad = [ksz // 2, ksz // 2, ksz // 2, ksz // 2]\n    im = F.pad(im, pad, mode='reflect')\n    im_mean = F.conv2d(im, gauss_kernel).view(shape)\n    return im_mean\n\n\ndef match_colors(im_ref, im_q, im_test, ksz, gauss_kernel):\n    \"\"\" Estimates a color transformation matrix between im_ref and im_q. 
Applies the estimated transformation to\n        im_test\n    \"\"\"\n    gauss_kernel = gauss_kernel.to(im_ref.device)\n    bi = 5\n\n    # Apply Gaussian smoothing\n    im_ref_mean = apply_kernel(im_ref, ksz, gauss_kernel)[:, :, bi:-bi, bi:-bi].contiguous()\n    im_q_mean = apply_kernel(im_q, ksz, gauss_kernel)[:, :, bi:-bi, bi:-bi].contiguous()\n\n    im_ref_mean_re = im_ref_mean.view(*im_ref_mean.shape[:2], -1)\n    im_q_mean_re = im_q_mean.view(*im_q_mean.shape[:2], -1)\n\n    # Estimate color transformation matrix by minimizing the least squares error\n    c_mat_all = []\n    for ir, iq in zip(im_ref_mean_re, im_q_mean_re):\n        # torch.lstsq(ir.t(), iq.t()) was deprecated and later removed;\n        # torch.linalg.lstsq solves the same system iq.t() @ c = ir.t()\n        c = torch.linalg.lstsq(iq.t(), ir.t()).solution\n        c_mat_all.append(c)\n\n    c_mat = torch.stack(c_mat_all, dim=0)\n    im_q_mean_conv = torch.matmul(im_q_mean_re.permute(0, 2, 1), c_mat).permute(0, 2, 1)\n    im_q_mean_conv = im_q_mean_conv.view(im_q_mean.shape)\n\n    err = ((im_q_mean_conv - im_ref_mean) * 255.0).norm(dim=1)\n\n    thresh = 20\n\n    # If the error is larger than a threshold, ignore these pixels\n    valid = err < thresh\n\n    pad = (im_q.shape[-1] - valid.shape[-1]) // 2\n    pad = [pad, pad, pad, pad]\n    valid = F.pad(valid, pad)\n\n    upsample_factor = im_test.shape[-1] / valid.shape[-1]\n    valid = F.interpolate(valid.unsqueeze(1).float(), scale_factor=upsample_factor, mode='bilinear', align_corners=False)\n    valid = valid > 0.9\n\n    # Apply the transformation to the test image\n    im_test_re = im_test.view(*im_test.shape[:2], -1)\n    im_t_conv = torch.matmul(im_test_re.permute(0, 2, 1), c_mat).permute(0, 2, 1)\n    im_t_conv = im_t_conv.view(im_test.shape)\n\n    return im_t_conv, valid\n\n"
  },
  {
    "path": "code/synthetic/bsrt/utils/stn.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass SpatialTransformer(nn.Module):\n    \"\"\"\n    [SpatialTransformer] represesents a spatial transformation block\n    that uses the output from the UNet to preform an grid_sample\n    https://pytorch.org/docs/stable/nn.functional.html#grid-sample\n    \"\"\"\n    def __init__(self, size, mode='bilinear'):\n        \"\"\"\n        Instiatiate the block\n            :param size: size of input to the spatial transformer block\n            :param mode: method of interpolation for grid_sampler\n        \"\"\"\n        super(OldSpatialTransformer, self).__init__()\n        if isinstance(size, int):\n            size = (size, size)\n        # Create sampling grid\n        vectors = [ torch.arange(0, s) for s in size ]\n        grids = torch.meshgrid(vectors)\n        grid  = torch.stack(grids) # y, x, z\n        grid  = torch.unsqueeze(grid, 0)  #add batch\n        grid = grid.type(torch.FloatTensor)\n        self.register_buffer('grid', grid)\n\n        self.mode = mode\n\n    def forward(self, src, flow):\n        \"\"\"\n        Push the src and flow through the spatial transform block\n            :param src: the original moving image\n            :param flow: the output from the U-Net\n        \"\"\"\n        new_locs = self.grid + flow\n\n        shape = flow.shape[2:]\n\n        # Need to normalize grid values to [-1, 1] for resampler\n        for i in range(len(shape)):\n            new_locs[:,i,...] = 2*(new_locs[:,i,...]/(shape[i]-1) - 0.5)\n\n        if len(shape) == 2:\n            new_locs = new_locs.permute(0, 2, 3, 1)\n            new_locs = new_locs[..., [1,0]]\n        elif len(shape) == 3:\n            new_locs = new_locs.permute(0, 2, 3, 4, 1)\n            new_locs = new_locs[..., [2,1,0]]\n\n        return F.grid_sample(src, new_locs, mode=self.mode, align_corners=True)\n"
  },
  {
    "path": "code/synthetic/bsrt/utils/warp.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef warp(feat, flow, mode='bilinear', padding_mode='zeros'):\n    \"\"\"\n    warp an image/tensor (im2) back to im1, according to the optical flow im1 --> im2\n\n    input flow must be in format (x, y) at every pixel\n    feat: [B, C, H, W] (im2)\n    flow: [B, 2, H, W] flow (x, y)\n\n    \"\"\"\n    B, C, H, W = feat.size()\n    # print(feat.device, flow.device)\n\n    # mesh grid\n    rowv, colv = torch.meshgrid([torch.arange(0.5, H + 0.5), torch.arange(0.5, W + 0.5)])\n    grid = torch.stack((colv, rowv), dim=0).unsqueeze(0).float().to(flow.device)\n    # print(grid.device, flow.device, feat.device)\n    # grid = grid.cuda()\n    grid = grid + flow\n\n    # scale grid to [-1,1]\n    grid_norm_c = 2.0 * grid[:, 0] / W - 1.0\n    grid_norm_r = 2.0 * grid[:, 1] / H - 1.0\n\n    grid_norm = torch.stack((grid_norm_c, grid_norm_r), dim=1).to(flow.device)\n\n    grid_norm = grid_norm.permute(0, 2, 3, 1)\n\n    output = F.grid_sample(feat, grid_norm, mode=mode, align_corners=False, padding_mode=padding_mode)\n\n    return output\n"
  },
  {
    "path": "requirements.txt",
    "content": "matplotlib\nimageio\nopencv-python\ntensorboardX\n"
  }
]