Repository: Algolzw/BSRT
Branch: main
Commit: da8d5154478a
Files: 184
Total size: 1.1 MB
Directory structure:
gitextract_kbz2i0m2/
├── .gitignore
├── LICENSE
├── README.md
├── code/
│ ├── real/
│ │ └── bsrt/
│ │ ├── README.md
│ │ ├── data_processing/
│ │ │ ├── __init__.py
│ │ │ ├── camera_pipeline.py
│ │ │ └── synthetic_burst_generation.py
│ │ ├── datasets/
│ │ │ ├── __init__.py
│ │ │ ├── burstsr_dataset.py
│ │ │ ├── burstsr_test_dataset.py
│ │ │ ├── data_sampler.py
│ │ │ ├── realworld_burst_test_set.py
│ │ │ ├── synthetic_burst_test_set.py
│ │ │ ├── synthetic_burst_train_set.py
│ │ │ ├── synthetic_burst_val_set.py
│ │ │ └── zurich_raw2rgb_dataset.py
│ │ ├── demo.sh
│ │ ├── loss/
│ │ │ ├── Charbonnier.py
│ │ │ ├── __init__.py
│ │ │ ├── adversarial.py
│ │ │ ├── discriminator.py
│ │ │ ├── filter.py
│ │ │ ├── hist_entropy.py
│ │ │ ├── mssim.py
│ │ │ └── vgg.py
│ │ ├── main.py
│ │ ├── model/
│ │ │ ├── DCNv2/
│ │ │ │ ├── LICENSE
│ │ │ │ ├── README.md
│ │ │ │ ├── __init__.py
│ │ │ │ ├── dcn_v2.py
│ │ │ │ ├── files.txt
│ │ │ │ ├── make.sh
│ │ │ │ ├── setup.py
│ │ │ │ ├── src/
│ │ │ │ │ ├── cpu/
│ │ │ │ │ │ ├── dcn_v2_cpu.cpp
│ │ │ │ │ │ ├── dcn_v2_im2col_cpu.cpp
│ │ │ │ │ │ ├── dcn_v2_im2col_cpu.h
│ │ │ │ │ │ ├── dcn_v2_psroi_pooling_cpu.cpp
│ │ │ │ │ │ └── vision.h
│ │ │ │ │ ├── cuda/
│ │ │ │ │ │ ├── dcn_v2_cuda.cu
│ │ │ │ │ │ ├── dcn_v2_im2col_cuda.cu
│ │ │ │ │ │ ├── dcn_v2_im2col_cuda.h
│ │ │ │ │ │ ├── dcn_v2_psroi_pooling_cuda.cu
│ │ │ │ │ │ └── vision.h
│ │ │ │ │ ├── dcn_v2.h
│ │ │ │ │ └── vision.cpp
│ │ │ │ └── test.py
│ │ │ ├── __init__.py
│ │ │ ├── arch_util.py
│ │ │ ├── bsrt.py
│ │ │ ├── checkpoint.py
│ │ │ ├── common.py
│ │ │ ├── non_local/
│ │ │ │ ├── network.py
│ │ │ │ ├── non_local_concatenation.py
│ │ │ │ ├── non_local_cross_dot_product.py
│ │ │ │ ├── non_local_dot_product.py
│ │ │ │ ├── non_local_embedded_gaussian.py
│ │ │ │ └── non_local_gaussian.py
│ │ │ ├── swin_util.py
│ │ │ └── utils/
│ │ │ ├── interp_methods.py
│ │ │ ├── psconv.py
│ │ │ └── resize_right.py
│ │ ├── option.py
│ │ ├── pwcnet/
│ │ │ ├── LICENSE
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── comparison/
│ │ │ │ └── comparison.py
│ │ │ ├── correlation/
│ │ │ │ ├── README.md
│ │ │ │ └── correlation.py
│ │ │ ├── download.bash
│ │ │ ├── images/
│ │ │ │ └── README.md
│ │ │ ├── out.flo
│ │ │ ├── pwcnet.py
│ │ │ ├── requirements.txt
│ │ │ └── run.py
│ │ ├── requirements.txt
│ │ ├── scripts/
│ │ │ ├── __init__.py
│ │ │ ├── cal_mean_std.py
│ │ │ ├── demo.sh
│ │ │ ├── download_burstsr_dataset.py
│ │ │ ├── evaluate.sh
│ │ │ ├── evaluate_burstsr_val.py
│ │ │ ├── save_results_synburst_val.py
│ │ │ ├── test_burstsr_dataset.py
│ │ │ └── test_synthetic_bursts.py
│ │ ├── test.py
│ │ ├── test_real.py
│ │ ├── trainer.py
│ │ ├── utility.py
│ │ ├── utils/
│ │ │ ├── __init__.py
│ │ │ ├── data_format_utils.py
│ │ │ ├── debayer.py
│ │ │ ├── interp_methods.py
│ │ │ ├── metrics.py
│ │ │ ├── postprocessing_functions.py
│ │ │ ├── resize_right.py
│ │ │ ├── spatial_color_alignment.py
│ │ │ ├── stn.py
│ │ │ └── warp.py
│ │ └── validate.py
│ └── synthetic/
│ └── bsrt/
│ ├── README.md
│ ├── data_processing/
│ │ ├── __init__.py
│ │ ├── camera_pipeline.py
│ │ └── synthetic_burst_generation.py
│ ├── datasets/
│ │ ├── __init__.py
│ │ ├── burstsr_dataset.py
│ │ ├── burstsr_test_dataset.py
│ │ ├── data_sampler.py
│ │ ├── realworld_burst_test_set.py
│ │ ├── synthetic_burst_test_set.py
│ │ ├── synthetic_burst_train_set.py
│ │ ├── synthetic_burst_val_set.py
│ │ └── zurich_raw2rgb_dataset.py
│ ├── demo.sh
│ ├── loss/
│ │ ├── Charbonnier.py
│ │ ├── __init__.py
│ │ ├── adversarial.py
│ │ ├── discriminator.py
│ │ ├── filter.py
│ │ ├── hist_entropy.py
│ │ ├── mssim.py
│ │ └── vgg.py
│ ├── main.py
│ ├── model/
│ │ ├── DCNv2/
│ │ │ ├── LICENSE
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── dcn_v2.py
│ │ │ ├── files.txt
│ │ │ ├── make.sh
│ │ │ ├── setup.py
│ │ │ ├── src/
│ │ │ │ ├── cpu/
│ │ │ │ │ ├── dcn_v2_cpu.cpp
│ │ │ │ │ ├── dcn_v2_im2col_cpu.cpp
│ │ │ │ │ ├── dcn_v2_im2col_cpu.h
│ │ │ │ │ ├── dcn_v2_psroi_pooling_cpu.cpp
│ │ │ │ │ └── vision.h
│ │ │ │ ├── cuda/
│ │ │ │ │ ├── dcn_v2_cuda.cu
│ │ │ │ │ ├── dcn_v2_im2col_cuda.cu
│ │ │ │ │ ├── dcn_v2_im2col_cuda.h
│ │ │ │ │ ├── dcn_v2_psroi_pooling_cuda.cu
│ │ │ │ │ └── vision.h
│ │ │ │ ├── dcn_v2.h
│ │ │ │ └── vision.cpp
│ │ │ └── test.py
│ │ ├── __init__.py
│ │ ├── arch_util.py
│ │ ├── bsrt.py
│ │ ├── checkpoint.py
│ │ ├── common.py
│ │ ├── ebsr.py
│ │ ├── non_local/
│ │ │ ├── network.py
│ │ │ ├── non_local_concatenation.py
│ │ │ ├── non_local_cross_dot_product.py
│ │ │ ├── non_local_dot_product.py
│ │ │ ├── non_local_embedded_gaussian.py
│ │ │ └── non_local_gaussian.py
│ │ ├── swin_util.py
│ │ └── utils/
│ │ ├── interp_methods.py
│ │ ├── psconv.py
│ │ └── resize_right.py
│ ├── option.py
│ ├── requirements.txt
│ ├── scripts/
│ │ ├── __init__.py
│ │ ├── cal_mean_std.py
│ │ ├── demo.sh
│ │ ├── download_burstsr_dataset.py
│ │ ├── evaluate.sh
│ │ ├── evaluate_burstsr_val.py
│ │ ├── save_results_synburst_val.py
│ │ ├── test_burstsr_dataset.py
│ │ └── test_synthetic_bursts.py
│ ├── test.py
│ ├── test_synburst.py
│ ├── trainer.py
│ ├── utility.py
│ └── utils/
│ ├── __init__.py
│ ├── data_format_utils.py
│ ├── debayer.py
│ ├── interp_methods.py
│ ├── metrics.py
│ ├── postprocessing_functions.py
│ ├── resize_right.py
│ ├── spatial_color_alignment.py
│ ├── stn.py
│ └── warp.py
└── requirements.txt
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2022 Megvii Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
## BSRT: Improving Burst Super-Resolution with Swin Transformer and Flow-Guided Deformable Alignment (CVPRW 2022)
[](https://paperswithcode.com/sota/burst-image-super-resolution-on-burstsr?p=bsrt-improving-burst-super-resolution-with) [](https://paperswithcode.com/sota/burst-image-super-resolution-on?p=bsrt-improving-burst-super-resolution-with)
#### [BSRT](https://arxiv.org/abs/2204.08332), the winner of the NTIRE 2022 Burst Super-Resolution Challenge Real-World Track.
You can also find our winner method in NTIRE 2021 Burst Super-Resolution Challenge [here](https://github.com/Algolzw/EBSR).
> This work addresses the Burst Super-Resolution (BurstSR) task using a new architecture, which requires restoring a high-quality image from a sequence of noisy, misaligned, and low-resolution RAW bursts. To overcome the challenges in BurstSR, we propose a **B**urst **S**uper-**R**esolution **T**ransformer (**BSRT**), which can significantly improve the capability of extracting inter-frame information and reconstruction. To achieve this goal, we propose a Pyramid Flow-Guided Deformable Convolution Network (Pyramid FG-DCN) and incorporate Swin Transformer Blocks and Groups as our main backbone. More specifically, we combine optical flows and deformable convolutions, hence our BSRT can handle misalignment and aggregate the potential texture information in multi-frames more efficiently. In addition, our Transformer-based structure can capture long-range dependency to further improve the performance. The evaluation on both synthetic and real-world tracks demonstrates that our approach achieves a new state-of-the-art in BurstSR task. Further, our BSRT wins the championship in the NTIRE2022 Burst Super-Resolution Challenge.
#### Comparison with State-of-the-art Burst Super-Resolution Methods

## Overview Architecture

## Dependencies
- OS: Ubuntu 18.04
- Python: Python 3.7
- nvidia :
- cuda: 10.1
- cudnn: 7.6.1
- Other reference requirements
## Quick Start
1.Create a conda virtual environment and activate it
```python3
conda create -n pytorch_1.6 python=3.7
source activate pytorch_1.6
```
2.Install PyTorch and torchvision following the official instructions
```python3
conda install pytorch==1.6.0 torchvision==0.7.0 cudatoolkit=10.1 -c pytorch
```
3.Install build requirements
```python3
pip3 install -r requirements.txt
```
4.Install DCN
```python3
cd DCNv2
python3 setup.py build develop # build
python3 test.py # run examples and check
```
## Training
We provide all pretrained model weights [here](https://drive.google.com/file/d/1Bv1ZwoE3s8trhG--wjB0Yt6WJIQPpvsn/view?usp=sharing).
#### For Synthetic data
```python3
cd code/synthetic/bsrt
# Modify the root path of training dataset and model etc.
# The number of GPUs should be more than 1
python main.py --n_GPUs 8 --print_every 40 --lr 0.0001 --decay 150-300 --save bsrt_tiny --model BSRT --fp16 --model_level S --swinfeature --batch_size 32 --burst_size 14 --patch_size 256
```
#### For Real-World data
```python3
cd code/real/bsrt
# Modify the root path of training dataset and model etc.
# The number of GPUs should be more than 1
python main.py --n_GPUs 8 --print_every 20 --lr 0.00005 --decay 40-80 --save bsrt_tiny --model BSRT --fp16 --model_level S --swinfeature --batch_size 8 --burst_size 14 --patch_size 80 --pre_train ../../synthetic/train_log/bsrt/real_models/bsrt_tiny/bsrt_best_epoch.pth
```
The pretrained PWC-Net model can be downloaded [here](https://drive.google.com/file/d/1dD6vB9QN3qwmOBi3AGKzJbbSojwDDlgV/view?usp=sharing).
## Test
#### For Synthetic data
```python3
# Modify the path of test dataset and the path of the trained model
python test_synburst.py --n_GPUs 1 --model BSRT --model_level S --swinfeature --burst_size 14 --patch_size 384 --pre_train ../train_log/bsrt/real_models/bsrt_tiny/bsrt_best_epoch.pth --root /data/dataset/ntire21/burstsr/synthetic
```
#### For Real-World data
```python3
# Modify the path of test dataset and the path of the trained model
python test_real.py --n_GPUs 1 --model BSRT --model_level S --swinfeature --batch_size 1 --burst_size 14 --patch_size 80 --pre_train ../train_log/bsrt/real_models/bsrt_tiny/bsrtbest_epoch.pth --root /data/dataset/ntire21/burstsr/real
```
## Results
### Comparison on Synthetic dataset

### Comparison on Real-World dataset

## Citations
If our code helps your research or work, please consider citing our paper.
The following is a BibTeX reference.
```
@inproceedings{luo2022bsrt,
title={BSRT: Improving Burst Super-Resolution with Swin Transformer and Flow-Guided Deformable Alignment},
author={Luo, Ziwei and Li, Youwei and Cheng, Shen and Yu, Lei and Wu, Qi and Wen, Zhihong and Fan, Haoqiang and Sun, Jian and Liu, Shuaicheng},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={998--1008},
year={2022}
}
```
## Contact
Email: [ziwei.ro@gmail.com](mailto:ziwei.ro@gmail.com)
================================================
FILE: code/real/bsrt/README.md
================================================
# BSRT: Improving Burst Super-Resolution with Swin Transformer and Flow-Guided Deformable Alignment (Real-World)
## Dependencies
- OS: Ubuntu 18.04
- Python: Python 3.7
- nvidia :
- cuda: 10.1
- cudnn: 7.6.1
- Other reference requirements
## Quick Start
1.Create a conda virtual environment and activate it
```python3
conda create -n pytorch_1.6 python=3.7
source activate pytorch_1.6
```
2.Install PyTorch and torchvision following the official instructions
```python3
conda install pytorch==1.6.0 torchvision==0.7.0 cudatoolkit=10.1 -c pytorch
```
3.Install build requirements
```python3
pip3 install -r requirements.txt
```
4.Install DCN
```python3
cd DCNv2
python3 setup.py build develop # build
python3 test.py # run examples and check
```
## Training
The pretrained PWC-Net model can be downloaded [here](https://drive.google.com/file/d/1dD6vB9QN3qwmOBi3AGKzJbbSojwDDlgV/view?usp=sharing).
```python3
# Modify the root path of training dataset and model etc.
# The number of GPUs should be more than 1
python main.py --n_GPUs 8 --print_every 20 --lr 0.00004 --decay 40-80 --save bsrt_tiny --model BSRT --fp16 --model_level S --swinfeature --batch_size 8 --burst_size 14 --patch_size 80 --pre_train ../../synthetic/train_log/bsrt/real_models/bsrt_tiny/bsrt_best_epoch.pth
```
## Test
```python3
# Modify the path of test dataset and the path of the trained model
python test_real.py --n_GPUs 1 --model BSRT --model_level S --swinfeature --batch_size 1 --burst_size 14 --patch_size 80 --pre_train ../train_log/bsrt/real_models/bsrt_tiny/bsrtbest_epoch.pth --root /data/dataset/ntire21/burstsr/real
```
================================================
FILE: code/real/bsrt/data_processing/__init__.py
================================================
================================================
FILE: code/real/bsrt/data_processing/camera_pipeline.py
================================================
import torch
import random
import math
import cv2 as cv
import numpy as np
import utils.data_format_utils as df_utils
""" Based on http://timothybrooks.com/tech/unprocessing
Functions for forward and inverse camera pipeline. All functions input a torch float tensor of shape (c, h, w).
Additionally, some also support batch operations, i.e. inputs of shape (b, c, h, w)
"""
def random_ccm():
    """Generates a random RGB -> Camera color correction matrix."""
    # Candidate XYZ -> Camera CCMs measured from real cameras.
    xyz2cams = [[[1.0234, -0.2969, -0.2266],
                 [-0.5625, 1.6328, -0.0469],
                 [-0.0703, 0.2188, 0.6406]],
                [[0.4913, -0.0541, -0.0202],
                 [-0.613, 1.3513, 0.2906],
                 [-0.1564, 0.2151, 0.7183]],
                [[0.838, -0.263, -0.0639],
                 [-0.2887, 1.0725, 0.2496],
                 [-0.0627, 0.1427, 0.5438]],
                [[0.6596, -0.2079, -0.0562],
                 [-0.4782, 1.3016, 0.1933],
                 [-0.097, 0.1581, 0.5181]]]
    candidates = torch.tensor(xyz2cams)

    # Take a random convex combination of the candidate matrices.
    weights = torch.FloatTensor(len(xyz2cams), 1, 1).uniform_(0.0, 1.0)
    xyz2cam = (candidates * weights).sum(dim=0) / weights.sum()

    # Compose with the fixed sRGB -> XYZ matrix to get RGB -> Camera.
    rgb2xyz = torch.tensor([[0.4124564, 0.3575761, 0.1804375],
                            [0.2126729, 0.7151522, 0.0721750],
                            [0.0193339, 0.1191920, 0.9503041]])
    rgb2cam = torch.mm(xyz2cam, rgb2xyz)

    # Normalize each row so that white is mapped to white.
    return rgb2cam / rgb2cam.sum(dim=-1, keepdims=True)
def random_gains():
    """Generates random gains for brightening and white balance."""
    # Overall brightening: inverse of a Gaussian centered slightly below 1.
    rgb_gain = 1.0 / random.gauss(mu=0.8, sigma=0.1)
    # White-balance gains for the red and blue channels.
    red_gain = random.uniform(1.9, 2.4)
    blue_gain = random.uniform(1.5, 1.9)
    return rgb_gain, red_gain, blue_gain
def apply_smoothstep(image):
    """Apply the global smoothstep tone mapping curve 3x^2 - 2x^3."""
    squared = image ** 2
    return squared * (3.0 - 2.0 * image)
def invert_smoothstep(image):
    """Approximately inverts the global smoothstep tone mapping curve."""
    # Closed-form inverse of 3x^2 - 2x^3 on [0, 1] via the triple-angle identity.
    clamped = image.clamp(0.0, 1.0)
    return 0.5 - torch.sin(torch.asin(1.0 - 2.0 * clamped) / 3.0)
def gamma_expansion(image):
    """Converts from gamma to linear space."""
    # Clamp away from zero: the gradient of x ** 2.2 blows up near 0.
    return image.clamp(1e-8).pow(2.2)
def gamma_compression(image):
    """Converts from linear to gamma space."""
    # Clamp away from zero to keep gradients finite near 0.
    return image.clamp(1e-8).pow(1.0 / 2.2)
def apply_ccm(image, ccm):
    """Applies a 3x3 color correction matrix to a (3, h, w) image."""
    assert image.dim() == 3 and image.shape[0] == 3
    original_shape = image.shape
    # Flatten spatial dims so the CCM can be applied as a single matmul.
    flat = image.view(3, -1)
    matrix = ccm.to(flat.device).type_as(flat)
    return torch.mm(matrix, flat).view(original_shape)
def apply_gains(image, rgb_gain, red_gain, blue_gain):
    """Applies white-balance and brightening gains, clamping the result to [0, 1].

    Works on 3-channel RGB images or 4-channel RGGB mosaics (the two green
    planes of a mosaic get a gain of 1).
    """
    assert image.dim() == 3 and image.shape[0] in [3, 4]
    if image.shape[0] == 3:
        per_channel = torch.tensor([red_gain, 1.0, blue_gain])
    else:
        per_channel = torch.tensor([red_gain, 1.0, 1.0, blue_gain])
    gains = (per_channel * rgb_gain).view(-1, 1, 1)
    gains = gains.to(image.device).type_as(image)
    return (image * gains).clamp(0.0, 1.0)
def safe_invert_gains(image, rgb_gain, red_gain, blue_gain):
    """Inverts white-balance/brightening gains while protecting saturated pixels."""
    assert image.dim() == 3 and image.shape[0] == 3
    inverse = torch.tensor([1.0 / red_gain, 1.0, 1.0 / blue_gain]) / rgb_gain
    inverse = inverse.view(-1, 1, 1)

    # Blend the inverse gains toward 1.0 for pixels close to white so that
    # saturated regions are not dimmed; never drop below the plain inverse gain.
    luminance = image.mean(dim=0, keepdims=True)
    inflection = 0.9
    blend = ((luminance - inflection).clamp(0.0) / (1.0 - inflection)) ** 2.0
    protected = torch.max(blend + (1.0 - blend) * inverse, inverse)
    return image * protected
def mosaic(image, mode='rggb'):
    """Extracts Bayer planes from an RGB image.

    Args:
        image: (3, h, w) or (b, 3, h, w) tensor.
        mode: Bayer pattern, 'rggb' or 'grbg'.
    Returns:
        (4, h/2, w/2) or (b, 4, h/2, w/2) tensor of stacked Bayer planes.
    Raises:
        ValueError: if mode is not a supported Bayer pattern.
    """
    shape = image.shape
    if image.dim() == 3:
        image = image.unsqueeze(0)

    if mode == 'rggb':
        red = image[:, 0, 0::2, 0::2]
        green_red = image[:, 1, 0::2, 1::2]
        green_blue = image[:, 1, 1::2, 0::2]
        blue = image[:, 2, 1::2, 1::2]
        image = torch.stack((red, green_red, green_blue, blue), dim=1)
    elif mode == 'grbg':
        green_red = image[:, 1, 0::2, 0::2]
        red = image[:, 0, 0::2, 1::2]
        # NOTE(review): for a GRBG pattern the blue sample site is (1::2, 0::2);
        # this samples (0::2, 1::2) -- confirm whether this is intentional.
        blue = image[:, 2, 0::2, 1::2]
        green_blue = image[:, 1, 1::2, 1::2]
        image = torch.stack((green_red, red, blue, green_blue), dim=1)
    else:
        # Fail fast on unknown patterns instead of falling through to a
        # confusing shape error in the view() below.
        raise ValueError('Unknown Bayer mode: {}'.format(mode))

    if len(shape) == 3:
        return image.view((4, shape[-2] // 2, shape[-1] // 2))
    else:
        return image.view((-1, 4, shape[-2] // 2, shape[-1] // 2))
def demosaic(image):
    """Demosaics a 4-plane RGGB mosaic into an RGB image using OpenCV.

    Args:
        image: (4, h, w) or (b, 4, h, w) float tensor with values in [0, 1].
    Returns:
        A demosaicked image at twice the spatial resolution; batched input
        returns a stacked batch. (Output type comes from
        df_utils.npimage_to_torch -- values are quantized to 8 bits first.)
    """
    assert isinstance(image, torch.Tensor)
    image = image.clamp(0.0, 1.0) * 255

    if image.dim() == 4:
        # Bug fix: the batch size is shape[0]; the original used image.dim(),
        # which is always 4 here regardless of the actual batch size.
        num_images = image.shape[0]
        batch_input = True
    else:
        num_images = 1
        batch_input = False
        image = image.unsqueeze(0)

    # Re-interleave the 4 planes into a single-channel Bayer image for OpenCV.
    im_sc = torch.zeros((num_images, image.shape[-2] * 2, image.shape[-1] * 2, 1))
    im_sc[:, ::2, ::2, 0] = image[:, 0, :, :]
    im_sc[:, ::2, 1::2, 0] = image[:, 1, :, :]
    im_sc[:, 1::2, ::2, 0] = image[:, 2, :, :]
    im_sc[:, 1::2, 1::2, 0] = image[:, 3, :, :]

    im_sc = im_sc.numpy().astype(np.uint8)

    out = []
    for im in im_sc:
        im_dem_np = cv.cvtColor(im, cv.COLOR_BAYER_BG2RGB)
        # Convert back to a torch image.
        im_t = df_utils.npimage_to_torch(im_dem_np, input_bgr=False)
        out.append(im_t)

    if batch_input:
        return torch.stack(out, dim=0)
    else:
        return out[0]
def random_noise_levels():
    """Generates random noise levels from a log-log linear distribution."""
    # Sample shot noise log-uniformly over the plausible sensor range.
    log_shot_noise = random.uniform(math.log(0.0001), math.log(0.012))
    shot_noise = math.exp(log_shot_noise)

    # Read noise follows shot noise along a log-log line, with Gaussian scatter.
    log_read_noise = 2.18 * log_shot_noise + 1.20 + random.gauss(mu=0.0, sigma=0.26)
    return shot_noise, math.exp(log_read_noise)
def add_noise(image, shot_noise=0.01, read_noise=0.0005):
    """Adds random shot (proportional to image) and read (independent) noise."""
    # Per-pixel variance: shot noise scales with intensity, read noise is constant.
    variance = image * shot_noise + read_noise
    gaussian = torch.FloatTensor(image.shape).normal_().to(image.device)
    return image + gaussian * variance.sqrt()
def process_linear_image_rgb(image, meta_info, return_np=False):
    """Runs the forward camera pipeline on a linear RGB image.

    Applies gains and color correction, then (per the meta_info flags) gamma
    compression and the smoothstep tone curve, and clamps to [0, 1].
    Optionally converts the result to a numpy image.
    """
    image = apply_gains(image, meta_info['rgb_gain'], meta_info['red_gain'], meta_info['blue_gain'])
    image = apply_ccm(image, meta_info['cam2rgb'])

    if meta_info['gamma']:
        image = gamma_compression(image)
    if meta_info['smoothstep']:
        image = apply_smoothstep(image)
    image = image.clamp(0.0, 1.0)

    return df_utils.torch_to_npimage(image) if return_np else image
def process_linear_image_raw(image, meta_info):
    """Runs the forward camera pipeline on a linear RAW mosaic.

    Applies gains, demosaics to RGB, applies color correction, then (per the
    meta_info flags) gamma compression and the smoothstep tone curve, and
    clamps to [0, 1].
    """
    image = apply_gains(image, meta_info['rgb_gain'], meta_info['red_gain'], meta_info['blue_gain'])
    image = demosaic(image)
    image = apply_ccm(image, meta_info['cam2rgb'])

    if meta_info['gamma']:
        image = gamma_compression(image)
    if meta_info['smoothstep']:
        image = apply_smoothstep(image)
    return image.clamp(0.0, 1.0)
================================================
FILE: code/real/bsrt/data_processing/synthetic_burst_generation.py
================================================
import torch
import random
import cv2
import numpy as np
import torch.nn.functional as F
from data_processing.camera_pipeline import *
from utils.data_format_utils import torch_to_numpy, numpy_to_torch
def random_crop(frames, crop_sz):
    """ Extract a random crop of size crop_sz from the input frames. If the crop_sz is larger than the input image size,
    then the largest possible crop of same aspect ratio as crop_sz will be extracted from frames, and upsampled to
    crop_sz.

    Args:
        frames: (c, h, w) tensor to crop.
        crop_sz: int or (h, w) pair giving the desired crop size.
    Returns:
        (c, crop_h, crop_w) tensor.
    """
    if not isinstance(crop_sz, (tuple, list)):
        crop_sz = (crop_sz, crop_sz)
    crop_sz = torch.tensor(crop_sz).float()

    shape = frames.shape

    # Select scale_factor. Ensure the crop fits inside the image.
    max_scale_factor = torch.tensor(shape[-2:]).float() / crop_sz
    max_scale_factor = max_scale_factor.min().item()

    if max_scale_factor < 1.0:
        scale_factor = max_scale_factor
    else:
        scale_factor = 1.0

    # Extract the crop.
    orig_crop_sz = (crop_sz * scale_factor).floor()

    assert orig_crop_sz[-2] <= shape[-2] and orig_crop_sz[-1] <= shape[-1], 'Bug in crop size estimation!'

    # Bug fix: convert the 0-dim float tensors to plain ints --
    # random.randint rejects non-integer bounds on Python >= 3.10.
    r1 = random.randint(0, shape[-2] - int(orig_crop_sz[-2].item()))
    c1 = random.randint(0, shape[-1] - int(orig_crop_sz[-1].item()))
    r2 = r1 + int(orig_crop_sz[0].item())
    c2 = c1 + int(orig_crop_sz[1].item())

    frames_crop = frames[:, r1:r2, c1:c2]

    # Resize to crop_sz if a smaller-than-requested crop was taken.
    if scale_factor < 1.0:
        frames_crop = F.interpolate(frames_crop.unsqueeze(0), size=crop_sz.int().tolist(), mode='bilinear',
                                    align_corners=False).squeeze(0)
    return frames_crop
def rgb2rawburst(image, burst_size, downsample_factor=1, burst_transformation_params=None,
                 image_processing_params=None, interpolation_type='bilinear'):
    """ Generates a synthetic LR RAW burst from the input image. The input sRGB image is first converted to linear
    sensor space using an inverse camera pipeline. A LR burst is then generated by applying random
    transformations defined by burst_transformation_params to the input image, and downsampling it by the
    downsample_factor. The generated burst is then mosaicked and corrupted by random noise.

    Returns a tuple (image_burst, image, image_burst_rgb, flow_vectors, meta_info) where image is the
    linear-space version of the input and meta_info holds the sampled pipeline parameters.
    """
    # Fill in defaults for any unspecified image-processing flags.
    if image_processing_params is None:
        image_processing_params = {}
    _defaults = {'random_ccm': True, 'random_gains': True, 'smoothstep': True, 'gamma': True, 'add_noise': True}
    for k, v in _defaults.items():
        if k not in image_processing_params:
            image_processing_params[k] = v

    # Sample camera pipeline params
    if image_processing_params['random_ccm']:
        rgb2cam = random_ccm()
    else:
        rgb2cam = torch.eye(3).float()
    cam2rgb = rgb2cam.inverse()

    # Sample gains
    if image_processing_params['random_gains']:
        rgb_gain, red_gain, blue_gain = random_gains()
    else:
        rgb_gain, red_gain, blue_gain = (1.0, 1.0, 1.0)

    # Approximately inverts global tone mapping.
    use_smoothstep = image_processing_params['smoothstep']
    if use_smoothstep:
        image = invert_smoothstep(image)

    # Inverts gamma compression.
    use_gamma = image_processing_params['gamma']
    if use_gamma:
        image = gamma_expansion(image)

    # Inverts color correction.
    image = apply_ccm(image, rgb2cam)

    # Approximately inverts white balance and brightening.
    image = safe_invert_gains(image, rgb_gain, red_gain, blue_gain)

    # Clip saturated pixels.
    image = image.clamp(0.0, 1.0)

    # Generate LR burst by random warps + downsampling of the linear image.
    image_burst_rgb, flow_vectors = single2lrburst(image, burst_size=burst_size,
                                                   downsample_factor=downsample_factor,
                                                   transformation_params=burst_transformation_params,
                                                   interpolation_type=interpolation_type)

    # mosaic (clone so the returned RGB burst is left untouched)
    image_burst = mosaic(image_burst_rgb.clone())

    # Add noise
    if image_processing_params['add_noise']:
        shot_noise_level, read_noise_level = random_noise_levels()
        image_burst = add_noise(image_burst, shot_noise_level, read_noise_level)
    else:
        shot_noise_level = 0
        read_noise_level = 0

    # Clip saturated pixels.
    image_burst = image_burst.clamp(0.0, 1.0)

    # Everything needed to (approximately) re-run the forward pipeline later.
    meta_info = {'rgb2cam': rgb2cam, 'cam2rgb': cam2rgb, 'rgb_gain': rgb_gain, 'red_gain': red_gain,
                 'blue_gain': blue_gain, 'smoothstep': use_smoothstep, 'gamma': use_gamma,
                 'shot_noise_level': shot_noise_level, 'read_noise_level': read_noise_level}
    return image_burst, image, image_burst_rgb, flow_vectors, meta_info
def get_tmat(image_shape, translation, theta, shear_values, scale_factors):
    """ Generates a transformation matrix corresponding to the input transformation parameters """
    im_h, im_w = image_shape

    # Start from a pure translation.
    t_mat = np.identity(3)
    t_mat[0, 2] = translation[0]
    t_mat[1, 2] = translation[1]

    # Rotation about the image center (OpenCV returns 2x3; pad back to 3x3).
    t_rot = cv2.getRotationMatrix2D((im_w * 0.5, im_h * 0.5), theta, 1.0)
    t_rot = np.concatenate((t_rot, np.array([0.0, 0.0, 1.0]).reshape(1, 3)))

    # Shear, with offsets that keep the image center fixed.
    t_shear = np.array([[1.0, shear_values[0], -shear_values[0] * 0.5 * im_w],
                        [shear_values[1], 1.0, -shear_values[1] * 0.5 * im_h],
                        [0.0, 0.0, 1.0]])

    # Anisotropic scaling about the origin.
    t_scale = np.array([[scale_factors[0], 0.0, 0.0],
                        [0.0, scale_factors[1], 0.0],
                        [0.0, 0.0, 1.0]])

    # Compose: translation first, then shear, rotation, and scale.
    t_mat = t_scale @ t_rot @ t_shear @ t_mat

    # Drop the homogeneous row: cv2.warpAffine expects a 2x3 matrix.
    t_mat = t_mat[:2, :]

    return t_mat
def single2lrburst(image, burst_size, downsample_factor=1, transformation_params=None,
                   interpolation_type='bilinear'):
    """ Generates a burst of size burst_size from the input image by applying random transformations defined by
    transformation_params, and downsampling the resulting burst by downsample_factor.

    Returns (burst_images, flow_vectors): the stacked burst and, per frame, the flow that maps each
    burst image back to the base (first) image.
    NOTE(review): for burst_size > 1 the code indexes transformation_params directly, so it must be a
    dict in that case despite the None default -- confirm callers always pass one.
    """
    if interpolation_type == 'bilinear':
        interpolation = cv2.INTER_LINEAR
    elif interpolation_type == 'lanczos':
        interpolation = cv2.INTER_LANCZOS4
    else:
        raise ValueError

    # Track whether the input was in [0, 1] so the output can be rescaled back.
    normalize = False
    if isinstance(image, torch.Tensor):
        if image.max() < 2.0:
            image = image * 255.0
            normalize = True
        image = torch_to_numpy(image).astype(np.uint8)

    burst = []
    sample_pos_inv_all = []

    # Homogeneous pixel-coordinate grid (x, y, 1) used to track where each
    # output pixel samples from in the base image.
    rvs, cvs = torch.meshgrid([torch.arange(0, image.shape[0]),
                               torch.arange(0, image.shape[1])])

    sample_grid = torch.stack((cvs, rvs, torch.ones_like(cvs)), dim=-1).float()

    for i in range(burst_size):
        if i == 0:
            # For base image, do not apply any random transformations. We only translate the image to center the
            # sampling grid
            shift = (downsample_factor / 2.0) - 0.5
            translation = (shift, shift)
            theta = 0.0
            shear_factor = (0.0, 0.0)
            scale_factor = (1.0, 1.0)
        else:
            # Sample random image transformation parameters
            max_translation = transformation_params.get('max_translation', 0.0)

            if max_translation <= 0.01:
                # Negligible translation requested: use the same centering
                # shift as the base frame.
                shift = (downsample_factor / 2.0) - 0.5
                translation = (shift, shift)
            else:
                translation = (random.uniform(-max_translation, max_translation),
                               random.uniform(-max_translation, max_translation))

            max_rotation = transformation_params.get('max_rotation', 0.0)
            theta = random.uniform(-max_rotation, max_rotation)

            max_shear = transformation_params.get('max_shear', 0.0)
            shear_x = random.uniform(-max_shear, max_shear)
            shear_y = random.uniform(-max_shear, max_shear)
            shear_factor = (shear_x, shear_y)

            # Aspect-ratio jitter is applied on top of the isotropic scale.
            max_ar_factor = transformation_params.get('max_ar_factor', 0.0)
            ar_factor = np.exp(random.uniform(-max_ar_factor, max_ar_factor))

            max_scale = transformation_params.get('max_scale', 0.0)
            scale_factor = np.exp(random.uniform(-max_scale, max_scale))

            scale_factor = (scale_factor, scale_factor * ar_factor)

        output_sz = (image.shape[1], image.shape[0])

        # Generate a affine transformation matrix corresponding to the sampled parameters
        t_mat = get_tmat((image.shape[0], image.shape[1]), translation, theta, shear_factor, scale_factor)
        t_mat_tensor = torch.from_numpy(t_mat)

        # Apply the sampled affine transformation
        image_t = cv2.warpAffine(image, t_mat, output_sz, flags=interpolation,
                                 borderMode=cv2.BORDER_CONSTANT)

        # Invert the transform to find, for each warped pixel, its source
        # position in the base image.
        t_mat_tensor_3x3 = torch.cat((t_mat_tensor.float(), torch.tensor([0.0, 0.0, 1.0]).view(1, 3)), dim=0)
        t_mat_tensor_inverse = t_mat_tensor_3x3.inverse()[:2, :].contiguous()

        sample_pos_inv = torch.mm(sample_grid.view(-1, 3), t_mat_tensor_inverse.t().float()).view(
            *sample_grid.shape[:2], -1)

        if transformation_params.get('border_crop') is not None:
            border_crop = transformation_params.get('border_crop')

            # Remove warp border artifacts from both image and sampling grid.
            image_t = image_t[border_crop:-border_crop, border_crop:-border_crop, :]
            sample_pos_inv = sample_pos_inv[border_crop:-border_crop, border_crop:-border_crop, :]

        # Downsample the image
        image_t = cv2.resize(image_t, None, fx=1.0 / downsample_factor, fy=1.0 / downsample_factor,
                             interpolation=interpolation)
        sample_pos_inv = cv2.resize(sample_pos_inv.numpy(), None, fx=1.0 / downsample_factor,
                                    fy=1.0 / downsample_factor,
                                    interpolation=interpolation)

        sample_pos_inv = torch.from_numpy(sample_pos_inv).permute(2, 0, 1).contiguous()

        if normalize:
            image_t = numpy_to_torch(image_t).float() / 255.0
        else:
            image_t = numpy_to_torch(image_t).float()

        burst.append(image_t)
        # Sampling positions are expressed in LR pixel units.
        sample_pos_inv_all.append(sample_pos_inv / downsample_factor)

    burst_images = torch.stack(burst)
    sample_pos_inv_all = torch.stack(sample_pos_inv_all)

    # Compute the flow vectors to go from the i'th burst image to the base image
    flow_vectors = sample_pos_inv_all - sample_pos_inv_all[:, :1, ...]

    return burst_images, flow_vectors
================================================
FILE: code/real/bsrt/datasets/__init__.py
================================================
================================================
FILE: code/real/bsrt/datasets/burstsr_dataset.py
================================================
import os
import torch
import cv2
import numpy as np
import pickle as pkl
import torch.nn.functional as F
import random
import time
class SamsungRAWImage:
@staticmethod
def load(path):
    """Load a SamsungRAWImage from a directory on disk.

    Expects path to contain 'im_raw.png' (a multi-channel raw mosaic read
    unchanged) and 'meta_info.pkl' with the pickled sensor metadata.
    """
    im_raw = cv2.imread('{}/im_raw.png'.format(path), cv2.IMREAD_UNCHANGED)
    # Reorder HWC -> CHW; int16 preserves the raw sensor values
    # (norm_factor of 1023 elsewhere suggests 10-bit data -- TODO confirm).
    im_raw = np.transpose(im_raw, (2, 0, 1)).astype(np.int16)
    im_raw = torch.from_numpy(im_raw)

    meta_data = pkl.load(open('{}/meta_info.pkl'.format(path), "rb", -1))

    # crop_info and im_preview are optional and may be absent from older files.
    return SamsungRAWImage(im_raw, meta_data['black_level'], meta_data['cam_wb'],
                           meta_data['daylight_wb'], meta_data['color_matrix'], meta_data['exif_data'],
                           meta_data.get('crop_info', None), meta_data.get('im_preview', None))
def __init__(self, im_raw, black_level, cam_wb, daylight_wb, color_matrix, exif_data, crop_info=None,
im_preview=None):
self.im_raw = im_raw
self.black_level = black_level
self.cam_wb = cam_wb
self.daylight_wb = daylight_wb
self.color_matrix = color_matrix
self.exif_data = exif_data
self.crop_info = crop_info
self.im_preview = im_preview
self.norm_factor = 1023.0
def get_all_meta_data(self):
return {'black_level': self.black_level, 'cam_wb': self.cam_wb, 'daylight_wb': self.daylight_wb,
'color_matrix': self.color_matrix.tolist()}
def get_exposure_time(self):
return self.exif_data['Image ExposureTime'].values[0].decimal()
def get_noise_profile(self):
noise = self.exif_data['Image Tag 0xC761'].values
noise = [n[0] for n in noise]
noise = np.array(noise).reshape(3, 2)
return noise
def get_f_number(self):
return self.exif_data['Image FNumber'].values[0].decimal()
def get_iso(self):
return self.exif_data['Image ISOSpeedRatings'].values[0]
def get_image_data(self, substract_black_level=False, white_balance=False, normalize=False):
im_raw = self.im_raw.float()
if substract_black_level:
im_raw = im_raw - torch.tensor(self.black_level).view(4, 1, 1)
if white_balance:
im_raw = im_raw * torch.tensor(self.cam_wb).view(4, 1, 1)
if normalize:
im_raw = im_raw / self.norm_factor
return im_raw
def shape(self):
shape = (4, self.im_raw.shape[1], self.im_raw.shape[2])
return shape
def crop_image(self, r1, r2, c1, c2):
self.im_raw = self.im_raw[:, r1:r2, c1:c2]
def get_crop(self, r1, r2, c1, c2):
im_raw = self.im_raw[:, r1:r2, c1:c2]
if self.im_preview is not None:
im_preview = self.im_preview[2*r1:2*r2, 2*c1:2*c2]
else:
im_preview = None
return SamsungRAWImage(im_raw, self.black_level, self.cam_wb, self.daylight_wb, self.color_matrix,
self.exif_data, im_preview=im_preview)
def postprocess(self, return_np=True, norm_factor=None):
# Convert to rgb
# im = torch.from_numpy(self.im_raw.astype(np.float32))
im = self.im_raw
im = (im - torch.tensor(self.black_level).view(4, 1, 1)) * torch.tensor(self.cam_wb).view(4, 1, 1)
if norm_factor is None:
im = im / im.max()
else:
im = im / norm_factor
im = torch.stack((im[0], (im[1] + im[2])/2, im[3]), dim=0)
# im = torch.stack((im[0], im[1], im[3]), dim=0)
im_out = im.clamp(0.0, 1.0)
if return_np:
im_out = im_out.permute(1, 2, 0).numpy() * 255.0
im_out = im_out.astype(np.uint8)
return im_out
class CanonImage:
    """Container for a Canon RAW ground-truth image loaded from disk.

    Stored as a 3-channel (RGB) tensor plus capture metadata. The norm factor
    of 16383 suggests 14-bit sensor values — TODO confirm.
    """

    @staticmethod
    def load(path, split='train'):
        """Load `im_raw.png` and `meta_info.pkl` from `path` (`split` is unused here).

        NOTE(review): deserializes with pickle — only load trusted data.
        """
        im_raw = cv2.imread('{}/im_raw.png'.format(path), cv2.IMREAD_UNCHANGED)
        # HWC -> CHW
        im_raw = np.transpose(im_raw, (2, 0, 1)).astype(np.int16)
        im_raw = torch.from_numpy(im_raw)
        meta_data = pkl.load(open('{}/meta_info.pkl'.format(path), "rb", -1))
        return CanonImage(im_raw.float(), meta_data['black_level'], meta_data['cam_wb'],
                          meta_data['daylight_wb'], meta_data['rgb_xyz_matrix'], meta_data.get('exif_data', None),
                          meta_data.get('crop_info', None))

    def __init__(self, im_raw, black_level, cam_wb, daylight_wb, rgb_xyz_matrix, exif_data, crop_info=None):
        super(CanonImage, self).__init__()
        self.im_raw = im_raw  # (3, H, W) image tensor
        # 4-entry (RGGB) calibration lists are collapsed to 3 (RGB) by
        # dropping the second green entry.
        if len(black_level) == 4:
            black_level = [black_level[0], black_level[1], black_level[3]]
        self.black_level = black_level
        if len(cam_wb) == 4:
            cam_wb = [cam_wb[0], cam_wb[1], cam_wb[3]]
        self.cam_wb = cam_wb
        if len(daylight_wb) == 4:
            daylight_wb = [daylight_wb[0], daylight_wb[1], daylight_wb[3]]
        self.daylight_wb = daylight_wb
        self.rgb_xyz_matrix = rgb_xyz_matrix
        # Standard XYZ -> sRGB conversion matrix.
        self.xyz_srgb_matrix = torch.tensor([3.2404542, -1.5371385, -0.4985314,
                                             -0.9692660, 1.8760108, 0.0415560,
                                             0.0556434, -0.2040259, 1.0572252]).view(3, 3)
        self.exif_data = exif_data
        self.crop_info = crop_info
        self.norm_factor = 16383  # divisor used by get_image_data(normalize=True)

    def shape(self):
        """Return the (C, H, W) shape tuple of the stored image."""
        shape = (3, self.im_raw.shape[1], self.im_raw.shape[2])
        return shape

    def get_all_meta_data(self):
        """Return the calibration metadata as a plain dict (matrix as list)."""
        return {'black_level': self.black_level, 'cam_wb': self.cam_wb, 'daylight_wb': self.daylight_wb,
                'rgb_xyz_matrix': self.rgb_xyz_matrix.tolist(), 'crop_info': self.crop_info,
                'norm_factor': self.norm_factor}

    def get_exposure_time(self):
        """Exposure time in seconds, read from EXIF."""
        return self.exif_data['EXIF ExposureTime'].values[0].decimal()

    def get_f_number(self):
        """Aperture f-number, read from EXIF."""
        return self.exif_data['EXIF FNumber'].values[0].decimal()

    def get_iso(self):
        """ISO speed rating, read from EXIF."""
        return self.exif_data['EXIF ISOSpeedRatings'].values[0]

    def get_image_data(self, substract_black_level=False, white_balance=False, normalize=False):
        """Return the raw data as float, optionally black-level corrected,
        white balanced and normalized to [0, 1] by `norm_factor`."""
        im_raw = self.im_raw.float()
        if substract_black_level:
            im_raw = im_raw - torch.tensor(self.black_level).view(3, 1, 1)
        if white_balance:
            # cam_wb gains appear to be fixed-point with 1024 == 1.0 — TODO confirm.
            im_raw = im_raw * torch.tensor(self.cam_wb).view(3, 1, 1) / 1024.0
        if normalize:
            im_raw = im_raw / self.norm_factor
        return im_raw

    def set_image_data(self, im_data):
        """Replace the stored image tensor."""
        self.im_raw = im_data

    def crop_image(self, r1, r2, c1, c2):
        """Crop the stored image in place to rows [r1, r2) and cols [c1, c2)."""
        self.im_raw = self.im_raw[:, r1:r2, c1:c2]

    def get_crop(self, r1, r2, c1, c2):
        """Return a new CanonImage cropped to the given window."""
        im_raw = self.im_raw[:, r1:r2, c1:c2]
        return CanonImage(im_raw, self.black_level, self.cam_wb, self.daylight_wb, self.rgb_xyz_matrix,
                          self.exif_data, self.crop_info)

    def set_crop_info(self, crop_info):
        self.crop_info = crop_info

    def resize(self, size=None, scale_factor=None):
        """Bilinearly resize the stored image in place."""
        self.im_raw = F.interpolate(self.im_raw.unsqueeze(0), size=size, scale_factor=scale_factor,
                                    mode='bilinear').squeeze(0)

    def postprocess(self, return_np=True):
        """Produce a quick-look RGB rendering (black level + WB + clamping)."""
        # Convert to rgb
        im = self.im_raw
        im = (im - torch.tensor(self.black_level).view(3, 1, 1)).float() * torch.tensor(self.cam_wb).view(3, 1, 1)
        im_out = im / im.max()
        im_out = im_out.clamp(0.0, 1.0)
        if return_np:
            im_out = im_out.permute(1, 2, 0).numpy() * 255.0
            im_out = im_out.astype(np.uint8)
        return im_out
def load_txt(path):
    """Read a text file and return its lines with trailing whitespace stripped."""
    with open(path, 'r') as fh:
        return [line.rstrip() for line in fh]
class BurstSRDataset(torch.utils.data.Dataset):
    """ Real-world burst super-resolution dataset.

    Each burst directory contains 14 Samsung RAW LR frames plus a Canon RAW
    HR ground truth; __getitem__ returns (burst, frame_gt, burst_meta, gt_meta).
    """

    def __init__(self, root, burst_size=8, crop_sz=80, center_crop=False, random_flip=False, split='train'):
        """
        args:
            root : path of the root directory
            burst_size : Burst size. Maximum allowed burst size is 14.
            crop_sz: Size of the extracted crop. Maximum allowed crop size is 80
            center_crop: Whether to extract a random crop, or a centered crop.
            random_flip: Whether to apply random horizontal and vertical flip
            split: Can be 'train' or 'val'
        """
        assert burst_size <= 14, 'burst_sz must be less than or equal to 14'
        assert crop_sz <= 80, 'crop_sz must be less than or equal to 80'
        assert split in ['train', 'val']
        root = root + '/' + split
        super().__init__()
        self.burst_size = burst_size
        self.crop_sz = crop_sz
        self.split = split
        self.center_crop = center_crop
        self.random_flip = random_flip
        self.root = root
        # Black level is always subtracted; white balance is left unapplied.
        self.substract_black_level = True
        self.white_balance = False
        self.burst_list = self._get_burst_list()

    def _get_burst_list(self):
        # Every subdirectory of root is one burst; sorted for a stable order.
        burst_list = sorted(os.listdir('{}'.format(self.root)))
        # print(burst_list)
        return burst_list

    def get_burst_info(self, burst_id):
        """Return basic info (full burst size and directory name) for a burst."""
        burst_info = {'burst_size': 14, 'burst_name': self.burst_list[burst_id]}
        return burst_info

    def _get_raw_image(self, burst_id, im_id):
        # LR frames are stored as samsung_00 .. samsung_13 inside the burst dir.
        raw_image = SamsungRAWImage.load('{}/{}/samsung_{:02d}'.format(self.root, self.burst_list[burst_id], im_id))
        return raw_image

    def _get_gt_image(self, burst_id):
        # The HR ground truth lives in the 'canon' subdirectory of the burst.
        canon_im = CanonImage.load('{}/{}/canon'.format(self.root, self.burst_list[burst_id]), split=self.split)
        return canon_im

    def get_burst(self, burst_id, im_ids, info=None):
        """Load the LR frames with ids `im_ids`, the HR ground truth and burst info."""
        frames = [self._get_raw_image(burst_id, i) for i in im_ids]
        gt = self._get_gt_image(burst_id)
        if info is None:
            info = self.get_burst_info(burst_id)
        return frames, gt, info

    def _sample_images(self):
        # The base frame (id 0) is always included; the remaining ids are
        # sampled without replacement from frames 1..13.
        burst_size = 14
        ids = random.sample(range(1, burst_size), k=self.burst_size - 1)
        ids = [0, ] + ids
        return ids

    def __len__(self):
        return len(self.burst_list)

    def __getitem__(self, index):
        """Return (burst, frame_gt, meta_info_burst, meta_info_gt) for one burst."""
        # Sample the images in the burst, in case a burst_size < 14 is used.
        im_ids = self._sample_images()
        # Read the burst images along with HR ground truth
        frames, gt, meta_info = self.get_burst(index, im_ids)
        # Extract crop if needed
        if frames[0].shape()[-1] != self.crop_sz:
            if getattr(self, 'center_crop', False):
                r1 = (frames[0].shape()[-2] - self.crop_sz) // 2
                c1 = (frames[0].shape()[-1] - self.crop_sz) // 2
            else:
                r1 = random.randint(0, frames[0].shape()[-2] - self.crop_sz)
                c1 = random.randint(0, frames[0].shape()[-1] - self.crop_sz)
            r2 = r1 + self.crop_sz
            c2 = c1 + self.crop_sz
            # GT crop coordinates are scaled by the HR/LR resolution ratio.
            scale_factor = gt.shape()[-1] // frames[0].shape()[-1]
            frames = [im.get_crop(r1, r2, c1, c2) for im in frames]
            gt = gt.get_crop(scale_factor * r1, scale_factor * r2, scale_factor * c1, scale_factor * c2)
        # Load the RAW image data
        burst_image_data = [im.get_image_data(normalize=True, substract_black_level=self.substract_black_level,
                                              white_balance=self.white_balance) for im in frames]
        # Convert to tensor
        gt_image_data = gt.get_image_data(normalize=True, white_balance=self.white_balance,
                                          substract_black_level=self.substract_black_level)
        if self.random_flip:
            # Flips are done in the flattened (H, W) Bayer domain; one flattened
            # row/column is trimmed after flipping and replicate-padded back,
            # which appears to keep the RGGB phase aligned — NOTE(review): confirm.
            burst_image_data = [flatten_raw_image(im) for im in burst_image_data]
            pad = [0, 0, 0, 0]
            if random.random() > 0.5:
                burst_image_data = [im.flip([1, ])[:, 1:-1].contiguous() for im in burst_image_data]
                gt_image_data = gt_image_data.flip([2, ])[:, :, 2:-2].contiguous()
                pad[1] = 1
            if random.random() > 0.5:
                burst_image_data = [im.flip([0, ])[1:-1, :].contiguous() for im in burst_image_data]
                gt_image_data = gt_image_data.flip([1, ])[:, 2:-2, :].contiguous()
                pad[3] = 1
            burst_image_data = [pack_raw_image(im) for im in burst_image_data]
            burst_image_data = [F.pad(im.unsqueeze(0), pad, mode='replicate').squeeze(0) for im in burst_image_data]
            # GT is padded 4x as much as the packed LR frames (HR assumed 4x
            # the packed LR resolution — TODO confirm).
            gt_image_data = F.pad(gt_image_data.unsqueeze(0), [4 * p for p in pad], mode='replicate').squeeze(0)
        burst_image_meta_info = frames[0].get_all_meta_data()
        burst_image_meta_info['black_level_subtracted'] = self.substract_black_level
        # NOTE(review): 'while_balance_applied' key name is a typo kept for compatibility.
        burst_image_meta_info['while_balance_applied'] = self.white_balance
        burst_image_meta_info['norm_factor'] = frames[0].norm_factor
        gt_image_meta_info = gt.get_all_meta_data()
        burst = torch.stack(burst_image_data, dim=0)
        burst_exposure = frames[0].get_exposure_time()
        canon_exposure = gt.get_exposure_time()
        burst_f_number = frames[0].get_f_number()
        canon_f_number = gt.get_f_number()
        burst_iso = frames[0].get_iso()
        canon_iso = gt.get_iso()
        # Normalize the GT image to account for differences in exposure, ISO etc
        light_factor_burst = burst_exposure * burst_iso / (burst_f_number ** 2)
        light_factor_canon = canon_exposure * canon_iso / (canon_f_number ** 2)
        exp_scale_factor = (light_factor_burst / light_factor_canon)
        gt_image_data = gt_image_data * exp_scale_factor
        gt_image_meta_info['black_level_subtracted'] = self.substract_black_level
        gt_image_meta_info['while_balance_applied'] = self.white_balance
        gt_image_meta_info['norm_factor'] = gt.norm_factor / exp_scale_factor
        burst_image_meta_info['exposure'] = burst_exposure
        burst_image_meta_info['f_number'] = burst_f_number
        burst_image_meta_info['iso'] = burst_iso
        gt_image_meta_info['exposure'] = canon_exposure
        gt_image_meta_info['f_number'] = canon_f_number
        gt_image_meta_info['iso'] = canon_iso
        burst = burst.float()
        frame_gt = gt_image_data.float()
        meta_info_burst = burst_image_meta_info
        meta_info_gt = gt_image_meta_info
        # assumes 'crop_info' is always present in the GT meta data — TODO confirm
        del meta_info_gt['crop_info']
        # Convert lists/tuples to tensors so the default collate function works.
        for k, v in meta_info_gt.items():
            if isinstance(v, (list, tuple)):
                meta_info_gt[k] = torch.tensor(v)
        for k, v in meta_info_burst.items():
            if isinstance(v, (list, tuple)):
                meta_info_burst[k] = torch.tensor(v)
        meta_info_burst['burst_name'] = meta_info['burst_name']
        return burst, frame_gt, meta_info_burst, meta_info_gt
def pack_raw_image(im_raw):
    """Pack a flat Bayer mosaic into four half-resolution channel planes.

    args:
        im_raw: 2D numpy array or torch tensor of shape (H, W) holding a
            2x2-periodic Bayer mosaic.
    returns:
        Array/tensor of shape (4, H // 2, W // 2), same dtype (and device for
        tensors); channels are the (0,0), (0,1), (1,0), (1,1) phase offsets.
    raises:
        TypeError: if im_raw is neither a numpy array nor a torch tensor.
    """
    if isinstance(im_raw, np.ndarray):
        im_out = np.zeros_like(im_raw, shape=(4, im_raw.shape[0] // 2, im_raw.shape[1] // 2))
    elif isinstance(im_raw, torch.Tensor):
        im_out = torch.zeros((4, im_raw.shape[0] // 2, im_raw.shape[1] // 2), dtype=im_raw.dtype).to(im_raw.device)
    else:
        # Was a bare `raise Exception`; TypeError is more precise and still a
        # subclass of Exception, so existing handlers keep working.
        raise TypeError('im_raw must be a numpy.ndarray or torch.Tensor, got {}'.format(type(im_raw)))
    # One output channel per 2x2 Bayer phase.
    im_out[0, :, :] = im_raw[0::2, 0::2]
    im_out[1, :, :] = im_raw[0::2, 1::2]
    im_out[2, :, :] = im_raw[1::2, 0::2]
    im_out[3, :, :] = im_raw[1::2, 1::2]
    return im_out
def flatten_raw_image(im_raw_4ch):
    """Inverse of pack_raw_image: spread 4 channel planes back into a mosaic.

    args:
        im_raw_4ch: numpy array or torch tensor of shape (4, h, w).
    returns:
        Array/tensor of shape (2*h, 2*w), same dtype (and device for tensors),
        with channel c written to the (c // 2, c % 2) phase offset.
    raises:
        TypeError: if im_raw_4ch is neither a numpy array nor a torch tensor.
    """
    if isinstance(im_raw_4ch, np.ndarray):
        im_out = np.zeros_like(im_raw_4ch, shape=(im_raw_4ch.shape[1] * 2, im_raw_4ch.shape[2] * 2))
    elif isinstance(im_raw_4ch, torch.Tensor):
        im_out = torch.zeros((im_raw_4ch.shape[1] * 2, im_raw_4ch.shape[2] * 2), dtype=im_raw_4ch.dtype).to(im_raw_4ch.device)
    else:
        # Was a bare `raise Exception`; TypeError is more precise and still a
        # subclass of Exception, so existing handlers keep working.
        raise TypeError('im_raw_4ch must be a numpy.ndarray or torch.Tensor, got {}'.format(type(im_raw_4ch)))
    # Each channel goes back to its 2x2 Bayer phase.
    im_out[0::2, 0::2] = im_raw_4ch[0, :, :]
    im_out[0::2, 1::2] = im_raw_4ch[1, :, :]
    im_out[1::2, 0::2] = im_raw_4ch[2, :, :]
    im_out[1::2, 1::2] = im_raw_4ch[3, :, :]
    return im_out
def pack_raw_image_batch(im_raw):
    """Batched version of pack_raw_image.

    Takes a (B, N, 1, H, W) tensor of flat Bayer mosaics and returns a
    (B, N, 4, H // 2, W // 2) tensor with one channel per 2x2 phase offset.
    """
    packed = torch.zeros((im_raw.shape[0], im_raw.shape[1], 4, im_raw.shape[3] // 2, im_raw.shape[4] // 2), dtype=im_raw.dtype).to(im_raw.device)
    for channel, (row, col) in enumerate(((0, 0), (0, 1), (1, 0), (1, 1))):
        packed[:, :, channel, :, :] = im_raw[:, :, 0, row::2, col::2]
    return packed
def flatten_raw_image_batch(im_raw_4ch):
    """Batched inverse of pack_raw_image_batch.

    Takes a (B, N, 4, h, w) tensor of packed Bayer planes and returns a
    (B, N, 1, 2*h, 2*w) tensor with each channel written to its phase offset.
    """
    flat = torch.zeros((im_raw_4ch.shape[0], im_raw_4ch.shape[1], 1, im_raw_4ch.shape[3] * 2, im_raw_4ch.shape[4] * 2), dtype=im_raw_4ch.dtype).to(im_raw_4ch.device)
    for channel, (row, col) in enumerate(((0, 0), (0, 1), (1, 0), (1, 1))):
        flat[:, :, 0, row::2, col::2] = im_raw_4ch[:, :, channel, :, :]
    return flat
================================================
FILE: code/real/bsrt/datasets/burstsr_test_dataset.py
================================================
import os
import torch
import torch.nn.functional as F
import random
from .burstsr_dataset import SamsungRAWImage, flatten_raw_image, pack_raw_image
class BurstSRDataset(torch.utils.data.Dataset):
    """ Real-world burst super-resolution dataset (test split, no ground truth).

    Each burst directory contains 14 Samsung RAW LR frames; __getitem__
    returns only (burst, meta_info_burst).
    """

    def __init__(self, root, burst_size=8, crop_sz=80, center_crop=False, random_flip=False, split='test'):
        """
        args:
            root : path of the root directory
            burst_size : Burst size. Maximum allowed burst size is 14.
            crop_sz: Size of the extracted crop. Maximum allowed crop size is 80
            center_crop: Whether to extract a random crop, or a centered crop.
            random_flip: Whether to apply random horizontal and vertical flip
            split: Must be 'test' (this class only serves the test split).
        """
        assert burst_size <= 14, 'burst_sz must be less than or equal to 14'
        assert crop_sz <= 80, 'crop_sz must be less than or equal to 80'
        assert split in ['test']
        root = root + '/' + split
        super().__init__()
        self.burst_size = burst_size
        self.crop_sz = crop_sz
        self.split = split
        self.center_crop = center_crop
        self.random_flip = random_flip
        self.root = root
        # Black level is always subtracted; white balance is left unapplied.
        self.substract_black_level = True
        self.white_balance = False
        self.burst_list = self._get_burst_list()

    def _get_burst_list(self):
        # Every subdirectory of root is one burst; sorted for a stable order.
        burst_list = sorted(os.listdir('{}'.format(self.root)))
        return burst_list

    def get_burst_info(self, burst_id):
        """Return basic info (full burst size and directory name) for a burst."""
        burst_info = {'burst_size': 14, 'burst_name': self.burst_list[burst_id]}
        return burst_info

    def _get_raw_image(self, burst_id, im_id):
        # LR frames are stored as samsung_00 .. samsung_13 inside the burst dir.
        raw_image = SamsungRAWImage.load('{}/{}/samsung_{:02d}'.format(self.root, self.burst_list[burst_id], im_id))
        return raw_image

    def get_burst(self, burst_id, im_ids, info=None):
        """Load the LR frames with ids `im_ids` and the burst info."""
        frames = [self._get_raw_image(burst_id, i) for i in im_ids]
        if info is None:
            info = self.get_burst_info(burst_id)
        return frames, info

    def _sample_images(self):
        # The base frame (id 0) is always included; the remaining ids are
        # sampled without replacement from frames 1..13.
        burst_size = 14
        ids = random.sample(range(1, burst_size), k=self.burst_size - 1)
        ids = [0, ] + ids
        return ids

    def __len__(self):
        return len(self.burst_list)

    def __getitem__(self, index):
        """Return (burst, meta_info_burst) for one test burst."""
        # Sample the images in the burst, in case a burst_size < 14 is used.
        im_ids = self._sample_images()
        # Read the burst images along with HR ground truth
        frames, meta_info = self.get_burst(index, im_ids)
        # Extract crop if needed
        if frames[0].shape()[-1] != self.crop_sz:
            if getattr(self, 'center_crop', False):
                r1 = (frames[0].shape()[-2] - self.crop_sz) // 2
                c1 = (frames[0].shape()[-1] - self.crop_sz) // 2
            else:
                r1 = random.randint(0, frames[0].shape()[-2] - self.crop_sz)
                c1 = random.randint(0, frames[0].shape()[-1] - self.crop_sz)
            r2 = r1 + self.crop_sz
            c2 = c1 + self.crop_sz
            frames = [im.get_crop(r1, r2, c1, c2) for im in frames]
        # Load the RAW image data
        burst_image_data = [im.get_image_data(normalize=True, substract_black_level=self.substract_black_level,
                                              white_balance=self.white_balance) for im in frames]
        if self.random_flip:
            # Flips are done in the flattened (H, W) Bayer domain; one flattened
            # row/column is trimmed after flipping and replicate-padded back,
            # which appears to keep the RGGB phase aligned — NOTE(review): confirm.
            burst_image_data = [flatten_raw_image(im) for im in burst_image_data]
            pad = [0, 0, 0, 0]
            if random.random() > 0.5:
                burst_image_data = [im.flip([1, ])[:, 1:-1].contiguous() for im in burst_image_data]
                pad[1] = 1
            if random.random() > 0.5:
                burst_image_data = [im.flip([0, ])[1:-1, :].contiguous() for im in burst_image_data]
                pad[3] = 1
            burst_image_data = [pack_raw_image(im) for im in burst_image_data]
            burst_image_data = [F.pad(im.unsqueeze(0), pad, mode='replicate').squeeze(0) for im in burst_image_data]
        burst_image_meta_info = frames[0].get_all_meta_data()
        burst_image_meta_info['black_level_subtracted'] = self.substract_black_level
        # NOTE(review): 'while_balance_applied' key name is a typo kept for compatibility.
        burst_image_meta_info['while_balance_applied'] = self.white_balance
        burst_image_meta_info['norm_factor'] = frames[0].norm_factor
        burst = torch.stack(burst_image_data, dim=0)
        burst_exposure = frames[0].get_exposure_time()
        burst_f_number = frames[0].get_f_number()
        burst_iso = frames[0].get_iso()
        burst_image_meta_info['exposure'] = burst_exposure
        burst_image_meta_info['f_number'] = burst_f_number
        burst_image_meta_info['iso'] = burst_iso
        burst = burst.float()
        meta_info_burst = burst_image_meta_info
        # Convert lists/tuples to tensors so the default collate function works.
        for k, v in meta_info_burst.items():
            if isinstance(v, (list, tuple)):
                meta_info_burst[k] = torch.tensor(v)
        return burst, meta_info_burst
================================================
FILE: code/real/bsrt/datasets/data_sampler.py
================================================
"""
Modified from torch.utils.data.distributed.DistributedSampler
Support enlarging the dataset for *iter-oriented* training, for saving time when restart the
dataloader after each epoch
"""
import math
import torch
import torch.distributed as dist
from torch.utils.data.sampler import Sampler
class DistIterSampler(Sampler):
    """Distributed sampler for iteration-oriented training.

    Works like :class:`torch.utils.data.distributed.DistributedSampler`, but
    the sampled index space is enlarged by `ratio`, so a single "epoch" can
    cover the dataset many times without restarting the dataloader. Each
    process receives a disjoint, epoch-deterministic slice of the indices.

    Arguments:
        dataset: Dataset used for sampling (only its length is used).
        num_replicas (optional): Number of processes participating in
            distributed training; defaults to the distributed world size.
        rank (optional): Rank of the current process within num_replicas.
        ratio: Enlargement factor applied to the dataset length.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, ratio=100):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Per-process share of the enlarged (ratio x) dataset.
        self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # Seed with the epoch so every process draws the same permutation.
        generator = torch.Generator()
        generator.manual_seed(self.epoch)
        permutation = torch.randperm(self.total_size, generator=generator).tolist()
        # Wrap the enlarged index space back onto real dataset indices.
        dataset_size = len(self.dataset)
        wrapped = [idx % dataset_size for idx in permutation]
        # Each rank takes every num_replicas-th entry, offset by its rank.
        own_indices = wrapped[self.rank:self.total_size:self.num_replicas]
        assert len(own_indices) == self.num_samples
        return iter(own_indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        """Set the epoch used to seed the shuffle (call once per epoch)."""
        self.epoch = epoch
================================================
FILE: code/real/bsrt/datasets/realworld_burst_test_set.py
================================================
import torch
import cv2
import numpy as np
import pickle as pkl
class RealWorldBurstTest(torch.utils.data.Dataset):
    """Real-world burst test set: 20 bursts of 14 RAW frames each, read from
    `root/{burst:04d}/im_raw_{frame:02d}.png` (no ground truth available).
    """

    def __init__(self, root):
        self.root = root
        self.burst_list = list(range(20))
        self.burst_size = 14

    def __len__(self):
        return len(self.burst_list)

    def _read_burst_image(self, index, image_id):
        path = '{}/{:04d}/im_raw_{:02d}.png'.format(self.root, index, image_id)
        raw = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        # HWC -> CHW float, scaled by 2**14 (14-bit data) into [0, 1].
        return torch.from_numpy(raw.astype(np.float32)).permute(2, 0, 1).float() / (2**14)

    def __getitem__(self, index):
        """
        args:
            index: Index of the burst
        returns:
            burst: LR RAW burst, a torch tensor of shape
                The 4 channels correspond to 'R', 'G', 'G', and 'B' values in the RGGB bayer mosaick.
            meta_info: Meta information about the burst
        """
        frames = [self._read_burst_image(index, frame_id) for frame_id in range(self.burst_size)]
        meta_info = {'burst_name': '{:04d}'.format(index)}
        return torch.stack(frames, 0), meta_info
================================================
FILE: code/real/bsrt/datasets/synthetic_burst_test_set.py
================================================
import torch
import cv2
import numpy as np
import pickle as pkl
class SyntheticBurstTest(torch.utils.data.Dataset):
    """ Synthetic burst test set. The test bursts have been generated using the
    same synthetic pipeline as employed in the SyntheticBurst dataset: 92
    bursts of 14 RAW frames each, read from
    `root/{burst:04d}/im_raw_{frame:02d}.png`.
    """

    def __init__(self, root):
        self.root = root
        self.burst_list = list(range(92))
        self.burst_size = 14

    def __len__(self):
        return len(self.burst_list)

    def _read_burst_image(self, index, image_id):
        path = '{}/{:04d}/im_raw_{:02d}.png'.format(self.root, index, image_id)
        raw = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        # HWC -> CHW float, scaled by 2**14 (14-bit data) into [0, 1].
        return torch.from_numpy(raw.astype(np.float32)).permute(2, 0, 1).float() / (2**14)

    def __getitem__(self, index):
        """ Generates a synthetic burst
        args:
            index: Index of the burst
        returns:
            burst: LR RAW burst, a torch tensor of shape
                The 4 channels correspond to 'R', 'G', 'G', and 'B' values in the RGGB bayer mosaick.
            meta_info: Meta information about the burst
        """
        frames = [self._read_burst_image(index, frame_id) for frame_id in range(self.burst_size)]
        meta_info = {'burst_name': '{:04d}'.format(index)}
        return torch.stack(frames, 0), meta_info
================================================
FILE: code/real/bsrt/datasets/synthetic_burst_train_set.py
================================================
import torch
import numpy as np
from PIL import Image
from data_processing.synthetic_burst_generation import rgb2rawburst, random_crop #syn_burst_utils
import torchvision.transforms as tfm
class SyntheticBurst(torch.utils.data.Dataset):
    """Synthetic RAW burst dataset for joint denoising, demosaicking and
    super-resolution.

    Each sample is generated on the fly: an RGB image from base_dataset is
    (optionally) transformed, randomly cropped, converted to linear sensor
    space with the inverse camera pipeline of [1], jittered with random
    translations/rotations to form a burst, and finally mosaicked and
    corrupted with random noise to obtain the RAW burst.

    [1] Unprocessing Images for Learned Raw Denoising, Brooks, Tim and
    Mildenhall, Ben and Xue, Tianfan and Chen, Jiawen and Sharlet, Dillon and
    Barron, Jonathan T, CVPR 2019
    """

    def __init__(self, base_dataset, burst_size=8, crop_sz=384, transform=tfm.ToTensor()):
        self.base_dataset = base_dataset
        self.burst_size = burst_size
        self.crop_sz = crop_sz
        self.transform = transform
        self.downsample_factor = 4
        # Random warp parameters used when synthesising the burst.
        self.burst_transformation_params = dict(max_translation=24.0,
                                                max_rotation=1.0,
                                                max_shear=0.0,
                                                max_scale=0.0,
                                                border_crop=24)
        # Inverse-ISP / noise options forwarded to rgb2rawburst.
        self.image_processing_params = dict(random_ccm=True, random_gains=True,
                                            smoothstep=True, gamma=True,
                                            add_noise=True)
        self.interpolation_type = 'bilinear'

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, index):
        """Generate one synthetic burst.

        args:
            index: Index of the image in base_dataset used to generate the burst.
        returns:
            burst: LR RAW burst of shape
                [burst_size, 4, crop_sz / (2*downsample_factor), crop_sz / (2*downsample_factor)];
                the 4 channels are the 'R', 'G', 'G', 'B' values of the RGGB
                bayer mosaick (the extra factor 2 comes from mosaicking).
            frame_gt: HR RGB ground truth in linear sensor space, shape
                [3, crop_sz, crop_sz].
            flow_vectors: Ground-truth flow from each burst image to the base
                (first) image, in the LR RGB space before mosaicking, shape
                [burst_size, 2, crop_sz / downsample_factor, crop_sz / downsample_factor].
                Only provided during training (e.g. for auxiliary losses).
            meta_info: Dictionary of parameters used to generate the burst.
        """
        image = self.base_dataset[index]
        # Augmentation, e.g. convert to tensor
        if self.transform is not None:
            # image = Image.fromarray(image)
            image = self.transform(image)
        # Crop with a margin so the border lost to warping can be trimmed off.
        margin = self.burst_transformation_params.get('border_crop', 0)
        image_crop = random_crop(image, self.crop_sz + 2 * margin)
        # Generate RAW burst
        burst, frame_gt, burst_rgb, flow_vectors, meta_info = rgb2rawburst(
            image_crop,
            self.burst_size,
            self.downsample_factor,
            burst_transformation_params=self.burst_transformation_params,
            image_processing_params=self.image_processing_params,
            interpolation_type=self.interpolation_type)
        border_crop = self.burst_transformation_params.get('border_crop')
        if border_crop is not None:
            # Trim the warp margin back off the ground truth.
            frame_gt = frame_gt[:, border_crop:-border_crop, border_crop:-border_crop]
        return burst, frame_gt, flow_vectors, meta_info
================================================
FILE: code/real/bsrt/datasets/synthetic_burst_val_set.py
================================================
import os
import torch
import cv2
import numpy as np
import pickle as pkl
class SyntheticBurstVal(torch.utils.data.Dataset):
    """ Synthetic burst validation set introduced in [1]. The validation bursts have been generated using a
    synthetic data generation pipeline. The dataset can be downloaded from
    https://data.vision.ee.ethz.ch/bhatg/SyntheticBurstVal.zip

    [1] Deep Burst Super-Resolution. Goutam Bhat, Martin Danelljan, Luc Van Gool, and Radu Timofte. CVPR 2021
    """

    def __init__(self, root=None, initialize=True):
        """
        args:
            root - Path to root dataset directory ('val' is appended)
            initialize - unused; kept for backward compatibility
        """
        self.root = os.path.join(root, 'val')
        self.burst_list = list(range(300))
        self.burst_size = 14

    def initialize(self):
        # No deferred setup needed; everything happens in __init__.
        pass

    def __len__(self):
        return len(self.burst_list)

    def _read_burst_image(self, index, image_id):
        path = '{}/bursts/{:04d}/im_raw_{:02d}.png'.format(self.root, index, image_id)
        raw = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        # HWC -> CHW float, scaled by 2**14 (14-bit data) into [0, 1].
        return torch.from_numpy(raw.astype(np.float32)).permute(2, 0, 1).float() / (2**14)

    def _read_gt_image(self, index):
        gt = cv2.imread('{}/gt/{:04d}/im_rgb.png'.format(self.root, index), cv2.IMREAD_UNCHANGED)
        return (torch.from_numpy(gt.astype(np.float32)) / 2 ** 14).permute(2, 0, 1).float()

    def _read_meta_info(self, index):
        # NOTE(review): pickle deserialisation — only load trusted files.
        with open('{}/gt/{:04d}/meta_info.pkl'.format(self.root, index), "rb") as input_file:
            return pkl.load(input_file)

    def __getitem__(self, index):
        """ Generates a synthetic burst
        args:
            index: Index of the burst
        returns:
            burst: LR RAW burst, a torch tensor of shape
                [14, 4, 48, 48]
                The 4 channels correspond to 'R', 'G', 'G', and 'B' values in the RGGB bayer mosaick.
            gt : Ground truth linear image
            meta_info: Meta info about the burst which can be used to convert gt to sRGB space
        """
        frames = [self._read_burst_image(index, i) for i in range(self.burst_size)]
        gt = self._read_gt_image(index)
        meta_info = self._read_meta_info(index)
        meta_info['burst_name'] = '{:04d}'.format(index)
        return torch.stack(frames, 0), gt, meta_info
================================================
FILE: code/real/bsrt/datasets/zurich_raw2rgb_dataset.py
================================================
import torch
import os
import numpy as np
from cv2 import imread
class ZurichRAW2RGB(torch.utils.data.Dataset):
    """ Canon RGB images from the "Zurich RAW to RGB mapping" dataset. You can download the full
    dataset (22 GB) from http://people.ee.ethz.ch/~ihnatova/pynet.html#dataset. Alternatively, you can only download the
    Canon RGB images (5.5 GB) from https://data.vision.ee.ethz.ch/bhatg/zurich-raw-to-rgb.zip
    """

    def __init__(self, root, split='train'):
        super().__init__()
        if split not in ('train', 'test'):
            raise Exception('Unknown split {}'.format(split))
        self.img_pth = os.path.join(root, split, 'canon')
        self.image_list = self._get_image_list(split)

    def _get_image_list(self, split):
        # Both splits contain consecutively numbered jpg files starting at 0.
        split_sizes = {'train': 46839, 'test': 1200}
        if split not in split_sizes:
            raise Exception
        # image_list = ['{:d}.jpg'.format(int(i)) for i in np.linspace(1, 1200, 200)]
        return ['{:d}.jpg'.format(i) for i in range(split_sizes[split])]

    def _get_image(self, im_id):
        return imread(os.path.join(self.img_pth, self.image_list[im_id]))

    def get_image(self, im_id):
        """Public alias for loading the image with id `im_id`."""
        return self._get_image(im_id)

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        return self._get_image(index)
================================================
FILE: code/real/bsrt/demo.sh
================================================
#!/usr/bin/env bash
# Fine-tune BSRT on the real-world BurstSR data.
#
# Active command: the tiny (model_level S) variant on 8 GPUs with mixed
# precision, initialised from the synthetic-stage checkpoint.
python main.py --n_GPUs 8 --print_every 20 --lr 0.00004 --decay 40-80 --save bsrt_tiny --model BSRT --fp16 --model_level S --swinfeature --batch_size 8 --burst_size 14 --patch_size 80 --pre_train ../../synthetic/train_log/bsrt/real_models/bsrt_tiny/bsrt_best_epoch.pth

# Alternative: train the large (model_level L) variant.
# python main.py --n_GPUs 8 --print_every 20 --lr 0.00004 --decay 40-80 --save bsrt_large --model BSRT --fp16 --model_level L --swinfeature --batch_size 8 --burst_size 14 --patch_size 48 --pre_train ../../synthetic/train_log/bsrt/real_models/bsrt_large/bsrt_best_epoch.pth

# Evaluation on the real-world test set (tiny / large checkpoints).
# python test_real.py --n_GPUs 1 --model BSRT --model_level S --swinfeature --batch_size 1 --burst_size 14 --patch_size 80 --pre_train ../train_log/bsrt/real_models/bsrt_tiny/bsrtbest_epoch.pth --root /data/dataset/ntire21/burstsr/real
# python test_real.py --n_GPUs 1 --model BSRT --model_level L --swinfeature --batch_size 1 --burst_size 14 --patch_size 80 --pre_train ../train_log/bsrt/real_models/bsrt_large/bsrt_realworld.pth --root /data/dataset/ntire21/burstsr/real
================================================
FILE: code/real/bsrt/loss/Charbonnier.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class CharbonnierLoss(nn.Module):
    """Charbonnier penalty (a smooth L1 variant): sqrt((x - y)^2 + epsilon^2).

    args:
        epsilon: smoothing constant; its square is added under the root.
        reduce: if True, return the mean penalty; otherwise the element-wise map.
    """

    def __init__(self, epsilon=1e-3, reduce=True):
        super(CharbonnierLoss, self).__init__()
        # Pre-square epsilon so forward() only adds it.
        self.eps = epsilon * epsilon
        self.reduce = reduce

    def forward(self, X, Y):
        residual = X - Y
        penalty = torch.sqrt(residual * residual + self.eps)
        return penalty.mean() if self.reduce else penalty
================================================
FILE: code/real/bsrt/loss/__init__.py
================================================
import os
from importlib import import_module
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class Loss(nn.modules.loss._Loss):
    """Composite training loss.

    Parses ``args.loss`` strings of the form ``"1*L1+0.1*VGG54"`` into a list
    of weighted terms, accumulates a per-epoch log tensor, and supports
    plotting/saving/loading its state. Only rank 0 prints progress messages.
    """

    def __init__(self, args, ckp):
        super(Loss, self).__init__()
        if args.local_rank == 0:
            print('Preparing loss function:')
        self.n_GPUs = args.n_GPUs
        # self.loss: list of {'type', 'weight', 'function'} descriptors.
        # Entries with function=None ('DIS', 'Total') exist only for logging.
        self.loss = []
        self.loss_module = nn.ModuleList()
        for loss in args.loss.split('+'):
            weight, loss_type = loss.split('*')
            if loss_type == 'MSE':
                loss_function = nn.MSELoss()
            elif loss_type == 'L1':
                loss_function = nn.L1Loss()
            elif loss_type.find('VGG') >= 0:
                # e.g. 'VGG54' -> conv_index '54'
                module = import_module('loss.vgg')
                loss_function = getattr(module, 'VGG')(
                    loss_type[3:],
                    rgb_range=args.rgb_range
                )
            elif loss_type.find('GAN') >= 0:
                module = import_module('loss.adversarial')
                loss_function = getattr(module, 'Adversarial')(
                    args,
                    loss_type
                )
            elif loss_type == 'FILTER':
                module = import_module('loss.filter')
                loss_function = getattr(module, 'Filter')(args)
            elif loss_type == 'SSIM':
                # NOTE(review): loss.mssim.SSIM.__init__ takes
                # (window_size=11, size_average=True, val_range=None); passing
                # `args` positionally makes window_size the whole namespace.
                # Looks wrong -- confirm against loss/mssim.py.
                module = import_module('loss.mssim')
                loss_function = getattr(module, 'SSIM')(args)
            elif loss_type == 'MSSSIM':
                # NOTE(review): same concern as the 'SSIM' branch above.
                module = import_module('loss.mssim')
                loss_function = getattr(module, 'MSSSIM')(args)
            self.loss.append({
                'type': loss_type,
                'weight': float(weight),
                'function': loss_function}
            )
            if loss_type.find('GAN') >= 0:
                # Extra log-only slot for the discriminator loss.
                self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})
        if len(self.loss) > 1:
            # Log-only slot for the weighted sum of all terms.
            self.loss.append({'type': 'Total', 'weight': 0, 'function': None})
        for l in self.loss:
            if l['function'] is not None:
                if args.local_rank == 0:
                    print('{:.3f} * {}'.format(l['weight'], l['type']))
                self.loss_module.append(l['function'])
        # One row per epoch, one column per self.loss entry; rows are appended
        # by start_log() and averaged by end_log().
        self.log = torch.Tensor()
        device = torch.device('cpu' if args.cpu else 'cuda')
        self.loss_module.to(device)
        if args.precision == 'half': self.loss_module.half()
        if not args.cpu and args.n_GPUs > 1:
            self.loss_module = nn.DataParallel(
                self.loss_module, range(args.n_GPUs)
            )
        if args.load != '': self.load(ckp.dir, cpu=args.cpu)

    def forward(self, sr, hr):
        """Evaluate all weighted terms, update the current log row, and
        return the weighted sum. Assumes start_log() has been called."""
        losses = []
        for i, l in enumerate(self.loss):
            if l['function'] is not None:
                loss = l['function'](sr, hr)
                effective_loss = l['weight'] * loss
                losses.append(effective_loss)
                self.log[-1, i] += effective_loss.item()
            elif l['type'] == 'DIS':
                # The adversarial wrapper stores its discriminator loss on
                # itself; it is the entry immediately before this one.
                self.log[-1, i] += self.loss[i - 1]['function'].loss
        loss_sum = sum(losses)
        if len(self.loss) > 1:
            self.log[-1, -1] += loss_sum.item()
        return loss_sum

    def step(self):
        # Advance LR schedulers of sub-losses that own one (e.g. adversarial).
        for l in self.get_loss_module():
            if hasattr(l, 'scheduler'):
                l.scheduler.step()

    def start_log(self):
        # Append a fresh zero row; forward() accumulates into log[-1].
        self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))

    def end_log(self, n_batches):
        # Convert the accumulated sums of the current row into per-batch means.
        self.log[-1].div_(n_batches)

    def display_loss(self, batch):
        """Return a compact string of running per-term averages; batch is 0-based."""
        n_samples = batch + 1
        log = []
        for l, c in zip(self.loss, self.log[-1]):
            log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))
        return ''.join(log)

    def plot_loss(self, apath, epoch):
        """Save one loss-vs-epoch PDF per term under `apath`."""
        axis = np.linspace(1, epoch, epoch)
        for i, l in enumerate(self.loss):
            label = '{} Loss'.format(l['type'])
            fig = plt.figure()
            plt.title(label)
            plt.plot(axis, self.log[:, i].numpy(), label=label)
            plt.legend()
            plt.xlabel('Epochs')
            plt.ylabel('Loss')
            plt.grid(True)
            plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type'])))
            plt.close(fig)

    def get_loss_module(self):
        # Unwrap DataParallel when several GPUs are used.
        if self.n_GPUs == 1:
            return self.loss_module
        else:
            return self.loss_module.module

    def save(self, apath):
        """Persist module weights and the per-epoch log next to each other."""
        torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
        torch.save(self.log, os.path.join(apath, 'loss_log.pt'))

    def load(self, apath, cpu=False):
        """Restore weights/log saved by save(); cpu=True remaps CUDA tensors."""
        if cpu:
            kwargs = {'map_location': lambda storage, loc: storage}
        else:
            kwargs = {}
        self.load_state_dict(torch.load(
            os.path.join(apath, 'loss.pt'),
            **kwargs
        ))
        self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
        # Fast-forward schedulers to match the number of logged epochs.
        for l in self.get_loss_module():
            if hasattr(l, 'scheduler'):
                for _ in range(len(self.log)): l.scheduler.step()
================================================
FILE: code/real/bsrt/loss/adversarial.py
================================================
import utility
from types import SimpleNamespace
from model import common
from loss import discriminator
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class Adversarial(nn.Module):
    """Adversarial (GAN) generator loss with an internal discriminator.

    forward(fake, real) first performs `gan_k` discriminator updates, then
    returns the generator loss for `fake`. The averaged discriminator loss is
    stored on `self.loss` so the caller (loss.Loss) can log it.
    Supported gan_type values: 'GAN', 'SNGAN', 'WGAN', 'WGAN_GP', 'RGAN'.
    """

    def __init__(self, args, gan_type):
        super(Adversarial, self).__init__()
        self.gan_type = gan_type
        self.gan_k = args.gan_k  # discriminator updates per generator update
        self.dis = discriminator.Discriminator(args)
        # if gan_type == 'WGAN_GP':
        if True:
            # see https://arxiv.org/pdf/1704.00028.pdf pp.4
            optim_dict = {
                'optimizer': 'ADAM',
                'betas': (0.5, 0.9),
                'epsilon': 1e-8,
                'lr': 1e-5,
                'weight_decay': args.weight_decay,
                'decay': args.decay,
                'gamma': args.gamma
            }
            optim_args = SimpleNamespace(**optim_dict)
        else:
            optim_args = args
        self.optimizer = utility.make_optimizer(optim_args, self.dis)

    def forward(self, fake, real):
        """Update the discriminator, then return the generator loss."""
        # ---- discriminator update(s) ----
        self.loss = 0
        fake_detach = fake.detach()  # do not backpropagate through G
        for _ in range(self.gan_k):
            self.optimizer.zero_grad()
            # d_*: B x 1 logits
            d_fake = self.dis(fake_detach)
            d_real = self.dis(real)
            retain_graph = False
            if self.gan_type in ['GAN', 'SNGAN']:
                loss_d = self.bce(d_real, d_fake)
            elif self.gan_type.find('WGAN') >= 0:
                loss_d = (d_fake - d_real).mean()
                if self.gan_type.find('GP') >= 0:
                    # Gradient penalty: one interpolation factor per sample.
                    # Bug fix: torch.rand_like(fake).view(-1, 1, 1, 1) produced
                    # a (B*C*H*W, 1, 1, 1) tensor that cannot broadcast
                    # against (B, C, H, W) and crashed this branch.
                    epsilon = torch.rand(
                        fake.size(0), 1, 1, 1,
                        device=fake.device, dtype=fake.dtype
                    )
                    hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
                    hat.requires_grad = True
                    d_hat = self.dis(hat)
                    gradients = torch.autograd.grad(
                        outputs=d_hat.sum(), inputs=hat,
                        retain_graph=True, create_graph=True, only_inputs=True
                    )[0]
                    gradients = gradients.view(gradients.size(0), -1)
                    gradient_norm = gradients.norm(2, dim=1)
                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
                    loss_d += gradient_penalty
            # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
            elif self.gan_type == 'RGAN':
                better_real = d_real - d_fake.mean(dim=0, keepdim=True)
                better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
                loss_d = self.bce(better_real, better_fake)
                retain_graph = True
            # Discriminator update
            self.loss += loss_d.item()
            loss_d.backward(retain_graph=retain_graph)
            self.optimizer.step()
            if self.gan_type == 'WGAN':
                # Weight clipping from the original WGAN formulation.
                for p in self.dis.parameters():
                    p.data.clamp_(-1, 1)
        self.loss /= self.gan_k
        # ---- generator loss ----
        d_fake_bp = self.dis(fake)  # for backpropagation, use fake as it is
        if self.gan_type in ['GAN', 'SNGAN']:
            label_real = torch.ones_like(d_fake_bp)
            loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)
        elif self.gan_type.find('WGAN') >= 0:
            loss_g = -d_fake_bp.mean()
        elif self.gan_type == 'RGAN':
            better_real = d_real.detach() - d_fake_bp.mean(dim=0, keepdim=True)
            better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True).detach()
            loss_g = self.bce(better_fake, better_real)
        # Generator loss
        return loss_g

    def state_dict(self, *args, **kwargs):
        # NOTE(review): merging the two dicts could silently drop entries on a
        # key collision, and there is no matching load path -- confirm intent.
        state_discriminator = self.dis.state_dict(*args, **kwargs)
        state_optimizer = self.optimizer.state_dict()
        return dict(**state_discriminator, **state_optimizer)

    def bce(self, real, fake):
        """Standard non-saturating GAN loss: real -> 1, fake -> 0 (on logits)."""
        label_real = torch.ones_like(real)
        label_fake = torch.zeros_like(fake)
        bce_real = F.binary_cross_entropy_with_logits(real, label_real)
        bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)
        bce_loss = bce_real + bce_fake
        return bce_loss
# Some references
# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
# OR
# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
================================================
FILE: code/real/bsrt/loss/discriminator.py
================================================
from model import common
import torch.nn as nn
class Discriminator(nn.Module):
'''
output is not normalized
'''
def __init__(self, args, gan_type='GAN'):
super(Discriminator, self).__init__()
in_channels = args.n_colors
out_channels = 32
depth = 6
def _block(_in_channels, _out_channels, stride=1):
Conv = nn.Conv2d(
_in_channels,
_out_channels,
3,
padding=1,
stride=stride,
bias=False
)
if gan_type == 'SNGAN':
return nn.Sequential(
spectral_norm(Conv),
nn.BatchNorm2d(_out_channels),
nn.LeakyReLU(negative_slope=0.2, inplace=True)
)
else:
return nn.Sequential(
Conv,
nn.BatchNorm2d(_out_channels),
nn.LeakyReLU(negative_slope=0.2, inplace=True)
)
m_features = [_block(in_channels, out_channels)]
for i in range(depth):
in_channels = out_channels
# if i % 2 == 1:
# stride = 1
# out_channels *= 2
# else:
out_channels *= 2
stride = 2
m_features.append(_block(in_channels, out_channels, stride=stride))
patch_size = args.patch_size // 2**(depth-1)
# print(out_channels, patch_size)
m_classifier = [
nn.Flatten(),
nn.Linear(out_channels*patch_size**2, 512),
nn.LeakyReLU(0.2, True),
nn.Linear(512, 1)
]
self.features = nn.Sequential(*m_features)
self.classifier = nn.Sequential(*m_classifier)
def forward(self, x):
features = self.features(x)
# print(features.shape)
output = self.classifier(features)
return output
================================================
FILE: code/real/bsrt/loss/filter.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class Filter(nn.Module):
    """High-frequency consistency loss.

    Applies a fixed 3x3 Laplacian-like kernel (rows [1,4,1],[4,-20,4],[1,4,1])
    to both inputs and returns the L1 distance between the filtered maps.
    The kernel coefficients sum to zero, so constant offsets between x and y
    do not contribute; the conv bias (never initialized here) also cancels in
    the difference.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        kernel = torch.tensor([[1, 4, 1], [4, -20, 4], [1, 4, 1]])
        # NOTE(review): positional args make this kernel_size=3, stride=3
        # (no padding); stride 3 looks like it may have been meant as
        # padding=1 -- confirm. Behavior kept as-is.
        self.conv = nn.Conv2d(args.n_colors, args.n_colors, 3, 3)
        with torch.no_grad():
            # Broadcast the 3x3 kernel into every (out, in) filter slice.
            self.conv.weight.copy_(kernel.float())
        # Bug fix: the filter is a fixed operator, not a learnable layer.
        # Without freezing, its weights receive gradients and would drift if
        # an optimizer ever covers the loss module's parameters.
        for p in self.conv.parameters():
            p.requires_grad_(False)
        self.loss = nn.L1Loss()

    def forward(self, x, y):
        """Return mean |filter(x) - filter(y)|."""
        preds_x = self.conv(x)
        preds_y = self.conv(y)
        return self.loss(preds_x, preds_y)
================================================
FILE: code/real/bsrt/loss/hist_entropy.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class HistEntropy(nn.Module):
    """Entropy of the channel-wise softmax distribution.

    Treats dim 1 of `x` as unnormalized logits, converts them to
    probabilities, sums the Shannon entropy over the spatial dims and
    averages over the remaining (batch, channel) entries.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

    def forward(self, x):
        probs = F.softmax(x, dim=1)
        log_probs = F.log_softmax(x, dim=1)
        spatial_entropy = (-probs * log_probs).sum(dim=(2, 3))
        return spatial_entropy.mean()
================================================
FILE: code/real/bsrt/loss/mssim.py
================================================
import torch
import torch.nn.functional as F
from math import exp
import numpy as np
def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length `window_size`, normalized to sum 1."""
    center = window_size // 2
    weights = torch.Tensor(
        [exp(-(i - center) ** 2 / float(2 * sigma ** 2)) for i in range(window_size)]
    )
    return weights / weights.sum()
def create_window(window_size, channel=1):
    """Build a (channel, 1, ws, ws) 2-D Gaussian window for grouped conv2d.

    The 2-D kernel is the outer product of the 1-D Gaussian (sigma=1.5) with
    itself, expanded so each channel uses the same depthwise filter.
    """
    kernel_1d = gaussian(window_size, 1.5).unsqueeze(1)
    kernel_2d = kernel_1d.mm(kernel_1d.t()).float().unsqueeze(0).unsqueeze(0)
    return kernel_2d.expand(channel, 1, window_size, window_size).contiguous()
def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
    """Single-scale SSIM between two image batches (N, C, H, W).

    Returns the mean SSIM (or per-image means when size_average=False); with
    full=True also returns the contrast-sensitivity term used by msssim().
    """
    # Infer the dynamic range L when the caller does not supply val_range:
    # 255 for uint8-style data, 1 for sigmoid outputs, shifted for tanh.
    if val_range is None:
        max_val = 255 if torch.max(img1) > 128 else 1
        min_val = -1 if torch.min(img1) < -0.5 else 0
        L = max_val - min_val
    else:
        L = val_range

    _, channel, height, width = img1.size()
    if window is None:
        # Shrink the window if the images are smaller than window_size.
        real_size = min(window_size, height, width)
        window = create_window(real_size, channel=channel).to(img1.device)

    def blur(t):
        # Gaussian-weighted local average (valid padding, depthwise groups).
        return F.conv2d(t, window, padding=0, groups=channel)

    mean1, mean2 = blur(img1), blur(img2)
    mean1_sq = mean1 ** 2
    mean2_sq = mean2 ** 2
    mean12 = mean1 * mean2

    var1 = blur(img1 * img1) - mean1_sq
    var2 = blur(img2 * img2) - mean2_sq
    covar = blur(img1 * img2) - mean12

    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2

    v1 = 2.0 * covar + C2
    v2 = var1 + var2 + C2
    cs = torch.mean(v1 / v2)  # contrast sensitivity

    ssim_map = ((2 * mean12 + C1) * v1) / ((mean1_sq + mean2_sq + C1) * v2)
    ret = ssim_map.mean() if size_average else ssim_map.mean(1).mean(1).mean(1)
    if full:
        return ret, cs
    return ret
def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=None):
    """Multi-scale SSIM over 5 dyadic scales (Wang et al. 2003 weights).

    normalize='relu' clamps negative terms; normalize='simple'/True rescales
    terms to [0, 1] (both deviate from the original definition, kept for
    backward compatibility).
    """
    device = img1.device
    weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
    levels = weights.size()[0]
    ssims = []
    mcs = []
    for _ in range(levels):
        sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
        # Relu normalize (not compliant with original definition)
        if normalize == "relu":
            ssims.append(torch.relu(sim))
            mcs.append(torch.relu(cs))
        else:
            ssims.append(sim)
            mcs.append(cs)
        img1 = F.avg_pool2d(img1, (2, 2))
        img2 = F.avg_pool2d(img2, (2, 2))
    ssims = torch.stack(ssims)
    mcs = torch.stack(mcs)
    # Simple normalize (not compliant with original definition)
    # TODO: remove support for normalize == True (kept for backward support)
    if normalize == "simple" or normalize == True:
        ssims = (ssims + 1) / 2
        mcs = (mcs + 1) / 2
    pow1 = mcs ** weights
    pow2 = ssims ** weights
    # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
    # MS-SSIM = prod_{i<L} cs_i^w_i * ssim_L^w_L.
    # Bug fix: torch.prod(pow1[:-1] * pow2[-1]) folded the top-level SSIM term
    # into every coarse scale before taking the product, effectively raising
    # it to 4*w_L instead of w_L.
    output = torch.prod(pow1[:-1]) * pow2[-1]
    return output
# Classes to re-use window
class SSIM(torch.nn.Module):
    """Stateful SSIM module that caches the Gaussian window between calls."""

    def __init__(self, window_size=11, size_average=True, val_range=None):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.val_range = val_range
        # Assume 1 channel for SSIM; the window is lazily rebuilt in forward()
        # whenever channel count, dtype or device of the inputs differ.
        self.channel = 1
        self.window = create_window(window_size)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()
        # Bug fix: also compare devices -- the initial cache lives on CPU and
        # would crash F.conv2d when the images are on CUDA.
        if (channel == self.channel
                and self.window.dtype == img1.dtype
                and self.window.device == img1.device):
            window = self.window
        else:
            window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
            self.window = window
            self.channel = channel
        # Bug fix: self.val_range was stored but never forwarded; passing it
        # through is a no-op for the default None.
        return ssim(img1, img2, window=window, window_size=self.window_size,
                    size_average=self.size_average, val_range=self.val_range)
class MSSSIM(torch.nn.Module):
    """Module wrapper around the functional msssim()."""

    def __init__(self, window_size=11, size_average=True, channel=3):
        super(MSSSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        # Kept for API compatibility; msssim() infers channels on its own.
        self.channel = channel

    def forward(self, img1, img2):
        # TODO: store window between calls if possible
        return msssim(img1, img2,
                      window_size=self.window_size,
                      size_average=self.size_average)
================================================
FILE: code/real/bsrt/loss/vgg.py
================================================
from model import common
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
class VGG(nn.Module):
    """VGG19-feature (perceptual) loss.

    conv_index selects feature depth: a string containing '22' keeps the first
    8 layers, '54' the first 35.
    NOTE(review): any other conv_index leaves self.vgg undefined and forward()
    would raise AttributeError -- confirm callers only pass '22'/'54'.
    """

    def __init__(self, conv_index, rgb_range=1):
        super(VGG, self).__init__()
        vgg_features = models.vgg19(pretrained=True).features
        modules = [m for m in vgg_features]
        if conv_index.find('22') >= 0:
            self.vgg = nn.Sequential(*modules[:8])
        elif conv_index.find('54') >= 0:
            self.vgg = nn.Sequential(*modules[:35])
        # ImageNet normalization constants, std scaled to the data's rgb_range.
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        # Built but currently unused: the normalization call in forward() is
        # commented out.
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        # Fixed feature extractor: freeze all VGG weights.
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, sr, hr):
        """Return MSE between VGG feature maps of sr and hr.

        Inputs are tiled from 1 to 3 channels, so this assumes
        single-channel inputs -- TODO confirm against callers.
        """
        def _forward(x):
            # x = self.sub_mean(x)
            x = self.vgg(x)
            return x
        sr = sr.repeat(1, 3, 1, 1)
        hr = hr.repeat(1, 3, 1, 1)
        vgg_sr = _forward(sr)
        with torch.no_grad():
            # Reference features need no gradient; saves memory.
            vgg_hr = _forward(hr.detach())
        loss = F.mse_loss(vgg_sr, vgg_hr)
        return loss
================================================
FILE: code/real/bsrt/main.py
================================================
import torch
import random
import numpy as np
from torch.utils.data import DataLoader
import os
import utility
import model
import loss
from option import args
from trainer import Trainer
from datasets.burstsr_dataset import BurstSRDataset, flatten_raw_image
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.utils.data.distributed
def init_seeds(seed=0, cuda_deterministic=True):
    """Seed the python, numpy and torch RNGs and set cudnn determinism flags."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
    # deterministic=True / benchmark=False -> slower, more reproducible;
    # the inverse -> faster, less reproducible.
    cudnn.deterministic = bool(cuda_deterministic)
    cudnn.benchmark = not cuda_deterministic
# Global experiment logger/state, created once in the parent process before
# mp.spawn forks the workers.
checkpoint = utility.checkpoint(args)
def main():
    # Launch one training worker per GPU; each worker receives its local rank.
    mp.spawn(main_worker, nprocs=args.n_GPUs, args=(args.n_GPUs, args))
def main_worker(local_rank, nprocs, args):
    """Per-process training entry point spawned by main().

    local_rank: GPU index of this worker; nprocs: total worker count.
    Builds the BurstSR train/val loaders (distributed when nprocs > 1),
    the model and the loss, then runs the Trainer loop until it terminates.
    """
    # print(local_rank)
    if checkpoint.ok:
        args.local_rank = local_rank
        # Different seed per rank so workers do not draw identical batches.
        init_seeds(local_rank+1)
        # Overrides the benchmark=False set by init_seeds: favors speed here.
        cudnn.benchmark = True
        utility.setup(local_rank, nprocs)
        torch.cuda.set_device(local_rank)
        # Global batch size is split evenly across processes.
        batch_size = int(args.batch_size / nprocs)
        train_data = BurstSRDataset(root=args.root,
                                    burst_size=args.burst_size,
                                    crop_sz=args.patch_size, random_flip=True,
                                    center_crop=True, split='train')
        # Validation uses a fixed burst size / crop regardless of args.
        valid_data = BurstSRDataset(root=args.root,
                                    burst_size=14,
                                    crop_sz=80, split='val')
        if local_rank <= 0:
            print(f"train data: {len(train_data)}, test data: {len(valid_data)}")
        if nprocs > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
            valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_data, shuffle=False)
            # NOTE(review): num_workers=args.batch_size looks unintended (the
            # trailing comment suggests args.cpus was meant) -- confirm.
            train_loader = DataLoader(dataset=train_data, batch_size=batch_size, num_workers=args.batch_size,
                                      pin_memory=True, drop_last=True, sampler=train_sampler)  # args.cpus
            valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size, num_workers=args.batch_size,
                                      pin_memory=True, drop_last=True, sampler=valid_sampler)  # args.cpus
        else:
            train_sampler = None
            train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, num_workers=8,
                                      shuffle=True, pin_memory=True, drop_last=True)  # args.cpus
            valid_loader = DataLoader(dataset=valid_data, batch_size=args.batch_size, num_workers=4, shuffle=False,
                                      pin_memory=True, drop_last=True)  # args.cpus
        _model = model.Model(args, checkpoint)
        # No loss needed in test-only mode.
        _loss = loss.Loss(args, checkpoint) if not args.test_only else None
        t = Trainer(args, train_loader, train_sampler, valid_loader, _model, _loss, checkpoint)
        while not t.terminate():
            t.train()
        # Explicitly release references before the worker exits.
        del _model
        del _loss
        del train_loader
        del valid_loader
        # checkpoint.done()
if __name__ == '__main__':
    # if not args.cpu: torch.cuda.set_device(0)
    # Entry point: spawns one training process per GPU (see main()).
    main()
================================================
FILE: code/real/bsrt/model/DCNv2/LICENSE
================================================
BSD 3-Clause License
Copyright (c) 2019, Charles Shang
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================
FILE: code/real/bsrt/model/DCNv2/README.md
================================================
## Deformable Convolutional Networks V2 with Pytorch 1.0
### Build
```bash
./make.sh # build
python test.py # run examples and gradient check
```
### An Example
- deformable conv
```python
from dcn_v2 import DCN
input = torch.randn(2, 64, 128, 128).cuda()
# wrap all things (offset and mask) in DCN
dcn = DCN(64, 64, kernel_size=(3,3), stride=1, padding=1, deformable_groups=2).cuda()
output = dcn(input)
print(output.shape)
```
- deformable roi pooling
```python
from dcn_v2 import DCNPooling
input = torch.randn(2, 32, 64, 64).cuda()
batch_inds = torch.randint(2, (20, 1)).cuda().float()
x = torch.randint(256, (20, 1)).cuda().float()
y = torch.randint(256, (20, 1)).cuda().float()
w = torch.randint(64, (20, 1)).cuda().float()
h = torch.randint(64, (20, 1)).cuda().float()
rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
# modulated deformable pooling (V2)
# wrap all things (offset and mask) in DCNPooling
dpooling = DCNPooling(spatial_scale=1.0 / 4,
pooled_size=7,
output_dim=32,
no_trans=False,
group_size=1,
trans_std=0.1).cuda()
dout = dpooling(input, rois)
```
### Note
Now the master branch is for pytorch 1.0 (new ATen API), you can switch back to pytorch 0.4 with,
```bash
git checkout pytorch_0.4
```
### Known Issues:
- [x] Gradient check w.r.t offset (solved)
- [ ] Backward is not reentrant (minor)
This is an adaption of the official [Deformable-ConvNets](https://github.com/msracver/Deformable-ConvNets/tree/master/DCNv2_op).
I have run the gradient check many times with DOUBLE type. Every tensor **except offset** passes.
However, when I set the offset to 0.5, it passes. I'm still wondering what causes this problem. Is it because of some
non-differentiable points?
Update: all gradient check passes with double precision.
Another issue is that it raises `RuntimeError: Backward is not reentrant`. However, the error is very small (`<1e-7` for
float, `<1e-15` for double),
so it may not be a serious problem (?)
Please post an issue or PR if you have any comments.
================================================
FILE: code/real/bsrt/model/DCNv2/__init__.py
================================================
================================================
FILE: code/real/bsrt/model/DCNv2/dcn_v2.py
================================================
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
import math
import torch
from torch import nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from torch.cuda.amp import custom_fwd, custom_bwd
# from apex import amp
import _ext as _backend
class _DCNv2(Function):
    """autograd Function bridging to the compiled DCNv2 extension.

    forward/backward only marshal arguments to `_backend.dcn_v2_forward` /
    `_backend.dcn_v2_backward`; all the deformable-convolution math lives in
    the CUDA/C++ extension. Inputs are cast to float32 for AMP compatibility.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    # @amp.float_function
    def forward(
        ctx, input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups
    ):
        # Normalize scalar hyper-parameters to (h, w) pairs and stash them on
        # ctx for backward (they are not tensors, so save_for_backward does
        # not apply to them).
        ctx.stride = _pair(stride)
        ctx.padding = _pair(padding)
        ctx.dilation = _pair(dilation)
        ctx.kernel_size = _pair(weight.shape[2:4])
        ctx.deformable_groups = deformable_groups
        output = _backend.dcn_v2_forward(
            input,
            weight,
            bias,
            offset,
            mask,
            ctx.kernel_size[0],
            ctx.kernel_size[1],
            ctx.stride[0],
            ctx.stride[1],
            ctx.padding[0],
            ctx.padding[1],
            ctx.dilation[0],
            ctx.dilation[1],
            ctx.deformable_groups,
        )
        ctx.save_for_backward(input, offset, mask, weight, bias)
        return output

    @staticmethod
    @once_differentiable
    @custom_bwd
    # @amp.float_function
    def backward(ctx, grad_output):
        # Returns gradients for the five tensor inputs; the remaining
        # non-tensor forward() arguments get None.
        input, offset, mask, weight, bias = ctx.saved_tensors
        grad_input, grad_offset, grad_mask, grad_weight, grad_bias = _backend.dcn_v2_backward(
            input,
            weight,
            bias,
            offset,
            mask,
            grad_output,
            ctx.kernel_size[0],
            ctx.kernel_size[1],
            ctx.stride[0],
            ctx.stride[1],
            ctx.padding[0],
            ctx.padding[1],
            ctx.dilation[0],
            ctx.dilation[1],
            ctx.deformable_groups,
        )
        return grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None

    @staticmethod
    def symbolic(
        g, input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups
    ):
        # ONNX export hook: emits a placeholder "DCNv2_2" op.
        from torch.nn.modules.utils import _pair

        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        # as of trt 7, the dcn operation will be translated again by modifying the onnx file
        # so the exporting code is kept to resemble the forward()
        return g.op(
            "DCNv2_2",
            input,
            offset,
            mask,
            weight,
            bias,
            stride_i=stride,
            padding_i=padding,
            dilation_i=dilation,
            deformable_groups_i=deformable_groups,
        )
# Functional, autograd-aware entry point used by all module wrappers below.
dcn_v2_conv = _DCNv2.apply
class DCNv2(nn.Module):
    """Base deformable convolution v2 module.

    Holds the conv weight/bias; the caller must supply `offset` and `mask`
    tensors (see forward) which are produced externally or by subclasses.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation=1,
        deformable_groups=1,
    ):
        super(DCNv2, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Normalize scalar arguments to (h, w) pairs.
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.deformable_groups = deformable_groups
        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size))
        self.bias = nn.Parameter(torch.Tensor(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        # Uniform init in [-1/sqrt(fan_in), 1/sqrt(fan_in)], zero bias.
        fan_in = self.in_channels
        for k in self.kernel_size:
            fan_in *= k
        bound = 1.0 / math.sqrt(fan_in)
        self.weight.data.uniform_(-bound, bound)
        self.bias.data.zero_()

    def forward(self, input, offset, mask):
        # offset carries (x, y) displacements per sampling point, mask one
        # modulation scalar per point.
        points = self.deformable_groups * self.kernel_size[0] * self.kernel_size[1]
        assert offset.shape[1] == 2 * points
        assert mask.shape[1] == points
        return dcn_v2_conv(
            input,
            offset,
            mask,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.deformable_groups,
        )
class DCN(DCNv2):
    """DCNv2 that predicts its own offsets and modulation masks from the input."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation=1,
        deformable_groups=1,
    ):
        super(DCN, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation, deformable_groups
        )
        # 3 maps per sampling point: x-offset, y-offset, modulation mask.
        offset_mask_ch = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
        self.conv_offset_mask = nn.Conv2d(
            self.in_channels,
            offset_mask_ch,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            bias=True,
        )
        self.init_offset()

    def init_offset(self):
        # Zero init: zero offsets and sigmoid(0)=0.5 masks at the start.
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, input):
        o1, o2, mask_logits = torch.chunk(self.conv_offset_mask(input), 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask_logits)
        return dcn_v2_conv(
            input,
            offset,
            mask,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.deformable_groups,
        )
class DCN_sep(DCNv2):
    '''Deformable conv whose offsets and masks are predicted from a separate
    feature map instead of the convolved input itself.'''

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 dilation=1,
                 deformable_groups=1):
        super(DCN_sep, self).__init__(in_channels, out_channels, kernel_size, stride, padding,
                                      dilation, deformable_groups)
        # 3 maps per sampling point: x-offset, y-offset, modulation mask.
        channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
        self.conv_offset_mask = nn.Conv2d(
            self.in_channels,
            channels_,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            bias=True)
        self.init_offset()

    def init_offset(self):
        # Zero init: zero offsets and sigmoid(0)=0.5 masks at the start.
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, input, fea):
        '''input: input features for deformable conv
        fea: other features used for generating offsets and mask'''
        out = self.conv_offset_mask(fea)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        # offset = torch.clamp(offset, -100, 100)
        offset_mean = torch.mean(torch.abs(offset))
        if offset_mean > 250:
            # Bug fix: the message previously claimed the threshold was 100
            # while the check above uses 250.
            print('Offset mean is {}, larger than 250.'.format(offset_mean))
            # return None
            # offset[offset>=150] = 1e-3
            # offset = offset.clamp(-50, 50)
        mask = torch.sigmoid(mask)
        return dcn_v2_conv(input, offset, mask, self.weight, self.bias, self.stride, self.padding,
                           self.dilation, self.deformable_groups)
class FlowGuidedDCN(DCNv2):
    '''Deformable conv whose predicted offsets are residuals added on top of
    an optical-flow prior (flows are flipped to (dy, dx) order).'''

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 dilation=1,
                 deformable_groups=1):
        super(FlowGuidedDCN, self).__init__(in_channels, out_channels, kernel_size, stride, padding,
                                            dilation, deformable_groups)
        channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
        self.conv_offset_mask = nn.Conv2d(
            in_channels, channels_, kernel_size, stride, padding, bias=True)
        self.init_offset()

    def init_offset(self):
        # Zero init so the initial offsets equal the flow prior exactly.
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, input, fea, flows):
        '''input: input features for deformable conv: N, C, H, W.
        fea: other features used for generating offsets and mask: N, C, H, W.
        flows: N, 2, H, W.
        '''
        out = self.conv_offset_mask(fea)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        # Residual offsets are bounded by tanh to +/-10 px before adding flow.
        offset = torch.tanh(torch.cat((o1, o2), dim=1)) * 10  # max_residue_magnitude
        offset = offset + flows.flip(1).repeat(1, offset.size(1)//2, 1, 1)
        offset_mean = torch.mean(torch.abs(offset))
        if offset_mean > 250:
            # Bug fix: the message previously claimed the threshold was 100
            # while the check above uses 250.
            print('FlowGuidedDCN: Offset mean is {}, larger than 250.'.format(offset_mean))
            # offset = offset.clamp(-50, 50)
            # return None
        mask = torch.sigmoid(mask)
        return dcn_v2_conv(input, offset, mask, self.weight, self.bias, self.stride, self.padding,
                           self.dilation, self.deformable_groups)
class InsideFlowGuidedDCN(DCNv2):
    '''Flow-guided deformable conv that predicts offsets/masks from the
    concatenation of warped features, reference features and the flow itself,
    through a small 3-layer conv head.'''

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 dilation=1,
                 deformable_groups=1):
        super(InsideFlowGuidedDCN, self).__init__(in_channels, out_channels, kernel_size, stride, padding,
                                                  dilation, deformable_groups)
        channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
        # Head input: warped (C) + ref (C) + flow (2) channels.
        self.conv_offset_mask = nn.Sequential(
            nn.Conv2d(in_channels*2+2, out_channels, kernel_size, stride, padding, bias=True),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=True),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(out_channels, channels_, kernel_size, stride, padding, bias=True)
        )
        self.reset_parameters()
        self.init_offset()

    def reset_parameters(self):
        # Same scheme as the base class (also invoked by its __init__).
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1.0 / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        self.bias.data.zero_()

    def init_offset(self):
        # Only the last head layer is zeroed: zero offsets / 0.5 masks at start.
        self.conv_offset_mask[-1].weight.data.zero_()
        self.conv_offset_mask[-1].bias.data.zero_()

    def forward(self, input, warped, ref, flows):
        '''input: input features for deformable conv: N, C, H, W.
        fea: other features used for generating offsets and mask: N, C, H, W.
        flows: N, 2, H, W.
        '''
        out = self.conv_offset_mask(torch.cat([warped, ref, flows], dim=1))
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        # Residual offsets are bounded by tanh to +/-10 px before adding flow.
        offset = torch.tanh(torch.cat((o1, o2), dim=1)) * 10  # max_residue_magnitude
        offset = offset + flows.flip(1).repeat(1, offset.size(1)//2, 1, 1)
        offset_mean = torch.mean(torch.abs(offset))
        if offset_mean > 250:
            # Bug fix: the message previously claimed the threshold was 100
            # while the check above uses 250.
            print('InsideFlowGuidedDCN: Offset mean is {}, larger than 250.'.format(offset_mean))
            print('flow mean is {}'.format(torch.abs(flows).mean()))
            offset = offset.clamp(-50, 50)
            # return None
        mask = torch.sigmoid(mask)
        return dcn_v2_conv(input, offset, mask, self.weight, self.bias, self.stride, self.padding,
                           self.dilation, self.deformable_groups)
class _DCNv2Pooling(Function):
    """autograd Function for deformable PSROI pooling.

    Marshals arguments to `_backend.dcn_v2_psroi_pooling_forward/backward`;
    all pooling math lives in the compiled extension.
    """

    @staticmethod
    def forward(
        ctx,
        input,
        rois,
        offset,
        spatial_scale,
        pooled_size,
        output_dim,
        no_trans,
        group_size=1,
        part_size=None,
        sample_per_part=4,
        trans_std=0.0,
    ):
        # Non-tensor hyper-parameters are stashed on ctx for backward.
        ctx.spatial_scale = spatial_scale
        ctx.no_trans = int(no_trans)
        ctx.output_dim = output_dim
        ctx.group_size = group_size
        ctx.pooled_size = pooled_size
        # part_size defaults to the pooled grid size.
        ctx.part_size = pooled_size if part_size is None else part_size
        ctx.sample_per_part = sample_per_part
        ctx.trans_std = trans_std
        output, output_count = _backend.dcn_v2_psroi_pooling_forward(
            input,
            rois,
            offset,
            ctx.no_trans,
            ctx.spatial_scale,
            ctx.output_dim,
            ctx.group_size,
            ctx.pooled_size,
            ctx.part_size,
            ctx.sample_per_part,
            ctx.trans_std,
        )
        # output_count is needed by the backward kernel for normalization.
        ctx.save_for_backward(input, rois, offset, output_count)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        # Gradients flow to `input` and `offset` only; rois and the scalar
        # hyper-parameters get None.
        input, rois, offset, output_count = ctx.saved_tensors
        grad_input, grad_offset = _backend.dcn_v2_psroi_pooling_backward(
            grad_output,
            input,
            rois,
            offset,
            output_count,
            ctx.no_trans,
            ctx.spatial_scale,
            ctx.output_dim,
            ctx.group_size,
            ctx.pooled_size,
            ctx.part_size,
            ctx.sample_per_part,
            ctx.trans_std,
        )
        return grad_input, None, grad_offset, None, None, None, None, None, None, None, None
# Functional, autograd-aware entry point for deformable PSROI pooling.
dcn_v2_pooling = _DCNv2Pooling.apply
class DCNv2Pooling(nn.Module):
    """Module wrapper over the functional dcn_v2_pooling op."""

    def __init__(
        self,
        spatial_scale,
        pooled_size,
        output_dim,
        no_trans,
        group_size=1,
        part_size=None,
        sample_per_part=4,
        trans_std=0.0,
    ):
        super(DCNv2Pooling, self).__init__()
        self.spatial_scale = spatial_scale
        self.pooled_size = pooled_size
        self.output_dim = output_dim
        self.no_trans = no_trans
        self.group_size = group_size
        # Fall back to the pooled grid size when no explicit part size given.
        self.part_size = part_size if part_size is not None else pooled_size
        self.sample_per_part = sample_per_part
        self.trans_std = trans_std

    def forward(self, input, rois, offset):
        """Pool `input` over `rois`; `offset` is ignored when no_trans is set."""
        assert input.shape[1] == self.output_dim
        if self.no_trans:
            # Trans-free mode expects an empty placeholder offset tensor.
            offset = input.new()
        return dcn_v2_pooling(
            input,
            rois,
            offset,
            self.spatial_scale,
            self.pooled_size,
            self.output_dim,
            self.no_trans,
            self.group_size,
            self.part_size,
            self.sample_per_part,
            self.trans_std,
        )
class DCNPooling(DCNv2Pooling):
    """Deformable PSROI pooling that predicts its own offsets and mask.

    A first, non-deformable ROI pooling pass gathers features for each ROI;
    a small fully-connected head maps them to per-bin (x, y) offsets plus a
    sigmoid modulation mask, which drive a second, deformable pooling pass.
    """

    def __init__(
        self,
        spatial_scale,
        pooled_size,
        output_dim,
        no_trans,
        group_size=1,
        part_size=None,
        sample_per_part=4,
        trans_std=0.0,
        deform_fc_dim=1024,
    ):
        super(DCNPooling, self).__init__(
            spatial_scale,
            pooled_size,
            output_dim,
            no_trans,
            group_size,
            part_size,
            sample_per_part,
            trans_std,
        )
        self.deform_fc_dim = deform_fc_dim

        if not no_trans:
            fc_in = self.pooled_size * self.pooled_size * self.output_dim
            # Three maps per bin: x-offset, y-offset, modulation mask.
            fc_out = self.pooled_size * self.pooled_size * 3
            self.offset_mask_fc = nn.Sequential(
                nn.Linear(fc_in, self.deform_fc_dim),
                nn.ReLU(inplace=True),
                nn.Linear(self.deform_fc_dim, self.deform_fc_dim),
                nn.ReLU(inplace=True),
                nn.Linear(self.deform_fc_dim, fc_out),
            )
            # Start from the identity transform: zero offsets / zero mask logits.
            self.offset_mask_fc[4].weight.data.zero_()
            self.offset_mask_fc[4].bias.data.zero_()

    def forward(self, input, rois):
        offset = input.new()

        if self.no_trans:
            # Plain (non-deformable) ROI pooling only.
            return dcn_v2_pooling(
                input,
                rois,
                offset,
                self.spatial_scale,
                self.pooled_size,
                self.output_dim,
                self.no_trans,
                self.group_size,
                self.part_size,
                self.sample_per_part,
                self.trans_std,
            )

        # First pass: ordinary pooling (trans disabled) to gather ROI features.
        n = rois.shape[0]
        roi = dcn_v2_pooling(
            input,
            rois,
            offset,
            self.spatial_scale,
            self.pooled_size,
            self.output_dim,
            True,  # no trans
            self.group_size,
            self.part_size,
            self.sample_per_part,
            self.trans_std,
        )

        # Predict per-bin offsets and the modulation mask from pooled features.
        offset_mask = self.offset_mask_fc(roi.view(n, -1))
        offset_mask = offset_mask.view(n, 3, self.pooled_size, self.pooled_size)
        o1, o2, mask = torch.chunk(offset_mask, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)

        # Second pass: deformable pooling, modulated by the predicted mask.
        pooled = dcn_v2_pooling(
            input,
            rois,
            offset,
            self.spatial_scale,
            self.pooled_size,
            self.output_dim,
            self.no_trans,
            self.group_size,
            self.part_size,
            self.sample_per_part,
            self.trans_std,
        )
        return pooled * mask
================================================
FILE: code/real/bsrt/model/DCNv2/files.txt
================================================
/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/_ext.cpython-37m-x86_64-linux-gnu.so
/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/_ext.py
/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/PKG-INFO
/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/SOURCES.txt
/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/dependency_links.txt
/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/native_libs.txt
/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/not-zip-safe
/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/top_level.txt
/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/__pycache__/_ext.cpython-37.pyc
================================================
FILE: code/real/bsrt/model/DCNv2/make.sh
================================================
#!/usr/bin/env bash
# Build the DCNv2 C++/CUDA extension in-place (setuptools develop mode).
python setup.py build develop
================================================
FILE: code/real/bsrt/model/DCNv2/setup.py
================================================
#!/usr/bin/env python
import glob
import os
import torch
from setuptools import find_packages, setup
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
requirements = ["torch", "torchvision"]
def get_extensions():
    """Collect the C++/CUDA sources under ``src`` and describe the ``_ext`` module.

    Returns:
        A single-element list containing the extension definition consumed by
        ``setup()`` below.
    """
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, "src")

    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
    source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
    source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))

    # Force g++ as the host compiler for the extension build.
    os.environ["CC"] = "g++"

    # NOTE(review): the CUDA build is forced unconditionally (the upstream
    # guard on torch.cuda.is_available()/CUDA_HOME was replaced by a constant
    # in this repo); the unreachable CPU-only fallback branch has been
    # removed.  Building therefore requires a CUDA toolchain.
    sources = main_file + source_cpu + source_cuda
    extension = CUDAExtension
    define_macros = [("WITH_CUDA", None)]
    extra_compile_args = {
        "cxx": [],
        "nvcc": [
            "-DCUDA_HAS_FP16=1",
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
        ],
    }

    # glob already returned absolute paths, so this join is a no-op kept for
    # parity with the upstream layout.
    sources = [os.path.join(extensions_dir, s) for s in sources]
    include_dirs = [extensions_dir]

    return [
        extension(
            "_ext",
            sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]
# Register the package; the compiled extension becomes importable as `_ext`.
setup(
    name="DCNv2",
    version="0.1",
    author="charlesshang",
    url="https://github.com/charlesshang/DCNv2",
    description="deformable convolutional networks",
    packages=find_packages(exclude=("configs", "tests")),
    # install_requires=requirements,
    ext_modules=get_extensions(),
    # BuildExtension supplies the torch-specific compiler/linker flags.
    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)
================================================
FILE: code/real/bsrt/model/DCNv2/src/cpu/dcn_v2_cpu.cpp
================================================
#include <vector>
#include "cpu/dcn_v2_im2col_cpu.h"
#include <ATen/ATen.h>
//#include <ATen/cuda/CUDAContext.h>
#include <TH/TH.h>
//#include <THC/THC.h>
//#include <THC/THCDeviceUtils.cuh>
//extern THCState *state;
// author: Charles Shang
// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
// modified from the CUDA version for CPU use by Daniel K. Suhendro
// CPU forward pass of modulated deformable convolution (DCNv2).
// Per batch element:
//   1. broadcast the bias into output_n via a rank-1 GEMM,
//   2. gather deformed, mask-modulated input samples into `columns`
//      (modulated_deformable_im2col_cpu),
//   3. apply the convolution weights with a second GEMM, accumulating on
//      top of the bias.
// NOTE(review): this path is float-only — scalar_t is hard-wired to float.
at::Tensor
dcn_v2_cpu_forward(const at::Tensor &input,
                   const at::Tensor &weight,
                   const at::Tensor &bias,
                   const at::Tensor &offset,
                   const at::Tensor &mask,
                   const int kernel_h,
                   const int kernel_w,
                   const int stride_h,
                   const int stride_w,
                   const int pad_h,
                   const int pad_w,
                   const int dilation_h,
                   const int dilation_w,
                   const int deformable_group)
{
    // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask));
    /*AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
    AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
    AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
    AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
    AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");*/
    const int batch = input.size(0);
    const int channels = input.size(1);
    const int height = input.size(2);
    const int width = input.size(3);
    const int channels_out = weight.size(0);
    const int channels_kernel = weight.size(1);
    const int kernel_h_ = weight.size(2);
    const int kernel_w_ = weight.size(3);
    // Fixed: the message previously passed kernel_h_/kernel_w on both sides
    // of the "vs"; report actual (arguments) vs expected (weight) extents.
    AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
               "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h, kernel_w, kernel_h_, kernel_w_);
    AT_ASSERTM(channels == channels_kernel,
               "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel);
    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
    // `ones` is the rank-1 helper used to broadcast the bias below.
    auto ones = at::ones({height_out, width_out}, input.options());
    auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
    auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());
    using scalar_t = float;
    for (int b = 0; b < batch; b++)
    {
        auto input_n = input.select(0, b);
        auto offset_n = offset.select(0, b);
        auto mask_n = mask.select(0, b);
        auto output_n = output.select(0, b);
        // Do Bias first:
        // M,N,K are dims of matrix A and B
        // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
        // (N x 1) (1 x M)
        long m_ = channels_out;
        long n_ = height_out * width_out;
        long k_ = 1;
        THFloatBlas_gemm('t', 'n', n_, m_, k_, 1.0f,
                         ones.contiguous().data(), k_,
                         bias.contiguous().data(), k_, 0.0f,
                         output_n.data(), n_);
        // Gather bilinear, mask-modulated samples into the column buffer.
        modulated_deformable_im2col_cpu(input_n.data(),
                                        offset_n.data(),
                                        mask_n.data(),
                                        1, channels, height, width,
                                        height_out, width_out, kernel_h, kernel_w,
                                        pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
                                        deformable_group,
                                        columns.data());
        //(k * m) x (m * n)
        // Y = WC — accumulate (beta = 1) on top of the bias written above.
        long m = channels_out;
        long n = height_out * width_out;
        long k = channels * kernel_h * kernel_w;
        THFloatBlas_gemm('n', 'n', n, m, k, 1.0f,
                         columns.data(), n,
                         weight.data(), k, 1.0f,
                         output_n.data(), n);
    }
    return output;
}
std::vector dcn_v2_cpu_backward(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
const at::Tensor &offset,
const at::Tensor &mask,
const at::Tensor &grad_output,
int kernel_h, int kernel_w,
int stride_h, int stride_w,
int pad_h, int pad_w,
int dilation_h, int dilation_w,
int deformable_group)
{
THArgCheck(input.is_contiguous(), 1, "input tensor has to be contiguous");
THArgCheck(weight.is_contiguous(), 2, "weight tensor has to be contiguous");
/*AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");*/
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_out = weight.size(0);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
"Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_);
AT_ASSERTM(channels == channels_kernel,
"Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel);
const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
auto ones = at::ones({height_out, width_out}, input.options());
auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());
auto grad_input = at::zeros_like(input);
auto grad_weight = at::zeros_like(weight);
auto grad_bias = at::zeros_like(bias);
auto grad_offset = at::zeros_like(offset);
auto grad_mask = at::zeros_like(mask);
using scalar_t = float;
for (int b = 0; b < batch; b++)
{
auto input_n = input.select(0, b);
auto offset_n = offset.select(0, b);
auto mask_n = mask.select(0, b);
auto grad_output_n = grad_output.select(0, b);
auto grad_input_n = grad_input.select(0, b);
auto grad_offset_n = grad_offset.select(0, b);
auto grad_mask_n = grad_mask.select(0, b);
long m = channels * kernel_h * kernel_w;
long n = height_out * width_out;
long k = channels_out;
THFloatBlas_gemm('n', 't', n, m, k, 1.0f,
grad_output_n.data(), n,
weight.data(), m, 0.0f,
columns.data(), n);
// gradient w.r.t. input coordinate data
modulated_deformable_col2im_coord_cpu(columns.data(),
input_n.data(),
offset_n.data(),
mask_n.data(),
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
grad_offset_n.data(),
grad_mask_n.data());
// gradient w.r.t. input data
modulated_deformable_col2im_cpu(columns.data(),
offset_n.data(),
mask_n.data(),
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
grad_input_n.data());
// gradient w.r.t. weight, dWeight should accumulate across the batch and group
modulated_deformable_im2col_cpu(input_n.data(),
offset_n.data(),
mask_n.data(),
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
columns.data());
long m_ = channels_out;
long n_ = channels * kernel_h * kernel_w;
long k_ = height_out * width_out;
THFloatBlas_gemm('t', 'n', n_, m_, k_, 1.0f,
columns.data(), k_,
grad_output_n.data(), k_, 1.0f,
grad_weight.data(), n_);
// gradient w.r.t. bias
// long m_ = channels_out;
// long k__ = height_out * width_out;
// THFloatBlas_gemv('t', k_, m_, 1.0f,
// grad_output_n.data(), k_,
// ones.data(), 1, 1.0f,
// grad_bias.data(), 1);
}
return {
grad_input, grad_offset, grad_mask, grad_weight, grad_bias
};
}
================================================
FILE: code/real/bsrt/model/DCNv2/src/cpu/dcn_v2_im2col_cpu.cpp
================================================
#include "dcn_v2_im2col_cpu.h"
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <ATen/ATen.h>
//#include <ATen/cuda/CUDAContext.h>
#include <TH/TH.h>
//#include <THC/THC.h>
//#include <THC/THCAtomics.cuh>
// modified from the CUDA version for CPU use by Daniel K. Suhendro
/*#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N)
{
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}*/
// Bilinearly sample `bottom_data` (row stride `data_width`) at the
// fractional location (h, w).  Neighbours falling outside
// [0, height) x [0, width) contribute zero, so sampling degrades
// gracefully at the image border.
float dmcn_im2col_bilinear_cpu(const float *bottom_data, const int data_width,
                               const int height, const int width, float h, float w)
{
    const int h0 = floor(h);
    const int w0 = floor(w);
    const int h1 = h0 + 1;
    const int w1 = w0 + 1;

    const float dh = h - h0;
    const float dw = w - w0;
    const float rh = 1 - dh, rw = 1 - dw;

    // Fetch the four neighbours, substituting 0 for out-of-range taps.
    float tl = 0;
    if (h0 >= 0 && w0 >= 0)
        tl = bottom_data[h0 * data_width + w0];
    float tr = 0;
    if (h0 >= 0 && w1 <= width - 1)
        tr = bottom_data[h0 * data_width + w1];
    float bl = 0;
    if (h1 <= height - 1 && w0 >= 0)
        bl = bottom_data[h1 * data_width + w0];
    float br = 0;
    if (h1 <= height - 1 && w1 <= width - 1)
        br = bottom_data[h1 * data_width + w1];

    const float w_tl = rh * rw, w_tr = rh * dw, w_bl = dh * rw, w_br = dh * dw;
    return (w_tl * tl + w_tr * tr + w_bl * bl + w_br * br);
}
// Bilinear weight of integer pixel (h, w) for the fractional sample point
// (argmax_h, argmax_w) — used to scatter gradients in col2im.  Returns 0
// when the sample is outside the image or (h, w) is not one of its four
// bilinear neighbours.
float dmcn_get_gradient_weight_cpu(float argmax_h, float argmax_w,
                                   const int h, const int w, const int height, const int width)
{
    // Sample entirely outside the image: nothing to propagate.
    if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
        return 0;

    const int h_lo = floor(argmax_h);
    const int w_lo = floor(argmax_w);
    const int h_hi = h_lo + 1;
    const int w_hi = w_lo + 1;

    // (h, w) matches at most one of the four bilinear corners.
    if (h == h_lo && w == w_lo)
        return (h + 1 - argmax_h) * (w + 1 - argmax_w);
    if (h == h_lo && w == w_hi)
        return (h + 1 - argmax_h) * (argmax_w + 1 - w);
    if (h == h_hi && w == w_lo)
        return (argmax_h + 1 - h) * (w + 1 - argmax_w);
    if (h == h_hi && w == w_hi)
        return (argmax_h + 1 - h) * (argmax_w + 1 - w);
    return 0;
}
// Derivative of the bilinearly-sampled value w.r.t. one sampling coordinate.
// bp_dir == 0 differentiates w.r.t. the vertical coordinate (argmax_h),
// bp_dir == 1 w.r.t. the horizontal one (argmax_w).  Out-of-range samples
// contribute no gradient.
float dmcn_get_coordinate_weight_cpu(float argmax_h, float argmax_w,
                                     const int height, const int width, const float *im_data,
                                     const int data_width, const int bp_dir)
{
    // Entirely outside the image: the sampled value is constant zero.
    if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
        return 0;

    const int h_lo = floor(argmax_h);
    const int w_lo = floor(argmax_w);
    const int h_hi = h_lo + 1;
    const int w_hi = w_lo + 1;

    float weight = 0;
    if (bp_dir == 0)
    {
        // d(val)/d(argmax_h): horizontal weights stay fixed, the vertical
        // interpolation factors differentiate to -1 / +1.
        if (h_lo >= 0 && w_lo >= 0)
            weight += -1 * (w_lo + 1 - argmax_w) * im_data[h_lo * data_width + w_lo];
        if (h_lo >= 0 && w_hi <= width - 1)
            weight += -1 * (argmax_w - w_lo) * im_data[h_lo * data_width + w_hi];
        if (h_hi <= height - 1 && w_lo >= 0)
            weight += (w_lo + 1 - argmax_w) * im_data[h_hi * data_width + w_lo];
        if (h_hi <= height - 1 && w_hi <= width - 1)
            weight += (argmax_w - w_lo) * im_data[h_hi * data_width + w_hi];
    }
    else if (bp_dir == 1)
    {
        // d(val)/d(argmax_w): symmetric to the case above.
        if (h_lo >= 0 && w_lo >= 0)
            weight += -1 * (h_lo + 1 - argmax_h) * im_data[h_lo * data_width + w_lo];
        if (h_lo >= 0 && w_hi <= width - 1)
            weight += (h_lo + 1 - argmax_h) * im_data[h_lo * data_width + w_hi];
        if (h_hi <= height - 1 && w_lo >= 0)
            weight += -1 * (argmax_h - h_lo) * im_data[h_hi * data_width + w_lo];
        if (h_hi <= height - 1 && w_hi <= width - 1)
            weight += (argmax_h - h_lo) * im_data[h_hi * data_width + w_hi];
    }
    return weight;
}
// im2col for modulated deformable convolution: for every (channel, batch,
// output-position) triple, sample the input at the kernel's deformed tap
// locations (bilinear) and write the mask-modulated values into the column
// buffer.  n = num_channels * batch_size * height_col * width_col.
//
// NOTE(review): the loop header and index-decomposition below were
// reconstructed from the upstream DCNv2 CPU implementation; the original
// text of this span was corrupted in this copy (angle-bracket contents
// stripped, leaving `for(int index=0; index(0);`).
void modulated_deformable_im2col_cpu_kernel(const int n, const float *data_im, const float *data_offset, const float *data_mask,
                                            const int height, const int width, const int kernel_h, const int kernel_w,
                                            const int pad_h, const int pad_w,
                                            const int stride_h, const int stride_w,
                                            const int dilation_h, const int dilation_w,
                                            const int channel_per_deformable_group,
                                            const int batch_size, const int num_channels, const int deformable_group,
                                            const int height_col, const int width_col,
                                            float *data_col)
{
    // launch channels * batch_size * height_col * width_col cores
    for (int index = 0; index < n; index++)
    {
        // Decompose the flat index into (c_im, b_col, h_col, w_col).
        const int w_col = index % width_col;
        const int h_col = (index / width_col) % height_col;
        const int b_col = (index / width_col / height_col) % batch_size;
        const int c_im = (index / width_col / height_col) / batch_size;
        const int c_col = c_im * kernel_h * kernel_w;

        // Deformable group this input channel belongs to.
        const int deformable_group_index = c_im / channel_per_deformable_group;

        const int h_in = h_col * stride_h - pad_h;
        const int w_in = w_col * stride_w - pad_w;

        float *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
        const float *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
        const float *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
        const float *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;

        for (int i = 0; i < kernel_h; ++i)
        {
            for (int j = 0; j < kernel_w; ++j)
            {
                // Offsets are stored interleaved: channel 2k is the h-offset,
                // channel 2k+1 the w-offset for kernel tap k = i*kernel_w + j.
                const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
                const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
                const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
                const float offset_h = data_offset_ptr[data_offset_h_ptr];
                const float offset_w = data_offset_ptr[data_offset_w_ptr];
                const float mask = data_mask_ptr[data_mask_hw_ptr];
                float val = static_cast<float>(0);
                const float h_im = h_in + i * dilation_h + offset_h;
                const float w_im = w_in + j * dilation_w + offset_w;
                //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
                if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
                {
                    // Bilinear sample at the (fractional) deformed location.
                    val = dmcn_im2col_bilinear_cpu(data_im_ptr, width, height, width, h_im, w_im);
                }
                *data_col_ptr = val * mask;
                // data_col_ptr += batch_size * height_col * width_col;
                data_col_ptr += height_col * width_col;
            }
        }
    }
}
// col2im for modulated deformable convolution: scatters each column-buffer
// gradient entry back into grad_im, distributing it over the pixels that
// the bilinearly-sampled tap location touched.
// n = channels * kernel_h * kernel_w * batch_size * height_col * width_col.
void modulated_deformable_col2im_cpu_kernel(const int n, const float *data_col, const float *data_offset, const float *data_mask,
                                            const int channels, const int height, const int width,
                                            const int kernel_h, const int kernel_w,
                                            const int pad_h, const int pad_w,
                                            const int stride_h, const int stride_w,
                                            const int dilation_h, const int dilation_w,
                                            const int channel_per_deformable_group,
                                            const int batch_size, const int deformable_group,
                                            const int height_col, const int width_col,
                                            float *grad_im)
{
    for(int index = 0; index < n; index++)
    {
        // Decompose the flat index into (c, i, j, b, h_out, w_out).
        const int j = (index / width_col / height_col / batch_size) % kernel_w;
        const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
        const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
        // compute the start and end of the output
        const int deformable_group_index = c / channel_per_deformable_group;
        int w_out = index % width_col;
        int h_out = (index / width_col) % height_col;
        int b = (index / width_col / height_col) % batch_size;
        int w_in = w_out * stride_w - pad_w;
        int h_in = h_out * stride_h - pad_h;
        // Per-(batch, deformable-group) offset and mask planes.
        const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
        const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
        const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
        const float offset_h = data_offset_ptr[data_offset_h_ptr];
        const float offset_w = data_offset_ptr[data_offset_w_ptr];
        const float mask = data_mask_ptr[data_mask_hw_ptr];
        // Fractional source position this column entry was sampled from.
        const float cur_inv_h_data = h_in + i * dilation_h + offset_h;
        const float cur_inv_w_data = w_in + j * dilation_w + offset_w;
        const float cur_top_grad = data_col[index] * mask;
        const int cur_h = (int)cur_inv_h_data;
        const int cur_w = (int)cur_inv_w_data;
        // Visit a 5x5 neighbourhood; dmcn_get_gradient_weight_cpu zeroes any
        // contribution farther than 1 pixel from the sample point.
        // NOTE(review): `abs` on a float relies on the C++ overload being in
        // scope; plain C abs(int) would truncate — confirm the <cmath>
        // overload is what gets picked up here.
        for (int dy = -2; dy <= 2; dy++)
        {
            for (int dx = -2; dx <= 2; dx++)
            {
                if (cur_h + dy >= 0 && cur_h + dy < height &&
                    cur_w + dx >= 0 && cur_w + dx < width &&
                    abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
                    abs(cur_inv_w_data - (cur_w + dx)) < 1)
                {
                    int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
                    float weight = dmcn_get_gradient_weight_cpu(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
                    //atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
                    *(grad_im + cur_bottom_grad_pos) += weight * cur_top_grad;
                }
            }
        }
    }
}
// Gradient of the column buffer w.r.t. the sampling offsets and mask for
// modulated deformable convolution.  One iteration per offset element:
// n = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group.
// While walking the column channels that share an offset entry, the mask
// gradient is also accumulated and written once per (h, w) position (on the
// even / h-offset channel).
void modulated_deformable_col2im_coord_cpu_kernel(const int n, const float *data_col, const float *data_im,
                                                  const float *data_offset, const float *data_mask,
                                                  const int channels, const int height, const int width,
                                                  const int kernel_h, const int kernel_w,
                                                  const int pad_h, const int pad_w,
                                                  const int stride_h, const int stride_w,
                                                  const int dilation_h, const int dilation_w,
                                                  const int channel_per_deformable_group,
                                                  const int batch_size, const int offset_channels, const int deformable_group,
                                                  const int height_col, const int width_col,
                                                  float *grad_offset, float *grad_mask)
{
    for(int index = 0; index < n; index++)
    {
        float val = 0, mval = 0;
        // Decompose the flat index into (b, c, h, w) over the offset maps.
        int w = index % width_col;
        int h = (index / width_col) % height_col;
        int c = (index / width_col / height_col) % offset_channels;
        int b = (index / width_col / height_col) / offset_channels;
        // compute the start and end of the output
        const int deformable_group_index = c / (2 * kernel_h * kernel_w);
        const int col_step = kernel_h * kernel_w;
        int cnt = 0;
        const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
        const float *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
        const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
        const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
        const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
        // Every col_step-th column channel shares this offset entry.
        for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
        {
            const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
            // Even offset channels are h-offsets, odd ones w-offsets.
            const int bp_dir = offset_c % 2;
            int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
            int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
            int w_out = col_pos % width_col;
            int h_out = (col_pos / width_col) % height_col;
            int w_in = w_out * stride_w - pad_w;
            int h_in = h_out * stride_h - pad_h;
            const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
            const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
            const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
            const float offset_h = data_offset_ptr[data_offset_h_ptr];
            const float offset_w = data_offset_ptr[data_offset_w_ptr];
            const float mask = data_mask_ptr[data_mask_hw_ptr];
            float inv_h = h_in + i * dilation_h + offset_h;
            float inv_w = w_in + j * dilation_w + offset_w;
            if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
            {
                // Sample fell outside the image: no value / mask gradient.
                inv_h = inv_w = -2;
            }
            else
            {
                // Mask gradient: the unmasked sampled value times the upstream grad.
                mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear_cpu(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
            }
            const float weight = dmcn_get_coordinate_weight_cpu(
                inv_h, inv_w,
                height, width, data_im_ptr + cnt * height * width, width, bp_dir);
            val += weight * data_col_ptr[col_pos] * mask;
            cnt += 1;
        }
        // KERNEL_ASSIGN(grad_offset[index], offset_req, val);
        grad_offset[index] = val;
        if (offset_c % 2 == 0)
            // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
            grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
    }
}
// Host-side wrapper: computes the element count and runs the (serial) CPU
// im2col kernel over channels * batch * height_col * width_col elements.
void modulated_deformable_im2col_cpu(const float* data_im, const float* data_offset, const float* data_mask,
    const int batch_size, const int channels, const int height_im, const int width_im,
    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int deformable_group, float* data_col) {
  // num_axes should be smaller than block size
  const int channel_per_deformable_group = channels / deformable_group;
  const int num_kernels = channels * batch_size * height_col * width_col;
  modulated_deformable_im2col_cpu_kernel(
      num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kernel_w,
      pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
      batch_size, channels, deformable_group, height_col, width_col, data_col);
  /*cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
  }*/
}
// Host-side wrapper for the CPU col2im kernel (gradient w.r.t. input data).
// Fixed: pad_h was mistakenly passed twice (pad_h, pad_h) to the kernel —
// the sibling wrappers pass (pad_h, pad_w) — which corrupted the input
// gradient whenever pad_h != pad_w.
void modulated_deformable_col2im_cpu(const float* data_col, const float* data_offset, const float* data_mask,
    const int batch_size, const int channels, const int height_im, const int width_im,
    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int deformable_group, float* grad_im){
  const int channel_per_deformable_group = channels / deformable_group;
  const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
  modulated_deformable_col2im_cpu_kernel(
      num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im,
      kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
      dilation_h, dilation_w, channel_per_deformable_group,
      batch_size, deformable_group, height_col, width_col, grad_im);
  /*cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
  }*/
}
// Host-side wrapper for the CPU coordinate-gradient kernel (gradients
// w.r.t. the sampling offsets and the modulation mask).
void modulated_deformable_col2im_coord_cpu(const float* data_col, const float* data_im, const float* data_offset, const float* data_mask,
    const int batch_size, const int channels, const int height_im, const int width_im,
    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int deformable_group,
    float* grad_offset, float* grad_mask) {
  // One element per offset value: 2 (h, w) per kernel tap per group.
  const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
  const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
  modulated_deformable_col2im_coord_cpu_kernel(
      num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im,
      kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
      dilation_h, dilation_w, channel_per_deformable_group,
      batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
      grad_offset, grad_mask);
  /*cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
  }*/
}
================================================
FILE: code/real/bsrt/model/DCNv2/src/cpu/dcn_v2_im2col_cpu.h
================================================
/*!
******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
*
* COPYRIGHT
*
* All contributions by the University of California:
* Copyright (c) 2014-2017 The Regents of the University of California (Regents)
* All rights reserved.
*
* All other contributions:
* Copyright (c) 2014-2017, the respective contributors
* All rights reserved.
*
* Caffe uses a shared copyright model: each contributor holds copyright over
* their contributions to Caffe. The project versioning records all such
* contribution and copyright details. If a contributor wants to further mark
* their specific copyright on a particular contribution, they should indicate
* their copyright solely in the commit message of the change when it is
* committed.
*
* LICENSE
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* CONTRIBUTION AGREEMENT
*
* By contributing to the BVLC/caffe repository through pull-request, comment,
* or otherwise, the contributor releases their content to the
* license and copyright terms herein.
*
***************** END Caffe Copyright Notice and Disclaimer ********************
*
* Copyright (c) 2018 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file modulated_deformable_im2col.h
* \brief Function definitions of converting an image to
* column matrix based on kernel, padding, dilation, and offset.
* These functions are mainly used in deformable convolution operators.
* \ref: https://arxiv.org/abs/1811.11168
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu
*/
/***************** Adapted by Charles Shang *********************/
// modified from the CUDA version for CPU use by Daniel K. Suhendro
#ifndef DCN_V2_IM2COL_CPU
#define DCN_V2_IM2COL_CPU

#ifdef __cplusplus
extern "C"
{
#endif

  // Fixed: parameter name typo `kenerl_w` -> `kernel_w` in all three
  // prototypes (prototype parameter names have no ABI effect).

  // im2col: gather deformable, mask-modulated samples of data_im into data_col.
  void modulated_deformable_im2col_cpu(const float *data_im, const float *data_offset, const float *data_mask,
                                       const int batch_size, const int channels, const int height_im, const int width_im,
                                       const int height_col, const int width_col, const int kernel_h, const int kernel_w,
                                       const int pad_h, const int pad_w, const int stride_h, const int stride_w,
                                       const int dilation_h, const int dilation_w,
                                       const int deformable_group, float *data_col);

  // col2im: scatter column-buffer gradients back into grad_im (input grad).
  void modulated_deformable_col2im_cpu(const float *data_col, const float *data_offset, const float *data_mask,
                                       const int batch_size, const int channels, const int height_im, const int width_im,
                                       const int height_col, const int width_col, const int kernel_h, const int kernel_w,
                                       const int pad_h, const int pad_w, const int stride_h, const int stride_w,
                                       const int dilation_h, const int dilation_w,
                                       const int deformable_group, float *grad_im);

  // Gradients w.r.t. the sampling offsets and the modulation mask.
  void modulated_deformable_col2im_coord_cpu(const float *data_col, const float *data_im, const float *data_offset, const float *data_mask,
                                             const int batch_size, const int channels, const int height_im, const int width_im,
                                             const int height_col, const int width_col, const int kernel_h, const int kernel_w,
                                             const int pad_h, const int pad_w, const int stride_h, const int stride_w,
                                             const int dilation_h, const int dilation_w,
                                             const int deformable_group,
                                             float *grad_offset, float *grad_mask);

#ifdef __cplusplus
}
#endif

#endif
================================================
FILE: code/real/bsrt/model/DCNv2/src/cpu/dcn_v2_psroi_pooling_cpu.cpp
================================================
/*!
* Copyright (c) 2017 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file deformable_psroi_pooling.cu
* \brief
* \author Yi Li, Guodong Zhang, Jifeng Dai
*/
/***************** Adapted by Charles Shang *********************/
// modified from the CUDA version for CPU use by Daniel K. Suhendro
#include
#include
#include
#include
//#include
#include |
//#include
//#include
/*#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N)
{
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}*/
// Bilinear interpolation of `data` (a height x width row-major grid) at the
// fractional location (x, y). The four surrounding grid points are blended
// according to their distance from (x, y). Callers must ensure
// 0 <= x <= width - 1 and 0 <= y <= height - 1 so that floor/ceil neighbours
// stay in bounds.
// (The `<typename T>` template header and `<T>` cast arguments were lost in
// extraction and are restored here.)
template <typename T>
T bilinear_interp_cpu(
    const T *data,
    const T x,
    const T y,
    const int width,
    const int height)
{
  // Corners of the grid cell containing (x, y); when x (or y) is an exact
  // integer, floor == ceil and the blend degenerates to a 1-D/0-D lookup.
  int x1 = floor(x);
  int x2 = ceil(x);
  int y1 = floor(y);
  int y2 = ceil(y);
  T dist_x = static_cast<T>(x - x1);
  T dist_y = static_cast<T>(y - y1);
  T value11 = data[y1 * width + x1];
  T value12 = data[y2 * width + x1];
  T value21 = data[y1 * width + x2];
  T value22 = data[y2 * width + x2];
  // Standard bilinear weighting of the four corner samples.
  T value = (1 - dist_x) * (1 - dist_y) * value11 +
            (1 - dist_x) * dist_y * value12 +
            dist_x * (1 - dist_y) * value21 +
            dist_x * dist_y * value22;
  return value;
}
// CPU port of the deformable PS-RoI pooling forward kernel: one loop
// iteration per output element instead of one CUDA thread per element.
//
// count              - total number of output elements (n * output_dim * ph * pw)
// bottom_data        - input features, laid out (batch, channels, height, width)
// bottom_rois        - one RoI per row: (batch_index, x1, y1, x2, y2)
// bottom_trans       - learned per-part offsets; unused when no_trans != 0
// top_data/top_count - pooled averages and the number of in-bounds samples
//                      that went into each bin (consumed by backward).
// (Template header and `<T>` cast arguments restored after extraction loss.)
template <typename T>
void DeformablePSROIPoolForwardKernelCpu(
    const int count,
    const T *bottom_data,
    const T spatial_scale,
    const int channels,
    const int height, const int width,
    const int pooled_height, const int pooled_width,
    const T *bottom_rois, const T *bottom_trans,
    const int no_trans,
    const T trans_std,
    const int sample_per_part,
    const int output_dim,
    const int group_size,
    const int part_size,
    const int num_classes,
    const int channels_each_class,
    T *top_data,
    T *top_count)
{
  for (int index = 0; index < count; index++)
  {
    // The output is in order (n, ctop, ph, pw)
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int ctop = (index / pooled_width / pooled_height) % output_dim;
    int n = index / pooled_width / pooled_height / output_dim;
    // [start, end) interval for spatial sampling; RoI coordinates are scaled
    // into feature-map space and shifted by -0.5 to align pixel centres.
    const T *offset_bottom_rois = bottom_rois + n * 5;
    int roi_batch_ind = offset_bottom_rois[0];
    T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
    T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
    T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
    T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
    // Force too small ROIs to be 1x1
    T roi_width = std::max(roi_end_w - roi_start_w, T(0.1)); // avoid 0
    T roi_height = std::max(roi_end_h - roi_start_h, T(0.1));
    // Compute w and h at bottom
    T bin_size_h = roi_height / static_cast<T>(pooled_height);
    T bin_size_w = roi_width / static_cast<T>(pooled_width);
    T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);
    T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);
    // Which cell of the (part_size x part_size) offset field this bin uses.
    int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);
    int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);
    int class_id = ctop / channels_each_class;
    T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;
    T trans_y = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;
    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
    wstart += trans_x * roi_width;
    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
    hstart += trans_y * roi_height;
    T sum = 0;
    // Renamed from `count`, which shadowed the kernel's `count` parameter.
    int num_samples = 0;
    int gw = floor(static_cast<T>(pw) * group_size / pooled_width);
    int gh = floor(static_cast<T>(ph) * group_size / pooled_height);
    gw = std::min(std::max(gw, 0), group_size - 1);
    gh = std::min(std::max(gh, 0), group_size - 1);
    const T *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
    for (int ih = 0; ih < sample_per_part; ih++)
    {
      for (int iw = 0; iw < sample_per_part; iw++)
      {
        T w = wstart + iw * sub_bin_size_w;
        T h = hstart + ih * sub_bin_size_h;
        // Skip samples falling entirely outside the feature map.
        if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
        {
          continue;
        }
        w = std::min(std::max(w, T(0.)), width - T(1.));
        h = std::min(std::max(h, T(0.)), height - T(1.));
        // Position-sensitive channel: each (gh, gw) group reads its own slice.
        int c = (ctop * group_size + gh) * group_size + gw;
        T val = bilinear_interp_cpu(offset_bottom_data + c * height * width, w, h, width, height);
        sum += val;
        num_samples++;
      }
    }
    top_data[index] = num_samples == 0 ? static_cast<T>(0) : sum / num_samples;
    top_count[index] = num_samples;
  }
}
// CPU port of the deformable PS-RoI pooling backward kernel. Accumulates
// gradients into the input feature map (bottom_data_diff) and, unless
// no_trans, into the per-part offsets (bottom_trans_diff). One loop
// iteration handles one pooled output element (order: n, ctop, ph, pw).
// top_count is the number of samples averaged into each bin during the
// forward pass, so each sample receives top_diff / top_count. Plain `+=`
// replaces the CUDA atomicAdd calls because this loop is serial.
// (Template header and `<T>` cast arguments restored after extraction loss.)
template <typename T>
void DeformablePSROIPoolBackwardAccKernelCpu(
    const int count,
    const T *top_diff,
    const T *top_count,
    const int num_rois,
    const T spatial_scale,
    const int channels,
    const int height, const int width,
    const int pooled_height, const int pooled_width,
    const int output_dim,
    T *bottom_data_diff, T *bottom_trans_diff,
    const T *bottom_data,
    const T *bottom_rois,
    const T *bottom_trans,
    const int no_trans,
    const T trans_std,
    const int sample_per_part,
    const int group_size,
    const int part_size,
    const int num_classes,
    const int channels_each_class)
{
  for (int index = 0; index < count; index++)
  {
    // The output is in order (n, ctop, ph, pw)
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int ctop = (index / pooled_width / pooled_height) % output_dim;
    int n = index / pooled_width / pooled_height / output_dim;
    // [start, end) interval for spatial sampling (mirrors the forward pass).
    const T *offset_bottom_rois = bottom_rois + n * 5;
    int roi_batch_ind = offset_bottom_rois[0];
    T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
    T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
    T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
    T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
    // Force too small ROIs to be 1x1
    T roi_width = std::max(roi_end_w - roi_start_w, T(0.1)); // avoid 0
    T roi_height = std::max(roi_end_h - roi_start_h, T(0.1));
    // Compute w and h at bottom
    T bin_size_h = roi_height / static_cast<T>(pooled_height);
    T bin_size_w = roi_width / static_cast<T>(pooled_width);
    T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);
    T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);
    int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);
    int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);
    int class_id = ctop / channels_each_class;
    T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;
    T trans_y = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;
    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
    wstart += trans_x * roi_width;
    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
    hstart += trans_y * roi_height;
    // Bins that received no samples in the forward pass get no gradient.
    if (top_count[index] <= 0)
    {
      continue;
    }
    T diff_val = top_diff[index] / top_count[index];
    const T *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
    T *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;
    int gw = floor(static_cast<T>(pw) * group_size / pooled_width);
    int gh = floor(static_cast<T>(ph) * group_size / pooled_height);
    gw = std::min(std::max(gw, 0), group_size - 1);
    gh = std::min(std::max(gh, 0), group_size - 1);
    for (int ih = 0; ih < sample_per_part; ih++)
    {
      for (int iw = 0; iw < sample_per_part; iw++)
      {
        T w = wstart + iw * sub_bin_size_w;
        T h = hstart + ih * sub_bin_size_h;
        // bilinear interpolation
        if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
        {
          continue;
        }
        w = std::min(std::max(w, T(0.)), width - T(1.));
        h = std::min(std::max(h, T(0.)), height - T(1.));
        int c = (ctop * group_size + gh) * group_size + gw;
        // backward on feature: distribute diff_val to the four neighbours
        // with the same bilinear weights used in the forward sample.
        int x0 = floor(w);
        int x1 = ceil(w);
        int y0 = floor(h);
        int y1 = ceil(h);
        T dist_x = w - x0, dist_y = h - y0;
        T q00 = (1 - dist_x) * (1 - dist_y);
        T q01 = (1 - dist_x) * dist_y;
        T q10 = dist_x * (1 - dist_y);
        T q11 = dist_x * dist_y;
        int bottom_index_base = c * height * width;
        *(offset_bottom_data_diff + bottom_index_base + y0 * width + x0) += q00 * diff_val;
        *(offset_bottom_data_diff + bottom_index_base + y1 * width + x0) += q01 * diff_val;
        *(offset_bottom_data_diff + bottom_index_base + y0 * width + x1) += q10 * diff_val;
        *(offset_bottom_data_diff + bottom_index_base + y1 * width + x1) += q11 * diff_val;
        if (no_trans)
        {
          continue;
        }
        // Gradient w.r.t. the learned offsets via the spatial derivative of
        // the bilinear sample, scaled back to normalized RoI units.
        T U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];
        T U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];
        T U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];
        T U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];
        T diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;
        diff_x *= roi_width;
        T diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;
        diff_y *= roi_height;
        *(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w) += diff_x;
        *(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w) += diff_y;
      }
    }
  }
}
std::tuple
dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,
const at::Tensor &bbox,
const at::Tensor &trans,
const int no_trans,
const float spatial_scale,
const int output_dim,
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std)
{
/*AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(bbox.type().is_cuda(), "rois must be a CUDA tensor");
AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor");*/
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_trans = no_trans ? 2 : trans.size(1);
const int num_bbox = bbox.size(0);
AT_ASSERTM(channels == output_dim, "input channels and output channels must equal");
auto pooled_height = pooled_size;
auto pooled_width = pooled_size;
auto out = at::empty({num_bbox, output_dim, pooled_height, pooled_width}, input.options());
long out_size = num_bbox * output_dim * pooled_height * pooled_width;
auto top_count = at::zeros({num_bbox, output_dim, pooled_height, pooled_width}, input.options());
const int num_classes = no_trans ? 1 : channels_trans / 2;
const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
//cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (out.numel() == 0)
{
//THCudaCheck(cudaGetLastError());
return std::make_tuple(out, top_count);
}
/*dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));
dim3 block(512);*/
AT_DISPATCH_FLOATING_TYPES(input.type(), "dcn_v2_psroi_pooling_cpu_forward", [&] {
DeformablePSROIPoolForwardKernelCpu(
out_size,
input.contiguous().data(),
spatial_scale,
channels,
height, width,
pooled_height,
pooled_width,
bbox.contiguous().data(),
trans.contiguous().data(),
no_trans,
trans_std,
sample_per_part,
output_dim,
group_size,
part_size,
num_classes,
channels_each_class,
out.data(),
top_count.data());
});
//THCudaCheck(cudaGetLastError());
return std::make_tuple(out, top_count);
}
std::tuple
dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad,
const at::Tensor &input,
const at::Tensor &bbox,
const at::Tensor &trans,
const at::Tensor &top_count,
const int no_trans,
const float spatial_scale,
const int output_dim,
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std)
{
/*AT_ASSERTM(out_grad.type().is_cuda(), "out_grad must be a CUDA tensor");
AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(bbox.type().is_cuda(), "bbox must be a CUDA tensor");
AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor");
AT_ASSERTM(top_count.type().is_cuda(), "top_count must be a CUDA tensor");*/
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_trans = no_trans ? 2 : trans.size(1);
const int num_bbox = bbox.size(0);
AT_ASSERTM(channels == output_dim, "input channels and output channels must equal");
auto pooled_height = pooled_size;
auto pooled_width = pooled_size;
long out_size = num_bbox * output_dim * pooled_height * pooled_width;
const int num_classes = no_trans ? 1 : channels_trans / 2;
const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
auto input_grad = at::zeros({batch, channels, height, width}, out_grad.options());
auto trans_grad = at::zeros_like(trans);
if (input_grad.numel() == 0)
{
//THCudaCheck(cudaGetLastError());
return std::make_tuple(input_grad, trans_grad);
}
/*dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));
dim3 block(512);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();*/
AT_DISPATCH_FLOATING_TYPES(out_grad.type(), "dcn_v2_psroi_pooling_cpu_backward", [&] {
DeformablePSROIPoolBackwardAccKernelCpu(
out_size,
out_grad.contiguous().data(),
top_count.contiguous().data(),
num_bbox,
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
output_dim,
input_grad.contiguous().data(),
trans_grad.contiguous().data(),
input.contiguous().data(),
bbox.contiguous().data(),
trans.contiguous().data(),
no_trans,
trans_std,
sample_per_part,
group_size,
part_size,
num_classes,
channels_each_class);
});
//THCudaCheck(cudaGetLastError());
return std::make_tuple(input_grad, trans_grad);
}
================================================
FILE: code/real/bsrt/model/DCNv2/src/cpu/vision.h
================================================
#pragma once
#include
at::Tensor
dcn_v2_cpu_forward(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
const at::Tensor &offset,
const at::Tensor &mask,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const int dilation_h,
const int dilation_w,
const int deformable_group);
std::vector
dcn_v2_cpu_backward(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
const at::Tensor &offset,
const at::Tensor &mask,
const at::Tensor &grad_output,
int kernel_h, int kernel_w,
int stride_h, int stride_w,
int pad_h, int pad_w,
int dilation_h, int dilation_w,
int deformable_group);
std::tuple
dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,
const at::Tensor &bbox,
const at::Tensor &trans,
const int no_trans,
const float spatial_scale,
const int output_dim,
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std);
std::tuple
dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad,
const at::Tensor &input,
const at::Tensor &bbox,
const at::Tensor &trans,
const at::Tensor &top_count,
const int no_trans,
const float spatial_scale,
const int output_dim,
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std);
================================================
FILE: code/real/bsrt/model/DCNv2/src/cuda/dcn_v2_cuda.cu
================================================
#include
#include "cuda/dcn_v2_im2col_cuda.h"
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
// Global THC state, initialised once at module load; consumed by the
// THCudaMalloc/THCudaFree/THCudaBlas_* calls below. NOTE(review): THC is a
// legacy API removed in recent PyTorch releases — confirm the targeted
// PyTorch version still ships it.
THCState *state = at::globalContext().lazyInitCUDA();
// Maps a BLAS transpose character ('n'/'t'/'c', either case) onto the
// corresponding cublasOperation_t; any other character aborts via AT_ERROR.
static cublasOperation_t _cublasOpFromChar(char op) {
  if (op == 'n' || op == 'N')
    return CUBLAS_OP_N;
  if (op == 't' || op == 'T')
    return CUBLAS_OP_T;
  if (op == 'c' || op == 'C')
    return CUBLAS_OP_C;
  AT_ERROR(
      "_cublasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
}
// Clamps the leading dimension for a level-2 BLAS call: when the vector
// dimension n collapses to <= 1, cuBLAS still requires lda >= max(m, 1).
static void _cublasAdjustLdLevel2(int64_t m, int64_t n, int64_t* lda) {
  // Note: leading dimensions generally are checked that they are > 0
  // and at least as big the result requires (even if the value won't
  // be used).
  // Q: Why does Level3 check trans but this doesn't?
  // A: In level 2, the sizes (m, n) specify the size of A
  // (independent of trans value). In level 3. the sizes (m, n, k)
  // specify the sizes of op(A), op(B) where op depend on trans
  // values.
  if (n <= 1)
    // Explicit <int64_t>: std::max(m, 1) fails deduction (int64_t vs int).
    *lda = std::max<int64_t>(m, 1);
}
// author: Charles Shang
// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
// [batch gemm]
// https://github.com/pytorch/pytorch/blob/master/aten/src/THC/generic/THCTensorMathBlas.cu
// Fills the per-sample device-pointer arrays consumed by the batched GEMM:
// slot idx points at sample idx's slice of each contiguous buffer, while the
// weight and bias pointers are shared across the whole mini-batch.
__global__ void createBatchGemmBuffer(const float **input_b, float **output_b,
                                      float **columns_b, const float **ones_b,
                                      const float **weight_b, const float **bias_b,
                                      float *input, float *output,
                                      float *columns, float *ones,
                                      float *weight, float *bias,
                                      const int input_stride, const int output_stride,
                                      const int columns_stride, const int ones_stride,
                                      const int num_batches)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= num_batches)
        return; // surplus threads in the last block have nothing to do
    input_b[idx] = input + idx * input_stride;
    output_b[idx] = output + idx * output_stride;
    columns_b[idx] = columns + idx * columns_stride;
    ones_b[idx] = ones + idx * ones_stride;
    // weights and bias are shared within the mini-batch
    weight_b[idx] = weight;
    bias_b[idx] = bias;
}
// Modulated deformable convolution (DCNv2), CUDA forward pass.
// Pipeline: (1) broadcast the bias into `output` with a batched rank-1 GEMM
// against a ones vector, (2) expand the deformably-sampled input into
// `columns` via im2col, (3) accumulate weight @ columns into `output` with a
// second batched GEMM.
// (The `<...>` template/cast/launch arguments stripped during extraction are
// restored here; the assert message arguments were also jumbled and are fixed.)
at::Tensor
dcn_v2_cuda_forward(const at::Tensor &input,
                    const at::Tensor &weight,
                    const at::Tensor &bias,
                    const at::Tensor &offset,
                    const at::Tensor &mask,
                    const int kernel_h,
                    const int kernel_w,
                    const int stride_h,
                    const int stride_w,
                    const int pad_h,
                    const int pad_w,
                    const int dilation_h,
                    const int dilation_w,
                    const int deformable_group)
{
    using scalar_t = float;
    AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
    AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
    AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
    AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
    AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");
    const int batch = input.size(0);
    const int channels = input.size(1);
    const int height = input.size(2);
    const int width = input.size(3);
    const int channels_out = weight.size(0);
    const int channels_kernel = weight.size(1);
    const int kernel_h_ = weight.size(2);
    const int kernel_w_ = weight.size(3);
    AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
               "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w_, kernel_h, kernel_w);
    AT_ASSERTM(channels == channels_kernel,
               "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel);
    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
    auto ones = at::ones({batch, height_out, width_out}, input.options());
    auto columns = at::empty({batch, channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
    auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());
    // prepare for batch-wise computing, which is significantly faster than instance-wise computing
    // when batch size is large.
    // launch batch threads
    int matrices_size = batch * sizeof(float *);
    auto input_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));
    auto output_b = static_cast<float **>(THCudaMalloc(state, matrices_size));
    auto columns_b = static_cast<float **>(THCudaMalloc(state, matrices_size));
    auto ones_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));
    auto weight_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));
    auto bias_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));
    const int block = 128;
    const int grid = (batch + block - 1) / block;
    // Build the per-sample pointer tables on the current stream.
    createBatchGemmBuffer<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
        input_b, output_b,
        columns_b, ones_b,
        weight_b, bias_b,
        input.data_ptr<scalar_t>(),
        output.data_ptr<scalar_t>(),
        columns.data_ptr<scalar_t>(),
        ones.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        bias.data_ptr<scalar_t>(),
        channels * width * height,
        channels_out * width_out * height_out,
        channels * kernel_h * kernel_w * height_out * width_out,
        height_out * width_out,
        batch);
    // Step 1: output = ones^T * bias (broadcasts the bias over each spatial map).
    long m_ = channels_out;
    long n_ = height_out * width_out;
    long k_ = 1;
    THCudaBlas_SgemmBatched(state,
                            't',
                            'n',
                            n_,
                            m_,
                            k_,
                            1.0f,
                            ones_b, k_,
                            bias_b, k_,
                            0.0f,
                            output_b, n_,
                            batch);
    // Step 2: deformable im2col fills `columns` with offset/mask-modulated samples.
    modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(),
                                     input.data_ptr<scalar_t>(),
                                     offset.data_ptr<scalar_t>(),
                                     mask.data_ptr<scalar_t>(),
                                     batch, channels, height, width,
                                     height_out, width_out, kernel_h, kernel_w,
                                     pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
                                     deformable_group,
                                     columns.data_ptr<scalar_t>());
    // Step 3: output += weight * columns (beta = 1 keeps the bias from step 1).
    long m = channels_out;
    long n = height_out * width_out;
    long k = channels * kernel_h * kernel_w;
    THCudaBlas_SgemmBatched(state,
                            'n',
                            'n',
                            n,
                            m,
                            k,
                            1.0f,
                            (const float **)columns_b, n,
                            weight_b, k,
                            1.0f,
                            output_b, n,
                            batch);
    THCudaFree(state, input_b);
    THCudaFree(state, output_b);
    THCudaFree(state, columns_b);
    THCudaFree(state, ones_b);
    THCudaFree(state, weight_b);
    THCudaFree(state, bias_b);
    return output;
}
// Backward-pass counterpart of createBatchGemmBuffer: builds per-sample
// pointer tables for grad_output/columns/ones, while weight and the two
// gradient accumulators are shared across the mini-batch.
__global__ void createBatchGemmBufferBackward(
    float **grad_output_b,
    float **columns_b,
    float **ones_b,
    float **weight_b,
    float **grad_weight_b,
    float **grad_bias_b,
    float *grad_output,
    float *columns,
    float *ones,
    float *weight,
    float *grad_weight,
    float *grad_bias,
    const int grad_output_stride,
    const int columns_stride,
    const int ones_stride,
    const int num_batches)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= num_batches)
        return; // surplus threads in the last block have nothing to do
    grad_output_b[idx] = grad_output + idx * grad_output_stride;
    columns_b[idx] = columns + idx * columns_stride;
    ones_b[idx] = ones + idx * ones_stride;
    // weights and the gradient accumulators are shared within the mini-batch
    weight_b[idx] = weight;
    grad_weight_b[idx] = grad_weight;
    grad_bias_b[idx] = grad_bias;
}
// Modulated deformable convolution (DCNv2), CUDA backward pass. For each
// sample: (1) columns = weight^T @ grad_output, (2) col2im_coord scatters
// into grad_offset/grad_mask, (3) col2im scatters into grad_input,
// (4) im2col + GEMM accumulates grad_weight, (5) a cuBLAS gemv against a
// ones vector accumulates grad_bias.
// Returns {grad_input, grad_offset, grad_mask, grad_weight, grad_bias}.
// (The `<...>` template arguments stripped during extraction are restored;
// the unused `output` allocation and jumbled assert-message args are fixed.)
std::vector<at::Tensor> dcn_v2_cuda_backward(const at::Tensor &input,
                                             const at::Tensor &weight,
                                             const at::Tensor &bias,
                                             const at::Tensor &offset,
                                             const at::Tensor &mask,
                                             const at::Tensor &grad_output,
                                             int kernel_h, int kernel_w,
                                             int stride_h, int stride_w,
                                             int pad_h, int pad_w,
                                             int dilation_h, int dilation_w,
                                             int deformable_group)
{
    THArgCheck(input.is_contiguous(), 1, "input tensor has to be contiguous");
    THArgCheck(weight.is_contiguous(), 2, "weight tensor has to be contiguous");
    AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
    AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
    AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
    AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
    AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");
    const int batch = input.size(0);
    const int channels = input.size(1);
    const int height = input.size(2);
    const int width = input.size(3);
    const int channels_out = weight.size(0);
    const int channels_kernel = weight.size(1);
    const int kernel_h_ = weight.size(2);
    const int kernel_w_ = weight.size(3);
    AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
               "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w_, kernel_h, kernel_w);
    AT_ASSERTM(channels == channels_kernel,
               "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel);
    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
    // `ones` feeds the bias-gradient gemv; `columns` is reused per sample.
    auto ones = at::ones({height_out, width_out}, input.options());
    auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
    auto grad_input = at::zeros_like(input);
    auto grad_weight = at::zeros_like(weight);
    auto grad_bias = at::zeros_like(bias);
    auto grad_offset = at::zeros_like(offset);
    auto grad_mask = at::zeros_like(mask);
    using scalar_t = float;
    for (int b = 0; b < batch; b++)
    {
        auto input_n = input.select(0, b);
        auto offset_n = offset.select(0, b);
        auto mask_n = mask.select(0, b);
        auto grad_output_n = grad_output.select(0, b);
        auto grad_input_n = grad_input.select(0, b);
        auto grad_offset_n = grad_offset.select(0, b);
        auto grad_mask_n = grad_mask.select(0, b);
        // columns = weight^T @ grad_output_n
        long m = channels * kernel_h * kernel_w;
        long n = height_out * width_out;
        long k = channels_out;
        THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f,
                         grad_output_n.data_ptr<scalar_t>(), n,
                         weight.data_ptr<scalar_t>(), m, 0.0f,
                         columns.data_ptr<scalar_t>(), n);
        // gradient w.r.t. input coordinate data
        modulated_deformable_col2im_coord_cuda(c10::cuda::getCurrentCUDAStream(),
                                               columns.data_ptr<scalar_t>(),
                                               input_n.data_ptr<scalar_t>(),
                                               offset_n.data_ptr<scalar_t>(),
                                               mask_n.data_ptr<scalar_t>(),
                                               1, channels, height, width,
                                               height_out, width_out, kernel_h, kernel_w,
                                               pad_h, pad_w, stride_h, stride_w,
                                               dilation_h, dilation_w, deformable_group,
                                               grad_offset_n.data_ptr<scalar_t>(),
                                               grad_mask_n.data_ptr<scalar_t>());
        // gradient w.r.t. input data
        modulated_deformable_col2im_cuda(c10::cuda::getCurrentCUDAStream(),
                                         columns.data_ptr<scalar_t>(),
                                         offset_n.data_ptr<scalar_t>(),
                                         mask_n.data_ptr<scalar_t>(),
                                         1, channels, height, width,
                                         height_out, width_out, kernel_h, kernel_w,
                                         pad_h, pad_w, stride_h, stride_w,
                                         dilation_h, dilation_w, deformable_group,
                                         grad_input_n.data_ptr<scalar_t>());
        // gradient w.r.t. weight, dWeight should accumulate across the batch and group
        modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(),
                                         input_n.data_ptr<scalar_t>(),
                                         offset_n.data_ptr<scalar_t>(),
                                         mask_n.data_ptr<scalar_t>(),
                                         1, channels, height, width,
                                         height_out, width_out, kernel_h, kernel_w,
                                         pad_h, pad_w, stride_h, stride_w,
                                         dilation_h, dilation_w, deformable_group,
                                         columns.data_ptr<scalar_t>());
        long m_ = channels_out;
        long n_ = channels * kernel_h * kernel_w;
        long k_ = height_out * width_out;
        THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f,
                         columns.data_ptr<scalar_t>(), k_,
                         grad_output_n.data_ptr<scalar_t>(), k_, 1.0f,
                         grad_weight.data_ptr<scalar_t>(), n_);
        // grad_bias += grad_output_n^T @ ones (beta = 1 accumulates over b).
        cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
        cublasOperation_t op = _cublasOpFromChar('t');
        _cublasAdjustLdLevel2(k_, m_, &k_);
        scalar_t *grad_output_n_float = grad_output_n.data_ptr<scalar_t>();
        scalar_t *one_float = ones.data_ptr<scalar_t>();
        scalar_t alpha = 1.0;
        scalar_t beta = 1.0;
        cublasSgemv(handle, op, k_, m_, &alpha, grad_output_n_float, k_, one_float, 1, &beta, grad_bias.data_ptr<scalar_t>(), 1);
    }
    return {
        grad_input, grad_offset, grad_mask, grad_weight, grad_bias
    };
}
================================================
FILE: code/real/bsrt/model/DCNv2/src/cuda/dcn_v2_im2col_cuda.cu
================================================
#include "dcn_v2_im2col_cuda.h"
#include
#include
#include
#include
#include
#include
#include
#include
// Grid-stride loop: each thread starts at its global index and advances by
// the total number of launched threads until all n elements are covered.
#define CUDA_KERNEL_LOOP(i, n)                        \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
       i < (n);                                       \
       i += blockDim.x * gridDim.x)
// Threads per block used by every kernel launch in this file.
const int CUDA_NUM_THREADS = 1024;
// Number of blocks needed to cover N elements (ceiling division).
inline int GET_BLOCKS(const int N)
{
  const int threads = CUDA_NUM_THREADS;
  return (N + threads - 1) / threads;
}
// Bilinearly samples bottom_data (height x width, row stride data_width) at
// the fractional location (h, w). Neighbours that fall outside the map
// contribute zero, so samples near the border are partially extrapolated.
__device__ float dmcn_im2col_bilinear_cuda(const float *bottom_data, const int data_width,
                                           const int height, const int width, float h, float w)
{
    const int y0 = floor(h);
    const int x0 = floor(w);
    const int y1 = y0 + 1;
    const int x1 = x0 + 1;
    const float dy = h - y0;
    const float dx = w - x0;
    const float ry = 1 - dy, rx = 1 - dx;
    // Corner values, zeroed when the corner lies outside the feature map.
    const float tl = (y0 >= 0 && x0 >= 0) ? bottom_data[y0 * data_width + x0] : 0.f;
    const float tr = (y0 >= 0 && x1 <= width - 1) ? bottom_data[y0 * data_width + x1] : 0.f;
    const float bl = (y1 <= height - 1 && x0 >= 0) ? bottom_data[y1 * data_width + x0] : 0.f;
    const float br = (y1 <= height - 1 && x1 <= width - 1) ? bottom_data[y1 * data_width + x1] : 0.f;
    // Same weight products and summation order as the reference version so
    // results match bit-for-bit.
    const float w_tl = ry * rx, w_tr = ry * dx, w_bl = dy * rx, w_br = dy * dx;
    return (w_tl * tl + w_tr * tr + w_bl * bl + w_br * br);
}
// Weight that the integer grid point (h, w) contributes to the gradient of a
// bilinear sample taken at (argmax_h, argmax_w). Returns zero when the sample
// point lies outside the map or (h, w) is not one of its four neighbours.
__device__ float dmcn_get_gradient_weight_cuda(float argmax_h, float argmax_w,
                                               const int h, const int w, const int height, const int width)
{
    if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
        return 0; // sample point entirely outside the feature map
    const int h0 = floor(argmax_h);
    const int w0 = floor(argmax_w);
    const int h1 = h0 + 1;
    const int w1 = w0 + 1;
    // The four neighbour cases are mutually exclusive, so an else-if chain
    // is equivalent to the reference's independent ifs.
    float weight = 0;
    if (h == h0 && w == w0)
        weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
    else if (h == h0 && w == w1)
        weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
    else if (h == h1 && w == w0)
        weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
    else if (h == h1 && w == w1)
        weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
    return weight;
}
// Derivative of a bilinear sample at (argmax_h, argmax_w) with respect to
// one sampling coordinate: bp_dir == 0 differentiates w.r.t. h (vertical),
// bp_dir == 1 w.r.t. w (horizontal). Used to backpropagate into the learned
// offsets. Out-of-map sample points contribute zero.
__device__ float dmcn_get_coordinate_weight_cuda(float argmax_h, float argmax_w,
                                                 const int height, const int width, const float *im_data,
                                                 const int data_width, const int bp_dir)
{
  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
  {
    // empty: sample point outside the feature map
    return 0;
  }
  int argmax_h_low = floor(argmax_h);
  int argmax_w_low = floor(argmax_w);
  int argmax_h_high = argmax_h_low + 1;
  int argmax_w_high = argmax_w_low + 1;
  float weight = 0;
  if (bp_dir == 0)
  {
    // d(sample)/dh: lower-row corners enter negatively, upper-row corners
    // positively, each scaled by its horizontal interpolation weight.
    // Corners outside the map are skipped (they contributed 0 to the sample).
    if (argmax_h_low >= 0 && argmax_w_low >= 0)
      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
  }
  else if (bp_dir == 1)
  {
    // d(sample)/dw: left-column corners enter negatively, right-column
    // corners positively, each scaled by its vertical interpolation weight.
    if (argmax_h_low >= 0 && argmax_w_low >= 0)
      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
  }
  return weight;
}
// im2col for modulated deformable convolution: for each output location and
// input channel, samples the input at kernel_h x kernel_w offset-shifted
// positions (bilinear), scales each sample by its modulation mask value, and
// writes the kernel_h*kernel_w results into the column buffer. One thread
// per (batch, channel, out_y, out_x); n = batch * channels * height_col * width_col.
__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
                                                       const float *data_im, const float *data_offset, const float *data_mask,
                                                       const int height, const int width, const int kernel_h, const int kernel_w,
                                                       const int pad_h, const int pad_w,
                                                       const int stride_h, const int stride_w,
                                                       const int dilation_h, const int dilation_w,
                                                       const int channel_per_deformable_group,
                                                       const int batch_size, const int num_channels, const int deformable_group,
                                                       const int height_col, const int width_col,
                                                       float *data_col)
{
  // launch channels * batch_size * height_col * width_col cores
  CUDA_KERNEL_LOOP(index, n)
  {
    // NOTE(CharlesShang): different from Dai Jifeng's MXNet implementation, col_buffer is of shape (c*kw*kh, N, oh, ow)
    // here columns is of shape (N, c*kw*kh, oh * ow), need to adapt axis
    // Decompose the flat index as (b_col, c_im, h_col, w_col); the decode
    // order below matches the (N, c*kw*kh, oh*ow) column layout.
    const int w_col = index % width_col;
    const int h_col = (index / width_col) % height_col;
    // const int b_col = (index / width_col / height_col) % batch_size;
    const int b_col = (index / width_col / height_col / num_channels) % batch_size;
    // const int c_im = (index / width_col / height_col) / batch_size;
    const int c_im = (index / width_col / height_col) % num_channels;
    // const int c_col = c_im * kernel_h * kernel_w;
    // First column row belonging to channel c_im (one row per kernel tap).
    const int c_col = c_im * kernel_h * kernel_w;
    // compute deformable group index: which offset/mask field this channel uses
    const int deformable_group_index = c_im / channel_per_deformable_group;
    // Top-left corner of the receptive field in input coordinates.
    const int h_in = h_col * stride_h - pad_h;
    const int w_in = w_col * stride_w - pad_w;
    // float *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
    float *data_col_ptr = data_col + ((b_col * num_channels * kernel_w * kernel_h + c_col) * height_col + h_col) * width_col + w_col;
    //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
    const float *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
    // Offsets are stored as 2 (y, x) planes per kernel tap; the mask has one
    // plane per tap.
    const float *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
    const float *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
    for (int i = 0; i < kernel_h; ++i)
    {
      for (int j = 0; j < kernel_w; ++j)
      {
        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
        const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
        const float offset_h = data_offset_ptr[data_offset_h_ptr];
        const float offset_w = data_offset_ptr[data_offset_w_ptr];
        const float mask = data_mask_ptr[data_mask_hw_ptr];
        float val = static_cast<float>(0);
        // Sampling position = dilated tap location + learned offset.
        const float h_im = h_in + i * dilation_h + offset_h;
        const float w_im = w_in + j * dilation_w + offset_w;
        //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
        // > -1 (not >= 0) so border samples with partial support still count.
        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
        {
          //const float map_h = i * dilation_h + offset_h;
          //const float map_w = j * dilation_w + offset_w;
          //const int cur_height = height - h_in;
          //const int cur_width = width - w_in;
          //val = dmcn_im2col_bilinear_cuda(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
          val = dmcn_im2col_bilinear_cuda(data_im_ptr, width, height, width, h_im, w_im);
        }
        *data_col_ptr = val * mask;
        // data_col_ptr += batch_size * height_col * width_col;
        // Advance one column row (one kernel tap) for this output location.
        data_col_ptr += height_col * width_col;
      }
    }
  }
}
// Gradient of modulated deformable im2col w.r.t. the input image (col2im).
// One thread per column-buffer element; each thread scatters its gradient
// into `grad_im` via atomicAdd over the bilinear taps around its fractional
// sampling location.
__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
    const float *data_col, const float *data_offset, const float *data_mask,
    const int channels, const int height, const int width,
    const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w,
    const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int channel_per_deformable_group,
    const int batch_size, const int deformable_group,
    const int height_col, const int width_col,
    float *grad_im)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Recover (kernel col j, kernel row i, input channel c) from the flat index.
    const int j = (index / width_col / height_col / batch_size) % kernel_w;
    const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
    // compute the start and end of the output
    const int deformable_group_index = c / channel_per_deformable_group;
    int w_out = index % width_col;
    int h_out = (index / width_col) % height_col;
    int b = (index / width_col / height_col) % batch_size;
    int w_in = w_out * stride_w - pad_w;
    int h_in = h_out * stride_h - pad_h;
    // Offsets are (2*kh*kw, oh, ow) and masks (kh*kw, oh, ow) per deformable group.
    const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
    const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
    const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
    const float offset_h = data_offset_ptr[data_offset_h_ptr];
    const float offset_w = data_offset_ptr[data_offset_w_ptr];
    const float mask = data_mask_ptr[data_mask_hw_ptr];
    // Fractional input location this column element was sampled from in forward.
    const float cur_inv_h_data = h_in + i * dilation_h + offset_h;
    const float cur_inv_w_data = w_in + j * dilation_w + offset_w;
    // Incoming gradient, modulated by the mask (forward multiplied by mask).
    const float cur_top_grad = data_col[index] * mask;
    const int cur_h = (int)cur_inv_h_data;
    const int cur_w = (int)cur_inv_w_data;
    // Scan a small neighborhood; only pixels within bilinear support
    // (|distance| < 1) get a non-zero weight from dmcn_get_gradient_weight_cuda.
    for (int dy = -2; dy <= 2; dy++)
    {
      for (int dx = -2; dx <= 2; dx++)
      {
        if (cur_h + dy >= 0 && cur_h + dy < height &&
            cur_w + dx >= 0 && cur_w + dx < width &&
            abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
            abs(cur_inv_w_data - (cur_w + dx)) < 1)
        {
          int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
          float weight = dmcn_get_gradient_weight_cuda(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
          // Multiple column elements can map to the same pixel -> atomic accumulate.
          atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
        }
      }
    }
  }
}
// Gradient of modulated deformable im2col w.r.t. the offsets and masks.
// One thread per offset-channel element; each thread walks the column
// channels sharing its deformable group, accumulating the offset gradient
// (via dmcn_get_coordinate_weight_cuda) and the mask gradient (the
// bilinearly sampled input value times the incoming gradient).
__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
    const float *data_col, const float *data_im,
    const float *data_offset, const float *data_mask,
    const int channels, const int height, const int width,
    const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w,
    const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int channel_per_deformable_group,
    const int batch_size, const int offset_channels, const int deformable_group,
    const int height_col, const int width_col,
    float *grad_offset, float *grad_mask)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // val accumulates the offset gradient, mval the mask gradient.
    float val = 0, mval = 0;
    int w = index % width_col;
    int h = (index / width_col) % height_col;
    int c = (index / width_col / height_col) % offset_channels;
    int b = (index / width_col / height_col) / offset_channels;
    // compute the start and end of the output
    const int deformable_group_index = c / (2 * kernel_h * kernel_w);
    const int col_step = kernel_h * kernel_w;
    int cnt = 0;
    const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
    const float *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
    const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
    const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
    // Offset channel within this group; even channels are h-offsets, odd are w-offsets.
    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
    {
      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
      // bp_dir selects the differentiation direction (0: h, 1: w).
      const int bp_dir = offset_c % 2;
      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
      int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
      int w_out = col_pos % width_col;
      int h_out = (col_pos / width_col) % height_col;
      int w_in = w_out * stride_w - pad_w;
      int h_in = h_out * stride_h - pad_h;
      const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
      const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
      const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
      const float offset_h = data_offset_ptr[data_offset_h_ptr];
      const float offset_w = data_offset_ptr[data_offset_w_ptr];
      const float mask = data_mask_ptr[data_mask_hw_ptr];
      float inv_h = h_in + i * dilation_h + offset_h;
      float inv_w = w_in + j * dilation_w + offset_w;
      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
      {
        // Sample fell outside the image; -2 is a sentinel passed on to
        // dmcn_get_coordinate_weight_cuda (presumably yields zero weight —
        // NOTE(review): confirm against that helper's definition).
        inv_h = inv_w = -2;
      }
      else
      {
        mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear_cuda(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
      }
      const float weight = dmcn_get_coordinate_weight_cuda(
          inv_h, inv_w,
          height, width, data_im_ptr + cnt * height * width, width, bp_dir);
      val += weight * data_col_ptr[col_pos] * mask;
      cnt += 1;
    }
    // KERNEL_ASSIGN(grad_offset[index], offset_req, val);
    grad_offset[index] = val;
    // The mask gradient is written once per (i, j) tap, keyed off the even
    // (h-direction) offset channel.
    if (offset_c % 2 == 0)
      // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
      grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
  }
}
// Host-side launcher for modulated_deformable_im2col_gpu_kernel (forward pass).
// Launch-configuration macros GET_BLOCKS/CUDA_NUM_THREADS are defined earlier
// in this translation unit.
void modulated_deformable_im2col_cuda(cudaStream_t stream,
    const float* data_im, const float* data_offset, const float* data_mask,
    const int batch_size, const int channels, const int height_im, const int width_im,
    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int deformable_group, float* data_col) {
  // num_axes should be smaller than block size
  const int channel_per_deformable_group = channels / deformable_group;
  // One thread per (batch, channel, h_col, w_col) output element.
  const int num_kernels = channels * batch_size * height_col * width_col;
  // FIX: the kernel launch configuration was missing (`<<>>`); restore
  // grid size, block size, shared memory (0) and the caller's stream.
  modulated_deformable_im2col_gpu_kernel
      <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>(
          num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kernel_w,
          pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
          batch_size, channels, deformable_group, height_col, width_col, data_col);
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
  }
}
// Host-side launcher for modulated_deformable_col2im_gpu_kernel (gradient
// w.r.t. the input image).
void modulated_deformable_col2im_cuda(cudaStream_t stream,
    const float* data_col, const float* data_offset, const float* data_mask,
    const int batch_size, const int channels, const int height_im, const int width_im,
    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int deformable_group, float* grad_im){
  const int channel_per_deformable_group = channels / deformable_group;
  // One thread per column-buffer element.
  const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
  // FIX 1: the kernel launch configuration was missing (`<<>>`); restored.
  // FIX 2: `pad_h` was passed twice — the kernel's pad_w argument now receives
  // pad_w. The old call was only correct for square padding (pad_h == pad_w).
  modulated_deformable_col2im_gpu_kernel
      <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>(
          num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im,
          kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
          dilation_h, dilation_w, channel_per_deformable_group,
          batch_size, deformable_group, height_col, width_col, grad_im);
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
  }
}
// Host-side launcher for modulated_deformable_col2im_coord_gpu_kernel
// (gradients w.r.t. the offsets and masks).
void modulated_deformable_col2im_coord_cuda(cudaStream_t stream,
    const float* data_col, const float* data_im, const float* data_offset, const float* data_mask,
    const int batch_size, const int channels, const int height_im, const int width_im,
    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int deformable_group,
    float* grad_offset, float* grad_mask) {
  // One thread per offset element (2 offsets per kernel tap per group).
  const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
  const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
  // FIX: the kernel launch configuration was missing (`<<>>`); restore
  // grid size, block size, shared memory (0) and the caller's stream.
  modulated_deformable_col2im_coord_gpu_kernel
      <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>(
          num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im,
          kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
          dilation_h, dilation_w, channel_per_deformable_group,
          batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
          grad_offset, grad_mask);
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
  }
}
================================================
FILE: code/real/bsrt/model/DCNv2/src/cuda/dcn_v2_im2col_cuda.h
================================================
/*!
******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
*
* COPYRIGHT
*
* All contributions by the University of California:
* Copyright (c) 2014-2017 The Regents of the University of California (Regents)
* All rights reserved.
*
* All other contributions:
* Copyright (c) 2014-2017, the respective contributors
* All rights reserved.
*
* Caffe uses a shared copyright model: each contributor holds copyright over
* their contributions to Caffe. The project versioning records all such
* contribution and copyright details. If a contributor wants to further mark
* their specific copyright on a particular contribution, they should indicate
* their copyright solely in the commit message of the change when it is
* committed.
*
* LICENSE
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* CONTRIBUTION AGREEMENT
*
* By contributing to the BVLC/caffe repository through pull-request, comment,
* or otherwise, the contributor releases their content to the
* license and copyright terms herein.
*
***************** END Caffe Copyright Notice and Disclaimer ********************
*
* Copyright (c) 2018 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file modulated_deformable_im2col.h
* \brief Function definitions of converting an image to
* column matrix based on kernel, padding, dilation, and offset.
* These functions are mainly used in deformable convolution operators.
* \ref: https://arxiv.org/abs/1811.11168
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu
*/
/***************** Adapted by Charles Shang *********************/
#ifndef DCN_V2_IM2COL_CUDA
#define DCN_V2_IM2COL_CUDA
#ifdef __cplusplus
extern "C"
{
#endif
// im2col for modulated deformable convolution (forward pass).
// FIX: corrected the `kenerl_w` typo to `kernel_w` in all three prototypes,
// matching the parameter names used by the definitions in
// dcn_v2_im2col_cuda.cu (parameter names in declarations do not affect ABI).
void modulated_deformable_im2col_cuda(cudaStream_t stream,
    const float *data_im, const float *data_offset, const float *data_mask,
    const int batch_size, const int channels, const int height_im, const int width_im,
    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int deformable_group, float *data_col);
// Gradient of im2col w.r.t. the input image (col2im).
void modulated_deformable_col2im_cuda(cudaStream_t stream,
    const float *data_col, const float *data_offset, const float *data_mask,
    const int batch_size, const int channels, const int height_im, const int width_im,
    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int deformable_group, float *grad_im);
// Gradients of im2col w.r.t. the offsets and masks.
void modulated_deformable_col2im_coord_cuda(cudaStream_t stream,
    const float *data_col, const float *data_im, const float *data_offset, const float *data_mask,
    const int batch_size, const int channels, const int height_im, const int width_im,
    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int deformable_group,
    float *grad_offset, float *grad_mask);
#ifdef __cplusplus
}
#endif
#endif
================================================
FILE: code/real/bsrt/model/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda.cu
================================================
/*!
* Copyright (c) 2017 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file deformable_psroi_pooling.cu
* \brief
* \author Yi Li, Guodong Zhang, Jifeng Dai
*/
/***************** Adapted by Charles Shang *********************/
#include <stdio.h>
#include <math.h>
#include <algorithm>
#include <cstring>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <THC/THC.h>
#include <THC/THCAtomics.cuh>
#include <THC/THCDeviceUtils.cuh>
// Grid-stride loop: each thread starts at its global thread id and advances
// by the total number of launched threads (blockDim.x * gridDim.x) until all
// `n` work items are covered, so any grid size handles any problem size.
#define CUDA_KERNEL_LOOP(i, n)                          \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x;   \
       i < (n);                                         \
       i += blockDim.x * gridDim.x)
// Threads per CUDA block used by every kernel launch in this file.
const int CUDA_NUM_THREADS = 1024;
// Smallest grid size whose total thread count covers N work items
// (ceiling division of N by the block size).
inline int GET_BLOCKS(const int N)
{
  const int threads_per_block = CUDA_NUM_THREADS;
  return (N + threads_per_block - 1) / threads_per_block;
}
// Bilinear interpolation of row-major image `data` (height x width) at the
// fractional coordinates (x, y). The caller clamps x/y into the image, so no
// bounds checks are performed here.
// FIX: restore the template header (`template <typename T>`) and the
// `static_cast<T>` template arguments that were missing.
template <typename T>
__device__ T bilinear_interp_cuda(
    const T *data,
    const T x,
    const T y,
    const int width,
    const int height)
{
  // Integer corner coordinates surrounding (x, y).
  int x1 = floor(x);
  int x2 = ceil(x);
  int y1 = floor(y);
  int y2 = ceil(y);
  // Fractional distances from the lower-left corner.
  T dist_x = static_cast<T>(x - x1);
  T dist_y = static_cast<T>(y - y1);
  // Four corner values (row-major indexing).
  T value11 = data[y1 * width + x1];
  T value12 = data[y2 * width + x1];
  T value21 = data[y1 * width + x2];
  T value22 = data[y2 * width + x2];
  // Standard bilinear blend of the four corners.
  T value = (1 - dist_x) * (1 - dist_y) * value11 +
            (1 - dist_x) * dist_y * value12 +
            dist_x * (1 - dist_y) * value21 +
            dist_x * dist_y * value22;
  return value;
}
// Deformable PS-ROI pooling, forward pass.
// One thread per output element (n, ctop, ph, pw): averages up to
// sample_per_part^2 bilinear samples from the (optionally trans-shifted) bin
// and records the number of valid samples in `top_count` for the backward pass.
// FIX: restore the missing template header (`template <typename T>`) and the
// `static_cast<T>` template arguments stripped from this kernel.
template <typename T>
__global__ void DeformablePSROIPoolForwardKernelCuda(
    const int count,
    const T *bottom_data,
    const T spatial_scale,
    const int channels,
    const int height, const int width,
    const int pooled_height, const int pooled_width,
    const T *bottom_rois, const T *bottom_trans,
    const int no_trans,
    const T trans_std,
    const int sample_per_part,
    const int output_dim,
    const int group_size,
    const int part_size,
    const int num_classes,
    const int channels_each_class,
    T *top_data,
    T *top_count)
{
  CUDA_KERNEL_LOOP(index, count)
  {
    // The output is in order (n, ctop, ph, pw)
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int ctop = (index / pooled_width / pooled_height) % output_dim;
    int n = index / pooled_width / pooled_height / output_dim;
    // [start, end) interval for spatial sampling; each ROI row is
    // (batch_ind, x1, y1, x2, y2).
    const T *offset_bottom_rois = bottom_rois + n * 5;
    int roi_batch_ind = offset_bottom_rois[0];
    T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
    T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
    T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
    T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
    // Force too small ROIs to be 1x1
    T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
    T roi_height = max(roi_end_h - roi_start_h, 0.1);
    // Compute w and h at bottom
    T bin_size_h = roi_height / static_cast<T>(pooled_height);
    T bin_size_w = roi_width / static_cast<T>(pooled_width);
    T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);
    T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);
    int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);
    int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);
    int class_id = ctop / channels_each_class;
    // Learned per-bin translation, scaled by trans_std (zero when no_trans).
    T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;
    T trans_y = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;
    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
    wstart += trans_x * roi_width;
    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
    hstart += trans_y * roi_height;
    T sum = 0;
    // FIX: renamed from `count`, which shadowed the kernel's `count`
    // parameter used by CUDA_KERNEL_LOOP's loop condition.
    int num_samples = 0;
    int gw = floor(static_cast<T>(pw) * group_size / pooled_width);
    int gh = floor(static_cast<T>(ph) * group_size / pooled_height);
    gw = min(max(gw, 0), group_size - 1);
    gh = min(max(gh, 0), group_size - 1);
    const T *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
    for (int ih = 0; ih < sample_per_part; ih++)
    {
      for (int iw = 0; iw < sample_per_part; iw++)
      {
        T w = wstart + iw * sub_bin_size_w;
        T h = hstart + ih * sub_bin_size_h;
        // bilinear interpolation
        if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
        {
          continue; // sample outside the image: skipped and not counted
        }
        w = min(max(w, 0.), width - 1.);
        h = min(max(h, 0.), height - 1.);
        int c = (ctop * group_size + gh) * group_size + gw;
        T val = bilinear_interp_cuda(offset_bottom_data + c * height * width, w, h, width, height);
        sum += val;
        num_samples++;
      }
    }
    // Average over the valid samples; empty bins output 0.
    top_data[index] = num_samples == 0 ? static_cast<T>(0) : sum / num_samples;
    top_count[index] = num_samples;
  }
}
// Deformable PS-ROI pooling, backward pass (gradient accumulation).
// One thread per output element; the thread re-derives the sampling
// locations used in the forward pass and scatters gradients into the
// feature map (`bottom_data_diff`) and, when enabled, the learned
// translations (`bottom_trans_diff`) via atomicAdd.
// FIX: restore the missing template header (`template <typename T>`) and the
// `static_cast<T>` template arguments stripped from this kernel.
template <typename T>
__global__ void DeformablePSROIPoolBackwardAccKernelCuda(
    const int count,
    const T *top_diff,
    const T *top_count,
    const int num_rois,
    const T spatial_scale,
    const int channels,
    const int height, const int width,
    const int pooled_height, const int pooled_width,
    const int output_dim,
    T *bottom_data_diff, T *bottom_trans_diff,
    const T *bottom_data,
    const T *bottom_rois,
    const T *bottom_trans,
    const int no_trans,
    const T trans_std,
    const int sample_per_part,
    const int group_size,
    const int part_size,
    const int num_classes,
    const int channels_each_class)
{
  CUDA_KERNEL_LOOP(index, count)
  {
    // The output is in order (n, ctop, ph, pw)
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int ctop = (index / pooled_width / pooled_height) % output_dim;
    int n = index / pooled_width / pooled_height / output_dim;
    // [start, end) interval for spatial sampling
    const T *offset_bottom_rois = bottom_rois + n * 5;
    int roi_batch_ind = offset_bottom_rois[0];
    T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
    T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
    T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
    T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
    // Force too small ROIs to be 1x1
    T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
    T roi_height = max(roi_end_h - roi_start_h, 0.1);
    // Compute w and h at bottom
    T bin_size_h = roi_height / static_cast<T>(pooled_height);
    T bin_size_w = roi_width / static_cast<T>(pooled_width);
    T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);
    T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);
    int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);
    int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);
    int class_id = ctop / channels_each_class;
    T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;
    T trans_y = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;
    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
    wstart += trans_x * roi_width;
    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
    hstart += trans_y * roi_height;
    // Bins with no valid samples produced no output -> no gradient to spread.
    if (top_count[index] <= 0)
    {
      continue;
    }
    // Forward averaged over top_count samples, so divide the gradient here.
    T diff_val = top_diff[index] / top_count[index];
    const T *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
    T *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;
    int gw = floor(static_cast<T>(pw) * group_size / pooled_width);
    int gh = floor(static_cast<T>(ph) * group_size / pooled_height);
    gw = min(max(gw, 0), group_size - 1);
    gh = min(max(gh, 0), group_size - 1);
    for (int ih = 0; ih < sample_per_part; ih++)
    {
      for (int iw = 0; iw < sample_per_part; iw++)
      {
        T w = wstart + iw * sub_bin_size_w;
        T h = hstart + ih * sub_bin_size_h;
        // bilinear interpolation
        if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
        {
          continue;
        }
        w = min(max(w, 0.), width - 1.);
        h = min(max(h, 0.), height - 1.);
        int c = (ctop * group_size + gh) * group_size + gw;
        // backward on feature: scatter gradient onto the four bilinear taps.
        int x0 = floor(w);
        int x1 = ceil(w);
        int y0 = floor(h);
        int y1 = ceil(h);
        T dist_x = w - x0, dist_y = h - y0;
        T q00 = (1 - dist_x) * (1 - dist_y);
        T q01 = (1 - dist_x) * dist_y;
        T q10 = dist_x * (1 - dist_y);
        T q11 = dist_x * dist_y;
        int bottom_index_base = c * height * width;
        atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);
        atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);
        atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);
        atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);
        if (no_trans)
        {
          continue; // translation input unused: nothing more to do
        }
        // backward on the translation input (chain rule through the
        // bilinear sampling location).
        T U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];
        T U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];
        T U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];
        T U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];
        T diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;
        diff_x *= roi_width;
        T diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;
        diff_y *= roi_height;
        atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x);
        atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);
      }
    }
  }
}
// CUDA entry point for deformable PS-ROI pooling forward.
// Returns (pooled output, per-element valid-sample count); the count tensor
// is consumed by the backward pass.
// FIX: restore the stripped template arguments — the return type
// `std::tuple<at::Tensor, at::Tensor>`, the `data_ptr<scalar_t>()` calls,
// and the `<scalar_t><<<grid, block, 0, stream>>>` kernel launch.
std::tuple<at::Tensor, at::Tensor>
dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input,
                                  const at::Tensor &bbox,
                                  const at::Tensor &trans,
                                  const int no_trans,
                                  const float spatial_scale,
                                  const int output_dim,
                                  const int group_size,
                                  const int pooled_size,
                                  const int part_size,
                                  const int sample_per_part,
                                  const float trans_std)
{
  AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
  AT_ASSERTM(bbox.type().is_cuda(), "rois must be a CUDA tensor");
  AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor");
  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);
  const int channels_trans = no_trans ? 2 : trans.size(1);
  const int num_bbox = bbox.size(0);
  AT_ASSERTM(channels == output_dim, "input channels and output channels must equal");
  auto pooled_height = pooled_size;
  auto pooled_width = pooled_size;
  auto out = at::empty({num_bbox, output_dim, pooled_height, pooled_width}, input.options());
  long out_size = num_bbox * output_dim * pooled_height * pooled_width;
  auto top_count = at::zeros({num_bbox, output_dim, pooled_height, pooled_width}, input.options());
  const int num_classes = no_trans ? 1 : channels_trans / 2;
  const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // Nothing to pool (e.g. zero ROIs): return the empty tensors directly.
  if (out.numel() == 0)
  {
    THCudaCheck(cudaGetLastError());
    return std::make_tuple(out, top_count);
  }
  dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));
  dim3 block(512);
  AT_DISPATCH_FLOATING_TYPES(input.type(), "dcn_v2_psroi_pooling_cuda_forward", [&] {
    DeformablePSROIPoolForwardKernelCuda<scalar_t><<<grid, block, 0, stream>>>(
        out_size,
        input.contiguous().data_ptr<scalar_t>(),
        spatial_scale,
        channels,
        height, width,
        pooled_height,
        pooled_width,
        bbox.contiguous().data_ptr<scalar_t>(),
        trans.contiguous().data_ptr<scalar_t>(),
        no_trans,
        trans_std,
        sample_per_part,
        output_dim,
        group_size,
        part_size,
        num_classes,
        channels_each_class,
        out.data_ptr<scalar_t>(),
        top_count.data_ptr<scalar_t>());
  });
  THCudaCheck(cudaGetLastError());
  return std::make_tuple(out, top_count);
}
// CUDA entry point for deformable PS-ROI pooling backward.
// Returns (grad w.r.t. input feature map, grad w.r.t. trans).
// FIX: restore the stripped template arguments — the return type
// `std::tuple<at::Tensor, at::Tensor>`, the `data_ptr<scalar_t>()` calls,
// and the `<scalar_t><<<grid, block, 0, stream>>>` kernel launch.
std::tuple<at::Tensor, at::Tensor>
dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad,
                                   const at::Tensor &input,
                                   const at::Tensor &bbox,
                                   const at::Tensor &trans,
                                   const at::Tensor &top_count,
                                   const int no_trans,
                                   const float spatial_scale,
                                   const int output_dim,
                                   const int group_size,
                                   const int pooled_size,
                                   const int part_size,
                                   const int sample_per_part,
                                   const float trans_std)
{
  AT_ASSERTM(out_grad.type().is_cuda(), "out_grad must be a CUDA tensor");
  AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
  AT_ASSERTM(bbox.type().is_cuda(), "bbox must be a CUDA tensor");
  AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor");
  AT_ASSERTM(top_count.type().is_cuda(), "top_count must be a CUDA tensor");
  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);
  const int channels_trans = no_trans ? 2 : trans.size(1);
  const int num_bbox = bbox.size(0);
  AT_ASSERTM(channels == output_dim, "input channels and output channels must equal");
  auto pooled_height = pooled_size;
  auto pooled_width = pooled_size;
  long out_size = num_bbox * output_dim * pooled_height * pooled_width;
  const int num_classes = no_trans ? 1 : channels_trans / 2;
  const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
  // Gradients are accumulated into zero-initialized buffers by the kernel.
  auto input_grad = at::zeros({batch, channels, height, width}, out_grad.options());
  auto trans_grad = at::zeros_like(trans);
  if (input_grad.numel() == 0)
  {
    THCudaCheck(cudaGetLastError());
    return std::make_tuple(input_grad, trans_grad);
  }
  dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));
  dim3 block(512);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(out_grad.type(), "dcn_v2_psroi_pooling_cuda_backward", [&] {
    DeformablePSROIPoolBackwardAccKernelCuda<scalar_t><<<grid, block, 0, stream>>>(
        out_size,
        out_grad.contiguous().data_ptr<scalar_t>(),
        top_count.contiguous().data_ptr<scalar_t>(),
        num_bbox,
        spatial_scale,
        channels,
        height,
        width,
        pooled_height,
        pooled_width,
        output_dim,
        input_grad.contiguous().data_ptr<scalar_t>(),
        trans_grad.contiguous().data_ptr<scalar_t>(),
        input.contiguous().data_ptr<scalar_t>(),
        bbox.contiguous().data_ptr<scalar_t>(),
        trans.contiguous().data_ptr<scalar_t>(),
        no_trans,
        trans_std,
        sample_per_part,
        group_size,
        part_size,
        num_classes,
        channels_each_class);
  });
  THCudaCheck(cudaGetLastError());
  return std::make_tuple(input_grad, trans_grad);
}
================================================
FILE: code/real/bsrt/model/DCNv2/src/cuda/vision.h
================================================
#pragma once
#include
#include
// Forward pass of modulated deformable convolution (DCNv2) on CUDA.
// Implemented in dcn_v2_cuda.cu.
at::Tensor
dcn_v2_cuda_forward(const at::Tensor &input,
                    const at::Tensor &weight,
                    const at::Tensor &bias,
                    const at::Tensor &offset,
                    const at::Tensor &mask,
                    const int kernel_h,
                    const int kernel_w,
                    const int stride_h,
                    const int stride_w,
                    const int pad_h,
                    const int pad_w,
                    const int dilation_h,
                    const int dilation_w,
                    const int deformable_group);
std::vector
dcn_v2_cuda_backward(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
const at::Tensor &offset,
const at::Tensor &mask,
const at::Tensor &grad_output,
int kernel_h, int kernel_w,
int stride_h, int stride_w,
int pad_h, int pad_w,
int dilation_h, int dilation_w,
int deformable_group);
std::tuple
dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input,
const at::Tensor &bbox,
const at::Tensor &trans,
const int no_trans,
const float spatial_scale,
const int output_dim,
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std);
std::tuple
dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad,
const at::Tensor &input,
const at::Tensor &bbox,
const at::Tensor &trans,
const at::Tensor &top_count,
const int no_trans,
const float spatial_scale,
const int output_dim,
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std);
================================================
FILE: code/real/bsrt/model/DCNv2/src/dcn_v2.h
================================================
#pragma once
#include "cpu/vision.h"
#ifdef WITH_CUDA
#include "cuda/vision.h"
#endif
// Dispatch modulated deformable convolution forward to the CUDA or CPU
// implementation depending on where `input` lives.
at::Tensor
dcn_v2_forward(const at::Tensor &input,
               const at::Tensor &weight,
               const at::Tensor &bias,
               const at::Tensor &offset,
               const at::Tensor &mask,
               const int kernel_h,
               const int kernel_w,
               const int stride_h,
               const int stride_w,
               const int pad_h,
               const int pad_w,
               const int dilation_h,
               const int dilation_w,
               const int deformable_group)
{
  // CPU path first: always available, no build-flag dependency.
  if (!input.type().is_cuda())
  {
    return dcn_v2_cpu_forward(input, weight, bias, offset, mask,
                              kernel_h, kernel_w,
                              stride_h, stride_w,
                              pad_h, pad_w,
                              dilation_h, dilation_w,
                              deformable_group);
  }
#ifdef WITH_CUDA
  return dcn_v2_cuda_forward(input, weight, bias, offset, mask,
                             kernel_h, kernel_w,
                             stride_h, stride_w,
                             pad_h, pad_w,
                             dilation_h, dilation_w,
                             deformable_group);
#else
  AT_ERROR("Not compiled with GPU support");
#endif
}
std::vector
dcn_v2_backward(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
const at::Tensor &offset,
const at::Tensor &mask,
const at::Tensor &grad_output,
int kernel_h, int kernel_w,
int stride_h, int stride_w,
int pad_h, int pad_w,
int dilation_h, int dilation_w,
int deformable_group)
{
if (input.type().is_cuda())
{
#ifdef WITH_CUDA
return dcn_v2_cuda_backward(input,
weight,
bias,
offset,
mask,
grad_output,
kernel_h, kernel_w,
stride_h, stride_w,
pad_h, pad_w,
dilation_h, dilation_w,
deformable_group);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
else{
return dcn_v2_cpu_backward(input,
weight,
bias,
offset,
mask,
grad_output,
kernel_h, kernel_w,
stride_h, stride_w,
pad_h, pad_w,
dilation_h, dilation_w,
deformable_group);
}
}
std::tuple
dcn_v2_psroi_pooling_forward(const at::Tensor &input,
const at::Tensor &bbox,
const at::Tensor &trans,
const int no_trans,
const float spatial_scale,
const int output_dim,
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std)
{
if (input.type().is_cuda())
{
#ifdef WITH_CUDA
return dcn_v2_psroi_pooling_cuda_forward(input,
bbox,
trans,
no_trans,
spatial_scale,
output_dim,
group_size,
pooled_size,
part_size,
sample_per_part,
trans_std);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
else{
return dcn_v2_psroi_pooling_cpu_forward(input,
bbox,
trans,
no_trans,
spatial_scale,
output_dim,
group_size,
pooled_size,
part_size,
sample_per_part,
trans_std);
}
}
std::tuple
dcn_v2_psroi_pooling_backward(const at::Tensor &out_grad,
const at::Tensor &input,
const at::Tensor &bbox,
const at::Tensor &trans,
const at::Tensor &top_count,
const int no_trans,
const float spatial_scale,
const int output_dim,
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std)
{
if (input.type().is_cuda())
{
#ifdef WITH_CUDA
return dcn_v2_psroi_pooling_cuda_backward(out_grad,
input,
bbox,
trans,
top_count,
no_trans,
spatial_scale,
output_dim,
group_size,
pooled_size,
part_size,
sample_per_part,
trans_std);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
else{
return dcn_v2_psroi_pooling_cpu_backward(out_grad,
input,
bbox,
trans,
top_count,
no_trans,
spatial_scale,
output_dim,
group_size,
pooled_size,
part_size,
sample_per_part,
trans_std);
}
}
================================================
FILE: code/real/bsrt/model/DCNv2/src/vision.cpp
================================================
#include "dcn_v2.h"

// Python bindings for the DCNv2 extension: exposes the dispatchers defined in
// dcn_v2.h (forward/backward for modulated deformable convolution and for
// deformable PSROI pooling) under the module name TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("dcn_v2_forward", &dcn_v2_forward, "dcn_v2_forward");
  m.def("dcn_v2_backward", &dcn_v2_backward, "dcn_v2_backward");
  m.def("dcn_v2_psroi_pooling_forward", &dcn_v2_psroi_pooling_forward, "dcn_v2_psroi_pooling_forward");
  m.def("dcn_v2_psroi_pooling_backward", &dcn_v2_psroi_pooling_backward, "dcn_v2_psroi_pooling_backward");
}
================================================
FILE: code/real/bsrt/model/DCNv2/test.py
================================================
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import time
import torch
import torch.nn as nn
from torch.autograd import gradcheck
from dcn_v2 import dcn_v2_conv, DCNv2, DCN
from dcn_v2 import dcn_v2_pooling, DCNv2Pooling, DCNPooling
# Test configuration: one deformable group, a tiny 2x2x4x4 input and
# 3x3 kernels, kept small so the gradcheck runs are fast.
deformable_groups = 1
N, inC, inH, inW = 2, 2, 4, 4
outC = 2
kH, kW = 3, 3
def conv_identify(weight, bias):
    """Overwrite `weight`/`bias` in place so the convolution acts as identity.

    Both parameters are zeroed, then a 1.0 is written at the spatial center
    of every kernel whose input channel index equals its output channel index.
    """
    weight.data.zero_()
    bias.data.zero_()
    out_ch, in_ch, kh, kw = weight.shape
    cy, cx = kh // 2, kw // 2
    for ch in range(min(in_ch, out_ch)):
        weight.data[ch, ch, cy, cx] = 1.0
def check_zero_offset():
    """With zero offsets/mask logits and an identity kernel, DCNv2 should
    reproduce its input (up to the constant 0.5 mask factor, undone below)."""
    conv_offset = nn.Conv2d(inC, deformable_groups * 2 * kH * kW,
                            kernel_size=(kH, kW),
                            stride=(1, 1),
                            padding=(1, 1),
                            bias=True).cuda()
    conv_mask = nn.Conv2d(inC, deformable_groups * 1 * kH * kW,
                          kernel_size=(kH, kW),
                          stride=(1, 1),
                          padding=(1, 1),
                          bias=True).cuda()
    dcn_v2 = DCNv2(inC, outC, (kH, kW),
                   stride=1, padding=1, dilation=1,
                   deformable_groups=deformable_groups).cuda()
    # Zero weights/biases -> predicted offsets are all zero and mask logits are 0.
    conv_offset.weight.data.zero_()
    conv_offset.bias.data.zero_()
    conv_mask.weight.data.zero_()
    conv_mask.bias.data.zero_()
    conv_identify(dcn_v2.weight, dcn_v2.bias)
    input = torch.randn(N, inC, inH, inW).cuda()
    offset = conv_offset(input)
    mask = conv_mask(input)
    mask = torch.sigmoid(mask)  # sigmoid(0) = 0.5, so the output is halved
    output = dcn_v2(input, offset, mask)
    output *= 2  # undo the 0.5 mask factor
    d = (input - output).abs().max()
    if d < 1e-10:
        print('Zero offset passed')
    else:
        print('Zero offset failed')
        print(input)
        print(output)
def check_gradient_dconv():
    """Numerically verify dcn_v2_conv gradients w.r.t. every differentiable
    input via torch.autograd.gradcheck."""
    input = torch.rand(N, inC, inH, inW).cuda() * 0.01
    input.requires_grad = True
    offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW).cuda() * 2
    # offset.data.zero_()
    # offset.data -= 0.5
    offset.requires_grad = True
    mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW).cuda()
    # mask.data.zero_()
    mask.requires_grad = True
    mask = torch.sigmoid(mask)
    weight = torch.randn(outC, inC, kH, kW).cuda()
    weight.requires_grad = True
    bias = torch.rand(outC).cuda()
    bias.requires_grad = True
    stride = 1
    padding = 1
    dilation = 1
    # Loose tolerances: gradcheck runs in the kernel's native precision here.
    print('check_gradient_dconv: ',
          gradcheck(dcn_v2_conv, (input, offset, mask, weight, bias,
                                  stride, padding, dilation, deformable_groups),
                    eps=1e-3, atol=1e-4, rtol=1e-2))
def check_pooling_zero_offset():
    """Compare plain (no_trans) pooling with deformable pooling fed all-zero
    offsets; both printed lines should agree."""
    input = torch.randn(2, 16, 64, 64).cuda().zero_()
    input[0, :, 16:26, 16:26] = 1.
    input[1, :, 10:20, 20:30] = 2.
    # rois rows: (batch_index, x1, y1, x2, y2)
    rois = torch.tensor([
        [0, 65, 65, 103, 103],
        [1, 81, 41, 119, 79],
    ]).cuda().float()
    pooling = DCNv2Pooling(spatial_scale=1.0 / 4,
                           pooled_size=7,
                           output_dim=16,
                           no_trans=True,
                           group_size=1,
                           trans_std=0.0).cuda()
    out = pooling(input, rois, input.new())
    s = ', '.join(['%f' % out[i, :, :, :].mean().item()
                   for i in range(rois.shape[0])])
    print(s)
    dpooling = DCNv2Pooling(spatial_scale=1.0 / 4,
                            pooled_size=7,
                            output_dim=16,
                            no_trans=False,
                            group_size=1,
                            trans_std=0.0).cuda()
    # zero offsets -> expected to reproduce the no_trans result above
    offset = torch.randn(20, 2, 7, 7).cuda().zero_()
    dout = dpooling(input, rois, offset)
    s = ', '.join(['%f' % dout[i, :, :, :].mean().item()
                   for i in range(rois.shape[0])])
    print(s)
def check_gradient_dpooling():
    """Numerically verify dcn_v2_pooling gradients w.r.t. input and offset."""
    input = torch.randn(2, 3, 5, 5).cuda() * 0.01
    N = 4  # number of rois (deliberately shadows the module-level batch size)
    batch_inds = torch.randint(2, (N, 1)).cuda().float()
    x = torch.rand((N, 1)).cuda().float() * 15
    y = torch.rand((N, 1)).cuda().float() * 15
    w = torch.rand((N, 1)).cuda().float() * 10
    h = torch.rand((N, 1)).cuda().float() * 10
    # rois rows: (batch_index, x1, y1, x2, y2)
    rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
    offset = torch.randn(N, 2, 3, 3).cuda()
    input.requires_grad = True
    offset.requires_grad = True
    spatial_scale = 1.0 / 4
    pooled_size = 3
    output_dim = 3
    no_trans = 0
    group_size = 1
    trans_std = 0.0
    sample_per_part = 4
    part_size = pooled_size
    print('check_gradient_dpooling:',
          gradcheck(dcn_v2_pooling, (input, rois, offset,
                                     spatial_scale,
                                     pooled_size,
                                     output_dim,
                                     no_trans,
                                     group_size,
                                     part_size,
                                     sample_per_part,
                                     trans_std),
                    eps=1e-4))
def example_dconv():
    """Forward/backward smoke test for DCN, which predicts its own offset
    and mask internally."""
    input = torch.randn(2, 64, 128, 128).cuda()
    # wrap all things (offset and mask) in DCN
    dcn = DCN(64, 64, kernel_size=(3, 3), stride=1,
              padding=1, deformable_groups=2).cuda()
    output = dcn(input)
    target = output.new(*output.size())
    target.data.uniform_(-0.01, 0.01)
    loss = (target - output).mean()
    loss.backward()
    print(output.shape)
def example_dpooling():
    """Forward/backward smoke test for DCNv2Pooling with externally supplied
    offsets, in both no_trans (plain) and deformable modes."""
    input = torch.randn(2, 32, 64, 64).cuda()
    batch_inds = torch.randint(2, (20, 1)).cuda().float()
    x = torch.randint(256, (20, 1)).cuda().float()
    y = torch.randint(256, (20, 1)).cuda().float()
    w = torch.randint(64, (20, 1)).cuda().float()
    h = torch.randint(64, (20, 1)).cuda().float()
    # rois rows: (batch_index, x1, y1, x2, y2)
    rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
    offset = torch.randn(20, 2, 7, 7).cuda()
    input.requires_grad = True
    offset.requires_grad = True
    # normal roi_align
    pooling = DCNv2Pooling(spatial_scale=1.0 / 4,
                           pooled_size=7,
                           output_dim=32,
                           no_trans=True,
                           group_size=1,
                           trans_std=0.1).cuda()
    # deformable pooling
    dpooling = DCNv2Pooling(spatial_scale=1.0 / 4,
                            pooled_size=7,
                            output_dim=32,
                            no_trans=False,
                            group_size=1,
                            trans_std=0.1).cuda()
    out = pooling(input, rois, offset)
    dout = dpooling(input, rois, offset)
    print(out.shape)
    print(dout.shape)
    target_out = out.new(*out.size())
    target_out.data.uniform_(-0.01, 0.01)
    target_dout = dout.new(*dout.size())
    target_dout.data.uniform_(-0.01, 0.01)
    e = (target_out - out).mean()
    e.backward()
    e = (target_dout - dout).mean()
    e.backward()
def example_mdpooling():
    """Forward/backward smoke test for DCNPooling, which predicts its own
    offsets through an internal FC head (deform_fc_dim)."""
    input = torch.randn(2, 32, 64, 64).cuda()
    input.requires_grad = True
    batch_inds = torch.randint(2, (20, 1)).cuda().float()
    x = torch.randint(256, (20, 1)).cuda().float()
    y = torch.randint(256, (20, 1)).cuda().float()
    w = torch.randint(64, (20, 1)).cuda().float()
    h = torch.randint(64, (20, 1)).cuda().float()
    # rois rows: (batch_index, x1, y1, x2, y2)
    rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
    # mdformable pooling (V2)
    dpooling = DCNPooling(spatial_scale=1.0 / 4,
                          pooled_size=7,
                          output_dim=32,
                          no_trans=False,
                          group_size=1,
                          trans_std=0.1,
                          deform_fc_dim=1024).cuda()
    dout = dpooling(input, rois)
    target = dout.new(*dout.size())
    target.data.uniform_(-0.1, 0.1)
    error = (target - dout).mean()
    error.backward()
    print(dout.shape)
if __name__ == '__main__':
    # smoke-test the wrapped modules first
    example_dconv()
    example_dpooling()
    example_mdpooling()
    check_pooling_zero_offset()
    # zero offset check (identity kernel only makes sense when inC == outC)
    if inC == outC:
        check_zero_offset()
    # numerical gradient checks via torch.autograd.gradcheck
    check_gradient_dpooling()
    check_gradient_dconv()
    # """
    # ****** Note: backward is not reentrant error may not be a serious problem,
    # ****** since the max error is less than 1e-7,
    # ****** Still looking for what trigger this problem
    # """
================================================
FILE: code/real/bsrt/model/__init__.py
================================================
import os
from importlib import import_module
import torch
import torch.nn as nn
import torch.nn.parallel as P
import torch.utils.model_zoo
import time
class Model(nn.Module):
    """Device/parallelism wrapper around the concrete SR architecture.

    Builds the network named by ``args.model`` (imported from
    ``model.<name>``), moves it to the target device, optionally converts it
    to half precision, restores weights via :meth:`load`, and wraps it in
    DistributedDataParallel when more than one GPU is used.  At eval time it
    supports chopped (patch-wise) inference and x8 self-ensemble.
    """

    def __init__(self, args, ckp):
        super(Model, self).__init__()
        self.args = args
        if args.local_rank == 0:
            print("Making model: ", args.model)
            print("Patch size: ", args.patch_size)
        self.scale = args.scale
        self.idx_scale = 0
        self.input_large = (args.model == 'VDSR')
        self.self_ensemble = args.self_ensemble
        self.chop = args.chop
        self.precision = args.precision
        self.cpu = args.cpu
        self.device = torch.device('cpu' if args.cpu else 'cuda:%d' % args.local_rank)
        self.n_GPUs = args.n_GPUs
        self.save_models = args.save_models

        # Build the concrete architecture from model/<name>.py.
        module = import_module('model.' + args.model.lower())
        self.model = module.make_model(args).to(self.device)
        if args.precision == 'half':
            self.model.half()

        self.load(
            ckp.get_path('model'),
            pre_train=args.pre_train,
            resume=args.resume,
            cpu=args.cpu
        )
        # time.sleep(3)
        if args.n_GPUs > 1:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                find_unused_parameters=True
            )
        print(self.model, file=ckp.log_file)

    def forward(self, x, idx_scale):
        """Run the network; at eval time optionally chop and/or self-ensemble."""
        self.idx_scale = idx_scale
        if hasattr(self.model, 'set_scale'):
            self.model.set_scale(idx_scale)
        if self.training:
            # if self.n_GPUs > 1:
            return self.model(x)
        else:
            if self.chop:
                forward_function = self.forward_chop
            else:
                forward_function = self.model.forward
            if self.self_ensemble:
                return self.forward_x8(x, forward_function=forward_function)
            else:
                # return self.model(x)
                return forward_function(x)

    def save(self, apath, epoch, is_best=False):
        """Save checkpoints (latest, optionally best and per-epoch copies)."""
        save_dirs = [os.path.join(apath, 'model_latest.pt')]
        if is_best:
            save_dirs.append(os.path.join(apath, 'model_best.pt'))
        if self.save_models:
            save_dirs.append(
                os.path.join(apath, 'model_{}.pt'.format(epoch))
            )
        # Unwrap DistributedDataParallel before saving; saving the wrapper
        # would prefix every state-dict key with 'module.' and break
        # load(..., strict=True) on restore.
        if self.n_GPUs > 1:
            model = self.model.module
        else:
            model = self.model
        for s in save_dirs:
            # fixed: previously saved self.model.state_dict(), ignoring the
            # unwrapped `model` computed above
            torch.save(model.state_dict(), s)

    def load(self, apath, pre_train='', resume=-1, cpu=False):
        """Restore weights and optionally freeze parts for fine-tuning.

        resume == -1 loads model_latest.pt from `apath`;
        resume ==  0 loads `pre_train` (or downloads when pre_train=='download');
        any other value loads model_<resume>.pt from `apath`.
        """
        load_from = None
        kwargs = {}
        if cpu:
            kwargs = {'map_location': lambda storage, loc: storage}
        if resume == -1:
            load_from = torch.load(
                os.path.join(apath, 'model_latest.pt'),
                **kwargs
            )
        elif resume == 0:
            if pre_train == 'download':
                print('Download the model')
                dir_model = os.path.join('..', 'models')
                os.makedirs(dir_model, exist_ok=True)
                load_from = torch.utils.model_zoo.load_url(
                    self.model.url,
                    model_dir=dir_model,
                    **kwargs
                )
            elif pre_train:
                if self.args.local_rank == 0:
                    print('Load the model from {}'.format(pre_train))
                # remap tensors saved on cuda:0 onto this rank's device
                map_location = {'cuda:%d' % 0: 'cuda:%d' % self.args.local_rank}
                load_from = torch.load(pre_train, map_location=map_location)
                # print(load_from.keys())
        else:
            load_from = torch.load(
                os.path.join(apath, 'model_{}.pt'.format(resume)),
                **kwargs
            )
        if load_from:
            self.model.load_state_dict(load_from, strict=True)
        del load_from

        # ---- optional fine-tuning: freeze everything, then re-enable parts ----
        if self.args.finetune:
            if self.args.local_rank == 0:
                print('finetune')
            for param in self.model.parameters():
                param.requires_grad = False
            for param in self.model.HRconv.parameters():
                param.requires_grad = True
            for param in self.model.conv_last.parameters():
                param.requires_grad = True
            if self.args.finetune_prelayer:
                if self.args.local_rank == 0:
                    print('finetune_prelayer')
                if self.args.swinfeature:
                    if self.args.model == 'MBSRT':
                        for param in self.model.pre_layer1.parameters():
                            param.requires_grad = True
                        for param in self.model.pre_layer2.parameters():
                            param.requires_grad = True
                    else:
                        for param in self.model.pre_layers.parameters():
                            param.requires_grad = True
                else:
                    for param in self.model.feature_extraction.parameters():
                        param.requires_grad = True
                for param in self.model.conv_after_pre_layer.parameters():
                    param.requires_grad = True
            if self.args.finetune_align:
                if self.args.local_rank == 0:
                    print('finetune_align')
                for param in self.model.align.parameters():
                    param.requires_grad = True
            if self.args.finetune_spynet:
                if self.args.local_rank == 0:
                    print('finetune_spynet')
                for param in self.model.spynet.parameters():
                    param.requires_grad = True
            if self.args.finetune_swin:
                if self.args.local_rank == 0:
                    print('finetune_swin')
                for param in self.model.layers.parameters():
                    param.requires_grad = True
                for param in self.model.conv_after_body.parameters():
                    param.requires_grad = True
            if self.args.finetune_upconv:
                if self.args.local_rank == 0:
                    print('finetune_upconv')
                for param in self.model.upconv1.parameters():
                    param.requires_grad = True
                for param in self.model.upconv2.parameters():
                    param.requires_grad = True
                for param in self.model.skipup1.parameters():
                    param.requires_grad = True
                for param in self.model.skipup2.parameters():
                    param.requires_grad = True
            if self.args.finetune_conv:
                if self.args.local_rank == 0:
                    print('finetune_conv')
                for param in self.model.conv_after_body.parameters():
                    param.requires_grad = True

    def forward_chop(self, *args, shave=10, min_size=160000):
        """Patch-wise forward: split input into 4 overlapping quadrants
        (recursing while h*w >= 4*min_size), run each, and stitch the
        outputs back together in the upscaled coordinate frame."""
        scale = 1 if self.input_large else self.scale[self.idx_scale]
        n_GPUs = min(self.n_GPUs, 4)
        # height, width
        h, w = args[0].size()[-2:]
        top = slice(0, h//2 + shave)
        bottom = slice(h - h//2 - shave, h)
        left = slice(0, w//2 + shave)
        right = slice(w - w//2 - shave, w)
        x_chops = [torch.cat([
            a[..., top, left],
            a[..., top, right],
            a[..., bottom, left],
            a[..., bottom, right]
        ]) for a in args]
        y_chops = []
        if h * w < 4 * min_size:
            # small enough: run the 4 quadrants in batches of n_GPUs
            for i in range(0, 4, n_GPUs):
                x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops]
                y = P.data_parallel(self.model, *x, range(n_GPUs))
                if not isinstance(y, list): y = [y]
                if not y_chops:
                    y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]
                else:
                    for y_chop, _y in zip(y_chops, y):
                        y_chop.extend(_y.chunk(n_GPUs, dim=0))
        else:
            # still too large: recurse on each quadrant
            for p in zip(*x_chops):
                y = self.forward_chop(*p, shave=shave, min_size=min_size)
                if not isinstance(y, list): y = [y]
                if not y_chops:
                    y_chops = [[_y] for _y in y]
                else:
                    for y_chop, _y in zip(y_chops, y): y_chop.append(_y)
        # stitch in the upscaled coordinate frame
        h *= scale
        w *= scale
        top = slice(0, h//2)
        bottom = slice(h - h//2, h)
        bottom_r = slice(h//2 - h, None)
        left = slice(0, w//2)
        right = slice(w - w//2, w)
        right_r = slice(w//2 - w, None)
        # batch size, number of color channels
        b, c = y_chops[0][0].size()[:-2]
        y = [y_chop[0].new(b, c, h, w) for y_chop in y_chops]
        for y_chop, _y in zip(y_chops, y):
            _y[..., top, left] = y_chop[0][..., top, left]
            _y[..., top, right] = y_chop[1][..., top, right_r]
            _y[..., bottom, left] = y_chop[2][..., bottom_r, left]
            _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r]
        if len(y) == 1: y = y[0]
        return y

    def forward_x8(self, *args, forward_function=None):
        """x8 self-ensemble: run the model on all flip/transpose variants of
        the input, undo each transform on the output, and average."""
        def _transform(v, op):
            # flips/transposes on the last two (spatial) axes via numpy
            if self.precision != 'single': v = v.float()
            v2np = v.data.cpu().numpy()
            if op == 'v':
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == 'h':
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == 't':
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()
            ret = torch.Tensor(tfnp).to(self.device)
            if self.precision == 'half': ret = ret.half()
            return ret

        # build the 8 transformed copies of every input tensor
        list_x = []
        for a in args:
            x = [a]
            for tf in 'v', 'h', 't': x.extend([_transform(_x, tf) for _x in x])
            list_x.append(x)
        list_y = []
        for x in zip(*list_x):
            y = forward_function(*x)
            if not isinstance(y, list): y = [y]
            if not list_y:
                list_y = [[_y] for _y in y]
            else:
                for _list_y, _y in zip(list_y, y): _list_y.append(_y)
        # invert the transforms so all outputs are aligned before averaging
        for _list_y in list_y:
            for i in range(len(_list_y)):
                if i > 3:
                    _list_y[i] = _transform(_list_y[i], 't')
                if i % 4 > 1:
                    _list_y[i] = _transform(_list_y[i], 'h')
                if (i % 4) % 2 == 1:
                    _list_y[i] = _transform(_list_y[i], 'v')
        y = [torch.cat(_y, dim=0).mean(dim=0, keepdim=True) for _y in list_y]
        if len(y) == 1: y = y[0]
        return y
================================================
FILE: code/real/bsrt/model/arch_util.py
================================================
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from model import common
from model.utils.psconv import PSGConv2d as PSConv2d, PyConv2d
def initialize_weights(net_l, scale=1):
    """Kaiming-initialize conv/linear weights (scaled by `scale`), zero their
    biases, and reset batch-norm layers to (weight=1, bias=0), all in place.

    `net_l` may be a single module or a list of modules.
    """
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for module in net.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(module.weight, a=0, mode='fan_in')
                module.weight.data *= scale  # damp residual branches
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias.data, 0.0)
def make_layer(block, n_layers):
    """Stack `n_layers` fresh instances produced by the `block` factory."""
    return nn.Sequential(*(block() for _ in range(n_layers)))
###########################
def conv_layer(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """2-D convolution with bias enabled; thin wrapper over nn.Conv2d."""
    return nn.Conv2d(in_channels, out_channels, kernel_size,
                     stride=stride, padding=padding, bias=True)
class ESA(nn.Module):
    """Enhanced Spatial Attention: derives a sigmoid mask from a strided,
    max-pooled view of the features and rescales the input by it."""
    def __init__(self, n_feats, conv=conv_layer):
        super(ESA, self).__init__()
        f = n_feats // 4
        self.conv1 = conv(n_feats, f, kernel_size=1)
        self.conv_f = conv(f, f, kernel_size=1)
        self.conv_max = conv(f, f, kernel_size=3, padding=1)
        self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0)
        self.conv3 = conv(f, f, kernel_size=3, padding=1)
        self.conv3_ = conv(f, f, kernel_size=3, padding=1)
        self.conv4 = conv(f, n_feats, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        squeezed = self.conv1(x)
        down = self.conv2(squeezed)
        pooled = F.max_pool2d(down, kernel_size=7, stride=3)
        feat = self.relu(self.conv3(self.relu(self.conv_max(pooled))))
        feat = self.conv3_(feat)
        # restore the original spatial resolution before fusing
        feat = F.interpolate(feat, (x.size(2), x.size(3)),
                             mode='bilinear', align_corners=False)
        mask = self.sigmoid(self.conv4(feat + self.conv_f(squeezed)))
        return x * mask
class DWConv(nn.Module):
    """Depth-wise 3x3 convolution: one filter per channel (groups == dim)."""
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x):
        return self.dwconv(x)
##########################
class SELayer(nn.Module):
    """SE-style channel gating WITHOUT the usual final sigmoid: channels are
    rescaled by the raw output of the squeeze-excite MLP."""
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            # nn.Sigmoid()
        )

    def forward(self, x):
        b, c = x.shape[:2]
        gate = self.fc(self.avg_pool(x).view(b, c)).view(b, c, 1, 1)
        return x * gate.expand_as(x)
class ResidualBlock_noBN(nn.Module):
    """Residual block without batch norm: x + conv2(relu(conv1(x)))."""
    def __init__(self, nf=64):
        super(ResidualBlock_noBN, self).__init__()
        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        # small-scale init keeps the residual branch near zero at the start
        initialize_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        branch = self.conv2(F.relu(self.conv1(x), inplace=True))
        return x + branch
class ResidualBlock_SE(nn.Module):
    """Conv block with an SE branch: the input, the conv branch, and its
    SE-gated version are concatenated and fused back to `nf` channels by a
    1x1 conv.  Note: despite the name there is no additive residual."""
    def __init__(self, nf=64, reduction=16):
        super(ResidualBlock_SE, self).__init__()
        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv3 = nn.Conv2d(3 * nf, nf, 1, padding=0, dilation=1, bias=True)
        self.se = SELayer(nf, reduction)
        # small-scale init for the conv layers
        initialize_weights([self.conv1, self.conv2, self.conv3], 0.1)

    def forward(self, x):
        branch = self.conv2(F.relu(self.conv1(x), inplace=True))
        gated = self.se(branch)
        fused = torch.cat((x, branch, gated), 1)
        return self.conv3(fused)
class _PositionAttentionModule(nn.Module):
""" Position attention module"""
def __init__(self, in_channels, **kwargs):
super(_PositionAttentionModule, self).__init__()
self.conv_b = nn.Conv2d(in_channels, in_channels // 8, 1)
self.conv_c = nn.Conv2d(in_channels, in_channels // 8, 1)
self.conv_d = nn.Conv2d(in_channels, in_channels, 1)
self.alpha = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
batch_size, _, height, width = x.size()
feat_b = self.conv_b(x).view(batch_size, -1, height * width).permute(0, 2, 1)
feat_c = self.conv_c(x).view(batch_size, -1, height * width)
attention_s = self.softmax(torch.bmm(feat_b, feat_c))
feat_d = self.conv_d(x).view(batch_size, -1, height * width)
feat_e = torch.bmm(feat_d, attention_s.permute(0, 2, 1)).view(batch_size, -1, height, width)
out = self.alpha * feat_e + x
return out
## Spatial Attention (CA) Layer
class SALayer(nn.Module):
    """Spatial attention: a sigmoid gate computed from channel-wise mean and
    max maps.  `wn` is a weight-normalization wrapper applied to the conv."""
    def __init__(self, wn=None):
        super(SALayer, self).__init__()
        self.body = nn.Sequential(
            wn(nn.Conv2d(2, 1, 7, 1, 3, bias=False)),
            nn.Sigmoid()
        )

    def forward(self, x):
        stats = torch.cat([x.mean(dim=1, keepdim=True),
                           x.max(dim=1, keepdim=True)[0]], dim=1)
        return self.body(stats).expand_as(x) * x
## Channel Attention (CA) Layer
class CALayerV2(nn.Module):
    """Channel attention combining average- and max-pooled descriptors; the
    two MLP outputs are summed before the sigmoid gate (CBAM-style)."""
    def __init__(self, n_feat, reduction=16, wn=None):
        super(CALayerV2, self).__init__()
        # global pooling: feature map -> per-channel scalar
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # shared bottleneck MLP (1x1 convs), sigmoid applied after summation
        self.conv_du = nn.Sequential(
            wn(nn.Conv2d(n_feat, n_feat // reduction, 1, padding=0, bias=False)),
            nn.ReLU(inplace=True),
            wn(nn.Conv2d(n_feat // reduction, n_feat, 1, padding=0, bias=False)),
            # nn.Sigmoid()
        )

    def forward(self, x):
        gate = self.conv_du(self.avg_pool(x)) + self.conv_du(self.max_pool(x))
        return x * torch.sigmoid(gate)
class DALayer(nn.Module):
    """Dual attention: channel (CALayer) and spatial (SALayer) branches are
    concatenated, fused by a 1x1 conv, and added back to the input."""
    def __init__(self, channel, reduction, wn):
        super(DALayer, self).__init__()
        self.ca = CALayer(channel, reduction, wn)
        self.sa = SALayer(wn)
        self.conv = wn(nn.Conv2d(channel * 2, channel, 1))

    def forward(self, x):
        fused = self.conv(torch.cat([self.ca(x), self.sa(x)], dim=1))
        return fused + x
## Channel Attention (CA) Layer
class CALayer(nn.Module):
    """Squeeze-and-excitation channel attention with a sigmoid gate."""
    def __init__(self, channel, reduction, wn):
        super(CALayer, self).__init__()
        # global average pooling: feature map -> per-channel scalar
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # bottleneck MLP (1x1 convs) producing the per-channel gate
        self.conv_du = nn.Sequential(
            wn(nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True)),
            nn.ReLU(inplace=True),
            wn(nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True)),
            nn.Sigmoid()
        )

    def forward(self, x):
        return x * self.conv_du(self.avg_pool(x))
## Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
    """Residual channel-attention block (wide-activation variant):
    1x1 expand -> act -> 1x1 shrink -> spatial conv -> channel/dual
    attention, with an additive skip connection."""
    def __init__(
        self, conv, n_feat, kernel_size, reduction, wn,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1, da=False):
        super(RCAB, self).__init__()
        expand = 6
        linear = 0.75
        # build in this exact order so parameter init matches the original
        layers = [
            wn(nn.Conv2d(n_feat, n_feat * expand, 1, bias=bias)),
            act,
            wn(nn.Conv2d(n_feat * expand, int(n_feat * linear), 1, bias=bias)),
            conv(int(n_feat * linear), n_feat, kernel_size, bias=bias),
        ]
        layers.append(DALayer(n_feat, reduction, wn) if da
                      else CALayer(n_feat, reduction, wn))
        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale  # kept for interface parity; unused in forward

    def forward(self, x):
        return self.body(x) + x
## Residual Group (RG)
class ResidualGroup(nn.Module):
    """A chain of RCAB blocks followed by a conv, wrapped in a long additive
    skip connection."""
    def __init__(self, n_feat, n_resblocks, da=False):
        super(ResidualGroup, self).__init__()
        kernel_size = 3
        reduction = 16
        conv = common.default_conv
        wn = lambda m: torch.nn.utils.weight_norm(m)
        blocks = [
            RCAB(conv, n_feat, kernel_size, reduction, wn=wn, bias=True,
                 bn=False, act=nn.ReLU(True), res_scale=1, da=da)
            for _ in range(n_resblocks)
        ]
        blocks.append(wn(conv(n_feat, n_feat, kernel_size)))
        self.body = nn.Sequential(*blocks)

    def forward(self, x):
        return self.body(x) + x
################################################################
################################################################
################################################################
def make_layer_idx(block, n_layers):
    """Stack `n_layers` blocks, passing each its position as keyword `idx`."""
    return nn.Sequential(*(block(idx=i) for i in range(n_layers)))
## Residual Channel Attention Block (RCAB)
class LRSCRCAB(nn.Module):
    """Long-range skip-connect RCAB.

    Same wide-activation body as RCAB, but instead of an additive skip the
    input is CONCATENATED to the output, so the channel count grows by
    `n_feat` per block; for `idx` > 0 a leading 1x1 conv fuses the
    (idx+1)*n_feat channels accumulated by earlier blocks back to `n_feat`.
    """
    def __init__(
        self, conv, n_feat, kernel_size, reduction, wn,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1, da=False, idx=0):
        super(LRSCRCAB, self).__init__()
        expand = 6
        linear = 0.75
        # fuse channels accumulated by earlier blocks back down to n_feat
        modules_body = [wn(nn.Conv2d(n_feat*(idx+1), n_feat, 1, 1, 0, bias=True))] if idx > 0 else []
        # for i in range(2):
        modules_body.append(wn(nn.Conv2d(n_feat, n_feat*expand, 1, bias=bias)))
        modules_body.append(act)
        modules_body.append(wn(nn.Conv2d(n_feat*expand, int(n_feat*linear), 1, bias=bias)))
        modules_body.append(wn(conv(int(n_feat*linear), n_feat, kernel_size, bias=bias)))
        if da:
            modules_body.append(DALayer(n_feat, reduction, wn))
        else:
            modules_body.append(CALayer(n_feat, reduction, wn))
        self.body = nn.Sequential(*modules_body)
        self.res_scale = res_scale  # kept for interface parity; unused in forward
    def forward(self, x):
        res = self.body(x)
        # long-range skip: concatenate the input instead of adding it
        res = torch.cat([res, x], dim=1)
        return res
## Residual Channel Attention Block (RCAB)
class LRSCPYRCAB(nn.Module):
    """Long-range skip-connect RCAB whose spatial conv is a pyramid conv
    (PyConv2d with 3/5/7 kernels); output is cat([body(x), x]), so channels
    grow by `n_feat` per block."""
    def __init__(
        self, conv, n_feat, kernel_size, reduction, wn,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1, da=False, idx=0):
        super(LRSCPYRCAB, self).__init__()
        expand = 6
        linear = 0.75
        # fuse channels accumulated by earlier blocks back down to n_feat
        modules_body = [wn(nn.Conv2d(n_feat*(idx+1), n_feat, 1, 1, 0, bias=True))] if idx > 0 else []
        # for i in range(2):
        modules_body.append(wn(nn.Conv2d(n_feat, n_feat*expand, 1, bias=bias)))
        modules_body.append(act)
        modules_body.append(wn(nn.Conv2d(n_feat*expand, int(n_feat*linear), 1, bias=bias)))
        modules_body.append(
            PyConv2d(in_channels=int(n_feat*linear),
                     out_channels=[n_feat//4, n_feat//4, n_feat//2],
                     pyconv_kernels=[3, 5, 7],
                     pyconv_groups=[1, 4, 8]))
        if da:
            modules_body.append(DALayer(n_feat, reduction, wn))
        else:
            modules_body.append(CALayer(n_feat, reduction, wn))
        self.body = nn.Sequential(*modules_body)
        self.res_scale = res_scale  # kept for interface parity; unused in forward
    def forward(self, x):
        res = self.body(x)
        # long-range skip: concatenate the input instead of adding it
        res = torch.cat([res, x], dim=1)
        return res
## Long-Range Skip-connect Residual Group (RG)
class LRSCResidualGroup(nn.Module):
    """Group of LRSCRCAB blocks; the final conv fuses the
    (n_resblocks+1)*n_feat concatenated channels back to n_feat, then the
    group input is concatenated on again (long-range skip)."""
    def __init__(self, n_feat, n_resblocks, da=False, idx=0):
        super(LRSCResidualGroup, self).__init__()
        kernel_size = 3
        res_scale = 1
        reduction = 16
        conv = common.default_conv
        wn = lambda x: torch.nn.utils.weight_norm(x)
        # when groups are stacked, fuse channels accumulated by earlier groups
        modules_head = [wn(conv(n_feat*(idx+1), n_feat, 1, bias=True))] if idx > 0 else []
        modules_body = [
            LRSCRCAB(
                conv, n_feat, kernel_size, reduction, wn=wn, bias=True,
                bn=False, act=nn.ReLU(True), res_scale=res_scale, da=da, idx=i) \
            for i in range(n_resblocks)]
        # fuse the channels accumulated across the blocks back to n_feat
        modules_body.append(wn(conv(n_feat*(n_resblocks+1), n_feat, kernel_size)))
        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
    def forward(self, x):
        res = self.head(x)
        res = self.body(res)
        # long-range skip: concatenate the group input
        res = torch.cat([res, x], dim=1)
        return res
## Long-Range Skip-connect Residual Group (RG)
class LRSCPSResidualGroup(nn.Module):
    """Long-range skip-connect group of LRSCRCAB blocks using PSConv2d as
    the spatial conv; the tail fuses accumulated channels back to n_feat,
    then the group input is concatenated on."""
    def __init__(self, n_feat, n_resblocks, da=False, idx=0):
        super(LRSCPSResidualGroup, self).__init__()
        kernel_size = 3
        res_scale = 1
        reduction = 16
        conv = PSConv2d
        wn = lambda x: torch.nn.utils.weight_norm(x)
        # when groups are stacked, fuse channels accumulated by earlier groups
        modules_head = [wn(nn.Conv2d(n_feat*(idx+1), n_feat, 1, 1, 0, bias=True))] if idx > 0 else []
        modules_body = [
            LRSCRCAB(
                conv, n_feat, kernel_size, reduction, wn=wn, bias=True,
                bn=False, act=nn.ReLU(True), res_scale=res_scale, da=da, idx=i) \
            for i in range(n_resblocks)]
        # fuse the channels accumulated across the blocks back to n_feat
        modules_tail = [wn(conv(n_feat*(n_resblocks+1), n_feat, kernel_size))]
        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)
    def forward(self, x):
        res = self.head(x)
        res = self.body(res)
        res = self.tail(res)
        # long-range skip: concatenate the group input
        res = torch.cat([res, x], dim=1)
        return res
## Long-Range Skip-connect Residual Group (RG)
class LRSCPyResidualGroup(nn.Module):
    """Long-range skip-connect group of LRSCPYRCAB (pyramid-conv) blocks;
    a 1x1 tail fuses accumulated channels back to n_feat, then the group
    input is concatenated on."""
    def __init__(self, n_feat, n_resblocks, da=False, idx=0):
        super(LRSCPyResidualGroup, self).__init__()
        kernel_size = 3
        res_scale = 1
        reduction = 16
        conv = PyConv2d
        wn = lambda x: torch.nn.utils.weight_norm(x)
        # when groups are stacked, fuse channels accumulated by earlier groups
        modules_head = [wn(nn.Conv2d(n_feat*(idx+1), n_feat, 1, 1, 0, bias=True))] if idx > 0 else []
        modules_body = [
            LRSCPYRCAB(
                conv, n_feat, kernel_size, reduction, wn=wn, bias=True,
                bn=False, act=nn.ReLU(True), res_scale=res_scale, da=da, idx=i) \
            for i in range(n_resblocks)]
        # fuse the channels accumulated across the blocks back to n_feat
        modules_tail = [wn(nn.Conv2d(n_feat*(n_resblocks+1), n_feat, 1))]
        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)
    def forward(self, x):
        res = self.head(x)
        res = self.body(res)
        res = self.tail(res)
        # long-range skip: concatenate the group input
        res = torch.cat([res, x], dim=1)
        return res
class LRSCWideActResBlock(nn.Module):
    """Wide-activation block with a concatenating long-range skip.

    For `idx` > 0 a leading 1x1 conv fuses the (idx+1)*nf channels
    accumulated by earlier blocks down to nf; the output is
    cat([body(head(x)), x]), so channels grow by nf per block.
    """
    def __init__(self, nf=64, idx=0):
        super(LRSCWideActResBlock, self).__init__()
        self.res_scale = 1
        expand = 6
        linear = 0.8
        kernel_size = 3
        wn = lambda m: torch.nn.utils.weight_norm(m)
        act = nn.ReLU(True)
        # fuse channels accumulated by earlier blocks back down to nf
        head = [wn(nn.Conv2d(nf * (idx + 1), nf, 1, bias=True))] if idx > 0 else []
        body = [
            wn(nn.Conv2d(nf, nf * expand, 1, padding=0)),
            act,
            wn(nn.Conv2d(nf * expand, int(nf * linear), 1, padding=0)),
            wn(nn.Conv2d(int(nf * linear), nf, kernel_size, padding=kernel_size // 2)),
        ]
        self.head = nn.Sequential(*head)
        self.body = nn.Sequential(*body)

    def forward(self, x):
        out = self.body(self.head(x))
        return torch.cat([out, x], dim=1)
class LRSCPyWideActResBlock(nn.Module):
    """Wide-activation block whose spatial conv is a pyramid conv (PyConv2d);
    output is cat([body(head(x)), x]) (concatenating long-range skip), so
    channels grow by nf per block."""
    def __init__(self, nf=64, idx=0):
        super(LRSCPyWideActResBlock, self).__init__()
        self.res_scale = 1
        expand = 6
        linear = 0.75
        kernel_size = 3
        wn = lambda x: torch.nn.utils.weight_norm(x)
        act=nn.ReLU(True)
        # fuse channels accumulated by earlier blocks back down to nf
        head = [wn(nn.Conv2d(nf*(idx+1), nf, 1, bias=True))] if idx > 0 else []
        body = []
        body.append(
            wn(nn.Conv2d(nf, nf*expand, 1, padding=1//2)))
        body.append(act)
        body.append(
            wn(nn.Conv2d(nf*expand, int(nf*linear), 1, padding=1//2)))
        body.append(
            PyConv2d(in_channels=int(nf*linear),
                     out_channels=[nf//4, nf//4, nf//2],
                     pyconv_kernels=[3, 5, 7],
                     pyconv_groups=[1, 4, 8]))
        self.head = nn.Sequential(*head)
        self.body = nn.Sequential(*body)
    def forward(self, x):
        res = self.head(x)
        res = self.body(res)
        # long-range skip: concatenate the input instead of adding it
        res = torch.cat([res, x], dim=1)
        return res
## Long-Range Skip-connect Residual Group (RG)
class LRSCPyWideActResGroup(nn.Module):
    """Long-Range Skip-connect Residual Group built from LRSCPyWideActResBlock.

    Each inner block concatenates its input to its output, so the channel
    count grows by nf per block; the tail 1x1 conv fuses the accumulated
    nf * (n_resblocks + 1) channels back to nf before the group-level skip
    concatenation.

    Args:
        nf (int): base number of feature channels.
        n_resblocks (int): number of inner blocks in the body.
        idx (int): group index inside an outer concat chain; when idx > 0
            a 1x1 head conv reduces nf * (idx + 1) input channels to nf.
    """

    def __init__(self, nf, n_resblocks, idx=0):
        super(LRSCPyWideActResGroup, self).__init__()
        # (Fix: removed unused locals `kernel_size` and `conv`.)
        wn = lambda x: torch.nn.utils.weight_norm(x)
        modules_head = [wn(nn.Conv2d(nf*(idx+1), nf, 1, 1, 0, bias=True))] if idx > 0 else []
        modules_body = [
            LRSCPyWideActResBlock(nf=nf, idx=i) for i in range(n_resblocks)]
        # Fuse the concatenated outputs of all inner blocks back to nf.
        modules_tail = [wn(nn.Conv2d(nf*(n_resblocks+1), nf, 1))]
        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x):
        res = self.head(x)
        res = self.body(res)
        res = self.tail(res)
        # Group-level long-range skip: concatenate input to output.
        res = torch.cat([res, x], dim=1)
        return res
## Long-Range Skip-connect Residual Group (RG)
class LRSCWideActResGroup(nn.Module):
    """Long-Range Skip-connect Residual Group built from LRSCWideActResBlock.

    Channel count grows by nf per inner block (concat skips); the tail 1x1
    conv fuses nf * (n_resblocks + 1) channels back to nf, then the group
    input is concatenated for the outer long-range skip.

    Args:
        nf (int): base number of feature channels.
        n_resblocks (int): number of inner blocks in the body.
        idx (int): group index inside an outer concat chain; when idx > 0
            a 1x1 head conv reduces nf * (idx + 1) input channels to nf.
    """

    def __init__(self, nf, n_resblocks, idx=0):
        super(LRSCWideActResGroup, self).__init__()
        # (Fix: removed unused locals `kernel_size` and `conv`.)
        wn = lambda x: torch.nn.utils.weight_norm(x)
        modules_head = [wn(nn.Conv2d(nf*(idx+1), nf, 1, 1, 0, bias=True))] if idx > 0 else []
        modules_body = [
            LRSCWideActResBlock(nf=nf, idx=i) for i in range(n_resblocks)]
        # Fuse the concatenated outputs of all inner blocks back to nf.
        modules_tail = [wn(nn.Conv2d(nf*(n_resblocks+1), nf, 1))]
        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x):
        res = self.head(x)
        res = self.body(res)
        res = self.tail(res)
        # Group-level long-range skip: concatenate input to output.
        res = torch.cat([res, x], dim=1)
        return res
################################################################
################################################################
################################################################
## Residual Channel Attention Block (RCAB)
class PYRCAB(nn.Module):
    """Residual Channel Attention Block built around a pyramidal conv.

    Wide-activation 1x1 expand -> act -> 1x1 slim, a multi-kernel
    PyConv2d, then a channel (CALayer) or dual (DALayer) attention layer;
    the result is added back to the input.

    Note: `conv` and `kernel_size` are accepted for interface
    compatibility but are not used — PyConv2d fixes its own kernels.
    """

    def __init__(
        self, conv, n_feat, kernel_size, reduction, wn,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1, da=False):
        super(PYRCAB, self).__init__()
        expand, linear = 6, 0.75
        layers = [
            wn(nn.Conv2d(n_feat, n_feat * expand, 1, bias=bias)),
            act,
            wn(nn.Conv2d(n_feat * expand, int(n_feat * linear), 1, bias=bias)),
            PyConv2d(in_channels=int(n_feat * linear),
                     out_channels=[n_feat // 4, n_feat // 4, n_feat // 2],
                     pyconv_kernels=[3, 5, 7],
                     pyconv_groups=[1, 4, 8], bias=bias),
        ]
        # Dual attention or plain channel attention, selected by `da`.
        layers.append(DALayer(n_feat, reduction, wn) if da else CALayer(n_feat, reduction, wn))
        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x)
        res += x
        return res
## Residual Group (RG)
class PyResidualGroup(nn.Module):
    """Residual Group of PYRCAB blocks followed by a PyConv2d, with a
    group-level additive skip connection.

    Args:
        n_feat (int): number of feature channels.
        n_resblocks (int): number of PYRCAB blocks in the body.
        da (bool): use dual attention (DALayer) instead of channel
            attention (CALayer) inside each PYRCAB.
    """

    def __init__(self, n_feat, n_resblocks, da=False):
        super(PyResidualGroup, self).__init__()
        kernel_size = 3
        res_scale = 1
        reduction = 16
        conv = PyConv2d
        wn = lambda x: torch.nn.utils.weight_norm(x)
        # (Fix: removed a redundant `modules_body = []` that was
        # immediately overwritten by this list comprehension.)
        modules_body = [
            PYRCAB(
                conv, n_feat, kernel_size, reduction, wn=wn, bias=True,
                bn=False, act=nn.ReLU(True), res_scale=res_scale, da=da)
            for _ in range(n_resblocks)]
        modules_body.append(
            PyConv2d(in_channels=n_feat,
                     out_channels=[n_feat//4, n_feat//4, n_feat//2],
                     pyconv_kernels=[3, 5, 7],
                     pyconv_groups=[1, 4, 8]))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res
class WideActResBlock(nn.Module):
    """WDSR-style wide-activation residual block.

    1x1 expand (x6) -> ReLU -> 1x1 slim (x0.8) -> 3x3, with an additive
    residual scaled by `res_scale` (fixed at 1).
    """

    def __init__(self, nf=64):
        super(WideActResBlock, self).__init__()
        self.res_scale = 1
        expand, linear, kernel_size = 6, 0.8, 3
        wn = lambda m: torch.nn.utils.weight_norm(m)
        slim = int(nf * linear)
        self.body = nn.Sequential(
            wn(nn.Conv2d(nf, nf * expand, 1, padding=0)),
            nn.ReLU(True),
            wn(nn.Conv2d(nf * expand, slim, 1, padding=0)),
            wn(nn.Conv2d(slim, nf, kernel_size, padding=kernel_size // 2)),
        )

    def forward(self, x):
        return x + self.body(x) * self.res_scale
class PSWideActResBlock(nn.Module):
    """Wide-activation residual block whose final 3x3 conv is a PSConv2d.

    1x1 expand (x6) -> ReLU -> 1x1 slim (x0.75) -> PSConv2d, with an
    additive residual scaled by `res_scale` (fixed at 1).
    """

    def __init__(self, nf=64):
        super(PSWideActResBlock, self).__init__()
        self.res_scale = 1
        expand, linear, kernel_size = 6, 0.75, 3
        wn = lambda m: torch.nn.utils.weight_norm(m)
        slim = int(nf * linear)
        self.body = nn.Sequential(
            wn(nn.Conv2d(nf, nf * expand, 1, padding=0)),
            nn.ReLU(True),
            wn(nn.Conv2d(nf * expand, slim, 1, padding=0)),
            wn(PSConv2d(slim, nf, kernel_size, padding=kernel_size // 2)),
        )

    def forward(self, x):
        return x + self.body(x) * self.res_scale
class PyWideActResBlock(nn.Module):
    """Wide-activation residual block whose final conv is a pyramidal
    convolution (PyConv2d with 3/5/7 kernels).

    Channel flow: nf -> nf*6 (wide) -> int(nf*0.75) (slim) -> nf.

    Args:
        nf (int): number of feature channels.
    """

    def __init__(self, nf=64):
        super(PyWideActResBlock, self).__init__()
        self.res_scale = 1
        expand = 6
        linear = 0.75
        wn = lambda x: torch.nn.utils.weight_norm(x)
        act = nn.ReLU(True)
        # (Fix: `expand_nf` was computed but never used while `nf*expand`
        # and `int(nf*linear)` were re-derived inline; use the named
        # channel counts consistently.)
        expand_nf = nf * expand
        linear_nf = int(nf * linear)
        body = [
            wn(nn.Conv2d(nf, expand_nf, 1, padding=0)),
            act,
            wn(nn.Conv2d(expand_nf, linear_nf, 1, padding=0)),
            PyConv2d(in_channels=linear_nf,
                     out_channels=[nf//4, nf//4, nf//2],
                     pyconv_kernels=[3, 5, 7],
                     pyconv_groups=[1, 4, 8]),
        ]
        self.body = nn.Sequential(*body)

    def forward(self, x):
        res = self.body(x) * self.res_scale
        res += x
        return res
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True, use_pad_mask=False):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): input of shape (n, c, h, w); cast to float internally.
        flow (Tensor): flow of shape (n, h, w, 2) in absolute pixel units.
        interp_mode (str): 'nearest', 'bilinear' or 'nearest4'.
        padding_mode (str): 'zeros', 'border' or 'reflection'. Default: 'zeros'.
        align_corners (bool): passed straight through to F.grid_sample;
            True matches the pre-1.3 PyTorch default.
        use_pad_mask (bool): PWCNet-only option; the masking code is
            currently disabled, so this is a no-op.

    Returns:
        Tensor: warped features. For 'nearest4' the four nearest-neighbour
        samples (floor/ceil combinations) are concatenated along the
        channel axis, giving 4*c channels.
    """
    # assert x.size()[-2:] == flow.size()[1:3]  # temporarily off for image-wise shift
    n, _, h, w = x.size()
    x = x.float()
    # Base sampling grid built on x's device/dtype explicitly:
    # torch.arange(...).type_as(x) triggered an illegal memory access on
    # TITAN RTX + PyTorch 1.9.1.
    grid_y, grid_x = torch.meshgrid(
        torch.arange(0, h, dtype=x.dtype, device=x.device),
        torch.arange(0, w, dtype=x.dtype, device=x.device))
    grid = torch.stack((grid_x, grid_y), 2).float()  # (h, w, 2) ordered as W(x), H(y)
    grid.requires_grad = False
    grid = grid.type_as(x)
    vgrid = grid + flow

    def _normalize(coord, size):
        # Map absolute pixel coordinates to the [-1, 1] range grid_sample expects.
        return 2.0 * coord / max(size - 1, 1) - 1.0

    if interp_mode == 'nearest4':
        # todo: bug — floor/ceil give no gradient to the flow model here,
        # but the results are empirically good.
        gx_floor = _normalize(torch.floor(vgrid[:, :, :, 0]), w)
        gx_ceil = _normalize(torch.ceil(vgrid[:, :, :, 0]), w)
        gy_floor = _normalize(torch.floor(vgrid[:, :, :, 1]), h)
        gy_ceil = _normalize(torch.ceil(vgrid[:, :, :, 1]), h)
        # Order matters: (floor,floor), (floor,ceil), (ceil,floor), (ceil,ceil).
        samples = [
            F.grid_sample(x, torch.stack((gx, gy), dim=3), mode='nearest',
                          padding_mode=padding_mode, align_corners=align_corners)
            for gx, gy in ((gx_floor, gy_floor), (gx_floor, gy_ceil),
                           (gx_ceil, gy_floor), (gx_ceil, gy_ceil))]
        return torch.cat(samples, 1)

    vgrid_scaled = torch.stack(
        (_normalize(vgrid[:, :, :, 0], w), _normalize(vgrid[:, :, :, 1], h)), dim=3)
    # TODO, what if align_corners=False
    return F.grid_sample(x, vgrid_scaled, mode=interp_mode,
                         padding_mode=padding_mode, align_corners=align_corners)
# def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
# """Warp an image or feature map with optical flow
# Args:
# x (Tensor): size (N, C, H, W)
# flow (Tensor): size (N, H, W, 2), normal value
# interp_mode (str): 'nearest' or 'bilinear'
# padding_mode (str): 'zeros' or 'border' or 'reflection'
# Returns:
# Tensor: warped image or feature map
# """
# assert x.size()[-2:] == flow.size()[1:3]
# B, C, H, W = x.size()
# # mesh grid
# grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
# grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
# grid.requires_grad = False
# grid = grid.type_as(x)
# vgrid = grid + flow
# # scale grid to [-1,1]
# vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
# vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
# vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
# output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
# return output
================================================
FILE: code/real/bsrt/model/bsrt.py
================================================
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import model.arch_util as arch_util
from torch.cuda.amp import autocast
import model.swin_util as swu
import time
import os
import math
from utils.debayer import Debayer3x3
import torchvision.utils as tvutils
from datasets.burstsr_dataset import pack_raw_image, flatten_raw_image_batch
try:
from model.non_local.non_local_cross_dot_product import NONLocalBlock2D as NonLocalCross
from model.non_local.non_local_dot_product import NONLocalBlock2D as NonLocal
except ImportError:
raise ImportError('Failed to import Non_Local module.')
try:
from model.DCNv2.dcn_v2 import DCN_sep as DCN, FlowGuidedDCN, InsideFlowGuidedDCN
except ImportError:
raise ImportError('Failed to import DCNv2 module.')
def make_model(args, parent=False):
    """Build a BSRT model from parsed command-line args.

    Args:
        args: namespace providing burst_size, patch_size, burst_channel,
            n_colors, model_level ('S' or 'L'), scale, non_local,
            use_checkpoint and local_rank.
        parent (bool): unused; kept for compatibility with the framework's
            model factory interface.

    Returns:
        BSRT: the constructed network.

    Raises:
        ValueError: if args.model_level is neither 'S' nor 'L'. (The
            original code silently fell through and later crashed with a
            NameError on `depths`.)
    """
    nframes = args.burst_size
    img_size = args.patch_size * 2
    patch_size = 1
    in_chans = args.burst_channel
    out_chans = args.n_colors

    if args.model_level == "S":
        depths = [6]*1 + [6] * 4
        num_heads = [6]*1 + [6] * 4
        embed_dim = 60
    elif args.model_level == "L":
        depths = [6]*1 + [8] * 6
        num_heads = [6]*1 + [6] * 6
        embed_dim = 180
    else:
        raise ValueError(
            f"unknown model_level: {args.model_level!r} (expected 'S' or 'L')")

    window_size = 8
    mlp_ratio = 2
    upscale = args.scale[0]
    non_local = args.non_local
    use_checkpoint = args.use_checkpoint

    if args.local_rank <= 0:
        print("depths: ", depths)

    return BSRT(args=args, nframes=nframes,
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                out_chans=out_chans,
                embed_dim=embed_dim,
                depths=depths,
                num_heads=num_heads,
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                upscale=upscale,
                non_local=non_local,
                use_checkpoint=use_checkpoint)
class BasicModule(nn.Module):
    """One pyramid level of SpyNet.

    Five 7x7 convs (ReLU between them) mapping the 8-channel input
    (ref RGB + warped supp RGB + 2-channel flow) to a 2-channel flow
    residual.
    """

    def __init__(self):
        super(BasicModule, self).__init__()
        channels = (8, 32, 64, 32, 16, 2)
        layers = []
        for cin, cout in zip(channels[:-1], channels[1:]):
            layers.append(nn.Conv2d(in_channels=cin, out_channels=cout,
                                    kernel_size=7, stride=1, padding=3))
            layers.append(nn.ReLU(inplace=False))
        # The last conv outputs raw flow values: drop the trailing ReLU.
        self.basic_module = nn.Sequential(*layers[:-1])

    def forward(self, tensor_input):
        return self.basic_module(tensor_input)
class SpyNet(nn.Module):
    """SpyNet optical-flow estimator (coarse-to-fine pyramid).

    Args:
        load_path (str): path to pretrained weights; downloaded from the
            VRT release if the file does not exist. Default: None
            (random init).
        return_levels (list[int]): pyramid levels whose flows are
            returned, level 5 being full resolution. Default: [5].
    """

    def __init__(self, load_path=None, return_levels=[5]):
        super(SpyNet, self).__init__()
        self.return_levels = return_levels
        self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)])
        if load_path:
            if not os.path.exists(load_path):
                import requests
                url = 'https://github.com/JingyunLiang/VRT/releases/download/v0.0/spynet_sintel_final-3d2a1287.pth'
                r = requests.get(url, allow_redirects=True)
                print(f'downloading SpyNet pretrained model from {url}')
                os.makedirs(os.path.dirname(load_path), exist_ok=True)
                # Fix: use a context manager so the file handle is closed
                # (the original left it open).
                with open(load_path, 'wb') as f:
                    f.write(r.content)

            self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])

        # ImageNet normalization constants used by preprocess().
        self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def preprocess(self, tensor_input):
        """Normalize an RGB batch with the ImageNet mean/std buffers."""
        tensor_output = (tensor_input - self.mean) / self.std
        return tensor_output

    def process(self, ref, supp, w, h, w_floor, h_floor):
        """Run coarse-to-fine flow estimation on resized inputs.

        Args:
            ref, supp (Tensor): reference/support batches already resized
                to (h_floor, w_floor), a multiple of 32.
            w, h (int): original spatial size, used to rescale the output.
            w_floor, h_floor (int): resized spatial size.

        Returns:
            list[Tensor]: flows for each level in self.return_levels,
            finest level first.
        """
        flow_list = []

        ref = [self.preprocess(ref)]
        supp = [self.preprocess(supp)]

        # Build a 6-level pyramid; index 0 is the coarsest.
        for level in range(5):
            ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False))
            supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False))

        flow = ref[0].new_zeros(
            [ref[0].size(0), 2,
             int(math.floor(ref[0].size(2) / 2.0)),
             int(math.floor(ref[0].size(3) / 2.0))])

        for level in range(len(ref)):
            upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0

            # Odd pyramid sizes can leave the upsampled flow one pixel short.
            if upsampled_flow.size(2) != ref[level].size(2):
                upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate')
            if upsampled_flow.size(3) != ref[level].size(3):
                upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate')

            # Residual refinement: each BasicModule predicts a correction
            # on top of the upsampled coarser flow.
            flow = self.basic_module[level](torch.cat([
                ref[level],
                arch_util.flow_warp(
                    supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'),
                upsampled_flow
            ], 1)) + upsampled_flow

            if level in self.return_levels:
                scale = 2**(5-level)  # level=5 (scale=1), level=4 (scale=2), level=3 (scale=4), level=2 (scale=8)
                flow_out = F.interpolate(input=flow, size=(h//scale, w//scale), mode='bilinear', align_corners=False)
                # Undo the resize to (h_floor, w_floor) by rescaling flow magnitudes.
                flow_out[:, 0, :, :] *= float(w//scale) / float(w_floor//scale)
                flow_out[:, 1, :, :] *= float(h//scale) / float(h_floor//scale)

                if torch.abs(flow_out).mean() > 200:
                    print(f"level {level}, flow > 200: {torch.abs(flow_out).mean():.4f}")
                    # return None

                # BUG FIX: Tensor.clamp is out-of-place; the original call
                # discarded its result, so the flow was never clamped.
                flow_out = flow_out.clamp(-250, 250)

                flow_list.insert(0, flow_out)

        return flow_list

    def forward(self, ref, supp):
        """Estimate flow from ref to supp for each requested level.

        Inputs are resized up to a multiple of 32 before processing; the
        output flows are resized and rescaled back to the original size.
        """
        assert ref.size() == supp.size()

        h, w = ref.size(2), ref.size(3)
        w_floor = math.floor(math.ceil(w / 32.0) * 32.0)
        h_floor = math.floor(math.ceil(h / 32.0) * 32.0)

        ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False)
        supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False)

        flow_list = self.process(ref, supp, w, h, w_floor, h_floor)

        return flow_list[0] if len(flow_list) == 1 else flow_list
class FlowGuidedPCDAlign(nn.Module):
    ''' Alignment module using Pyramid, Cascading and Deformable convolution
    with 3 pyramid levels. [From EDVR]

    Flow-guided variant: at every level the DCN offsets are predicted from
    the concatenation of (flow-warped neighbour features, reference
    features, flow), so the deformable conv only needs to learn a residual
    on top of the optical flow. Offsets are propagated coarse-to-fine
    (L3 -> L2 -> L1) and doubled when upsampled, matching the 2x change in
    pixel units per level.
    '''

    def __init__(self, nf=64, groups=8):
        # nf: feature channels at every level; groups: deformable groups.
        super(FlowGuidedPCDAlign, self).__init__()
        # L3: level 3, 1/4 spatial size
        self.L3_offset_conv1 = nn.Conv2d(nf * 2 + 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L3_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L3_dcnpack = FlowGuidedDCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)
        # L2: level 2, 1/2 spatial size
        self.L2_offset_conv1 = nn.Conv2d(nf * 2 + 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L2_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset
        self.L2_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L2_dcnpack = FlowGuidedDCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)
        self.L2_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea
        # L1: level 1, original spatial size
        self.L1_offset_conv1 = nn.Conv2d(nf * 2 + 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L1_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset
        self.L1_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L1_dcnpack = FlowGuidedDCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)
        self.L1_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea
        # Cascading DCN: one final (non-flow-guided) refinement pass at L1.
        self.cas_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.cas_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.cas_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, nbr_fea_l, nbr_fea_warped_l, ref_fea_l, flows_l):
        '''align other neighboring frames to the reference frame in the feature level

        nbr_fea_l, nbr_fea_warped_l, ref_fea_l: [L1, L2, L3] feature pyramids,
        each entry of shape [B, C, H, W] (H, W halved per level).
        flows_l: [L1, L2, L3] optical flows toward the reference frame.
        Returns the aligned L1 features, shape [B, C, H, W].
        '''
        # L3
        L3_offset = torch.cat([nbr_fea_warped_l[2], ref_fea_l[2], flows_l[2]], dim=1)
        L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset))
        L3_offset = self.lrelu(self.L3_offset_conv2(L3_offset))
        L3_fea = self.lrelu(self.L3_dcnpack(nbr_fea_l[2], L3_offset, flows_l[2]))
        # L2
        L3_offset = F.interpolate(L3_offset, scale_factor=2, mode='bilinear', align_corners=False)
        L2_offset = torch.cat([nbr_fea_warped_l[1], ref_fea_l[1], flows_l[1]], dim=1)
        L2_offset = self.lrelu(self.L2_offset_conv1(L2_offset))
        # *2: offsets are in pixel units and double when upsampled a level.
        L2_offset = self.lrelu(self.L2_offset_conv2(torch.cat([L2_offset, L3_offset*2], dim=1)))
        L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset))
        L2_fea = self.L2_dcnpack(nbr_fea_l[1], L2_offset, flows_l[1])
        L3_fea = F.interpolate(L3_fea, scale_factor=2, mode='bilinear', align_corners=False)
        L2_fea = self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L3_fea], dim=1)))
        # L1
        L2_offset = F.interpolate(L2_offset, scale_factor=2, mode='bilinear', align_corners=False)
        L1_offset = torch.cat([nbr_fea_warped_l[0], ref_fea_l[0], flows_l[0]], dim=1)
        L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset))
        L1_offset = self.lrelu(self.L1_offset_conv2(torch.cat([L1_offset, L2_offset * 2], dim=1)))
        L1_offset = self.lrelu(self.L1_offset_conv3(L1_offset))
        L1_fea = self.L1_dcnpack(nbr_fea_l[0], L1_offset, flows_l[0])
        L2_fea = F.interpolate(L2_fea, scale_factor=2, mode='bilinear', align_corners=False)
        L1_fea = self.L1_fea_conv(torch.cat([L1_fea, L2_fea], dim=1))
        # Cascading: refine the aligned L1 features once more against the reference.
        offset = torch.cat([L1_fea, ref_fea_l[0]], dim=1)
        offset = self.lrelu(self.cas_offset_conv1(offset))
        offset = self.lrelu(self.cas_offset_conv2(offset))
        L1_fea = self.cas_dcnpack(L1_fea, offset)

        return L1_fea
class CrossNonLocal_Fusion(nn.Module):
    ''' Cross Non Local fusion module

    Each frame is passed through a self non-local block (NonLocal) and a
    cross non-local block against the reference frame (NonLocalCross);
    all 2*nframes outputs are concatenated and fused by one 3x3 conv.
    '''

    def __init__(self, nf=64, out_feat=96, nframes=5, center=2):
        super(CrossNonLocal_Fusion, self).__init__()
        self.center = center

        self.non_local_T = nn.ModuleList()
        self.non_local_F = nn.ModuleList()

        for _ in range(nframes):
            self.non_local_T.append(NonLocalCross(nf, inter_channels=nf//2, sub_sample=True, bn_layer=False))
            self.non_local_F.append(NonLocal(nf, inter_channels=nf//2, sub_sample=True, bn_layer=False))

        # fusion conv: 3x3 over the nframes * nf * 2 concatenated channels
        self.fea_fusion = nn.Conv2d(nframes * nf*2, out_feat, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, aligned_fea):
        B, N, C, H, W = aligned_fea.size()  # N burst frames
        ref = aligned_fea[:, self.center, :, :, :].clone()

        cross_feats = []
        self_feats = []
        for i in range(N):
            nbr = aligned_fea[:, i, :, :, :]
            self_feats.append(self.non_local_F[i](nbr))
            cross_feats.append(self.non_local_T[i](nbr, ref))

        # Cross-attention outputs first, then self-attention outputs.
        fused = torch.cat(cross_feats + self_feats, dim=1)
        return self.fea_fusion(fused)
class BSRT(nn.Module):
    """Burst Super-Resolution Transformer.

    Pipeline (mirrors the numbered banners in __init__):
      1. shallow feature extraction (Swin blocks or wide-activation blocks),
      2. flow-guided FPN + PCD alignment of every burst frame to the
         reference frame (index self.center),
      3. multi-frame fusion (plain conv or cross non-local),
      4. deep feature extraction with a stack of RSTB blocks,
      5. pixel-shuffle reconstruction with a skip branch built from the
         reference frame.
    """
    def __init__(self, args, nframes=8, img_size=64, patch_size=1, in_chans=3, out_chans=3,
                 embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, upscale=4, non_local=False,
                 **kwargs):
        super(BSRT, self).__init__()
        num_in_ch = in_chans
        num_out_ch = out_chans
        num_feat = 64
        groups = 8
        # embed_dim = num_feat
        back_RBs = 5
        n_resblocks = 6

        self.args = args
        self.center = 0  # reference frame index within the burst
        self.upscale = upscale
        self.window_size = window_size
        self.non_local = non_local
        self.nframes = nframes

        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio

        # NOTE(review): hard-coded absolute path to the SpyNet weights —
        # TODO make this configurable via args.
        spynet_path='/home/luoziwei/.pretrained_models/spynet_sintel_final-3d2a1287.pth'
        self.spynet = SpyNet(spynet_path, [3, 4, 5])
        # conv_flow expects a 1-channel input: assumes C == 4 RAW burst
        # channels so pixel-shuffle by 2 yields C/4 == 1 — confirm for
        # other burst_channel values.
        self.conv_flow = nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1)
        self.flow_ps = nn.PixelShuffle(2)

        # split image into non-overlapping patches
        self.patch_embed = swu.PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # merge non-overlapping patches into image
        self.patch_unembed = swu.PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        #####################################################################################################
        ################################### 1, shallow feature extraction ###################################
        self.conv_first = nn.Conv2d(num_in_ch*(1+2*0), embed_dim, 3, 1, 1, bias=True)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        if args.swinfeature:
            if self.args.local_rank <= 0:
                print("using swinfeature")
            # depths[0] Swin blocks serve as the shallow feature extractor.
            self.pre_layers = nn.ModuleList()
            for i_layer in range(depths[0]):
                layer = swu.SwinTransformerBlock(dim=embed_dim,
                                 input_resolution=(patches_resolution[0]//2,
                                                   patches_resolution[1]//2),
                                 num_heads=num_heads[0], window_size=window_size,
                                 shift_size=0 if (i_layer % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop_rate, attn_drop=attn_drop_rate,
                                 drop_path=dpr[i_layer],
                                 norm_layer=norm_layer)
                self.pre_layers.append(layer)

            self.pre_norm = norm_layer(embed_dim)
        else:
            WARB = functools.partial(arch_util.WideActResBlock, nf=embed_dim)
            self.feature_extraction = arch_util.make_layer(WARB, 5)

        self.conv_after_pre_layer = nn.Conv2d(embed_dim, num_feat*4, 3, 1, 1, bias=True)
        self.mid_ps = nn.PixelShuffle(2)

        # Strided convs build the L2 (1/2) and L3 (1/4) pyramid levels.
        self.fea_L2_conv1 = nn.Conv2d(num_feat, num_feat*2, 3, 2, 1, bias=True)
        self.fea_L3_conv1 = nn.Conv2d(num_feat*2, num_feat*4, 3, 2, 1, bias=True)

        #####################################################################################################
        ################################### 2, Feature Enhanced PCD Align ###################################
        # Top layers
        self.toplayer = nn.Conv2d(num_feat*4, num_feat, kernel_size=1, stride=1, padding=0)
        # Smooth layers
        self.smooth1 = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1)
        self.smooth2 = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1)
        # Lateral layers
        self.latlayer1 = nn.Conv2d(num_feat*2, num_feat, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(num_feat*1, num_feat, kernel_size=1, stride=1, padding=0)

        # self.align = PCD_Align(nf=num_feat, groups=groups)
        self.align = FlowGuidedPCDAlign(nf=num_feat, groups=groups)

        #####################################################################################################
        ################################### 3, Multi-frame Feature Fusion  ##################################
        if self.non_local:
            if self.args.local_rank <= 0:
                print("using non_local")
            self.fusion = CrossNonLocal_Fusion(nf=num_feat, out_feat=embed_dim, nframes=nframes, center=self.center)
        else:
            self.fusion = nn.Conv2d(nframes * num_feat, embed_dim, 1, 1, bias=True)

        #####################################################################################################
        ################################### 4, deep feature extraction ######################################
        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            swu.trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # build Residual Swin Transformer blocks (RSTB); depths[0] was
        # consumed by the shallow stage, so start at index 1.
        self.layers = nn.ModuleList()
        for i_layer in range(1, self.num_layers):
            layer = swu.RSTB(dim=embed_dim,
                             input_resolution=(patches_resolution[0],
                                               patches_resolution[1]),
                             depth=depths[i_layer],
                             num_heads=num_heads[i_layer],
                             window_size=window_size,
                             mlp_ratio=self.mlp_ratio,
                             qkv_bias=qkv_bias, qk_scale=qk_scale,
                             drop=drop_rate, attn_drop=attn_drop_rate,
                             drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
                             norm_layer=norm_layer,
                             downsample=None,
                             use_checkpoint=use_checkpoint,
                             img_size=img_size,
                             patch_size=patch_size
                             )
            self.layers.append(layer)
        self.norm = norm_layer(self.num_features)

        # build the last conv layer in deep feature extraction
        self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)

        #####################################################################################################
        ################################ 5, high quality image reconstruction ################################
        self.upconv1 = nn.Conv2d(embed_dim, num_feat * 4, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
        self.pixel_shuffle = nn.PixelShuffle(2)
        self.HRconv = nn.Conv2d(64, 64, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(64, args.n_colors, 3, 1, 1, bias=True)

        #### skip: upsamples the reference frame for global residuals ####
        self.skip_pixel_shuffle = nn.PixelShuffle(2)
        self.skipup1 = nn.Conv2d(num_in_ch//4, num_feat * 4, 3, 1, 1, bias=True)
        self.skipup2 = nn.Conv2d(num_feat, args.n_colors * 4, 3, 1, 1, bias=True)

        #### activation function
        self.lrelu = nn.LeakyReLU(0.1, inplace=True)
        self.lrelu2 = nn.LeakyReLU(0.1, inplace=True)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Swin-style init: truncated normal for Linear, unit LayerNorm.
        if isinstance(m, nn.Linear):
            swu.trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameter names excluded from weight decay by the optimizer setup.
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        # Parameter-name keywords excluded from weight decay.
        return {'relative_position_bias_table'}

    def _upsample_add(self, x, y):
        # FPN-style merge: upsample x by 2 and add the lateral features y.
        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + y

    def check_image_size(self, x):
        # Reflect-pad H and W up to multiples of the Swin window size.
        _, _, h, w = x.size()
        mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
        mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
        return x

    def pre_forward_features(self, x):
        # Shallow extractor: Swin blocks when args.swinfeature, else WARBs.
        if self.args.swinfeature:
            x_size = (x.shape[-2], x.shape[-1])
            x = self.patch_embed(x, use_norm=True)
            if self.ape:
                x = x + self.absolute_pos_embed
            x = self.pos_drop(x)

            for idx, layer in enumerate(self.pre_layers):
                x = layer(x, x_size)

            x = self.pre_norm(x)
            x = self.patch_unembed(x, x_size)
        else:
            x = self.feature_extraction(x)
        return x

    def forward_features(self, x):
        # Deep extractor: run the fused features through the RSTB stack.
        x_size = (x.shape[-2], x.shape[-1])
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for idx, layer in enumerate(self.layers):
            x = layer(x, x_size)
            # Debug aid: report which layer produced inf/nan activations.
            if torch.any(torch.isinf(x)) or torch.any(torch.isnan(x)):
                print('layer: ', idx)

        x = self.norm(x)  # B L C
        x = self.patch_unembed(x, x_size)

        return x

    @autocast()
    def forward(self, x, print_time=False):
        """x: (B, N, C, H, W) burst. Returns the super-resolved output
        built from the aligned and fused burst features plus the
        reference-frame skip branch."""
        B, N, C, H, W = x.size()  # N video frames
        x_center = x[:, self.center, :, :, :].contiguous()

        #### skip module: two pixel-shuffle upsamplings of the reference frame
        skip1 = self.lrelu2(self.skip_pixel_shuffle(self.skipup1(self.skip_pixel_shuffle(x_center))))
        skip2 = self.skip_pixel_shuffle(self.skipup2(skip1))

        # Pseudo-RGB frames at 2x resolution for SpyNet flow estimation.
        x_ = self.conv_flow(self.flow_ps(x.view(B*N, C, H, W))).view(B, N, -1, H*2, W*2)

        # calculate flows
        ref_flows = self.get_ref_flows(x_)

        #### extract LR features
        x = self.lrelu(self.conv_first(x.view(B*N, -1, H, W)))

        L1_fea = self.mid_ps(self.conv_after_pre_layer(self.pre_forward_features(x)))
        # H, W are rebound here to the 2x-upsampled L1 feature size.
        _, _, H, W = L1_fea.size()

        L2_fea = self.lrelu(self.fea_L2_conv1(L1_fea))
        L3_fea = self.lrelu(self.fea_L3_conv1(L2_fea))

        # FPN enhance features
        L3_fea = self.lrelu(self.toplayer(L3_fea))
        L2_fea = self.smooth1(self._upsample_add(L3_fea, self.latlayer1(L2_fea)))
        L1_fea = self.smooth2(self._upsample_add(L2_fea, self.latlayer2(L1_fea)))

        L1_fea = L1_fea.view(B, N, -1, H, W).contiguous()
        L2_fea = L2_fea.view(B, N, -1, H // 2, W // 2 ).contiguous()
        L3_fea = L3_fea.view(B, N, -1, H // 4, W // 4).contiguous()

        #### PCD align
        # ref feature list
        ref_fea_l = [
            L1_fea[:, self.center, :, :, :].clone(),
            L2_fea[:, self.center, :, :, :].clone(),
            L3_fea[:, self.center, :, :, :].clone()
        ]
        aligned_fea = []
        for i in range(N):
            nbr_fea_l = [
                L1_fea[:, i, :, :, :].clone(),
                L2_fea[:, i, :, :, :].clone(),
                L3_fea[:, i, :, :, :].clone()
            ]
            flows_l = [
                ref_flows[0][:, i, :, :, :].clone(),
                ref_flows[1][:, i, :, :, :].clone(),
                ref_flows[2][:, i, :, :, :].clone()
            ]
            # print(nbr_fea_l[0].shape, flows_l[0].shape)
            # Pre-warp neighbour features with the flow; PCD refines the rest.
            nbr_warped_l = [
                arch_util.flow_warp(nbr_fea_l[0], flows_l[0].permute(0, 2, 3, 1), 'bilinear'),
                arch_util.flow_warp(nbr_fea_l[1], flows_l[1].permute(0, 2, 3, 1), 'bilinear'),
                arch_util.flow_warp(nbr_fea_l[2], flows_l[2].permute(0, 2, 3, 1), 'bilinear')
            ]
            aligned_fea.append(self.align(nbr_fea_l, nbr_warped_l, ref_fea_l, flows_l))
        aligned_fea = torch.stack(aligned_fea, dim=1)  # [B, N, C, H, W] --> [B, T, C, H, W]

        if not self.non_local:
            # Plain conv fusion consumes the frames flattened into channels.
            aligned_fea = aligned_fea.view(B, -1, H, W)
        x = self.lrelu(self.fusion(aligned_fea))

        x = self.lrelu(self.conv_after_body(self.forward_features(x))) + x

        x = self.lrelu(self.pixel_shuffle(self.upconv1(x)))
        x = skip1 + x
        x = self.lrelu(self.pixel_shuffle(self.upconv2(x)))
        x = self.lrelu(self.HRconv(x))
        x = self.conv_last(x)
        x = skip2 + x

        return x

    def get_ref_flows(self, x):
        '''Get flow between frames ref and other'''
        b, n, c, h, w = x.size()
        x_nbr = x.reshape(-1, c, h, w)
        # Repeat the reference frame so every neighbour is paired with it.
        x_ref = x[:, self.center:self.center+1, :, :, :].repeat(1, n, 1, 1, 1).reshape(-1, c, h, w)

        # backward
        flows = self.spynet(x_ref, x_nbr)
        flows_list = [flow.view(b, n, 2, h // (2 ** (i)), w // (2 ** (i))) for flow, i in
                      zip(flows, range(3))]

        return flows_list
================================================
FILE: code/real/bsrt/model/checkpoint.py
================================================
import torch
import warnings
def detach_variable(inputs):
    """Return detached copies of a tuple of tensors, preserving each
    tensor's requires_grad flag.

    Raises:
        RuntimeError: if `inputs` is not a tuple.
    """
    if not isinstance(inputs, tuple):
        raise RuntimeError(
            "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)
    detached = []
    for inp in inputs:
        t = inp.detach()
        t.requires_grad = inp.requires_grad
        detached.append(t)
    return tuple(detached)
def check_backward_validity(inputs):
    """Warn when no input tensor requires grad (backward would produce
    only None gradients)."""
    if all(not inp.requires_grad for inp in inputs):
        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
class CheckpointFunction(torch.autograd.Function):
    """Gradient checkpointing: skip storing intermediate activations in
    forward and recompute them during backward.

    forward() runs `run_function` under no_grad so no graph is retained;
    backward() re-runs it under enable_grad on detached inputs and
    differentiates the recomputed graph.
    """
    @staticmethod
    def forward(ctx, run_function, length, *args):
        # args = input tensors (first `length` entries) followed by the
        # parameters that also need gradients.
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Detach the saved inputs (keeping requires_grad) so the recomputed
        # graph is rooted at them rather than at the original graph.
        for i in range(len(ctx.input_tensors)):
            temp = ctx.input_tensors[i]
            ctx.input_tensors[i] = temp.detach()
            ctx.input_tensors[i].requires_grad = temp.requires_grad
        with torch.enable_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        input_grads = torch.autograd.grad(output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True)
        # The two leading Nones match run_function and length, which take
        # no gradient.
        return (None, None) + input_grads
================================================
FILE: code/real/bsrt/model/common.py
================================================
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Build a stride-1 Conv2d whose padding keeps spatial size unchanged
    for odd kernel sizes ('same' padding)."""
    pad = kernel_size // 2
    return nn.Conv2d(in_channels, out_channels, kernel_size, padding=pad, bias=bias)
class MeanShift(nn.Conv2d):
    """Frozen 1x1 conv that subtracts (sign=-1) or adds (sign=+1) the
    per-channel dataset mean, optionally normalized by rgb_std."""
    def __init__(
            self, rgb_range,
            rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        scale = torch.Tensor(rgb_std)
        # Identity kernel divided by the std, so each channel is scaled independently.
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / scale.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / scale
        # The shift is fixed data statistics, never trained.
        self.requires_grad_(False)
class BasicBlock(nn.Sequential):
    """Convenience stack: conv -> optional BatchNorm -> optional activation.

    NOTE(review): the `stride` argument is accepted but never forwarded to
    `conv` -- preserved as-is for interface compatibility.
    """
    def __init__(
            self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False,
            bn=True, act=nn.ReLU(True)):
        modules = [conv(in_channels, out_channels, kernel_size, bias=bias)]
        if bn:
            modules += [nn.BatchNorm2d(out_channels)]
        if act is not None:
            modules += [act]
        super(BasicBlock, self).__init__(*modules)
class ResBlock(nn.Module):
    """EDSR-style residual block: x + res_scale * body(x), where body is
    conv -> (BN) -> act -> conv -> (BN)."""
    def __init__(
            self, conv, n_feats, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(ResBlock, self).__init__()
        layers = []
        for idx in range(2):
            layers.append(conv(n_feats, n_feats, kernel_size, bias=bias))
            if bn:
                layers.append(nn.BatchNorm2d(n_feats))
            # Activation only between the two convolutions.
            if idx == 0:
                layers.append(act)
        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        # Scale the residual branch before adding the identity path.
        return x + self.body(x).mul(self.res_scale)
class Upsampler(nn.Sequential):
    """Pixel-shuffle upsampler supporting scale factors 2**n and 3.

    Each stage is conv (expanding channels by factor**2) -> PixelShuffle,
    optionally followed by BatchNorm and ReLU/PReLU.
    """
    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
        layers = []

        def add_stage(factor):
            # One upsampling stage by `factor`.
            layers.append(conv(n_feats, factor * factor * n_feats, 3, bias))
            layers.append(nn.PixelShuffle(factor))
            if bn:
                layers.append(nn.BatchNorm2d(n_feats))
            if act == 'relu':
                layers.append(nn.ReLU(True))
            elif act == 'prelu':
                layers.append(nn.PReLU(n_feats))

        if (scale & (scale - 1)) == 0:  # power of two: repeat x2 stages
            for _ in range(int(math.log(scale, 2))):
                add_stage(2)
        elif scale == 3:
            add_stage(3)
        else:
            raise NotImplementedError
        super(Upsampler, self).__init__(*layers)
class UpOnly(nn.Sequential):
    """Pure pixel-shuffle upsampling (no convolutions); channels shrink by
    scale**2 per stage. Supports scale 2**n and 3."""
    def __init__(self, scale):
        stages = []
        if (scale & (scale - 1)) == 0:  # power of two
            stages = [nn.PixelShuffle(2) for _ in range(int(math.log(scale, 2)))]
        elif scale == 3:
            stages = [nn.PixelShuffle(3)]
        else:
            raise NotImplementedError
        super(UpOnly, self).__init__(*stages)
def lanczos_kernel(dx, a=3, N=None, dtype=None, device=None):
    '''
    Generates 1D Lanczos kernels for translation and interpolation.
    Args:
        dx : float, tensor (batch_size, 1), the translation in pixels to shift an image.
        a : int, number of lobes in the kernel support.
            If N is None, then the width is the kernel support (length of all lobes),
            S = 2(a + ceil(dx)) + 1.
        N : int, width of the kernel. If smaller than S then N is set to S.
    Returns:
        k: tensor (?, ?), lanczos kernel
    '''
    if not torch.is_tensor(dx):
        dx = torch.tensor(dx, dtype=dtype, device=device)
    # Fall back to the shift tensor's own placement when unspecified.
    device = dx.device if device is None else device
    dtype = dx.dtype if dtype is None else dtype
    D = dx.abs().ceil().int()
    support = 2 * (a + D) + 1  # width of kernel support
    support_max = support.max() if hasattr(support, 'shape') else support
    if (N is None) or (N < support_max):
        N = support
    pad = (N - support) // 2  # width of zeros beyond kernel support
    lo = (-(a + D + pad)).min()
    hi = (a + D + pad + 1).max()
    # Sample positions relative to the (fractional) shift.
    x = torch.arange(lo, hi, dtype=dtype, device=device).view(1, -1) - dx
    # Small epsilon avoids 0/0 at integer sample positions.
    px = (np.pi * x) + 1e-3
    # Lanczos: sinc(x) windowed by sinc(x / a).
    return a * torch.sin(px) * torch.sin(px / a) / px ** 2
def lanczos_shift(img, shift, p=5, a=3):
    '''
    Shifts an image by convolving it with a Lanczos kernel.
    Lanczos interpolation is an approximation to ideal sinc interpolation,
    by windowing a sinc kernel with another sinc function extending up to a
    few number of its lobes (typically a=3).
    Args:
        img : tensor (batch_size, channels, height, width), the images to be shifted
        shift : tensor (batch_size, 2) of translation parameters (dy, dx)
        p : int, padding width prior to convolution (default=5)
        a : int, number of lobes in the Lanczos interpolation kernel (default=3)
    Returns:
        I_s: tensor (batch_size, channels, height, width), shifted images
    '''
    # Swap batch/channel so the grouped convolutions below apply one kernel
    # per shift entry across all images.
    img = img.transpose(0, 1)
    dtype = img.dtype
    if len(img.shape) == 2:
        img = img[None, None].repeat(1, shift.shape[0], 1, 1)  # batch of one image
    elif len(img.shape) == 3:  # one image per shift
        assert img.shape[0] == shift.shape[0]
        img = img[None, ]
    # Apply padding (reflection keeps borders plausible under the wide kernel).
    padder = torch.nn.ReflectionPad2d(p)  # reflect pre-padding
    I_padded = padder(img)
    # Create 1D shifting kernels, one per axis (the 2-D shift is separable).
    y_shift = shift[:, [0]]
    x_shift = shift[:, [1]]
    k_y = (lanczos_kernel(y_shift, a=a, N=None, dtype=dtype)
           .flip(1)  # flip axis of convolution
           )[:, None, :, None]  # expand dims to get shape (batch, channels, y_kernel, 1)
    k_x = (lanczos_kernel(x_shift, a=a, N=None, dtype=dtype)
           .flip(1)
           )[:, None, None, :]  # shape (batch, channels, 1, x_kernel)
    # Apply kernels.
    # NOTE(review): torch.conv1d is invoked with 4-D inputs/weights and
    # 2-element padding, i.e. conv2d-style arguments; presumably this relies
    # on the dispatch behavior of the torch version pinned by this repo --
    # confirm, otherwise torch.conv2d appears to be the intended call.
    # print(I_padded.shape, k_y.shape)
    I_s = torch.conv1d(I_padded,
                       groups=k_y.shape[0],
                       weight=k_y,
                       padding=[k_y.shape[2] // 2, 0])  # same padding
    I_s = torch.conv1d(I_s,
                       groups=k_x.shape[0],
                       weight=k_x,
                       padding=[0, k_x.shape[3] // 2])
    I_s = I_s[..., p:-p, p:-p]  # remove padding
    # print(I_s.shape)
    return I_s.transpose(0, 1)  # , k.squeeze()
================================================
FILE: code/real/bsrt/model/non_local/network.py
================================================
from torch import nn
# from lib.non_local_concatenation import NONLocalBlock2D
# from lib.non_local_gaussian import NONLocalBlock2D
from lib.non_local_embedded_gaussian import NONLocalBlock2D
# from lib.non_local_dot_product import NONLocalBlock2D
class Network(nn.Module):
    """Small CNN classifier with non-local blocks after the first two conv
    stages (sized for 1x28x28 inputs, 10 output logits)."""

    def __init__(self):
        super(Network, self).__init__()

        def conv_stage(cin, cout):
            # Conv -> BN -> ReLU -> 2x2 max-pool.
            return nn.Sequential(
                nn.Conv2d(in_channels=cin, out_channels=cout, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(cout),
                nn.ReLU(),
                nn.MaxPool2d(2),
            )

        self.conv_1 = conv_stage(1, 32)
        self.nl_1 = NONLocalBlock2D(in_channels=32)
        self.conv_2 = conv_stage(32, 64)
        self.nl_2 = NONLocalBlock2D(in_channels=64)
        self.conv_3 = conv_stage(64, 128)
        self.fc = nn.Sequential(
            nn.Linear(in_features=128*3*3, out_features=256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(in_features=256, out_features=10)
        )

    def forward(self, x):
        n = x.size(0)
        feat = self.nl_1(self.conv_1(x))
        feat = self.nl_2(self.conv_2(feat))
        flat = self.conv_3(feat).view(n, -1)
        return self.fc(flat)

    def forward_with_nl_map(self, x):
        # Same as forward(), additionally returning the two attention maps.
        n = x.size(0)
        feat, nl_map_1 = self.nl_1(self.conv_1(x), return_nl_map=True)
        feat, nl_map_2 = self.nl_2(self.conv_2(feat), return_nl_map=True)
        flat = self.conv_3(feat).view(n, -1)
        return self.fc(flat), [nl_map_1, nl_map_2]
if __name__ == '__main__':
    import torch

    # Quick smoke test on an MNIST-shaped batch.
    sample = torch.randn(3, 1, 28, 28)
    net = Network()
    prediction = net(sample)
    print(prediction.size())
================================================
FILE: code/real/bsrt/model/non_local/non_local_concatenation.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
class _NonLocalBlockND(nn.Module):
    """Non-local block using the *concatenation* pairwise function from
    "Non-local Neural Networks": pairs (theta_i, phi_j) are concatenated and
    scored by a small conv (`concat_project`).

    Note: attribute names (g, W, theta, phi, concat_project) and the
    Sequential wrapping under `sub_sample` determine checkpoint state_dict
    keys, so the structure is documented rather than restyled.
    """
    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        # dimension: 1/2/3 selects the Conv/Pool/BatchNorm flavor.
        # sub_sample: max-pool the phi/g paths to shrink the pairwise matrix.
        # bn_layer: end W with a zero-initialized BatchNorm (identity at init).
        super(_NonLocalBlockND, self).__init__()
        assert dimension in [1, 2, 3]
        self.dimension = dimension
        self.sub_sample = sub_sample
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        if self.inter_channels is None:
            # Bottleneck to half the channels (at least 1).
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1
        if dimension == 3:
            conv_nd = nn.Conv3d
            # Pool only spatially; keep the temporal dimension.
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d
        # g: value embedding (1x1 conv to inter_channels).
        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)
        if bn_layer:
            # W maps back to in_channels; its trailing BN is zero-initialized
            # so W_y starts at 0 and the residual z = W_y + x is identity.
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            # Same identity-at-init trick without BN.
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)
        # theta / phi: query / key embeddings.
        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)
        # Scores each concatenated (theta_i, phi_j) pair down to a scalar.
        self.concat_project = nn.Sequential(
            nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
            nn.ReLU()
        )
        if sub_sample:
            # Wrapping in Sequential keeps the call sites unchanged.
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)
    def forward(self, x, return_nl_map=False):
        '''
        :param x: (b, c, t, h, w) for 3D; (b, c, h, w) for 2D; (b, c, t) for 1D.
        :param return_nl_map: if True return z, nl_map, else only return z.
        :return: z with the same shape as x; optionally the (b, N, M) map.
        '''
        batch_size = x.size(0)
        # Values: (b, M, inter) where M = number of (possibly pooled) positions.
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)
        # (b, c, N, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)
        # (b, c, 1, N)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1)
        h = theta_x.size(2)
        w = phi_x.size(3)
        # Broadcast both embeddings to (b, c, N, M) so every pair can be scored.
        theta_x = theta_x.repeat(1, 1, 1, w)
        phi_x = phi_x.repeat(1, 1, h, 1)
        concat_feature = torch.cat([theta_x, phi_x], dim=1)
        # (b, 1, N, M) pairwise scores.
        f = self.concat_project(concat_feature)
        b, _, h, w = f.size()
        f = f.view(b, h, w)
        # Normalize by the number of key positions (no softmax in this variant).
        N = f.size(-1)
        f_div_C = f / N
        # Aggregate values: (b, N, inter).
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        # Residual connection.
        z = W_y + x
        if return_nl_map:
            return z, f_div_C
        return z
class NONLocalBlock1D(_NonLocalBlockND):
    """1-D (sequence) specialization of the non-local block."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D, self).__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=1,
            sub_sample=sub_sample,
            bn_layer=bn_layer)


class NONLocalBlock2D(_NonLocalBlockND):
    """2-D (spatial) specialization of the non-local block."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=2,
            sub_sample=sub_sample,
            bn_layer=bn_layer)


class NONLocalBlock3D(_NonLocalBlockND):
    """3-D (spatio-temporal) specialization of the non-local block."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock3D, self).__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=3,
            sub_sample=sub_sample,
            bn_layer=bn_layer)
if __name__ == '__main__':
    import torch

    # Smoke-test every (sub_sample, bn_layer) combination in 1D/2D/3D.
    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
        for block_cls, make_input in (
                (NONLocalBlock1D, lambda: torch.zeros(2, 3, 20)),
                (NONLocalBlock2D, lambda: torch.zeros(2, 3, 20, 20)),
                (NONLocalBlock3D, lambda: torch.randn(2, 3, 8, 20, 20))):
            img = make_input()
            net = block_cls(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
            out = net(img)
            print(out.size())
================================================
FILE: code/real/bsrt/model/non_local/non_local_cross_dot_product.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
class _NonLocalBlockND(nn.Module):
    """*Cross* non-local block (dot-product pairwise function): queries come
    from a separate reference input `ref`, keys/values from `x`.

    Note: attribute names (g, W, theta, phi) and the Sequential wrapping
    under `sub_sample` determine checkpoint state_dict keys.
    """
    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        # dimension: 1/2/3 selects the Conv/Pool/BatchNorm flavor.
        # sub_sample: max-pool the phi/g paths (factor 4 spatially here).
        # bn_layer: end W with a zero-initialized BatchNorm (identity at init).
        super(_NonLocalBlockND, self).__init__()
        assert dimension in [1, 2, 3]
        self.dimension = dimension
        self.sub_sample = sub_sample
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        if self.inter_channels is None:
            # Bottleneck to half the channels (at least 1).
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1
        if dimension == 3:
            conv_nd = nn.Conv3d
            # Pool only spatially; keep the temporal dimension.
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 4, 4))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(4, 4))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(4))
            bn = nn.BatchNorm1d
        # g: value embedding (1x1 conv to inter_channels).
        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)
        if bn_layer:
            # W maps back to in_channels; its trailing BN is zero-initialized
            # so W_y starts at 0 and the residual z = W_y + x is identity.
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            # Same identity-at-init trick without BN.
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)
        # theta: query embedding (applied to `ref`); phi: key embedding (on x).
        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)
        if sub_sample:
            # Wrapping in Sequential keeps the call sites unchanged.
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)
    def forward(self, x, ref, return_nl_map=False):
        """
        :param x: (b, c, t, h, w) source of keys/values; also the residual path.
        :param ref: reference tensor providing the queries; must have the same
            spatial size as x (the output is reshaped to x's spatial dims).
        :param return_nl_map: if True return z, nl_map, else only return z.
        :return: z with the same shape as x; optionally the (b, N, M) map.
        """
        batch_size = x.size(0)
        # Values from x: (b, M, inter), M = (possibly pooled) positions.
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)
        # Queries from ref: (b, N, inter).
        theta_ref = self.theta(ref).view(batch_size, self.inter_channels, -1)
        theta_ref = theta_ref.permute(0, 2, 1)
        # Keys from x: (b, inter, M).
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        # (b, N, M) dot-product affinities, normalized by the key count.
        f = torch.matmul(theta_ref, phi_x)
        N = f.size(-1)
        f_div_C = f / N
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        # Residual connection onto x.
        z = W_y + x
        if return_nl_map:
            return z, f_div_C
        return z
class NONLocalBlock1D(_NonLocalBlockND):
    """1-D (sequence) specialization of the cross non-local block."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D, self).__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=1,
            sub_sample=sub_sample,
            bn_layer=bn_layer)


class NONLocalBlock2D(_NonLocalBlockND):
    """2-D (spatial) specialization of the cross non-local block."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=2,
            sub_sample=sub_sample,
            bn_layer=bn_layer)


class NONLocalBlock3D(_NonLocalBlockND):
    """3-D (spatio-temporal) specialization of the cross non-local block."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock3D, self).__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=3,
            sub_sample=sub_sample,
            bn_layer=bn_layer)
if __name__ == '__main__':
    import torch

    # BUGFIX: this file's forward() signature is forward(self, x, ref, ...),
    # so the original demo call net(img) raised a TypeError (missing `ref`).
    # Pass the input as its own reference, which reduces to self-attention
    # and still exercises every configuration.
    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
        img = torch.zeros(2, 3, 20)
        net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img, img)
        print(out.size())
        img = torch.zeros(2, 3, 20, 20)
        net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img, img)
        print(out.size())
        img = torch.randn(2, 3, 8, 20, 20)
        net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img, img)
        print(out.size())
================================================
FILE: code/real/bsrt/model/non_local/non_local_dot_product.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
class _NonLocalBlockND(nn.Module):
    """Non-local block with the *dot-product* pairwise function (affinities
    normalized by the key count instead of softmax).

    Note: attribute names (g, W, theta, phi) and the Sequential wrapping
    under `sub_sample` determine checkpoint state_dict keys.
    """
    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        # dimension: 1/2/3 selects the Conv/Pool/BatchNorm flavor.
        # sub_sample: max-pool the phi/g paths to shrink the pairwise matrix.
        # bn_layer: end W with a zero-initialized BatchNorm (identity at init).
        super(_NonLocalBlockND, self).__init__()
        assert dimension in [1, 2, 3]
        self.dimension = dimension
        self.sub_sample = sub_sample
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        if self.inter_channels is None:
            # Bottleneck to half the channels (at least 1).
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1
        if dimension == 3:
            conv_nd = nn.Conv3d
            # Pool only spatially; keep the temporal dimension.
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 4, 4))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(4, 4))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            # NOTE(review): 1-D pools by 2 while 2-D/3-D pool by 4 -- looks
            # intentional but confirm against the original configuration.
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d
        # g: value embedding (1x1 conv to inter_channels).
        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)
        if bn_layer:
            # W maps back to in_channels; its trailing BN is zero-initialized
            # so W_y starts at 0 and the residual z = W_y + x is identity.
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            # Same identity-at-init trick without BN.
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)
        # theta / phi: query / key embeddings.
        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)
        if sub_sample:
            # Wrapping in Sequential keeps the call sites unchanged.
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)
    def forward(self, x, return_nl_map=False):
        """
        :param x: (b, c, t, h, w) for 3D; (b, c, h, w) for 2D; (b, c, t) for 1D.
        :param return_nl_map: if True return z, nl_map, else only return z.
        :return: z with the same shape as x; optionally the (b, N, M) map.
        """
        batch_size = x.size(0)
        # Values: (b, M, inter), M = (possibly pooled) positions.
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)
        # Queries: (b, N, inter).
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        # Keys: (b, inter, M).
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        # (b, N, M) dot-product affinities, normalized by the key count.
        f = torch.matmul(theta_x, phi_x)
        N = f.size(-1)
        f_div_C = f / N
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        # Residual connection.
        z = W_y + x
        if return_nl_map:
            return z, f_div_C
        return z
class NONLocalBlock1D(_NonLocalBlockND):
    """1-D (sequence) specialization of the dot-product non-local block."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D, self).__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=1,
            sub_sample=sub_sample,
            bn_layer=bn_layer)


class NONLocalBlock2D(_NonLocalBlockND):
    """2-D (spatial) specialization of the dot-product non-local block."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=2,
            sub_sample=sub_sample,
            bn_layer=bn_layer)


class NONLocalBlock3D(_NonLocalBlockND):
    """3-D (spatio-temporal) specialization of the dot-product non-local block."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock3D, self).__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=3,
            sub_sample=sub_sample,
            bn_layer=bn_layer)
if __name__ == '__main__':
    import torch

    # Smoke-test every (sub_sample, bn_layer) combination in 1D/2D/3D.
    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
        for block_cls, make_input in (
                (NONLocalBlock1D, lambda: torch.zeros(2, 3, 20)),
                (NONLocalBlock2D, lambda: torch.zeros(2, 3, 20, 20)),
                (NONLocalBlock3D, lambda: torch.randn(2, 3, 8, 20, 20))):
            img = make_input()
            net = block_cls(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
            out = net(img)
            print(out.size())
================================================
FILE: code/real/bsrt/model/non_local/non_local_embedded_gaussian.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
class _NonLocalBlockND(nn.Module):
    """Non-local block with the *embedded gaussian* pairwise function:
    softmax over dot-product affinities of embedded queries/keys.

    Note: attribute names (g, W, theta, phi) and the Sequential wrapping
    under `sub_sample` determine checkpoint state_dict keys.
    """
    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        """
        :param in_channels: channels of the input feature map.
        :param inter_channels: embedding channels; defaults to in_channels // 2
            (at least 1).
        :param dimension: 1, 2 or 3 -- selects Conv/Pool/BatchNorm flavor.
        :param sub_sample: if True, max-pool the phi/g paths to reduce the
            pairwise matrix size.
        :param bn_layer: if True, end W with a zero-initialized BatchNorm so
            the block is an identity mapping at initialization.
        """
        super(_NonLocalBlockND, self).__init__()
        assert dimension in [1, 2, 3]
        self.dimension = dimension
        self.sub_sample = sub_sample
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        if self.inter_channels is None:
            # Bottleneck to half the channels (at least 1).
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1
        if dimension == 3:
            conv_nd = nn.Conv3d
            # Pool only spatially; keep the temporal dimension.
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d
        # g: value embedding (1x1 conv to inter_channels).
        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)
        if bn_layer:
            # W maps back to in_channels; its trailing BN is zero-initialized
            # so W_y starts at 0 and the residual z = W_y + x is identity.
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            # Same identity-at-init trick without BN.
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)
        # theta / phi: query / key embeddings.
        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)
        if sub_sample:
            # Wrapping in Sequential keeps the call sites unchanged.
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)
    def forward(self, x, return_nl_map=False):
        """
        :param x: (b, c, t, h, w) for 3D; (b, c, h, w) for 2D; (b, c, t) for 1D.
        :param return_nl_map: if True return z, nl_map, else only return z.
        :return: z with the same shape as x; optionally the (b, N, M) map.
        """
        batch_size = x.size(0)
        # Values: (b, M, inter), M = (possibly pooled) positions.
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)
        # Queries: (b, N, inter).
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        # Keys: (b, inter, M).
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        # Softmax over keys: embedded-gaussian attention weights.
        f_div_C = F.softmax(f, dim=-1)
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        # Residual connection.
        z = W_y + x
        if return_nl_map:
            return z, f_div_C
        return z
class NONLocalBlock1D(_NonLocalBlockND):
    """1-D (sequence) specialization of the embedded-gaussian non-local block."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D, self).__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=1,
            sub_sample=sub_sample,
            bn_layer=bn_layer)


class NONLocalBlock2D(_NonLocalBlockND):
    """2-D (spatial) specialization of the embedded-gaussian non-local block."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=2,
            sub_sample=sub_sample,
            bn_layer=bn_layer)


class NONLocalBlock3D(_NonLocalBlockND):
    """3-D (spatio-temporal) specialization of the embedded-gaussian non-local block."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock3D, self).__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=3,
            sub_sample=sub_sample,
            bn_layer=bn_layer)
if __name__ == '__main__':
    import torch

    # Smoke-test every (sub_sample, bn_layer) combination in 1D/2D/3D.
    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
        for block_cls, make_input in (
                (NONLocalBlock1D, lambda: torch.zeros(2, 3, 20)),
                (NONLocalBlock2D, lambda: torch.zeros(2, 3, 20, 20)),
                (NONLocalBlock3D, lambda: torch.randn(2, 3, 8, 20, 20))):
            img = make_input()
            net = block_cls(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
            out = net(img)
            print(out.size())
================================================
FILE: code/real/bsrt/model/non_local/non_local_gaussian.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
class _NonLocalBlockND(nn.Module):
    """Non-local block with the *gaussian* pairwise function: softmax over
    raw dot products of the (un-embedded) input -- no theta/phi convs.

    Note: `self.phi` exists only when sub_sample=True (it is just the pooling
    layer); forward() branches on self.sub_sample accordingly.
    """
    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        # dimension: 1/2/3 selects the Conv/Pool/BatchNorm flavor.
        # sub_sample: max-pool the key/value paths to reduce the pairwise matrix.
        # bn_layer: end W with a zero-initialized BatchNorm (identity at init).
        super(_NonLocalBlockND, self).__init__()
        assert dimension in [1, 2, 3]
        self.dimension = dimension
        self.sub_sample = sub_sample
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        if self.inter_channels is None:
            # Bottleneck to half the channels (at least 1).
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1
        if dimension == 3:
            conv_nd = nn.Conv3d
            # Pool only spatially; keep the temporal dimension.
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d
        # g: value embedding (1x1 conv to inter_channels).
        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)
        if bn_layer:
            # W maps back to in_channels; its trailing BN is zero-initialized
            # so W_y starts at 0 and the residual z = W_y + x is identity.
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            # Same identity-at-init trick without BN.
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)
        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            # phi is just the pooling layer (no conv in the gaussian variant).
            self.phi = max_pool_layer
    def forward(self, x, return_nl_map=False):
        """
        :param x: (b, c, t, h, w) for 3D; (b, c, h, w) for 2D; (b, c, t) for 1D.
        :param return_nl_map: if True return z, nl_map, else only return z.
        :return: z with the same shape as x; optionally the (b, N, M) map.
        """
        batch_size = x.size(0)
        # Values: (b, M, inter), M = (possibly pooled) positions.
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)
        # Queries are the raw input: (b, N, in_channels).
        theta_x = x.view(batch_size, self.in_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        # Keys: raw input, pooled only when sub-sampling is enabled.
        if self.sub_sample:
            phi_x = self.phi(x).view(batch_size, self.in_channels, -1)
        else:
            phi_x = x.view(batch_size, self.in_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        # Softmax over keys: gaussian attention weights.
        f_div_C = F.softmax(f, dim=-1)
        # if self.store_last_batch_nl_map:
        #     self.nl_map = f_div_C
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        # Residual connection.
        z = W_y + x
        if return_nl_map:
            return z, f_div_C
        return z
class NONLocalBlock1D(_NonLocalBlockND):
    """1-D (sequence) specialization of the gaussian non-local block."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D, self).__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=1,
            sub_sample=sub_sample,
            bn_layer=bn_layer)


class NONLocalBlock2D(_NonLocalBlockND):
    """2-D (spatial) specialization of the gaussian non-local block."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=2,
            sub_sample=sub_sample,
            bn_layer=bn_layer)


class NONLocalBlock3D(_NonLocalBlockND):
    """3-D (spatio-temporal) specialization of the gaussian non-local block."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock3D, self).__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=3,
            sub_sample=sub_sample,
            bn_layer=bn_layer)
if __name__ == '__main__':
    import torch

    # Smoke-test every (sub_sample, bn_layer) combination in 1D/2D/3D.
    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
        for block_cls, make_input in (
                (NONLocalBlock1D, lambda: torch.zeros(2, 3, 20)),
                (NONLocalBlock2D, lambda: torch.zeros(2, 3, 20, 20)),
                (NONLocalBlock3D, lambda: torch.randn(2, 3, 8, 20, 20))):
            img = make_input()
            net = block_cls(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
            out = net(img)
            print(out.size())
================================================
FILE: code/real/bsrt/model/swin_util.py
================================================
# -----------------------------------------------------------------------------------
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
# Originally Written by Ze Liu, Modified by Jingyun Liang.
# -----------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# import torch.utils.checkpoint as checkpoint
from model.checkpoint import CheckpointFunction as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import time
from functools import reduce, lru_cache
class Mlp(nn.Module):
    """Transformer feed-forward network: Linear -> act -> drop -> Linear -> drop.
    Hidden and output widths default to `in_features`."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class Mlp_GEGLU(nn.Module):
    """ Multilayer perceptron with gated linear unit (GEGLU). Ref. "GLU Variants Improve Transformer".
    The hidden activation is act(fc11(x)) * fc12(x).
    Args:
        x: (B, D, H, W, C)
    Returns:
        x: (B, D, H, W, C)
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        # Two parallel projections: fc11 feeds the gate, fc12 stays linear.
        self.fc11 = nn.Linear(in_features, hidden_features)
        self.fc12 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        gated = self.act(self.fc11(x)) * self.fc12(x)
        return self.fc2(self.drop(gated))
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    rows, cols = H // window_size, W // window_size
    # Carve the grid into (rows x cols) tiles, then flatten tiles into batch.
    tiles = x.view(B, rows, window_size, cols, window_size, C)
    return tiles.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    # Recover the batch size from the number of windows per image.
    B = windows.shape[0] * window_size * window_size // (H * W)
    grid = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return grid.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, -1)
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None

        Returns:
            Attention output with shape (num_windows*B, N, C).
        """
        B_, N, C = x.shape
        # Project to q, k, v and split heads: qkv is 3 x (B_, nH, N, C // nH).
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        # Gather the per-head bias for every token pair in the window.
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            # Broadcast the per-window shift mask over the batch and heads.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        # Combine heads and project back to the embedding dimension.
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
def flops(self, N):
# calculate flops for 1 window with token length of N
flops = 0
# qkv = self.qkv(x)
flops += N * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * N * (self.dim // self.num_heads) * N
# x = (attn @ v)
flops += self.num_heads * N * N * (self.dim // self.num_heads)
# x = self.proj(x)
flops += N * self.dim * self.dim
return flops
@lru_cache()
def calculate_mask(x_size, window_size, shift_size):
    """Build the (0 / -100) attention mask used by shifted-window MSA.

    Cached on (x_size, window_size, shift_size), since the mask depends only
    on these three values.
    """
    H, W = x_size
    # Label the nine shift regions of the image with distinct integer ids.
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
    region_slices = (slice(0, -window_size),
                     slice(-window_size, -shift_size),
                     slice(-shift_size, None))
    cnt = 0
    for h_slice in region_slices:
        for w_slice in region_slices:
            img_mask[:, h_slice, w_slice, :] = cnt
            cnt += 1
    # Tokens that fall in different regions but share a window after the
    # cyclic shift must not attend to each other: mark those pairs with -100.
    mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, window_size * window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    return attn_mask
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    One (shifted-)window attention block: LN -> W-MSA/SW-MSA -> residual,
    then LN -> MLP -> residual.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        use_checkpoint (bool, optional): Stored for interface compatibility; the
            checkpointed attention path below is commented out. Default: False
    """
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.use_checkpoint = use_checkpoint
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        # Mask construction was moved to the module-level (lru_cached)
        # calculate_mask so it adapts to the runtime x_size; the original
        # buffer-based version is kept below for reference.
        # if self.shift_size > 0:
        #     attn_mask = self.calculate_mask(self.input_resolution)
        # else:
        #     attn_mask = None
        # self.register_buffer("attn_mask", attn_mask)
    def forward(self, x, x_size):
        # x: (B, H*W, C) token sequence; x_size: the (H, W) it was flattened from.
        H, W = x_size
        B, L, C = x.shape
        # assert L == H * W, "input feature has wrong size"
        # if self.input_resolution != x_size:
        #     self.input_resolution = x_size
        #     if self.attn_mask is not None:
        #         self.attn_mask = self.calculate_mask(x_size).to(x.device)
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
        # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
        # if self.input_resolution == x_size:
        #     if self.use_checkpoint:
        #         attn_windows = checkpoint.apply(self.attn, x_windows, self.attn_mask) # nW*B, window_size*window_size, C
        #     else:
        #         attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
        # else:
        #     if self.use_checkpoint:
        #         attn_windows = checkpoint.apply(self.attn, x_windows, self.calculate_mask(x_size).to(x.device))
        #     else:
        #         attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
        # NOTE(review): the mask is applied even when shift_size == 0; with a
        # zero shift calculate_mask uses slice(-0, None) spans — confirm this
        # degenerates to a harmless all-equal labeling as intended.
        attn_mask = calculate_mask(x_size, self.window_size, self.shift_size).to(x.device)
        attn_windows = self.attn(x_windows, mask=attn_mask)
        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)
        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
    def extra_repr(self) -> str:
        # Summary string shown in this module's repr().
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
    def flops(self):
        # FLOPs estimate based on the construction-time input_resolution.
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = H * W * self.dim
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
return flops
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    A stack of `depth` SwinTransformerBlocks with alternating shift (W-MSA /
    SW-MSA), optionally followed by a downsampling (patch merging) layer.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
            NOTE(review): currently force-disabled in __init__ (see below); the
            flag is still forwarded to the individual blocks — confirm intentional.
    """
    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        # NOTE(review): hard-coded False — the use_checkpoint argument is
        # ignored for this layer's own forward path. TODO confirm intentional.
        self.use_checkpoint = False
        # build blocks; even-indexed blocks use W-MSA (shift 0), odd-indexed
        # blocks use SW-MSA (shift window_size // 2).
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer, use_checkpoint=use_checkpoint)
            for i in range(depth)])
        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None
    def forward(self, x, x_size):
        # x: (B, H*W, C) tokens; x_size: (H, W) runtime spatial size.
        for i, blk in enumerate(self.blocks):
            if self.use_checkpoint:
                # x = checkpoint.checkpoint(blk, x, x_size)
                # NOTE(review): unreachable while use_checkpoint is hard-coded
                # False above; `checkpoint.apply(blk, 2, x, x_size)` looks like
                # a project-local checkpoint wrapper rather than
                # torch.utils.checkpoint — confirm before re-enabling.
                x = checkpoint.apply(blk, 2, x, x_size)
            else:
                x = blk(x, x_size)
        if self.downsample is not None:
            x = self.downsample(x)
        return x
    def extra_repr(self) -> str:
        # Summary string shown in this module's repr().
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
    def flops(self):
        # Sum of per-block estimates plus the optional downsample layer.
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops
class RSTB(nn.Module):
    """Residual Swin Transformer Block (RSTB).

    A BasicLayer (stack of Swin blocks) followed by a convolution, wrapped in
    a residual connection: x + embed(conv(unembed(blocks(x)))).

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        img_size: Input image size.
        patch_size: Patch size.
        resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
    """
    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 img_size=224, patch_size=4, resi_connection='1conv'):
        super(RSTB, self).__init__()
        # print(f'dim: {dim}, input_resolution: {input_resolution}, depth: {depth}, num_heads: {num_heads}, window_size: {window_size}, img_size: {img_size}. patch_size: {patch_size}')
        self.dim = dim
        self.input_resolution = input_resolution
        self.residual_group = BasicLayer(dim=dim,
                                         input_resolution=input_resolution,
                                         depth=depth,
                                         num_heads=num_heads,
                                         window_size=window_size,
                                         mlp_ratio=mlp_ratio,
                                         qkv_bias=qkv_bias, qk_scale=qk_scale,
                                         drop=drop, attn_drop=attn_drop,
                                         drop_path=drop_path,
                                         norm_layer=norm_layer,
                                         downsample=downsample,
                                         use_checkpoint=use_checkpoint)
        if resi_connection == '1conv':
            self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory: bottleneck 3x3 -> 1x1 -> 3x3
            self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                      nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
                                      nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                      nn.Conv2d(dim // 4, dim, 3, 1, 1))
        # in_chans=0 is never read by PatchEmbed/PatchUnEmbed.forward;
        # embed_dim alone drives the token shapes.
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
            norm_layer=None)
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
            norm_layer=None)
    def forward(self, x, x_size):
        # Residual around the whole group:
        # tokens -> Swin blocks -> feature map -> conv -> tokens, plus input.
        x = self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
        return x
    def flops(self):
        flops = 0
        flops += self.residual_group.flops()
        H, W = self.input_resolution
        # the 3x3 convolution on the residual branch
        flops += H * W * self.dim * self.dim * 9
        flops += self.patch_embed.flops()
        flops += self.patch_unembed.flops()
        return flops
class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding.

    Flattens a (B, C, H, W) feature map into a (B, H*W, C) token sequence,
    with an optional normalization layer.

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [s // p for s, p in zip(img_size, patch_size)]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x, use_norm=True):
        # (B, C, Ph, Pw) -> (B, Ph*Pw, C): flatten spatial dims into tokens.
        tokens = x.flatten(2).transpose(1, 2)
        if use_norm and self.norm is not None:
            tokens = self.norm(tokens)
        return tokens

    def flops(self):
        # Only the optional normalization layer contributes.
        if self.norm is None:
            return 0
        H, W = self.img_size
        return H * W * self.embed_dim
class PatchUnEmbed(nn.Module):
    r""" Image to Patch Unembedding.

    Reshapes a (B, H*W, C) token sequence back into a (B, C, H, W) feature map.

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Accepted for interface symmetry; unused.
    """
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [s // p for s, p in zip(img_size, patch_size)]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x, x_size):
        # (B, H*W, C) tokens -> (B, C, H, W) feature map with (H, W) = x_size.
        batch = x.shape[0]
        return x.transpose(1, 2).view(batch, self.embed_dim, x_size[0], x_size[1])

    def flops(self):
        # Pure reshape: no multiply-accumulate operations.
        return 0
class Upsample(nn.Sequential):
    """Upsample module.

    Builds a stack of (3x3 conv -> PixelShuffle) stages: n stages of 2x
    shuffling for scale = 2^n, or a single 3x stage for scale = 3.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: If ``scale`` is neither a positive power of two nor 3.
    """
    def __init__(self, scale, num_feat):
        m = []
        if scale > 0 and (scale & (scale - 1)) == 0:  # scale = 2^n
            # bit_length() - 1 gives the exponent n exactly, avoiding the
            # float round-off risk of int(math.log(scale, 2)); the scale > 0
            # guard also routes scale == 0 to the ValueError below.
            for _ in range(scale.bit_length() - 1):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)
class UpsampleOneStep(nn.Sequential):
    """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle).

    Used in lightweight SR to save parameters.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
        num_out_ch (int): Number of output channels.
        input_resolution (tuple[int] | None): (H, W) used only by flops().
    """
    def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
        # Plain (non-module) attributes can be set before nn.Sequential's
        # __init__: nn.Module.__setattr__ falls back to object.__setattr__.
        self.num_feat = num_feat
        self.input_resolution = input_resolution
        layers = [nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1),
                  nn.PixelShuffle(scale)]
        super(UpsampleOneStep, self).__init__(*layers)

    def flops(self):
        # NOTE(review): the factor 3 hard-codes 3 output channels — TODO
        # confirm against callers with num_out_ch != 3.
        H, W = self.input_resolution
        return H * W * self.num_feat * 3 * 9
class SwinIR(nn.Module):
    r""" SwinIR
        A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.

    Args:
        img_size (int | tuple(int)): Input image size. Default 64
        patch_size (int | tuple(int)): Patch size. Default: 1
        in_chans (int): Number of input image channels. Default: 3
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
        img_range: Image range. 1. or 255.
        upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
        resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
    """
    def __init__(self, img_size=64, patch_size=1, in_chans=3,
                 embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
                 **kwargs):
        super(SwinIR, self).__init__()
        num_in_ch = in_chans
        num_out_ch = in_chans
        num_feat = 64
        self.img_range = img_range
        if in_chans == 3:
            # per-channel means used to zero-center RGB inputs in forward()
            rgb_mean = (0.4488, 0.4371, 0.4040)
            self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
        else:
            self.mean = torch.zeros(1, 1, 1, 1)
        self.upscale = upscale
        self.upsampler = upsampler
        self.window_size = window_size
        #####################################################################################################
        ################################### 1, shallow feature extraction ###################################
        self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
        #####################################################################################################
        ################################### 2, deep feature extraction ######################################
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio
        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution
        # print('patches_resolution: ', patches_resolution)
        # merge non-overlapping patches into image
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        self.pos_drop = nn.Dropout(p=drop_rate)
        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        # build Residual Swin Transformer blocks (RSTB)
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = RSTB(dim=embed_dim,
                         input_resolution=(patches_resolution[0],
                                           patches_resolution[1]),
                         depth=depths[i_layer],
                         num_heads=num_heads[i_layer],
                         window_size=window_size,
                         mlp_ratio=self.mlp_ratio,
                         qkv_bias=qkv_bias, qk_scale=qk_scale,
                         drop=drop_rate, attn_drop=attn_drop_rate,
                         drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
                         norm_layer=norm_layer,
                         downsample=None,
                         use_checkpoint=use_checkpoint,
                         img_size=img_size,
                         patch_size=patch_size,
                         resi_connection=resi_connection
                         )
            self.layers.append(layer)
        self.norm = norm_layer(self.num_features)
        # build the last conv layer in deep feature extraction
        if resi_connection == '1conv':
            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory
            self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                                 nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                                 nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
        #####################################################################################################
        ################################ 3, high quality image reconstruction ################################
        if self.upsampler == 'pixelshuffle':
            # for classical SR
            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                                                      nn.LeakyReLU(inplace=True))
            self.upsample = Upsample(upscale, num_feat)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        elif self.upsampler == 'pixelshuffledirect':
            # for lightweight SR (to save parameters)
            self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
                                            (patches_resolution[0], patches_resolution[1]))
        elif self.upsampler == 'nearest+conv':
            # for real-world SR (less artifacts)
            assert self.upscale == 4, 'only support x4 now.'
            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                                                      nn.LeakyReLU(inplace=True))
            self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
            self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            # for image denoising and JPEG compression artifact reduction
            self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
        self.apply(self._init_weights)
    def _init_weights(self, m):
        # Truncated-normal init for Linear layers, unit scale for LayerNorm.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameter names an optimizer setup can exclude from weight decay.
        return {'absolute_pos_embed'}
    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}
    def check_image_size(self, x):
        # Reflect-pad H and W up to multiples of the window size so window
        # partitioning divides evenly; forward() crops the output back.
        _, _, h, w = x.size()
        mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
        mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
        return x
    def forward_features(self, x):
        # Deep feature extraction: tokens -> RSTB stack -> norm -> feature map.
        x_size = (x.shape[2], x.shape[3])
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        for layer in self.layers:
            x = layer(x, x_size)
        x = self.norm(x)  # B L C
        x = self.patch_unembed(x, x_size)
        return x
    def forward(self, x):
        H, W = x.shape[2:]
        x = self.check_image_size(x)
        # zero-center and scale the input; undone at the end of forward()
        self.mean = self.mean.type_as(x)
        x = (x - self.mean) * self.img_range
        if self.upsampler == 'pixelshuffle':
            # for classical SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.conv_before_upsample(x)
            x = self.conv_last(self.upsample(x))
        elif self.upsampler == 'pixelshuffledirect':
            # for lightweight SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.upsample(x)
        elif self.upsampler == 'nearest+conv':
            # for real-world SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.conv_before_upsample(x)
            x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
            x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
            x = self.conv_last(self.lrelu(self.conv_hr(x)))
        else:
            # for image denoising and JPEG compression artifact reduction
            x_first = self.conv_first(x)
            res = self.conv_after_body(self.forward_features(x_first)) + x_first
            x = x + self.conv_last(res)
        x = x / self.img_range + self.mean
        # crop away the window-size padding added by check_image_size
        return x[:, :, :H*self.upscale, :W*self.upscale]
    def flops(self):
        # NOTE(review): references self.upsample.flops(), which only exists on
        # the 'pixelshuffledirect' path (UpsampleOneStep); other upsampler
        # modes would raise AttributeError here — confirm intended usage.
        flops = 0
        H, W = self.patches_resolution
        flops += H * W * 3 * self.embed_dim * 9
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += H * W * 3 * self.embed_dim * self.embed_dim
        flops += self.upsample.flops()
        return flops
if __name__ == '__main__':
    # Smoke test: build a lightweight SwinIR and run one forward pass.
    upscale = 4
    window_size = 8
    # Round the LR spatial size up to a multiple of the window size.
    height = (1024 // upscale // window_size + 1) * window_size
    width = (720 // upscale // window_size + 1) * window_size

    model = SwinIR(upscale=2, img_size=(height, width),
                   window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
                   embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
    print(model)
    print(height, width, model.flops() / 1e9)

    dummy = torch.randn((1, 3, height, width))
    out = model(dummy)
    print(out.shape)
================================================
FILE: code/real/bsrt/model/utils/interp_methods.py
================================================
from math import pi
try:
import torch
except ImportError:
torch = None
try:
import numpy
except ImportError:
numpy = None
if numpy is None and torch is None:
raise ImportError("Must have either Numpy or PyTorch but both not found")
def set_framework_dependencies(x):
    """Pick the array framework matching ``x``.

    Args:
        x: Either a ``numpy.ndarray`` or a ``torch.Tensor``.

    Returns:
        Tuple of (framework module, dtype-cast function, float32 machine eps).
        The cast function aligns boolean masks with ``x``'s dtype before they
        participate in arithmetic.
    """
    # Guard the numpy check: this module explicitly supports numpy being
    # absent (numpy is None), in which case `numpy.ndarray` would raise
    # AttributeError.
    if numpy is not None and type(x) is numpy.ndarray:
        to_dtype = lambda a: a
        fw = numpy
    else:
        to_dtype = lambda a: a.to(x.dtype)
        fw = torch
    eps = fw.finfo(fw.float32).eps
    return fw, to_dtype, eps
def support_sz(sz):
    """Decorator attaching a ``support_sz`` attribute (kernel support width) to a function."""
    def attach(func):
        func.support_sz = sz
        return func
    return attach
@support_sz(4)
def cubic(x):
    """Keys bicubic interpolation kernel (support width 4)."""
    fw, to_dtype, eps = set_framework_dependencies(x)
    absx = fw.abs(x)
    absx2 = absx ** 2
    absx3 = absx ** 3
    # Two polynomial pieces: |x| <= 1 and 1 < |x| <= 2; zero elsewhere.
    inner = (1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.)
    outer = (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) * to_dtype((1. < absx) & (absx <= 2.))
    return inner + outer
@support_sz(4)
def lanczos2(x):
    """Lanczos kernel with a = 2 (eps-regularized sinc product, support width 4)."""
    fw, to_dtype, eps = set_framework_dependencies(x)
    numerator = fw.sin(pi * x) * fw.sin(pi * x / 2) + eps
    denominator = (pi ** 2 * x ** 2 / 2) + eps
    return (numerator / denominator) * to_dtype(abs(x) < 2)
@support_sz(6)
def lanczos3(x):
    """Lanczos kernel with a = 3 (eps-regularized sinc product, support width 6)."""
    fw, to_dtype, eps = set_framework_dependencies(x)
    numerator = fw.sin(pi * x) * fw.sin(pi * x / 3) + eps
    denominator = (pi ** 2 * x ** 2 / 3) + eps
    return (numerator / denominator) * to_dtype(abs(x) < 3)
@support_sz(2)
def linear(x):
    """Triangle (bilinear) interpolation kernel, support width 2."""
    fw, to_dtype, eps = set_framework_dependencies(x)
    rising = (x + 1) * to_dtype((-1 <= x) & (x < 0))
    falling = (1 - x) * to_dtype((0 <= x) & (x <= 1))
    return rising + falling
@support_sz(1)
def box(x):
    """Box (nearest-neighbour) kernel: 1 on [-1, 1], 0 elsewhere. Support width 1."""
    fw, to_dtype, eps = set_framework_dependencies(x)
    left = to_dtype((-1 <= x) & (x < 0))
    right = to_dtype((0 <= x) & (x <= 1))
    return left + right
================================================
FILE: code/real/bsrt/model/utils/psconv.py
================================================
import torch
import torch.nn as nn
class PyConv2d(nn.Module):
    """PyConv2d with padding (general case). Applies a 2D PyConv over an input signal composed of several input planes.

    Each pyramid level is a grouped conv with its own kernel size; outputs of
    all levels are concatenated on the channel axis.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (list): Number of channels for each pyramid level produced by the convolution
        pyconv_kernels (list): Spatial size of the kernel for each pyramid level
        pyconv_groups (list): Number of blocked connections from input channels to output channels for each pyramid level
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``False``

    Example::
        >>> # PyConv with two pyramid levels, kernels: 3x3, 5x5
        >>> m = PyConv2d(in_channels=64, out_channels=[32, 32], pyconv_kernels=[3, 5], pyconv_groups=[1, 4])
        >>> input = torch.randn(4, 64, 56, 56)
        >>> output = m(input)
    """
    def __init__(self, in_channels, out_channels, pyconv_kernels, pyconv_groups, stride=1, dilation=1, bias=False):
        super(PyConv2d, self).__init__()
        assert len(out_channels) == len(pyconv_kernels) == len(pyconv_groups)
        # One conv per pyramid level; padding = k // 2 keeps spatial size.
        self.pyconv_levels = nn.ModuleList([
            nn.Conv2d(in_channels, out_ch, kernel_size=k,
                      stride=stride, padding=k // 2, groups=g,
                      dilation=dilation, bias=bias)
            for out_ch, k, g in zip(out_channels, pyconv_kernels, pyconv_groups)])

    def forward(self, x):
        # Run every level on the same input and stack results channel-wise.
        return torch.cat([level(x) for level in self.pyconv_levels], 1)
################################################################
class PSConv2d(nn.Module):
    """Poly-Scale convolution (PSConv).

    Sums three branches: a grouped dilated conv, a channel-half-swapped
    grouped conv with doubled dilation, and a dense conv whose weights are
    zeroed (and kept zero via a gradient hook) on the block-diagonal
    positions already covered by the grouped branches.

    NOTE(review): the mask is created with ``.cuda()``, so this module can
    only be constructed on a CUDA machine — kept as-is to match PSGConv2d
    below and existing usage.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, parts=4, bias=False):
        super(PSConv2d, self).__init__()
        self.gwconv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, dilation, dilation, groups=parts, bias=bias)
        self.gwconv_shift = nn.Conv2d(in_channels, out_channels, kernel_size, stride, 2 * dilation, 2 * dilation, groups=parts, bias=bias)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)

        def backward_hook(grad):
            # Zero the gradient at masked positions so those weights stay 0.
            out = grad.clone()
            out[self.mask] = 0
            return out

        # Use .bool() instead of the deprecated .byte(): uint8 mask indexing
        # is deprecated in newer PyTorch, and PSGConv2d below already uses bool.
        self.mask = torch.zeros(self.conv.weight.shape).bool().cuda()
        _in_channels = in_channels // parts
        _out_channels = out_channels // parts
        for i in range(parts):
            # block-diagonal and half-rotated block positions
            self.mask[i * _out_channels: (i + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, :, :] = 1
            self.mask[(i + parts // 2) % parts * _out_channels: ((i + parts // 2) % parts + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, :, :] = 1
        self.conv.weight.data[self.mask] = 0
        self.conv.weight.register_hook(backward_hook)
        # Expose the dense conv's parameters under the conventional names.
        self.weight = self.conv.weight
        self.bias = self.conv.bias

    def forward(self, x):
        # Swap the two channel halves before the double-dilation branch.
        x1, x2 = x.chunk(2, dim=1)
        x_shift = self.gwconv_shift(torch.cat((x2, x1), dim=1))
        return self.gwconv(x) + self.conv(x) + x_shift
# PSConv-based Group Convolution
class PSGConv2d(nn.Module):
    """PSConv-based Group Convolution: PSConv generalized to `groups` groups.

    Sums a grouped conv, a per-group channel-half-swapped grouped conv with
    doubled padding/dilation, and a grouped dense conv whose weights are
    zeroed (and kept zero via a gradient hook) at masked positions.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, parts=4, bias=False):
        super(PSGConv2d, self).__init__()
        self.gwconv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups=groups * parts, bias=bias)
        self.gwconv_shift = nn.Conv2d(in_channels, out_channels, kernel_size, stride, 2 * padding, 2 * dilation, groups=groups * parts, bias=bias)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=bias)
        def backward_hook(grad):
            # Zero the gradient at masked positions so those weights stay 0.
            out = grad.clone()
            out[self.mask] = 0
            return out
        # Boolean weight mask on the GPU (module requires CUDA at construction).
        self.mask = torch.zeros(self.conv.weight.shape).bool().cuda()
        _in_channels = in_channels // (groups * parts)
        _out_channels = out_channels // (groups * parts)
        for i in range(parts):
            for j in range(groups):
                # NOTE(review): the output-channel offset uses `j * groups`
                # rather than `j * parts`; for groups != parts this looks
                # inconsistent with the PSConv2d layout — TODO confirm.
                self.mask[(i + j * groups) * _out_channels: (i + j * groups + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, : , :] = 1
                self.mask[((i + parts // 2) % parts + j * groups) * _out_channels: ((i + parts // 2) % parts + j * groups + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, :, :] = 1
        self.conv.weight.data[self.mask] = 0
        self.conv.weight.register_hook(backward_hook)
        self.groups = groups
        # Expose the dense conv's parameters under the conventional names.
        self.weight = self.conv.weight
        self.bias = self.conv.bias
    def forward(self, x):
        # Swap the two channel halves inside each group before the shifted branch.
        x_split = (z.chunk(2, dim=1) for z in x.chunk(self.groups, dim=1))
        x_merge = torch.cat(tuple(torch.cat((x2, x1), dim=1) for (x1, x2) in x_split), dim=1)
        x_shift = self.gwconv_shift(x_merge)
        gx = self.gwconv(x)
        cx = self.conv(x)
        # print(x.shape, gx.shape, cx.shape, x_merge.shape, x_shift.shape)
        return gx + cx + x_shift
================================================
FILE: code/real/bsrt/model/utils/resize_right.py
================================================
import warnings
from math import ceil
import model.utils.interp_methods as interp_methods
class NoneClass:
    # Stand-in base class used in place of nn.Module when PyTorch is absent.
    pass

# Prefer PyTorch when available; fall back to a Numpy-only mode otherwise.
try:
    import torch
    from torch import nn
    nnModuleWrapped = nn.Module
except ImportError:
    warnings.warn('No PyTorch found, will work only with Numpy')
    torch = None
    nnModuleWrapped = NoneClass

# Numpy is optional as well; at least one backend must be importable.
try:
    import numpy
except ImportError:
    warnings.warn('No Numpy found, will work only with PyTorch')
    numpy = None

# Fail fast if neither backend is present.
if numpy is None and torch is None:
    raise ImportError("Must have either Numpy or PyTorch but both not found")
def resize(input, scale_factors=None, out_shape=None,
           interp_method=interp_methods.cubic, support_sz=None,
           antialiasing=True):
    """Resize an N-D Numpy array or PyTorch tensor, one dim at a time.

    Either ``scale_factors`` or ``out_shape`` must be given; the missing one
    is derived from the other.  ``interp_method`` provides the 1d kernel and,
    unless ``support_sz`` overrides it, the kernel's support size.
    """
    in_shape, n_dims = input.shape, input.ndim
    # fw is the backing framework (numpy or torch), chosen by input type.
    fw = numpy if type(input) is numpy.ndarray else torch
    eps = fw.finfo(fw.float32).eps
    # Fill in whichever of scale_factors/out_shape is missing (raises if both
    # are missing).
    scale_factors, out_shape = set_scale_and_out_sz(in_shape, out_shape,
                                                    scale_factors, fw)
    # Process dims in increasing-scale order, skipping dims with scale == 1;
    # going dim by dim keeps the computation efficient.
    sorted_filtered_dims_and_scales = [(dim, scale_factors[dim])
                                       for dim in sorted(range(n_dims),
                                                         key=lambda ind: scale_factors[ind])
                                       if scale_factors[dim] != 1.]
    # Unless specified by the caller, support size comes from the kernel.
    if support_sz is None:
        support_sz = interp_method.support_sz
    # Bug fix: 'device' was only assigned on the torch path but is passed
    # unconditionally below, raising NameError for numpy inputs.
    device = None
    if fw is torch:
        device = input.device
    # Output starts identical to input and is refined per dim.
    output = input
    for dim, scale_factor in sorted_filtered_dims_and_scales:
        # 1d weights and fields of view for every output location on this dim.
        field_of_view, weights = prepare_weights_and_field_of_view_1d(
            dim, scale_factor, in_shape[dim], out_shape[dim], interp_method,
            support_sz, antialiasing, fw, eps, device)
        # Weighted aggregation over each field of view.
        output = apply_weights(output, field_of_view, weights, dim, n_dims,
                               fw)
    return output
class ResizeLayer(nnModuleWrapped):
    """Torch module that precomputes resize weights for a fixed input shape.

    All fields of view and weights are computed once in ``__init__`` and
    stored as frozen parameters, so ``forward`` only performs the weighted
    sums.
    """

    def __init__(self, in_shape, scale_factors=None, out_shape=None,
                 interp_method=interp_methods.cubic, support_sz=None,
                 antialiasing=True, device=None):
        super(ResizeLayer, self).__init__()
        # This is a torch layer, so the framework is fixed to torch.
        fw = torch
        eps = fw.finfo(fw.float32).eps
        # Fill in whichever of scale_factors/out_shape is missing (raises if
        # both are missing).
        scale_factors, out_shape = set_scale_and_out_sz(in_shape, out_shape,
                                                        scale_factors, fw)
        # Unless specified by the caller, support size comes from the kernel.
        if support_sz is None:
            support_sz = interp_method.support_sz
        self.n_dims = len(in_shape)
        # Process dims in increasing-scale order, skipping dims with
        # scale == 1.
        self.sorted_filtered_dims_and_scales = [
            (dim, scale_factors[dim])
            for dim in sorted(range(self.n_dims),
                              key=lambda ind: scale_factors[ind])
            if scale_factors[dim] != 1.]
        field_of_view_list = []
        weights_list = []
        for dim, scale_factor in self.sorted_filtered_dims_and_scales:
            # Bug fix: the original read ``input.device`` here, but ``input``
            # is not defined inside __init__ (NameError at construction).
            # The target device is now an optional, backward-compatible
            # constructor argument (None keeps tensors on their default
            # device).
            field_of_view, weights = prepare_weights_and_field_of_view_1d(
                dim, scale_factor, in_shape[dim], out_shape[dim],
                interp_method, support_sz, antialiasing, fw, eps, device)
            # Store as frozen parameters so .to()/.cuda() moves them with
            # the module.
            weights_list.append(nn.Parameter(weights, requires_grad=False))
            field_of_view_list.append(nn.Parameter(field_of_view,
                                                   requires_grad=False))
        self.field_of_view = nn.ParameterList(field_of_view_list)
        self.weights = nn.ParameterList(weights_list)
        self.in_shape = in_shape

    def forward(self, input):
        # Apply the precomputed per-dim weights one dim at a time; output
        # starts identical to input and is refined per dim.
        output = input
        for (dim, scale_factor), field_of_view, weights in zip(
                self.sorted_filtered_dims_and_scales,
                self.field_of_view,
                self.weights):
            output = apply_weights(output, field_of_view, weights, dim,
                                   self.n_dims, torch)
        return output
def prepare_weights_and_field_of_view_1d(dim, scale_factor, in_sz, out_sz,
                                         interp_method, support_sz,
                                         antialiasing, fw, eps, device=None):
    """Compute, for one dim, each output pixel's input neighbors and weights."""
    # Antialiasing (downscaling only) stretches the kernel and widens its
    # support accordingly.
    interp_method, cur_support_sz = apply_antialiasing_if_needed(
        interp_method, support_sz, scale_factor, antialiasing)
    # Project output pixel centers onto fractional input coordinates.
    projected_grid = get_projected_grid(in_sz, out_sz, scale_factor, fw,
                                        device)
    # Collect, per output pixel, the input pixels inside the kernel support.
    field_of_view = get_field_of_view(projected_grid, cur_support_sz, in_sz,
                                      fw, eps)
    # Evaluate the kernel on the signed distances to those pixels.
    weights = get_weights(interp_method, projected_grid, field_of_view)
    return field_of_view, weights
def apply_weights(input, field_of_view, weights, dim, n_dims, fw):
    """Aggregate ``input`` along ``dim`` using precomputed fields of view.

    Separated from weight computation so the same weights can be reused on
    several inputs.
    """
    # Move the resized dim to the front so 1d indexing below applies to it.
    swapped = fw_swapaxes(input, dim, 0, fw)
    # field_of_view is 2d: per output position along this dim, the 1d
    # neighbor indices.  Indexing with it yields, per output pixel, the
    # values of its neighbors along the current dim (tensor order grows
    # by one).
    neighbors = swapped[field_of_view]
    # Augment the 2d weights with singleton trailing dims so broadcasting
    # affects only the leading (resized) dims.
    bcast_shape = (*weights.shape, *([1] * (n_dims - 1)))
    weighted = neighbors * fw.reshape(weights, bcast_shape)
    # Sum over the field of view to get one value per output pixel, then
    # restore the original dim order.
    return fw_swapaxes(weighted.sum(1), 0, dim, fw)
def set_scale_and_out_sz(in_shape, out_shape, scale_factors, fw):
    """Derive the missing one of (scale_factors, out_shape) from the other.

    Partial specifications are completed: unspecified dims default to the
    input size / scale 1, placed at the leading dims for torch and the
    trailing dims for numpy.  Raises ValueError when both are None.
    """
    if scale_factors is None and out_shape is None:
        raise ValueError("either scale_factors or out_shape should be "
                         "provided")
    if out_shape is not None:
        # Complete a partial out_shape with the untouched input dims
        # (first dims for numpy, last dims for torch).
        untouched = list(in_shape[:-len(out_shape)])
        out_shape = (list(out_shape) + untouched if fw is numpy
                     else untouched + list(out_shape))
        if scale_factors is None:
            # Derive scales as the per-dim out/in ratio (not recommended).
            scale_factors = [o / i for o, i in zip(out_shape, in_shape)]
    if scale_factors is not None:
        # A scalar scale means: resize two dims (the common image case).
        if not isinstance(scale_factors, (list, tuple)):
            scale_factors = [scale_factors, scale_factors]
        # Complete missing dims with scale 1 (first dims for numpy, last
        # dims for torch).
        pad = [1] * (len(in_shape) - len(scale_factors))
        scale_factors = (list(scale_factors) + pad if fw is numpy
                         else pad + list(scale_factors))
        if out_shape is None:
            # Derive out_shape by scaling in_shape (not recommended).
            out_shape = [ceil(sf * sz)
                         for sf, sz in zip(scale_factors, in_shape)]
        # Cast after out_shape is determined, for numerical stability of
        # the ceil above.
        scale_factors = [float(sf) for sf in scale_factors]
    return scale_factors, out_shape
def get_projected_grid(in_sz, out_sz, scale_factor, fw, device=None):
    """Project 1d integer output coordinates onto fractional input coords.

    Uses the center-aligned projection formula derived in "From Discrete to
    Continuous Convolutions" (Shocher et al.).
    """
    # Output coordinates are plain integer positions, moved to the right
    # device when running under torch.
    out_coordinates = fw_set_device(fw.arange(out_sz), device, fw)
    offset = (in_sz - 1) / 2 - (out_sz - 1) / (2 * scale_factor)
    return out_coordinates / scale_factor + offset
def get_field_of_view(projected_grid, cur_support_sz, in_sz, fw, eps):
    """Map, per output pixel, the 1d input pixels inside the kernel support."""
    # Leftmost neighbor: half a support-window to the left (eps guards the
    # case where the boundary is an exact integer).
    left_boundaries = fw_ceil(projected_grid - cur_support_sz / 2 - eps, fw)
    # Bug fix: numpy arrays have no .device attribute; only query the device
    # on the torch path.
    device = projected_grid.device if fw is torch else None
    # Count support-size pixels rightwards from each left boundary.
    ordinal_numbers = fw.arange(ceil(cur_support_sz - eps))
    ordinal_numbers = fw_set_device(ordinal_numbers, device, fw)
    field_of_view = left_boundaries[:, None] + ordinal_numbers
    # Instead of padding, remap out-of-range indices as mirror padding
    # (avoids enlarging the input tensor).
    mirror = fw_cat((fw.arange(in_sz), fw.arange(in_sz - 1, -1, step=-1)), fw)
    field_of_view = mirror[fw.remainder(field_of_view, mirror.shape[0])]
    field_of_view = fw_set_device(field_of_view, device, fw)
    return field_of_view
def get_weights(interp_method, projected_grid, field_of_view):
    """Kernel weights per output pixel, normalized to sum to 1."""
    # Signed distances from each projected location to its field-of-view
    # pixel centers, passed through the interpolation kernel.
    weights = interp_method(projected_grid[:, None] - field_of_view)
    # Normalize carefully: rows that sum to zero keep a denominator of 1.
    totals = weights.sum(1, keepdims=True)
    totals[totals == 0] = 1
    return weights / totals
def apply_antialiasing_if_needed(interp_method, support_sz, scale_factor,
                                 antialiasing):
    """Low-pass adjust the kernel for downscaling (scale < 1).

    Stretches the 1d kernel by the scale factor (and rescales its output to
    preserve total mass) while widening its support accordingly.  Upscaling
    or disabled antialiasing returns the inputs unchanged.
    """
    if not antialiasing or scale_factor >= 1.0:
        return interp_method, support_sz
    stretched = lambda arg: scale_factor * interp_method(scale_factor * arg)
    return stretched, support_sz / scale_factor
def fw_ceil(x, fw):
    """Elementwise ceil returning an integer array/tensor in either framework."""
    if fw is numpy:
        return fw.int_(fw.ceil(x))
    return x.ceil().long()
def fw_cat(x, fw):
    """Concatenate a sequence along the first axis in either framework."""
    return fw.concatenate(x) if fw is numpy else fw.cat(x)
def fw_swapaxes(x, ax_1, ax_2, fw):
    """Swap two axes: numpy.swapaxes or torch Tensor.transpose."""
    return fw.swapaxes(x, ax_1, ax_2) if fw is numpy else x.transpose(ax_1, ax_2)
def fw_set_device(x, device, fw):
    """Move ``x`` to ``device`` under torch; no-op for numpy arrays."""
    return x if fw is numpy else x.to(device)
================================================
FILE: code/real/bsrt/option.py
================================================
# Command-line options (argparse): network structure, fine-tuning switches,
# and hardware settings.
import argparse

parser = argparse.ArgumentParser(description='EDSR and MDSR')

# Network structure
parser.add_argument('--n_resblocks', type=int, default=16, help='number of residual blocks')
parser.add_argument('--n_feats', type=int, default=64, help='number of feature maps')
parser.add_argument('--n_colors', type=int, default=3, help='number of color channels to use')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
parser.add_argument('--burst_size', type=int, default=14, help='burst size, max 14')
parser.add_argument('--burst_channel', type=int, default=4, help='RAW channel, default:4')
parser.add_argument('--swinfeature', action='store_true', help='use swin transformer to extract features')
parser.add_argument('--model_level', type=str, default='S', help='S: small, L: large')

################## fine-tune ##################
parser.add_argument('--finetune', action='store_true', help='finetune model')
parser.add_argument('--finetune_align', action='store_true', help='finetune alignment module')
parser.add_argument('--finetune_swin', action='store_true', help='finetune swin trans module')
parser.add_argument('--finetune_conv', action='store_true', help='finetune rest convs')
parser.add_argument('--finetune_prelayer', action='store_true', help='finetune finetune pre feature extract layer')
parser.add_argument('--finetune_upconv', action='store_true', help='finetune finetune up conv layer')
parser.add_argument('--finetune_spynet', action='store_true', help='finetune finetune up conv layer')

# Hardware specifications
parser.add_argument('--n_threads', type=int, default=6, help='number of threads for data loading')
parser.add_argument('--cpu', action='store_true', help='use cpu only')
parser.add_argument('--n_GPUs', type=int, default=1, help='number of GPUs')
parser.add_argument('--seed', type=int, default=1, help='random seed')
parser.add_argument('--local_rank', type=int, default=-1, help='proc index')
parser.add_argument('--fp16', action='store_true', help='use fp16 only')
parser.add_argument('--use_checkpoint', action='store_true', help='use use_checkpoint in swin transformer')
# Data specifications
parser.add_argument('--root', type=str, default='/data/dataset/ntire21/burstsr/real',
                    help='dataset directory')
parser.add_argument('--val_root', type=str, default='../test_set',
                    help='dataset directory')
# Bug fix: help text previously said 'demo image directory' (copy-paste
# error from another option).
parser.add_argument('--mode', type=str, default='train',
                    help='running mode (train | test)')
parser.add_argument('--scale', type=str, default='4',
                    help='super resolution scale')
parser.add_argument('--patch_size', type=int, default=256,
                    help='output patch size')
parser.add_argument('--rgb_range', type=int, default=1,
                    help='maximum value of RGB')
parser.add_argument('--chop', action='store_true',
                    help='enable memory-efficient forward')
parser.add_argument('--no_augment', action='store_true',
                    help='do not use data augmentation')

# Model specifications
parser.add_argument('--model', default='LRSC_EDVR',
                    help='model name')
parser.add_argument('--act', type=str, default='relu',
                    help='activation function')
parser.add_argument('--pre_train', type=str, default='',
                    help='pre-trained model directory')
parser.add_argument('--extend', type=str, default='.',
                    help='pre-trained model directory')
parser.add_argument('--res_scale', type=float, default=1,
                    help='residual scaling')
# NOTE(review): default=True with no type makes '--shift_mean False' yield
# the string 'False' on the CLI; it is normalized to a bool later in this
# file — consider a proper boolean parser.
parser.add_argument('--shift_mean', default=True,
                    help='subtract pixel mean from the input')
parser.add_argument('--dilation', action='store_true',
                    help='use dilated convolution')
parser.add_argument('--precision', type=str, default='single',
                    choices=('single', 'half'),
                    help='FP precision for test (single | half)')

# Option for Residual channel attention network (RCAN)
parser.add_argument('--n_resgroups', type=int, default=20,
                    help='number of residual groups')
parser.add_argument('--reduction', type=int, default=16,
                    help='number of feature maps reduction')
parser.add_argument('--DA', action='store_true',
                    help='use Dual Attention')
parser.add_argument('--CA', action='store_true',
                    help='use Channel Attention')
# Bug fix: help text previously said 'use Dual Attention' (copied from --DA).
parser.add_argument('--non_local', action='store_true',
                    help='use Non-Local attention')
# Training specifications
parser.add_argument('--reset', action='store_true',
                    help='reset the training')
parser.add_argument('--test_every', type=int, default=1000,
                    help='do test per every N batches')
parser.add_argument('--epochs', type=int, default=100,
                    help='number of epochs to train')
parser.add_argument('--batch_size', type=int, default=8,
                    help='input batch size for training')
parser.add_argument('--split_batch', type=int, default=1,
                    help='split the batch into smaller chunks')
parser.add_argument('--self_ensemble', action='store_true',
                    help='use self-ensemble method for test')
parser.add_argument('--test_only', action='store_true',
                    help='set this option to test the model')
parser.add_argument('--gan_k', type=int, default=1,
                    help='k value for adversarial loss')

# Optimization specifications
parser.add_argument('--decay', type=str, default='40-80',
                    help='learning rate decay type')
parser.add_argument('--gamma', type=float, default=0.5,
                    help='learning rate decay factor for step decay')
parser.add_argument('--optimizer', default='ADAM',
                    choices=('SGD', 'ADAM', 'RMSprop'),
                    help='optimizer to use (SGD | ADAM | RMSprop)')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='SGD momentum')
# NOTE(review): type=tuple applied to a CLI string produces a tuple of
# characters (e.g. tuple('0.9,0.999')); only the default is usable as-is.
parser.add_argument('--betas', type=tuple, default=(0.9, 0.999),
                    help='ADAM beta')
parser.add_argument('--epsilon', type=float, default=1e-8,
                    help='ADAM epsilon for numerical stability')
parser.add_argument('--weight_decay', type=float, default=0,
                    help='weight decay')
parser.add_argument('--gclip', type=float, default=0,
                    help='gradient clipping threshold (0 = no clipping)')

# Loss specifications
parser.add_argument('--loss', type=str, default='1*L1',
                    help='loss function configuration')
# Bug fix: the default was the string '1e8', relying on argparse's implicit
# conversion of string defaults; an explicit float literal is unambiguous.
parser.add_argument('--skip_threshold', type=float, default=1e8,
                    help='skipping batch that has large error')

# Log specifications
parser.add_argument('--save', type=str, default='test',
                    help='file name to save')
parser.add_argument('--load', type=str, default='',
                    help='file name to load')
parser.add_argument('--resume', type=int, default=0,
                    help='resume from specific checkpoint')
parser.add_argument('--save_models', action='store_true',
                    help='save all intermediate models')
parser.add_argument('--print_every', type=int, default=10,
                    help='how many batches to wait before logging training status')
parser.add_argument('--save_results', action='store_true',
                    help='save output results')
parser.add_argument('--save_gt', action='store_true',
                    help='save low-resolution and high-resolution images together')
args = parser.parse_args()

# '--scale' is given as e.g. '4' or '2+4'; expand it to a list of ints.
args.scale = list(map(lambda x: int(x), args.scale.split('+')))

# epochs == 0 means "train (effectively) forever".  Bug fix: 1e8 is a float,
# inconsistent with the option's type=int; use an int so integer-only
# consumers (e.g. range()) keep working.
if args.epochs == 0:
    args.epochs = int(1e8)

# Normalize literal 'True'/'False' strings (e.g. from '--shift_mean False')
# into real booleans.
for arg in vars(args):
    if vars(args)[arg] == 'True':
        vars(args)[arg] = True
    elif vars(args)[arg] == 'False':
        vars(args)[arg] = False
================================================
FILE: code/real/bsrt/pwcnet/LICENSE
================================================
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc.
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The GNU General Public License is a free, copyleft license for
software and other kinds of works.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.
For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.
Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.
For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software. For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.
Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so. This is fundamentally incompatible with the aim of
protecting users' freedom to change the software. The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable. Therefore, we
have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.
Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary. To prevent this, the GPL assures that
patents cannot be used to render the program non-free.
The precise terms and conditions for copying, distribution and
modification follow.
TERMS AND CONDITIONS
0. Definitions.
"This License" refers to version 3 of the GNU General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.
A "covered work" means either the unmodified Program or a work based
on the Program.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
1. Source Code.
The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.
A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.
The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.
The Corresponding Source for a work in source code form is that
same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Use with the GNU Affero General Public License.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
{one line to give the program's name and a brief idea of what it does.}
Copyright (C) {year} {name of author}
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:
{project} Copyright (C) {year} {fullname}
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
This is free software, and you are welcome to redistribute it
under certain conditions; type `show c' for details.
The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.
The GNU General Public License does not permit incorporating your program
into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.
================================================
FILE: code/real/bsrt/pwcnet/README.md
================================================
# pytorch-pwc
This is a personal reimplementation of PWC-Net [1] using PyTorch. Should you be making use of this work, please cite the paper accordingly. Also, make sure to adhere to the | | |