Showing preview only (1,142K chars total). Download the full file or copy to clipboard to get everything.
Repository: Algolzw/BSRT
Branch: main
Commit: da8d5154478a
Files: 184
Total size: 1.1 MB
Directory structure:
gitextract_kbz2i0m2/
├── .gitignore
├── LICENSE
├── README.md
├── code/
│ ├── real/
│ │ └── bsrt/
│ │ ├── README.md
│ │ ├── data_processing/
│ │ │ ├── __init__.py
│ │ │ ├── camera_pipeline.py
│ │ │ └── synthetic_burst_generation.py
│ │ ├── datasets/
│ │ │ ├── __init__.py
│ │ │ ├── burstsr_dataset.py
│ │ │ ├── burstsr_test_dataset.py
│ │ │ ├── data_sampler.py
│ │ │ ├── realworld_burst_test_set.py
│ │ │ ├── synthetic_burst_test_set.py
│ │ │ ├── synthetic_burst_train_set.py
│ │ │ ├── synthetic_burst_val_set.py
│ │ │ └── zurich_raw2rgb_dataset.py
│ │ ├── demo.sh
│ │ ├── loss/
│ │ │ ├── Charbonnier.py
│ │ │ ├── __init__.py
│ │ │ ├── adversarial.py
│ │ │ ├── discriminator.py
│ │ │ ├── filter.py
│ │ │ ├── hist_entropy.py
│ │ │ ├── mssim.py
│ │ │ └── vgg.py
│ │ ├── main.py
│ │ ├── model/
│ │ │ ├── DCNv2/
│ │ │ │ ├── LICENSE
│ │ │ │ ├── README.md
│ │ │ │ ├── __init__.py
│ │ │ │ ├── dcn_v2.py
│ │ │ │ ├── files.txt
│ │ │ │ ├── make.sh
│ │ │ │ ├── setup.py
│ │ │ │ ├── src/
│ │ │ │ │ ├── cpu/
│ │ │ │ │ │ ├── dcn_v2_cpu.cpp
│ │ │ │ │ │ ├── dcn_v2_im2col_cpu.cpp
│ │ │ │ │ │ ├── dcn_v2_im2col_cpu.h
│ │ │ │ │ │ ├── dcn_v2_psroi_pooling_cpu.cpp
│ │ │ │ │ │ └── vision.h
│ │ │ │ │ ├── cuda/
│ │ │ │ │ │ ├── dcn_v2_cuda.cu
│ │ │ │ │ │ ├── dcn_v2_im2col_cuda.cu
│ │ │ │ │ │ ├── dcn_v2_im2col_cuda.h
│ │ │ │ │ │ ├── dcn_v2_psroi_pooling_cuda.cu
│ │ │ │ │ │ └── vision.h
│ │ │ │ │ ├── dcn_v2.h
│ │ │ │ │ └── vision.cpp
│ │ │ │ └── test.py
│ │ │ ├── __init__.py
│ │ │ ├── arch_util.py
│ │ │ ├── bsrt.py
│ │ │ ├── checkpoint.py
│ │ │ ├── common.py
│ │ │ ├── non_local/
│ │ │ │ ├── network.py
│ │ │ │ ├── non_local_concatenation.py
│ │ │ │ ├── non_local_cross_dot_product.py
│ │ │ │ ├── non_local_dot_product.py
│ │ │ │ ├── non_local_embedded_gaussian.py
│ │ │ │ └── non_local_gaussian.py
│ │ │ ├── swin_util.py
│ │ │ └── utils/
│ │ │ ├── interp_methods.py
│ │ │ ├── psconv.py
│ │ │ └── resize_right.py
│ │ ├── option.py
│ │ ├── pwcnet/
│ │ │ ├── LICENSE
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── comparison/
│ │ │ │ └── comparison.py
│ │ │ ├── correlation/
│ │ │ │ ├── README.md
│ │ │ │ └── correlation.py
│ │ │ ├── download.bash
│ │ │ ├── images/
│ │ │ │ └── README.md
│ │ │ ├── out.flo
│ │ │ ├── pwcnet.py
│ │ │ ├── requirements.txt
│ │ │ └── run.py
│ │ ├── requirements.txt
│ │ ├── scripts/
│ │ │ ├── __init__.py
│ │ │ ├── cal_mean_std.py
│ │ │ ├── demo.sh
│ │ │ ├── download_burstsr_dataset.py
│ │ │ ├── evaluate.sh
│ │ │ ├── evaluate_burstsr_val.py
│ │ │ ├── save_results_synburst_val.py
│ │ │ ├── test_burstsr_dataset.py
│ │ │ └── test_synthetic_bursts.py
│ │ ├── test.py
│ │ ├── test_real.py
│ │ ├── trainer.py
│ │ ├── utility.py
│ │ ├── utils/
│ │ │ ├── __init__.py
│ │ │ ├── data_format_utils.py
│ │ │ ├── debayer.py
│ │ │ ├── interp_methods.py
│ │ │ ├── metrics.py
│ │ │ ├── postprocessing_functions.py
│ │ │ ├── resize_right.py
│ │ │ ├── spatial_color_alignment.py
│ │ │ ├── stn.py
│ │ │ └── warp.py
│ │ └── validate.py
│ └── synthetic/
│ └── bsrt/
│ ├── README.md
│ ├── data_processing/
│ │ ├── __init__.py
│ │ ├── camera_pipeline.py
│ │ └── synthetic_burst_generation.py
│ ├── datasets/
│ │ ├── __init__.py
│ │ ├── burstsr_dataset.py
│ │ ├── burstsr_test_dataset.py
│ │ ├── data_sampler.py
│ │ ├── realworld_burst_test_set.py
│ │ ├── synthetic_burst_test_set.py
│ │ ├── synthetic_burst_train_set.py
│ │ ├── synthetic_burst_val_set.py
│ │ └── zurich_raw2rgb_dataset.py
│ ├── demo.sh
│ ├── loss/
│ │ ├── Charbonnier.py
│ │ ├── __init__.py
│ │ ├── adversarial.py
│ │ ├── discriminator.py
│ │ ├── filter.py
│ │ ├── hist_entropy.py
│ │ ├── mssim.py
│ │ └── vgg.py
│ ├── main.py
│ ├── model/
│ │ ├── DCNv2/
│ │ │ ├── LICENSE
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── dcn_v2.py
│ │ │ ├── files.txt
│ │ │ ├── make.sh
│ │ │ ├── setup.py
│ │ │ ├── src/
│ │ │ │ ├── cpu/
│ │ │ │ │ ├── dcn_v2_cpu.cpp
│ │ │ │ │ ├── dcn_v2_im2col_cpu.cpp
│ │ │ │ │ ├── dcn_v2_im2col_cpu.h
│ │ │ │ │ ├── dcn_v2_psroi_pooling_cpu.cpp
│ │ │ │ │ └── vision.h
│ │ │ │ ├── cuda/
│ │ │ │ │ ├── dcn_v2_cuda.cu
│ │ │ │ │ ├── dcn_v2_im2col_cuda.cu
│ │ │ │ │ ├── dcn_v2_im2col_cuda.h
│ │ │ │ │ ├── dcn_v2_psroi_pooling_cuda.cu
│ │ │ │ │ └── vision.h
│ │ │ │ ├── dcn_v2.h
│ │ │ │ └── vision.cpp
│ │ │ └── test.py
│ │ ├── __init__.py
│ │ ├── arch_util.py
│ │ ├── bsrt.py
│ │ ├── checkpoint.py
│ │ ├── common.py
│ │ ├── ebsr.py
│ │ ├── non_local/
│ │ │ ├── network.py
│ │ │ ├── non_local_concatenation.py
│ │ │ ├── non_local_cross_dot_product.py
│ │ │ ├── non_local_dot_product.py
│ │ │ ├── non_local_embedded_gaussian.py
│ │ │ └── non_local_gaussian.py
│ │ ├── swin_util.py
│ │ └── utils/
│ │ ├── interp_methods.py
│ │ ├── psconv.py
│ │ └── resize_right.py
│ ├── option.py
│ ├── requirements.txt
│ ├── scripts/
│ │ ├── __init__.py
│ │ ├── cal_mean_std.py
│ │ ├── demo.sh
│ │ ├── download_burstsr_dataset.py
│ │ ├── evaluate.sh
│ │ ├── evaluate_burstsr_val.py
│ │ ├── save_results_synburst_val.py
│ │ ├── test_burstsr_dataset.py
│ │ └── test_synthetic_bursts.py
│ ├── test.py
│ ├── test_synburst.py
│ ├── trainer.py
│ ├── utility.py
│ └── utils/
│ ├── __init__.py
│ ├── data_format_utils.py
│ ├── debayer.py
│ ├── interp_methods.py
│ ├── metrics.py
│ ├── postprocessing_functions.py
│ ├── resize_right.py
│ ├── spatial_color_alignment.py
│ ├── stn.py
│ └── warp.py
└── requirements.txt
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2022 Megvii Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
## BSRT: Improving Burst Super-Resolution with Swin Transformer and Flow-Guided Deformable Alignment (CVPRW 2022)
[](https://paperswithcode.com/sota/burst-image-super-resolution-on-burstsr?p=bsrt-improving-burst-super-resolution-with) [](https://paperswithcode.com/sota/burst-image-super-resolution-on?p=bsrt-improving-burst-super-resolution-with)
#### [BSRT](https://arxiv.org/abs/2204.08332), the winner of the NTIRE 2022 Burst Super-Resolution Challenge Real-World Track.
You can also find our winning method from the NTIRE 2021 Burst Super-Resolution Challenge [here](https://github.com/Algolzw/EBSR).
> This work addresses the Burst Super-Resolution (BurstSR) task using a new architecture. The task requires restoring a high-quality image from a sequence of noisy, misaligned, low-resolution RAW bursts. To overcome the challenges in BurstSR, we propose a **B**urst **S**uper-**R**esolution **T**ransformer (**BSRT**), which significantly improves the extraction of inter-frame information and the reconstruction. To achieve this goal, we propose a Pyramid Flow-Guided Deformable Convolution Network (Pyramid FG-DCN) and use Swin Transformer Blocks and Groups as our main backbone. More specifically, we combine optical flow and deformable convolutions, so BSRT can handle misalignment and aggregate potential texture information across multiple frames more efficiently. In addition, our Transformer-based structure captures long-range dependencies to further improve performance. Evaluations on both the synthetic and real-world tracks demonstrate that our approach achieves new state-of-the-art results on the BurstSR task. Furthermore, BSRT won first place in the NTIRE 2022 Burst Super-Resolution Challenge.
#### Comparison with State-of-the-art Burst Super-Resolution Methods

## Overview Architecture

## Dependencies
- OS: Ubuntu 18.04
- Python: Python 3.7
- nvidia :
- cuda: 10.1
- cudnn: 7.6.1
- Other reference requirements
## Quick Start
1. Create a conda virtual environment and activate it:
```bash
conda create -n pytorch_1.6 python=3.7
source activate pytorch_1.6
```
2. Install PyTorch and torchvision following the official instructions:
```bash
conda install pytorch==1.6.0 torchvision==0.7.0 cudatoolkit=10.1 -c pytorch
```
3. Install the build requirements:
```bash
pip3 install -r requirements.txt
```
4. Install DCN (DCNv2):
```bash
cd code/synthetic/bsrt/model/DCNv2  # or code/real/bsrt/model/DCNv2
python3 setup.py build develop # build
python3 test.py # run examples and check
```
## Training
We provide all pretrained model weights [here](https://drive.google.com/file/d/1Bv1ZwoE3s8trhG--wjB0Yt6WJIQPpvsn/view?usp=sharing).
#### For Synthetic data
```python3
cd code/synthetic/bsrt
# Modify the root path of training dataset and model etc.
# The number of GPUs should be more than 1
python main.py --n_GPUs 8 --print_every 40 --lr 0.0001 --decay 150-300 --save bsrt_tiny --model BSRT --fp16 --model_level S --swinfeature --batch_size 32 --burst_size 14 --patch_size 256
```
#### For Real-World data
```python3
cd code/real/bsrt
# Modify the root path of training dataset and model etc.
# The number of GPUs should be more than 1
python main.py --n_GPUs 8 --print_every 20 --lr 0.00005 --decay 40-80 --save bsrt_tiny --model BSRT --fp16 --model_level S --swinfeature --batch_size 8 --burst_size 14 --patch_size 80 --pre_train ../../synthetic/train_log/bsrt/real_models/bsrt_tiny/bsrt_best_epoch.pth
```
The pretrained PWC-Net model can be downloaded [here](https://drive.google.com/file/d/1dD6vB9QN3qwmOBi3AGKzJbbSojwDDlgV/view?usp=sharing).
## Test
#### For Synthetic data
```python3
# Modify the path of test dataset and the path of the trained model
python test_synburst.py --n_GPUs 1 --model BSRT --model_level S --swinfeature --burst_size 14 --patch_size 384 --pre_train ../train_log/bsrt/real_models/bsrt_tiny/bsrt_best_epoch.pth --root /data/dataset/ntire21/burstsr/synthetic
```
#### For Real-World data
```python3
# Modify the path of test dataset and the path of the trained model
python test_real.py --n_GPUs 1 --model BSRT --model_level S --swinfeature --batch_size 1 --burst_size 14 --patch_size 80 --pre_train ../train_log/bsrt/real_models/bsrt_tiny/bsrt_best_epoch.pth --root /data/dataset/ntire21/burstsr/real
```
## Results
### Comparison on Synthetic dataset

### Comparison on Real-World dataset

## Citations
If our code helps your research or work, please consider citing our paper.
The following is a BibTeX reference.
```
@inproceedings{luo2022bsrt,
title={BSRT: Improving Burst Super-Resolution with Swin Transformer and Flow-Guided Deformable Alignment},
author={Luo, Ziwei and Li, Youwei and Cheng, Shen and Yu, Lei and Wu, Qi and Wen, Zhihong and Fan, Haoqiang and Sun, Jian and Liu, Shuaicheng},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={998--1008},
year={2022}
}
```
## Contact
email: [ziwei.ro@gmail.com]
================================================
FILE: code/real/bsrt/README.md
================================================
# BSRT: Improving Burst Super-Resolution with Swin Transformer and Flow-Guided Deformable Alignment (Real-World)
## Dependencies
- OS: Ubuntu 18.04
- Python: Python 3.7
- nvidia :
- cuda: 10.1
- cudnn: 7.6.1
- Other reference requirements
## Quick Start
1. Create a conda virtual environment and activate it:
```bash
conda create -n pytorch_1.6 python=3.7
source activate pytorch_1.6
```
2. Install PyTorch and torchvision following the official instructions:
```bash
conda install pytorch==1.6.0 torchvision==0.7.0 cudatoolkit=10.1 -c pytorch
```
3. Install the build requirements:
```bash
pip3 install -r requirements.txt
```
4. Install DCN (DCNv2):
```bash
cd model/DCNv2
python3 setup.py build develop # build
python3 test.py # run examples and check
```
## Training
The pretrained PWC-Net model can be downloaded [here](https://drive.google.com/file/d/1dD6vB9QN3qwmOBi3AGKzJbbSojwDDlgV/view?usp=sharing).
```python3
# Modify the root path of training dataset and model etc.
# The number of GPUs should be more than 1
python main.py --n_GPUs 8 --print_every 20 --lr 0.00004 --decay 40-80 --save bsrt_tiny --model BSRT --fp16 --model_level S --swinfeature --batch_size 8 --burst_size 14 --patch_size 80 --pre_train ../../synthetic/train_log/bsrt/real_models/bsrt_tiny/bsrt_best_epoch.pth
```
## Test
```python3
# Modify the path of test dataset and the path of the trained model
python test_real.py --n_GPUs 1 --model BSRT --model_level S --swinfeature --batch_size 1 --burst_size 14 --patch_size 80 --pre_train ../train_log/bsrt/real_models/bsrt_tiny/bsrt_best_epoch.pth --root /data/dataset/ntire21/burstsr/real
```
================================================
FILE: code/real/bsrt/data_processing/__init__.py
================================================
================================================
FILE: code/real/bsrt/data_processing/camera_pipeline.py
================================================
import torch
import random
import math
import cv2 as cv
import numpy as np
import utils.data_format_utils as df_utils
""" Based on http://timothybrooks.com/tech/unprocessing
Functions for forward and inverse camera pipeline. All functions input a torch float tensor of shape (c, h, w).
Additionally, some also support batch operations, i.e. inputs of shape (b, c, h, w)
"""
def random_ccm():
    """Sample a random RGB -> camera colour correction matrix.

    A random convex combination of four reference XYZ -> camera matrices is
    composed with the fixed linear-RGB -> XYZ matrix, and every row of the
    result is normalised to sum to one.

    Returns:
        A 3x3 float tensor.
    """
    # Reference XYZ -> camera matrices that get blended with random weights.
    xyz2cam_candidates = torch.tensor([
        [[1.0234, -0.2969, -0.2266],
         [-0.5625, 1.6328, -0.0469],
         [-0.0703, 0.2188, 0.6406]],
        [[0.4913, -0.0541, -0.0202],
         [-0.613, 1.3513, 0.2906],
         [-0.1564, 0.2151, 0.7183]],
        [[0.838, -0.263, -0.0639],
         [-0.2887, 1.0725, 0.2496],
         [-0.0627, 0.1427, 0.5438]],
        [[0.6596, -0.2079, -0.0562],
         [-0.4782, 1.3016, 0.1933],
         [-0.097, 0.1581, 0.5181]],
    ])
    n_candidates = xyz2cam_candidates.shape[0]

    # Uniform weights divided by their sum give a convex combination.
    blend = torch.FloatTensor(n_candidates, 1, 1).uniform_(0.0, 1.0)
    xyz2cam = (xyz2cam_candidates * blend).sum(dim=0) / blend.sum()

    rgb2xyz = torch.tensor([[0.4124564, 0.3575761, 0.1804375],
                            [0.2126729, 0.7151522, 0.0721750],
                            [0.0193339, 0.1191920, 0.9503041]])
    rgb2cam = torch.mm(xyz2cam, rgb2xyz)

    # Row-normalise so each output channel's weights sum to one.
    row_sums = rgb2cam.sum(dim=-1, keepdim=True)
    return rgb2cam / row_sums
def random_gains():
    """Sample random brightening and white-balance gains.

    Returns:
        (rgb_gain, red_gain, blue_gain). rgb_gain is the reciprocal of a
        Gaussian sample (mean 0.8, std 0.1); red_gain ~ U(1.9, 2.4) and
        blue_gain ~ U(1.5, 1.9).
    """
    # Overall brightness gain.
    brightness = 1.0 / random.gauss(mu=0.8, sigma=0.1)
    # Per-channel white-balance gains for red and blue.
    red = random.uniform(1.9, 2.4)
    blue = random.uniform(1.5, 1.9)
    return brightness, red, blue
def apply_smoothstep(image):
    """Apply the global smoothstep tone curve 3*x**2 - 2*x**3 element-wise."""
    squared = image ** 2
    cubed = image ** 3
    return 3 * squared - 2 * cubed
def invert_smoothstep(image):
    """Approximately invert the smoothstep tone curve.

    The input is clamped to [0, 1]. The result is computed as
    0.5 - sin(asin(1 - 2x) / 3).
    """
    clipped = torch.clamp(image, 0.0, 1.0)
    angle = torch.asin(1.0 - 2.0 * clipped) / 3.0
    return 0.5 - torch.sin(angle)
def gamma_expansion(image):
    """Map gamma-encoded values to linear space via x ** 2.2."""
    # Keep values away from zero so gradients of the power stay finite.
    return torch.clamp(image, min=1e-8).pow(2.2)
def gamma_compression(image):
    """Map linear values to gamma space via x ** (1 / 2.2)."""
    # Keep values away from zero so gradients of the power stay finite.
    return torch.clamp(image, min=1e-8).pow(1.0 / 2.2)
def apply_ccm(image, ccm):
    """Apply a 3x3 colour correction matrix to every pixel of a (3, H, W) image.

    Args:
        image: tensor of shape (3, H, W). It does not need to be contiguous,
            e.g. a transposed or column-cropped view is accepted.
        ccm: 3x3 matrix; moved to ``image``'s device and dtype before use.

    Returns:
        Tensor with the same shape as ``image``, where each output pixel is
        ``ccm @ pixel``.
    """
    assert image.dim() == 3 and image.shape[0] == 3
    shape = image.shape
    # reshape (not view): view raises on non-contiguous inputs, reshape copies
    # only when it has to.
    pixels = image.reshape(3, -1)
    ccm = ccm.to(image.device).type_as(image)
    return torch.mm(ccm, pixels).view(shape)
def apply_gains(image, rgb_gain, red_gain, blue_gain):
    """Apply white-balance and brightness gains, then clamp the result to [0, 1].

    Args:
        image: tensor of shape (3, H, W) (RGB) or (4, H, W) (R, G, G, B planes).
        rgb_gain: overall brightness gain applied to every channel.
        red_gain: extra gain for the red channel.
        blue_gain: extra gain for the blue channel.

    Returns:
        ``image * gains`` clamped to [0, 1], same shape as ``image``.
    """
    assert image.dim() == 3 and image.shape[0] in [3, 4]
    if image.shape[0] == 3:
        gains = torch.tensor([red_gain, 1.0, blue_gain]) * rgb_gain
    else:
        # Both green planes of a 4-channel Bayer image receive unit WB gain.
        gains = torch.tensor([red_gain, 1.0, 1.0, blue_gain]) * rgb_gain
    # Broadcast one gain per channel over (H, W).
    gains = gains.view(-1, 1, 1).to(image.device).type_as(image)
    return (image * gains).clamp(0.0, 1.0)
def safe_invert_gains(image, rgb_gain, red_gain, blue_gain):
    """Undo brightness/white-balance gains on a (3, H, W) image without dimming highlights.

    The inverse gains are blended towards 1 for pixels whose channel mean
    exceeds 0.9. The blended gain is never smaller than the plain inverse gain.
    """
    assert image.dim() == 3 and image.shape[0] == 3
    inverse = (torch.tensor([1.0 / red_gain, 1.0, 1.0 / blue_gain]) / rgb_gain).view(-1, 1, 1)

    # Per-pixel weight that rises quadratically from 0 to 1 as the channel
    # mean goes from 0.9 to 1.0.
    knee = 0.9
    brightness = image.mean(dim=0, keepdim=True)
    weight = ((brightness - knee).clamp(0.0) / (1.0 - knee)) ** 2.0

    blended = weight + (1.0 - weight) * inverse
    return image * torch.max(blended, inverse)
def mosaic(image, mode='rggb'):
    """Extract the four Bayer planes from an RGB image.

    Args:
        image: tensor of shape (3, H, W) or (B, 3, H, W); H and W are assumed to be even.
        mode: 'rggb' (2x2 tile R G / G B) or 'grbg' (2x2 tile G R / B G).

    Returns:
        Tensor of shape (4, H/2, W/2), or (B, 4, H/2, W/2) for batched input.
        Plane order is (R, G_r, G_b, B) for 'rggb' and (G_r, R, B, G_b) for 'grbg'.

    Raises:
        ValueError: if ``mode`` is not 'rggb' or 'grbg'.
    """
    shape = image.shape
    if image.dim() == 3:
        image = image.unsqueeze(0)

    if mode == 'rggb':
        red = image[:, 0, 0::2, 0::2]
        green_red = image[:, 1, 0::2, 1::2]
        green_blue = image[:, 1, 1::2, 0::2]
        blue = image[:, 2, 1::2, 1::2]
        image = torch.stack((red, green_red, green_blue, blue), dim=1)
    elif mode == 'grbg':
        green_red = image[:, 1, 0::2, 0::2]
        red = image[:, 0, 0::2, 1::2]
        # Blue sits at (odd row, even col) in a GRBG tile. The previous
        # [0::2, 1::2] indexing sampled the red site instead.
        blue = image[:, 2, 1::2, 0::2]
        green_blue = image[:, 1, 1::2, 1::2]
        image = torch.stack((green_red, red, blue, green_blue), dim=1)
    else:
        raise ValueError("Unsupported Bayer mode '{}'; expected 'rggb' or 'grbg'".format(mode))

    if len(shape) == 3:
        return image.view((4, shape[-2] // 2, shape[-1] // 2))
    return image.view((-1, 4, shape[-2] // 2, shape[-1] // 2))
def demosaic(image):
    """Demosaic packed Bayer planes to RGB using OpenCV.

    Args:
        image: tensor of shape (4, H, W) or (B, 4, H, W). Values are clamped to
            [0, 1] and quantised to uint8 before the OpenCV conversion.

    Returns:
        One demosaiced image per input, converted by ``df_utils.npimage_to_torch``
        (presumably (3, 2H, 2W) each; TODO confirm against df_utils). The images
        are stacked along dim 0 for batched input; otherwise the single image is
        returned.
    """
    assert isinstance(image, torch.Tensor)
    image = image.clamp(0.0, 1.0) * 255

    batch_input = image.dim() == 4
    if not batch_input:
        image = image.unsqueeze(0)
    # Batch size. The original used image.dim(), which is always 4 here and
    # broke every batch size other than 1 and 4.
    num_images = image.shape[0]

    # Re-interleave the four planes into one single-channel 2H x 2W mosaic,
    # which is the input layout cv.cvtColor expects for Bayer codes.
    im_sc = torch.zeros((num_images, image.shape[-2] * 2, image.shape[-1] * 2, 1))
    im_sc[:, ::2, ::2, 0] = image[:, 0, :, :]
    im_sc[:, ::2, 1::2, 0] = image[:, 1, :, :]
    im_sc[:, 1::2, ::2, 0] = image[:, 2, :, :]
    im_sc[:, 1::2, 1::2, 0] = image[:, 3, :, :]
    im_sc = im_sc.numpy().astype(np.uint8)

    out = []
    for im in im_sc:
        im_dem_np = cv.cvtColor(im, cv.COLOR_BAYER_BG2RGB)
        out.append(df_utils.npimage_to_torch(im_dem_np, input_bgr=False))

    return torch.stack(out, dim=0) if batch_input else out[0]
def random_noise_levels():
    """Sample (shot_noise, read_noise) from a log-log linear model.

    The log of the shot noise is uniform between log(1e-4) and log(0.012).
    The log of the read noise is 2.18 * log(shot) + 1.20 plus Gaussian
    jitter with std 0.26.
    """
    log_shot = random.uniform(math.log(0.0001), math.log(0.012))
    log_read = 2.18 * log_shot + 1.20 + random.gauss(mu=0.0, sigma=0.26)
    return math.exp(log_shot), math.exp(log_read)
def add_noise(image, shot_noise=0.01, read_noise=0.0005):
    """Add heteroscedastic Gaussian noise to ``image``.

    The per-pixel variance is ``image * shot_noise + read_noise``: a
    signal-dependent shot term plus a constant read term. The noise is drawn
    on the CPU and moved to the image's device.
    """
    std = (image * shot_noise + read_noise).sqrt()
    gaussian = torch.FloatTensor(image.shape).normal_().to(image.device)
    return image + gaussian * std
def process_linear_image_rgb(image, meta_info, return_np=False):
    """Render a linear RGB image to display space using the burst metadata.

    The steps are: apply the gains, apply the ``cam2rgb`` colour matrix,
    optionally gamma-compress and smoothstep tone-map (per ``meta_info``),
    then clamp to [0, 1]. If ``return_np`` is true, the result is converted
    with ``df_utils.torch_to_npimage``.
    """
    out = apply_gains(image, meta_info['rgb_gain'], meta_info['red_gain'], meta_info['blue_gain'])
    out = apply_ccm(out, meta_info['cam2rgb'])
    if meta_info['gamma']:
        out = gamma_compression(out)
    if meta_info['smoothstep']:
        out = apply_smoothstep(out)
    out = out.clamp(0.0, 1.0)
    return df_utils.torch_to_npimage(out) if return_np else out
def process_linear_image_raw(image, meta_info):
    """Render a linear 4-plane RAW image to display-space RGB.

    The steps are: apply the gains, demosaic, apply the ``cam2rgb`` colour
    matrix, optionally gamma-compress and smoothstep tone-map (per
    ``meta_info``). The result is clamped to [0, 1].
    """
    gains = (meta_info['rgb_gain'], meta_info['red_gain'], meta_info['blue_gain'])
    rgb = demosaic(apply_gains(image, *gains))
    rgb = apply_ccm(rgb, meta_info['cam2rgb'])
    if meta_info['gamma']:
        rgb = gamma_compression(rgb)
    if meta_info['smoothstep']:
        rgb = apply_smoothstep(rgb)
    return rgb.clamp(0.0, 1.0)
================================================
FILE: code/real/bsrt/data_processing/synthetic_burst_generation.py
================================================
import torch
import random
import cv2
import numpy as np
import torch.nn.functional as F
from data_processing.camera_pipeline import *
from utils.data_format_utils import torch_to_numpy, numpy_to_torch
def random_crop(frames, crop_sz):
    """Extract a random crop of size ``crop_sz`` from ``frames``.

    If ``crop_sz`` is larger than the image, the largest crop with the same
    aspect ratio is taken and bilinearly upsampled to ``crop_sz``.

    Args:
        frames: tensor of shape (C, H, W).
        crop_sz: int or (height, width).

    Returns:
        Tensor of shape (C, crop_h, crop_w).
    """
    if not isinstance(crop_sz, (tuple, list)):
        crop_sz = (crop_sz, crop_sz)
    crop_sz = torch.tensor(crop_sz).float()
    shape = frames.shape

    # Largest scale at which the requested crop still fits inside the image.
    max_scale_factor = (torch.tensor(shape[-2:]).float() / crop_sz).min().item()
    scale_factor = min(max_scale_factor, 1.0)

    orig_crop_sz = (crop_sz * scale_factor).floor()
    crop_h = int(orig_crop_sz[0].item())
    crop_w = int(orig_crop_sz[1].item())
    assert crop_h <= shape[-2] and crop_w <= shape[-1], 'Bug in crop size estimation!'

    # random.randint needs plain ints; float tensors raise TypeError on Python >= 3.12.
    r1 = random.randint(0, shape[-2] - crop_h)
    c1 = random.randint(0, shape[-1] - crop_w)
    frames_crop = frames[:, r1:r1 + crop_h, c1:c1 + crop_w]

    # Upsample back to the requested size if the image was too small.
    if scale_factor < 1.0:
        frames_crop = F.interpolate(frames_crop.unsqueeze(0), size=crop_sz.int().tolist(),
                                    mode='bilinear', align_corners=False).squeeze(0)
    return frames_crop
def rgb2rawburst(image, burst_size, downsample_factor=1, burst_transformation_params=None,
                 image_processing_params=None, interpolation_type='bilinear'):
    """Generate a synthetic low-resolution RAW burst from a single sRGB image.

    The input sRGB image is first converted to linear sensor space with an
    inverse camera pipeline. A low-resolution burst is then generated by
    applying the random transformations in ``burst_transformation_params``
    and downsampling by ``downsample_factor``. Finally the burst is
    mosaicked and corrupted by random noise.

    Args:
        image: sRGB image tensor (3, H, W).
        burst_size: number of frames to generate.
        downsample_factor: spatial downsampling factor.
        burst_transformation_params: dict forwarded to ``single2lrburst``.
        image_processing_params: optional dict with boolean keys 'random_ccm',
            'random_gains', 'smoothstep', 'gamma' and 'add_noise'. Missing keys
            default to True. The caller's dict is not modified.
        interpolation_type: 'bilinear' or 'lanczos'.

    Returns:
        (image_burst, image, image_burst_rgb, flow_vectors, meta_info):
        - the noisy mosaicked burst;
        - the linear-space input image;
        - the RGB burst before mosaicking;
        - the flow vectors from ``single2lrburst``;
        - a dict of the sampled pipeline parameters.
    """
    # Merge over the defaults without mutating the caller-owned dict.
    _defaults = {'random_ccm': True, 'random_gains': True, 'smoothstep': True, 'gamma': True, 'add_noise': True}
    image_processing_params = {**_defaults, **(image_processing_params or {})}

    # Sample camera pipeline params
    if image_processing_params['random_ccm']:
        rgb2cam = random_ccm()
    else:
        rgb2cam = torch.eye(3).float()
    cam2rgb = rgb2cam.inverse()

    # Sample gains
    if image_processing_params['random_gains']:
        rgb_gain, red_gain, blue_gain = random_gains()
    else:
        rgb_gain, red_gain, blue_gain = (1.0, 1.0, 1.0)

    # Approximately inverts global tone mapping.
    use_smoothstep = image_processing_params['smoothstep']
    if use_smoothstep:
        image = invert_smoothstep(image)

    # Inverts gamma compression.
    use_gamma = image_processing_params['gamma']
    if use_gamma:
        image = gamma_expansion(image)

    # Inverts color correction.
    image = apply_ccm(image, rgb2cam)

    # Approximately inverts white balance and brightening.
    image = safe_invert_gains(image, rgb_gain, red_gain, blue_gain)

    # Clip saturated pixels.
    image = image.clamp(0.0, 1.0)

    # Generate LR burst
    image_burst_rgb, flow_vectors = single2lrburst(image, burst_size=burst_size,
                                                   downsample_factor=downsample_factor,
                                                   transformation_params=burst_transformation_params,
                                                   interpolation_type=interpolation_type)

    # mosaic
    image_burst = mosaic(image_burst_rgb.clone())

    # Add noise
    if image_processing_params['add_noise']:
        shot_noise_level, read_noise_level = random_noise_levels()
        image_burst = add_noise(image_burst, shot_noise_level, read_noise_level)
    else:
        shot_noise_level = 0
        read_noise_level = 0

    # Clip saturated pixels.
    image_burst = image_burst.clamp(0.0, 1.0)

    meta_info = {'rgb2cam': rgb2cam, 'cam2rgb': cam2rgb, 'rgb_gain': rgb_gain, 'red_gain': red_gain,
                 'blue_gain': blue_gain, 'smoothstep': use_smoothstep, 'gamma': use_gamma,
                 'shot_noise_level': shot_noise_level, 'read_noise_level': read_noise_level}

    return image_burst, image, image_burst_rgb, flow_vectors, meta_info
def get_tmat(image_shape, translation, theta, shear_values, scale_factors):
    """Build the 2x3 affine matrix for the given transformation parameters.

    The matrix is composed as scale @ rotation @ shear @ translation.
    Rotation (``theta`` in degrees, via cv2.getRotationMatrix2D) and shear are
    both taken about the image centre.
    """
    im_h, im_w = image_shape

    shift = np.identity(3)
    shift[0, 2], shift[1, 2] = translation[0], translation[1]

    rotation = np.vstack((cv2.getRotationMatrix2D((im_w * 0.5, im_h * 0.5), theta, 1.0),
                          [0.0, 0.0, 1.0]))

    sx, sy = shear_values[0], shear_values[1]
    shear = np.array([[1.0, sx, -sx * 0.5 * im_w],
                      [sy, 1.0, -sy * 0.5 * im_h],
                      [0.0, 0.0, 1.0]])

    scale = np.diag([scale_factors[0], scale_factors[1], 1.0])

    composed = scale @ rotation @ shear @ shift
    return composed[:2, :]
def single2lrburst(image, burst_size, downsample_factor=1, transformation_params=None,
                   interpolation_type='bilinear'):
    """Generate a burst from a single image via random affine warps and downsampling.

    Frame 0 is only shifted so the sampling grid is centred. Every other frame
    gets a random translation, rotation, shear and scale drawn from
    ``transformation_params``. All frames are then downsampled by
    ``downsample_factor``.

    Args:
        image: tensor (converted with ``torch_to_numpy``) or numpy image. Tensors
            whose max is below 2.0 are treated as being in [0, 1] and rescaled
            to 0-255 before the uint8 conversion.
        burst_size: number of frames.
        downsample_factor: spatial downsampling factor.
        transformation_params: dict with optional keys 'max_translation',
            'max_rotation', 'max_shear', 'max_ar_factor', 'max_scale' and
            'border_crop'. None is treated as an empty dict (no random motion).
        interpolation_type: 'bilinear' or 'lanczos'.

    Returns:
        (burst_images, flow_vectors). ``burst_images`` stacks the frames along
        dim 0. ``flow_vectors`` (burst, 2, h, w) holds each frame's
        downsampled sampling positions minus those of the base frame (frame 0).

    Raises:
        ValueError: for an unknown ``interpolation_type``.
    """
    if interpolation_type == 'bilinear':
        interpolation = cv2.INTER_LINEAR
    elif interpolation_type == 'lanczos':
        interpolation = cv2.INTER_LANCZOS4
    else:
        raise ValueError('Unknown interpolation_type: {}'.format(interpolation_type))

    # The None default used to crash on the first .get() call below.
    if transformation_params is None:
        transformation_params = {}

    normalize = False
    if isinstance(image, torch.Tensor):
        if image.max() < 2.0:
            image = image * 255.0
            normalize = True
        image = torch_to_numpy(image).astype(np.uint8)

    burst = []
    sample_pos_inv_all = []

    # Homogeneous (x, y, 1) coordinates of every pixel of the input image.
    rvs, cvs = torch.meshgrid([torch.arange(0, image.shape[0]),
                               torch.arange(0, image.shape[1])])
    sample_grid = torch.stack((cvs, rvs, torch.ones_like(cvs)), dim=-1).float()

    # A border_crop of 0 (or None) means "no crop"; x[0:-0] would otherwise be empty.
    border_crop = transformation_params.get('border_crop')

    for i in range(burst_size):
        if i == 0:
            # For base image, do not apply any random transformations. We only translate the image to center the
            # sampling grid
            shift = (downsample_factor / 2.0) - 0.5
            translation = (shift, shift)
            theta = 0.0
            shear_factor = (0.0, 0.0)
            scale_factor = (1.0, 1.0)
        else:
            # Sample random image transformation parameters
            max_translation = transformation_params.get('max_translation', 0.0)

            if max_translation <= 0.01:
                shift = (downsample_factor / 2.0) - 0.5
                translation = (shift, shift)
            else:
                translation = (random.uniform(-max_translation, max_translation),
                               random.uniform(-max_translation, max_translation))

            max_rotation = transformation_params.get('max_rotation', 0.0)
            theta = random.uniform(-max_rotation, max_rotation)

            max_shear = transformation_params.get('max_shear', 0.0)
            shear_x = random.uniform(-max_shear, max_shear)
            shear_y = random.uniform(-max_shear, max_shear)
            shear_factor = (shear_x, shear_y)

            max_ar_factor = transformation_params.get('max_ar_factor', 0.0)
            ar_factor = np.exp(random.uniform(-max_ar_factor, max_ar_factor))

            max_scale = transformation_params.get('max_scale', 0.0)
            scale_factor = np.exp(random.uniform(-max_scale, max_scale))

            scale_factor = (scale_factor, scale_factor * ar_factor)

        output_sz = (image.shape[1], image.shape[0])

        # Generate a affine transformation matrix corresponding to the sampled parameters
        t_mat = get_tmat((image.shape[0], image.shape[1]), translation, theta, shear_factor, scale_factor)
        t_mat_tensor = torch.from_numpy(t_mat)

        # Apply the sampled affine transformation
        image_t = cv2.warpAffine(image, t_mat, output_sz, flags=interpolation,
                                 borderMode=cv2.BORDER_CONSTANT)

        # Map every output pixel back to its sampling position in the input image.
        t_mat_tensor_3x3 = torch.cat((t_mat_tensor.float(), torch.tensor([0.0, 0.0, 1.0]).view(1, 3)), dim=0)
        t_mat_tensor_inverse = t_mat_tensor_3x3.inverse()[:2, :].contiguous()

        sample_pos_inv = torch.mm(sample_grid.view(-1, 3), t_mat_tensor_inverse.t().float()).view(
            *sample_grid.shape[:2], -1)

        if border_crop:
            image_t = image_t[border_crop:-border_crop, border_crop:-border_crop, :]
            sample_pos_inv = sample_pos_inv[border_crop:-border_crop, border_crop:-border_crop, :]

        # Downsample the image
        image_t = cv2.resize(image_t, None, fx=1.0 / downsample_factor, fy=1.0 / downsample_factor,
                             interpolation=interpolation)
        sample_pos_inv = cv2.resize(sample_pos_inv.numpy(), None, fx=1.0 / downsample_factor,
                                    fy=1.0 / downsample_factor,
                                    interpolation=interpolation)

        sample_pos_inv = torch.from_numpy(sample_pos_inv).permute(2, 0, 1).contiguous()

        if normalize:
            image_t = numpy_to_torch(image_t).float() / 255.0
        else:
            image_t = numpy_to_torch(image_t).float()
        burst.append(image_t)
        sample_pos_inv_all.append(sample_pos_inv / downsample_factor)

    burst_images = torch.stack(burst)
    sample_pos_inv_all = torch.stack(sample_pos_inv_all)

    # Flow from each frame to the base frame: subtract frame 0's sampling
    # positions (dim 0 is the burst axis). The previous `[:, :1]` selected each
    # frame's own x-coordinate channel instead of the base frame.
    flow_vectors = sample_pos_inv_all - sample_pos_inv_all[:1, ...]

    return burst_images, flow_vectors
================================================
FILE: code/real/bsrt/datasets/__init__.py
================================================
================================================
FILE: code/real/bsrt/datasets/burstsr_dataset.py
================================================
import os
import torch
import cv2
import numpy as np
import pickle as pkl
import torch.nn.functional as F
import random
import time
class SamsungRAWImage:
@staticmethod
def load(path):
    """Load a SamsungRAWImage from directory ``path``.

    Reads ``im_raw.png`` with OpenCV (IMREAD_UNCHANGED), transposes it from
    HWC to CHW and stores it as an int16 tensor. Metadata is read from
    ``meta_info.pkl``.

    Raises:
        FileNotFoundError: if ``im_raw.png`` cannot be read.
    """
    im_raw = cv2.imread('{}/im_raw.png'.format(path), cv2.IMREAD_UNCHANGED)
    if im_raw is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError('Could not read {}/im_raw.png'.format(path))
    im_raw = np.transpose(im_raw, (2, 0, 1)).astype(np.int16)
    im_raw = torch.from_numpy(im_raw)

    # NOTE(review): pickle can execute arbitrary code; only load trusted dataset files.
    # The context manager closes the file handle, which the original leaked.
    with open('{}/meta_info.pkl'.format(path), "rb") as f:
        meta_data = pkl.load(f)

    return SamsungRAWImage(im_raw, meta_data['black_level'], meta_data['cam_wb'],
                           meta_data['daylight_wb'], meta_data['color_matrix'], meta_data['exif_data'],
                           meta_data.get('crop_info', None), meta_data.get('im_preview', None))
def __init__(self, im_raw, black_level, cam_wb, daylight_wb, color_matrix, exif_data, crop_info=None,
im_preview=None):
self.im_raw = im_raw
self.black_level = black_level
self.cam_wb = cam_wb
self.daylight_wb = daylight_wb
self.color_matrix = color_matrix
self.exif_data = exif_data
self.crop_info = crop_info
self.im_preview = im_preview
self.norm_factor = 1023.0
def get_all_meta_data(self):
return {'black_level': self.black_level, 'cam_wb': self.cam_wb, 'daylight_wb': self.daylight_wb,
'color_matrix': self.color_matrix.tolist()}
def get_exposure_time(self):
return self.exif_data['Image ExposureTime'].values[0].decimal()
def get_noise_profile(self):
noise = self.exif_data['Image Tag 0xC761'].values
noise = [n[0] for n in noise]
noise = np.array(noise).reshape(3, 2)
return noise
def get_f_number(self):
return self.exif_data['Image FNumber'].values[0].decimal()
def get_iso(self):
return self.exif_data['Image ISOSpeedRatings'].values[0]
def get_image_data(self, substract_black_level=False, white_balance=False, normalize=False):
im_raw = self.im_raw.float()
if substract_black_level:
im_raw = im_raw - torch.tensor(self.black_level).view(4, 1, 1)
if white_balance:
im_raw = im_raw * torch.tensor(self.cam_wb).view(4, 1, 1)
if normalize:
im_raw = im_raw / self.norm_factor
return im_raw
def shape(self):
shape = (4, self.im_raw.shape[1], self.im_raw.shape[2])
return shape
def crop_image(self, r1, r2, c1, c2):
self.im_raw = self.im_raw[:, r1:r2, c1:c2]
def get_crop(self, r1, r2, c1, c2):
im_raw = self.im_raw[:, r1:r2, c1:c2]
if self.im_preview is not None:
im_preview = self.im_preview[2*r1:2*r2, 2*c1:2*c2]
else:
im_preview = None
return SamsungRAWImage(im_raw, self.black_level, self.cam_wb, self.daylight_wb, self.color_matrix,
self.exif_data, im_preview=im_preview)
def postprocess(self, return_np=True, norm_factor=None):
# Convert to rgb
# im = torch.from_numpy(self.im_raw.astype(np.float32))
im = self.im_raw
im = (im - torch.tensor(self.black_level).view(4, 1, 1)) * torch.tensor(self.cam_wb).view(4, 1, 1)
if norm_factor is None:
im = im / im.max()
else:
im = im / norm_factor
im = torch.stack((im[0], (im[1] + im[2])/2, im[3]), dim=0)
# im = torch.stack((im[0], im[1], im[3]), dim=0)
im_out = im.clamp(0.0, 1.0)
if return_np:
im_out = im_out.permute(1, 2, 0).numpy() * 255.0
im_out = im_out.astype(np.uint8)
return im_out
class CanonImage:
    """Canon RGB ground-truth image of the BurstSR dataset, held as a 3-channel tensor.

    4-element metadata vectors (black level, white-balance gains) are reduced to
    3 entries by dropping the third element, matching the 3-channel data.
    """

    @staticmethod
    def load(path, split='train'):
        """Load the image in directory ``path`` (``im_raw.png`` + pickled ``meta_info.pkl``).

        ``split`` is accepted for interface compatibility and is not used.

        NOTE(review): ``meta_info.pkl`` is unpickled; only load trusted dataset files.

        Raises:
            FileNotFoundError: if ``im_raw.png`` or ``meta_info.pkl`` cannot be read.
        """
        raw_path = '{}/im_raw.png'.format(path)
        im_raw = cv2.imread(raw_path, cv2.IMREAD_UNCHANGED)
        if im_raw is None:
            # cv2.imread reports missing/unreadable files by returning None.
            raise FileNotFoundError('Could not read image {}'.format(raw_path))
        im_raw = np.transpose(im_raw, (2, 0, 1)).astype(np.int16)
        im_raw = torch.from_numpy(im_raw)
        # Context manager guarantees the metadata file handle is closed.
        with open('{}/meta_info.pkl'.format(path), 'rb') as fh:
            meta_data = pkl.load(fh)
        return CanonImage(im_raw.float(), meta_data['black_level'], meta_data['cam_wb'],
                          meta_data['daylight_wb'], meta_data['rgb_xyz_matrix'], meta_data.get('exif_data', None),
                          meta_data.get('crop_info', None))

    def __init__(self, im_raw, black_level, cam_wb, daylight_wb, rgb_xyz_matrix, exif_data, crop_info=None):
        super().__init__()
        self.im_raw = im_raw
        # Drop the third entry of 4-element vectors so they align with the 3 channels.
        if len(black_level) == 4:
            black_level = [black_level[0], black_level[1], black_level[3]]
        self.black_level = black_level
        if len(cam_wb) == 4:
            cam_wb = [cam_wb[0], cam_wb[1], cam_wb[3]]
        self.cam_wb = cam_wb
        if len(daylight_wb) == 4:
            daylight_wb = [daylight_wb[0], daylight_wb[1], daylight_wb[3]]
        self.daylight_wb = daylight_wb
        self.rgb_xyz_matrix = rgb_xyz_matrix
        self.xyz_srgb_matrix = torch.tensor([3.2404542, -1.5371385, -0.4985314,
                                             -0.9692660, 1.8760108, 0.0415560,
                                             0.0556434, -0.2040259, 1.0572252]).view(3, 3)
        self.exif_data = exif_data
        self.crop_info = crop_info
        # 2**14 - 1; divisor used by get_image_data(normalize=True).
        self.norm_factor = 16383

    def shape(self):
        """Return ``(3, H, W)`` of the image data."""
        return (3, self.im_raw.shape[1], self.im_raw.shape[2])

    def get_all_meta_data(self):
        """Return the metadata dict (rgb_xyz_matrix converted to a nested list)."""
        return {'black_level': self.black_level, 'cam_wb': self.cam_wb, 'daylight_wb': self.daylight_wb,
                'rgb_xyz_matrix': self.rgb_xyz_matrix.tolist(), 'crop_info': self.crop_info,
                'norm_factor': self.norm_factor}

    def get_exposure_time(self):
        """Return the EXIF 'EXIF ExposureTime' value converted with ``.decimal()``."""
        return self.exif_data['EXIF ExposureTime'].values[0].decimal()

    def get_f_number(self):
        """Return the EXIF 'EXIF FNumber' value converted with ``.decimal()``."""
        return self.exif_data['EXIF FNumber'].values[0].decimal()

    def get_iso(self):
        """Return the EXIF 'EXIF ISOSpeedRatings' value."""
        return self.exif_data['EXIF ISOSpeedRatings'].values[0]

    def get_image_data(self, substract_black_level=False, white_balance=False, normalize=False):
        """Return the image as a float tensor, optionally black-level corrected,
        white balanced (``cam_wb / 1024`` gains) and divided by ``norm_factor``."""
        im_raw = self.im_raw.float()
        if substract_black_level:
            im_raw = im_raw - torch.tensor(self.black_level).view(3, 1, 1)
        if white_balance:
            # NOTE(review): the /1024 presumably undoes a fixed-point scaling of cam_wb; confirm.
            im_raw = im_raw * torch.tensor(self.cam_wb).view(3, 1, 1) / 1024.0
        if normalize:
            im_raw = im_raw / self.norm_factor
        return im_raw

    def set_image_data(self, im_data):
        """Replace the stored image tensor."""
        self.im_raw = im_data

    def crop_image(self, r1, r2, c1, c2):
        """Crop the image in place to rows ``r1:r2`` and columns ``c1:c2``."""
        self.im_raw = self.im_raw[:, r1:r2, c1:c2]

    def get_crop(self, r1, r2, c1, c2):
        """Return a new CanonImage cropped to rows ``r1:r2`` / columns ``c1:c2`` (metadata shared)."""
        im_raw = self.im_raw[:, r1:r2, c1:c2]
        return CanonImage(im_raw, self.black_level, self.cam_wb, self.daylight_wb, self.rgb_xyz_matrix,
                          self.exif_data, self.crop_info)

    def set_crop_info(self, crop_info):
        """Replace the stored crop info."""
        self.crop_info = crop_info

    def resize(self, size=None, scale_factor=None):
        """Bilinearly resize the image in place to ``size`` or by ``scale_factor``."""
        self.im_raw = F.interpolate(self.im_raw.unsqueeze(0), size=size, scale_factor=scale_factor,
                                    mode='bilinear').squeeze(0)

    def postprocess(self, return_np=True):
        """Return a max-normalised, white-balanced visualisation clamped to [0, 1].

        With ``return_np`` the result is an HxWx3 uint8 numpy array, otherwise a
        3xHxW float tensor.
        """
        im = self.im_raw
        im = (im - torch.tensor(self.black_level).view(3, 1, 1)).float() * torch.tensor(self.cam_wb).view(3, 1, 1)
        im_out = im / im.max()
        im_out = im_out.clamp(0.0, 1.0)
        if return_np:
            im_out = im_out.permute(1, 2, 0).numpy() * 255.0
            im_out = im_out.astype(np.uint8)
        return im_out
def load_txt(path):
    """Read the text file at *path* and return its lines without trailing whitespace."""
    with open(path, 'r') as handle:
        return [line.rstrip() for line in handle]
class BurstSRDataset(torch.utils.data.Dataset):
""" Real-world burst super-resolution dataset: Samsung RAW bursts paired with a Canon ground-truth image. """
def __init__(self, root, burst_size=8, crop_sz=80, center_crop=False, random_flip=False, split='train'):
"""
args:
root : path of the root directory
burst_size : Burst size. Maximum allowed burst size is 14.
crop_sz: Size (in packed RAW pixels) of the extracted crop. Maximum allowed crop size is 80
center_crop: If True extract a centered crop, otherwise a random crop.
random_flip: Whether to apply random horizontal and vertical flip
split: Can be 'train' or 'val'
"""
assert burst_size <= 14, 'burst_sz must be less than or equal to 14'
assert crop_sz <= 80, 'crop_sz must be less than or equal to 80'
assert split in ['train', 'val']
# Bursts of the selected split are read from '<root>/<split>/<burst_name>/'.
root = root + '/' + split
super().__init__()
self.burst_size = burst_size
self.crop_sz = crop_sz
self.split = split
self.center_crop = center_crop
self.random_flip = random_flip
self.root = root
# Preprocessing options used in __getitem__: subtract the black level, do not white balance.
self.substract_black_level = True
self.white_balance = False
self.burst_list = self._get_burst_list()
def _get_burst_list(self):
# Every entry of the split directory is one burst; sorted for a stable index order.
burst_list = sorted(os.listdir('{}'.format(self.root)))
# print(burst_list)
return burst_list
def get_burst_info(self, burst_id):
# NOTE: 'burst_size' is the number of stored frames per burst (14), not self.burst_size.
burst_info = {'burst_size': 14, 'burst_name': self.burst_list[burst_id]}
return burst_info
def _get_raw_image(self, burst_id, im_id):
# LR frames are stored in '<burst>/samsung_XX' sub-directories (XX = two-digit frame index).
raw_image = SamsungRAWImage.load('{}/{}/samsung_{:02d}'.format(self.root, self.burst_list[burst_id], im_id))
return raw_image
def _get_gt_image(self, burst_id):
# The HR ground truth is stored in '<burst>/canon'.
canon_im = CanonImage.load('{}/{}/canon'.format(self.root, self.burst_list[burst_id]), split=self.split)
return canon_im
def get_burst(self, burst_id, im_ids, info=None):
# Returns (list of SamsungRAWImage frames, CanonImage ground truth, info dict).
frames = [self._get_raw_image(burst_id, i) for i in im_ids]
gt = self._get_gt_image(burst_id)
if info is None:
info = self.get_burst_info(burst_id)
return frames, gt, info
def _sample_images(self):
# Frame 0 is always kept first (reference frame); burst_size - 1 distinct frames are
# sampled at random from 1..13.
burst_size = 14
ids = random.sample(range(1, burst_size), k=self.burst_size - 1)
ids = [0, ] + ids
return ids
def __len__(self):
return len(self.burst_list)
def __getitem__(self, index):
# Returns (burst, frame_gt, meta_info_burst, meta_info_gt): burst is a float tensor of stacked
# packed 4-channel RAW frames, frame_gt the 3-channel Canon image rescaled for exposure differences.
# Sample the images in the burst, in case a burst_size < 14 is used.
im_ids = self._sample_images()
# Read the burst images along with HR ground truth
frames, gt, meta_info = self.get_burst(index, im_ids)
# Extract crop if needed
# NOTE(review): only the width is compared against crop_sz; the height is assumed to match.
if frames[0].shape()[-1] != self.crop_sz:
if getattr(self, 'center_crop', False):
r1 = (frames[0].shape()[-2] - self.crop_sz) // 2
c1 = (frames[0].shape()[-1] - self.crop_sz) // 2
else:
r1 = random.randint(0, frames[0].shape()[-2] - self.crop_sz)
c1 = random.randint(0, frames[0].shape()[-1] - self.crop_sz)
r2 = r1 + self.crop_sz
c2 = c1 + self.crop_sz
# The GT crop is scaled by the GT/LR width ratio so it covers the same scene region.
scale_factor = gt.shape()[-1] // frames[0].shape()[-1]
frames = [im.get_crop(r1, r2, c1, c2) for im in frames]
gt = gt.get_crop(scale_factor * r1, scale_factor * r2, scale_factor * c1, scale_factor * c2)
# Load the RAW image data
burst_image_data = [im.get_image_data(normalize=True, substract_black_level=self.substract_black_level,
white_balance=self.white_balance) for im in frames]
# Convert to tensor
gt_image_data = gt.get_image_data(normalize=True, white_balance=self.white_balance,
substract_black_level=self.substract_black_level)
if self.random_flip:
# Flips are applied to the flattened mosaic; one pixel is removed on each side of the flipped
# axis (presumably to restore the Bayer phase), then the packed frames and the GT are
# replicate-padded on the right/bottom to restore their sizes.
# NOTE: the order of the random.random() calls determines the augmentation sequence.
burst_image_data = [flatten_raw_image(im) for im in burst_image_data]
pad = [0, 0, 0, 0]
if random.random() > 0.5:
burst_image_data = [im.flip([1, ])[:, 1:-1].contiguous() for im in burst_image_data]
gt_image_data = gt_image_data.flip([2, ])[:, :, 2:-2].contiguous()
pad[1] = 1
if random.random() > 0.5:
burst_image_data = [im.flip([0, ])[1:-1, :].contiguous() for im in burst_image_data]
gt_image_data = gt_image_data.flip([1, ])[:, 2:-2, :].contiguous()
pad[3] = 1
burst_image_data = [pack_raw_image(im) for im in burst_image_data]
burst_image_data = [F.pad(im.unsqueeze(0), pad, mode='replicate').squeeze(0) for im in burst_image_data]
gt_image_data = F.pad(gt_image_data.unsqueeze(0), [4 * p for p in pad], mode='replicate').squeeze(0)
burst_image_meta_info = frames[0].get_all_meta_data()
burst_image_meta_info['black_level_subtracted'] = self.substract_black_level
# NOTE: key is spelled 'while_balance_applied' (sic); kept as-is for downstream consumers.
burst_image_meta_info['while_balance_applied'] = self.white_balance
burst_image_meta_info['norm_factor'] = frames[0].norm_factor
gt_image_meta_info = gt.get_all_meta_data()
burst = torch.stack(burst_image_data, dim=0)
burst_exposure = frames[0].get_exposure_time()
canon_exposure = gt.get_exposure_time()
burst_f_number = frames[0].get_f_number()
canon_f_number = gt.get_f_number()
burst_iso = frames[0].get_iso()
canon_iso = gt.get_iso()
# Normalize the GT image to account for differences in exposure, ISO etc
# Light factor = exposure time * ISO / f_number^2 for each camera.
light_factor_burst = burst_exposure * burst_iso / (burst_f_number ** 2)
light_factor_canon = canon_exposure * canon_iso / (canon_f_number ** 2)
exp_scale_factor = (light_factor_burst / light_factor_canon)
gt_image_data = gt_image_data * exp_scale_factor
gt_image_meta_info['black_level_subtracted'] = self.substract_black_level
gt_image_meta_info['while_balance_applied'] = self.white_balance
# Keep norm_factor consistent with the exposure rescaling applied to the GT data above.
gt_image_meta_info['norm_factor'] = gt.norm_factor / exp_scale_factor
burst_image_meta_info['exposure'] = burst_exposure
burst_image_meta_info['f_number'] = burst_f_number
burst_image_meta_info['iso'] = burst_iso
gt_image_meta_info['exposure'] = canon_exposure
gt_image_meta_info['f_number'] = canon_f_number
gt_image_meta_info['iso'] = canon_iso
burst = burst.float()
frame_gt = gt_image_data.float()
meta_info_burst = burst_image_meta_info
meta_info_gt = gt_image_meta_info
# crop_info is dropped (it may be None); presumably so the dict can be batched by the DataLoader.
del meta_info_gt['crop_info']
# Convert list/tuple metadata to tensors (presumably for default DataLoader collation).
for k, v in meta_info_gt.items():
if isinstance(v, (list, tuple)):
meta_info_gt[k] = torch.tensor(v)
for k, v in meta_info_burst.items():
if isinstance(v, (list, tuple)):
meta_info_burst[k] = torch.tensor(v)
meta_info_burst['burst_name'] = meta_info['burst_name']
return burst, frame_gt, meta_info_burst, meta_info_gt
def pack_raw_image(im_raw):
    """Pack a single-channel 2x2 mosaic of shape (H, W) into 4 channels (4, H // 2, W // 2).

    Channel ``c`` holds the pixels at row/column offset (c // 2, c % 2) of every
    2x2 cell, i.e. R, G, G, B for an RGGB mosaic. Works for numpy arrays and torch
    tensors; the output keeps the input's dtype (and device for tensors).

    Raises:
        TypeError: if *im_raw* is neither a numpy array nor a torch tensor.
    """
    out_shape = (4, im_raw.shape[0] // 2, im_raw.shape[1] // 2) if hasattr(im_raw, 'shape') else None
    if isinstance(im_raw, np.ndarray):
        im_out = np.zeros_like(im_raw, shape=out_shape)
    elif isinstance(im_raw, torch.Tensor):
        im_out = torch.zeros(out_shape, dtype=im_raw.dtype, device=im_raw.device)
    else:
        # TypeError is an Exception subclass, so existing `except Exception` handlers still work.
        raise TypeError('pack_raw_image expects a numpy array or torch tensor, got {}'.format(
            type(im_raw).__name__))
    im_out[0, :, :] = im_raw[0::2, 0::2]
    im_out[1, :, :] = im_raw[0::2, 1::2]
    im_out[2, :, :] = im_raw[1::2, 0::2]
    im_out[3, :, :] = im_raw[1::2, 1::2]
    return im_out
def flatten_raw_image(im_raw_4ch):
    """Inverse of ``pack_raw_image``: turn a (4, H, W) packed image into a (2H, 2W) mosaic.

    Channel ``c`` is written to row/column offset (c // 2, c % 2) of every 2x2
    cell. Works for numpy arrays and torch tensors; the output keeps the input's
    dtype (and device for tensors).

    Raises:
        TypeError: if *im_raw_4ch* is neither a numpy array nor a torch tensor.
    """
    if isinstance(im_raw_4ch, np.ndarray):
        im_out = np.zeros_like(im_raw_4ch, shape=(im_raw_4ch.shape[1] * 2, im_raw_4ch.shape[2] * 2))
    elif isinstance(im_raw_4ch, torch.Tensor):
        im_out = torch.zeros((im_raw_4ch.shape[1] * 2, im_raw_4ch.shape[2] * 2),
                             dtype=im_raw_4ch.dtype, device=im_raw_4ch.device)
    else:
        # TypeError is an Exception subclass, so existing `except Exception` handlers still work.
        raise TypeError('flatten_raw_image expects a numpy array or torch tensor, got {}'.format(
            type(im_raw_4ch).__name__))
    im_out[0::2, 0::2] = im_raw_4ch[0, :, :]
    im_out[0::2, 1::2] = im_raw_4ch[1, :, :]
    im_out[1::2, 0::2] = im_raw_4ch[2, :, :]
    im_out[1::2, 1::2] = im_raw_4ch[3, :, :]
    return im_out
def pack_raw_image_batch(im_raw):
    """Batched ``pack_raw_image``: (B, N, C, H, W) mosaic -> (B, N, 4, H // 2, W // 2).

    Only channel 0 of the input is used; output channel ``c`` takes the pixels at
    row/column offset (c // 2, c % 2) of every 2x2 cell.
    """
    dims = im_raw.shape
    packed = im_raw.new_zeros((dims[0], dims[1], 4, dims[3] // 2, dims[4] // 2))
    offsets = ((0, 0), (0, 1), (1, 0), (1, 1))
    for channel, (dy, dx) in enumerate(offsets):
        packed[:, :, channel, :, :] = im_raw[:, :, 0, dy::2, dx::2]
    return packed
def flatten_raw_image_batch(im_raw_4ch):
    """Batched ``flatten_raw_image``: (B, N, 4, H, W) packed -> (B, N, 1, 2H, 2W) mosaic.

    Input channel ``c`` is written to row/column offset (c // 2, c % 2) of every
    2x2 cell of the single output channel.
    """
    dims = im_raw_4ch.shape
    mosaic = im_raw_4ch.new_zeros((dims[0], dims[1], 1, dims[3] * 2, dims[4] * 2))
    offsets = ((0, 0), (0, 1), (1, 0), (1, 1))
    for channel, (dy, dx) in enumerate(offsets):
        mosaic[:, :, 0, dy::2, dx::2] = im_raw_4ch[:, :, channel, :, :]
    return mosaic
================================================
FILE: code/real/bsrt/datasets/burstsr_test_dataset.py
================================================
import os
import torch
import torch.nn.functional as F
import random
from .burstsr_dataset import SamsungRAWImage, flatten_raw_image, pack_raw_image
class BurstSRDataset(torch.utils.data.Dataset):
""" Real-world burst super-resolution test dataset: Samsung RAW bursts only (no ground truth). """
def __init__(self, root, burst_size=8, crop_sz=80, center_crop=False, random_flip=False, split='test'):
"""
args:
root : path of the root directory
burst_size : Burst size. Maximum allowed burst size is 14.
crop_sz: Size (in packed RAW pixels) of the extracted crop. Maximum allowed crop size is 80
center_crop: If True extract a centered crop, otherwise a random crop.
random_flip: Whether to apply random horizontal and vertical flip
split: Must be 'test' (the only split accepted by this class)
"""
assert burst_size <= 14, 'burst_sz must be less than or equal to 14'
assert crop_sz <= 80, 'crop_sz must be less than or equal to 80'
assert split in ['test']
# Bursts are read from '<root>/test/<burst_name>/'.
root = root + '/' + split
super().__init__()
self.burst_size = burst_size
self.crop_sz = crop_sz
self.split = split
self.center_crop = center_crop
self.random_flip = random_flip
self.root = root
# Preprocessing options used in __getitem__: subtract the black level, do not white balance.
self.substract_black_level = True
self.white_balance = False
self.burst_list = self._get_burst_list()
def _get_burst_list(self):
# Every entry of the split directory is one burst; sorted for a stable index order.
burst_list = sorted(os.listdir('{}'.format(self.root)))
return burst_list
def get_burst_info(self, burst_id):
# NOTE: 'burst_size' is the number of stored frames per burst (14), not self.burst_size.
burst_info = {'burst_size': 14, 'burst_name': self.burst_list[burst_id]}
return burst_info
def _get_raw_image(self, burst_id, im_id):
# LR frames are stored in '<burst>/samsung_XX' sub-directories (XX = two-digit frame index).
raw_image = SamsungRAWImage.load('{}/{}/samsung_{:02d}'.format(self.root, self.burst_list[burst_id], im_id))
return raw_image
def get_burst(self, burst_id, im_ids, info=None):
# Returns (list of SamsungRAWImage frames, info dict); there is no ground truth in this split.
frames = [self._get_raw_image(burst_id, i) for i in im_ids]
if info is None:
info = self.get_burst_info(burst_id)
return frames, info
def _sample_images(self):
# Frame 0 is always kept first (reference frame); burst_size - 1 distinct frames are
# sampled at random from 1..13.
burst_size = 14
ids = random.sample(range(1, burst_size), k=self.burst_size - 1)
ids = [0, ] + ids
return ids
def __len__(self):
return len(self.burst_list)
def __getitem__(self, index):
# Returns (burst, meta_info_burst): burst is a float tensor of stacked packed 4-channel RAW frames.
# Sample the images in the burst, in case a burst_size < 14 is used.
im_ids = self._sample_images()
# Read the burst images (the test split has no HR ground truth)
frames, meta_info = self.get_burst(index, im_ids)
# Extract crop if needed
# NOTE(review): only the width is compared against crop_sz; the height is assumed to match.
if frames[0].shape()[-1] != self.crop_sz:
if getattr(self, 'center_crop', False):
r1 = (frames[0].shape()[-2] - self.crop_sz) // 2
c1 = (frames[0].shape()[-1] - self.crop_sz) // 2
else:
r1 = random.randint(0, frames[0].shape()[-2] - self.crop_sz)
c1 = random.randint(0, frames[0].shape()[-1] - self.crop_sz)
r2 = r1 + self.crop_sz
c2 = c1 + self.crop_sz
frames = [im.get_crop(r1, r2, c1, c2) for im in frames]
# Load the RAW image data
burst_image_data = [im.get_image_data(normalize=True, substract_black_level=self.substract_black_level,
white_balance=self.white_balance) for im in frames]
if self.random_flip:
# Flips are applied to the flattened mosaic; one pixel is removed on each side of the flipped
# axis (presumably to restore the Bayer phase), then the packed frames are replicate-padded
# on the right/bottom to restore their size.
burst_image_data = [flatten_raw_image(im) for im in burst_image_data]
pad = [0, 0, 0, 0]
if random.random() > 0.5:
burst_image_data = [im.flip([1, ])[:, 1:-1].contiguous() for im in burst_image_data]
pad[1] = 1
if random.random() > 0.5:
burst_image_data = [im.flip([0, ])[1:-1, :].contiguous() for im in burst_image_data]
pad[3] = 1
burst_image_data = [pack_raw_image(im) for im in burst_image_data]
burst_image_data = [F.pad(im.unsqueeze(0), pad, mode='replicate').squeeze(0) for im in burst_image_data]
burst_image_meta_info = frames[0].get_all_meta_data()
burst_image_meta_info['black_level_subtracted'] = self.substract_black_level
# NOTE: key is spelled 'while_balance_applied' (sic); kept as-is for downstream consumers.
burst_image_meta_info['while_balance_applied'] = self.white_balance
burst_image_meta_info['norm_factor'] = frames[0].norm_factor
burst = torch.stack(burst_image_data, dim=0)
burst_exposure = frames[0].get_exposure_time()
burst_f_number = frames[0].get_f_number()
burst_iso = frames[0].get_iso()
burst_image_meta_info['exposure'] = burst_exposure
burst_image_meta_info['f_number'] = burst_f_number
burst_image_meta_info['iso'] = burst_iso
burst = burst.float()
meta_info_burst = burst_image_meta_info
# Convert list/tuple metadata to tensors (presumably for default DataLoader collation).
for k, v in meta_info_burst.items():
if isinstance(v, (list, tuple)):
meta_info_burst[k] = torch.tensor(v)
return burst, meta_info_burst
================================================
FILE: code/real/bsrt/datasets/data_sampler.py
================================================
"""
Modified from torch.utils.data.distributed.DistributedSampler
Supports enlarging the dataset for *iteration-oriented* training, which saves the time spent
restarting the dataloader after each epoch.
"""
import math
import torch
import torch.distributed as dist
from torch.utils.data.sampler import Sampler
class DistIterSampler(Sampler):
    """Distributed sampler that virtually enlarges the dataset by ``ratio``.

    Every replica draws the same epoch-seeded permutation of
    ``num_samples * num_replicas`` positions and keeps every ``num_replicas``-th
    one starting at its ``rank``, so the per-replica subsets are disjoint.
    Positions are mapped back onto the dataset with ``% len(dataset)``.
    Useful with :class:`torch.nn.parallel.DistributedDataParallel`.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training (defaults to the world size).
        rank (optional): Rank of the current process within num_replicas
            (defaults to the current distributed rank).
        ratio (optional): Factor by which the dataset length is enlarged.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, ratio=100):
        needs_dist = num_replicas is None or rank is None
        if needs_dist and not dist.is_available():
            raise RuntimeError("Requires distributed package to be available")
        self.dataset = dataset
        self.num_replicas = dist.get_world_size() if num_replicas is None else num_replicas
        self.rank = dist.get_rank() if rank is None else rank
        self.epoch = 0
        per_replica = len(self.dataset) * ratio / self.num_replicas
        self.num_samples = int(math.ceil(per_replica))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # Seeding with the epoch keeps the permutation identical across replicas.
        generator = torch.Generator()
        generator.manual_seed(self.epoch)
        permutation = torch.randperm(self.total_size, generator=generator).tolist()
        own_positions = permutation[self.rank:self.total_size:self.num_replicas]
        dataset_len = len(self.dataset)
        indices = [position % dataset_len for position in own_positions]
        assert len(indices) == self.num_samples
        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
================================================
FILE: code/real/bsrt/datasets/realworld_burst_test_set.py
================================================
import torch
import cv2
import numpy as np
import pickle as pkl
class RealWorldBurstTest(torch.utils.data.Dataset):
    """Real-world burst test set of 20 bursts with 14 frames each.

    Frames are read from ``{root}/{index:04d}/im_raw_{image_id:02d}.png``.
    """

    def __init__(self, root):
        self.root = root
        self.burst_list = list(range(20))
        self.burst_size = 14

    def __len__(self):
        return len(self.burst_list)

    def _read_burst_image(self, index, image_id):
        """Read one frame as a channel-first float tensor divided by 2**14.

        Raises:
            FileNotFoundError: if the frame cannot be read.
        """
        path = '{}/{:04d}/im_raw_{:02d}.png'.format(self.root, index, image_id)
        im = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if im is None:
            # cv2.imread reports missing/unreadable files by returning None.
            raise FileNotFoundError('Could not read burst image {}'.format(path))
        return torch.from_numpy(im.astype(np.float32)).permute(2, 0, 1).float() / (2 ** 14)

    def __getitem__(self, index):
        """Load a burst.

        args:
            index: Index of the burst
        returns:
            burst: LR RAW burst, a torch tensor of shape [burst_size, C, H, W] (all frames stacked).
                The 4 channels correspond to 'R', 'G', 'G', and 'B' values in the RGGB bayer mosaic.
            meta_info: Dict containing the burst name ('{index:04d}').
        """
        burst_name = '{:04d}'.format(index)
        burst = [self._read_burst_image(index, i) for i in range(self.burst_size)]
        burst = torch.stack(burst, 0)
        meta_info = {'burst_name': burst_name}
        return burst, meta_info
================================================
FILE: code/real/bsrt/datasets/synthetic_burst_test_set.py
================================================
import torch
import cv2
import numpy as np
import pickle as pkl
class SyntheticBurstTest(torch.utils.data.Dataset):
    """ Synthetic burst test set (92 bursts of 14 frames). The test bursts have been generated using
    the same synthetic pipeline as employed in the SyntheticBurst dataset.

    Frames are read from ``{root}/{index:04d}/im_raw_{image_id:02d}.png``.
    """

    def __init__(self, root):
        self.root = root
        self.burst_list = list(range(92))
        self.burst_size = 14

    def __len__(self):
        return len(self.burst_list)

    def _read_burst_image(self, index, image_id):
        """Read one frame as a channel-first float tensor divided by 2**14.

        Raises:
            FileNotFoundError: if the frame cannot be read.
        """
        path = '{}/{:04d}/im_raw_{:02d}.png'.format(self.root, index, image_id)
        im = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if im is None:
            # cv2.imread reports missing/unreadable files by returning None.
            raise FileNotFoundError('Could not read burst image {}'.format(path))
        return torch.from_numpy(im.astype(np.float32)).permute(2, 0, 1).float() / (2 ** 14)

    def __getitem__(self, index):
        """Load a synthetic test burst.

        args:
            index: Index of the burst
        returns:
            burst: LR RAW burst, a torch tensor of shape [burst_size, C, H, W] (all frames stacked).
                The 4 channels correspond to 'R', 'G', 'G', and 'B' values in the RGGB bayer mosaic.
            meta_info: Dict containing the burst name ('{index:04d}').
        """
        burst_name = '{:04d}'.format(index)
        burst = [self._read_burst_image(index, i) for i in range(self.burst_size)]
        burst = torch.stack(burst, 0)
        meta_info = {'burst_name': burst_name}
        return burst, meta_info
================================================
FILE: code/real/bsrt/datasets/synthetic_burst_train_set.py
================================================
import torch
import numpy as np
from PIL import Image
from data_processing.synthetic_burst_generation import rgb2rawburst, random_crop #syn_burst_utils
import torchvision.transforms as tfm
class SyntheticBurst(torch.utils.data.Dataset):
    """ Synthetic burst dataset for joint denoising, demosaicking, and super-resolution. RAW burst sequences are
    synthetically generated on the fly as follows. First, a single image is loaded from the base_dataset. The sampled
    image is converted to linear sensor space using the inverse camera pipeline employed in [1]. A burst
    sequence is then generated by adding random translations and rotations to the converted image. The generated burst
    is then mosaicked and corrupted by random noise to obtain the RAW burst.

    [1] Unprocessing Images for Learned Raw Denoising, Brooks, Tim and Mildenhall, Ben and Xue, Tianfan and Chen,
    Jiawen and Sharlet, Dillon and Barron, Jonathan T, CVPR 2019
    """

    def __init__(self, base_dataset, burst_size=8, crop_sz=384, transform=tfm.ToTensor()):
        self.base_dataset = base_dataset
        self.burst_size = burst_size
        self.crop_sz = crop_sz
        self.transform = transform
        self.downsample_factor = 4
        self.burst_transformation_params = {'max_translation': 24.0,
                                            'max_rotation': 1.0,
                                            'max_shear': 0.0,
                                            'max_scale': 0.0,
                                            'border_crop': 24}
        self.image_processing_params = {'random_ccm': True, 'random_gains': True, 'smoothstep': True,
                                        'gamma': True,
                                        'add_noise': True}
        self.interpolation_type = 'bilinear'

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, index):
        """ Generates a synthetic burst
        args:
            index: Index of the image in the base_dataset used to generate the burst
        returns:
            burst: Generated LR RAW burst, a torch tensor of shape
                [burst_size, 4, self.crop_sz / (2*self.downsample_factor), self.crop_sz / (2*self.downsample_factor)]
                The 4 channels correspond to 'R', 'G', 'G', and 'B' values in the RGGB bayer mosaic.
                The extra factor 2 in the denominator (2*self.downsample_factor) corresponds to the mosaicking
                operation.
            frame_gt: The HR RGB ground truth in the linear sensor space, a torch tensor of shape
                [3, self.crop_sz, self.crop_sz]
            flow_vectors: The ground truth flow vectors between a burst image and the base image (i.e. the first
                image in the burst). The flow_vectors can be used to warp the burst images to the base frame,
                using the 'warp' function in utils.warp package.
                flow_vectors is torch tensor of shape
                [burst_size, 2, self.crop_sz / self.downsample_factor, self.crop_sz / self.downsample_factor].
                Note that the flow_vectors are in the LR RGB space, before mosaicking. Hence it has twice
                the number of rows and columns, compared to the output burst.
                NOTE: The flow_vectors are only available during training for the purpose of using any
                auxiliary losses if needed. The flow_vectors will NOT be provided for the bursts in the
                test set
            meta_info: A dictionary containing the parameters used to generate the synthetic burst.
        """
        frame = self.base_dataset[index]

        # Augmentation, e.g. convert to tensor
        if self.transform is not None:
            frame = self.transform(frame)

        # Extract a random crop with an extra border on each side; the border absorbs the
        # invalid pixels introduced by the random burst transformations.
        border_crop = self.burst_transformation_params.get('border_crop', 0)
        crop_sz = self.crop_sz + 2 * border_crop
        frame_crop = random_crop(frame, crop_sz)

        # Generate RAW burst
        burst, frame_gt, _burst_rgb, flow_vectors, meta_info = rgb2rawburst(
            frame_crop,
            self.burst_size,
            self.downsample_factor,
            burst_transformation_params=self.burst_transformation_params,
            image_processing_params=self.image_processing_params,
            interpolation_type=self.interpolation_type
        )

        # Remove the border again. Guard on truthiness: with border_crop == 0 the slice
        # `0:-0` would be empty and silently produce an empty ground truth.
        if border_crop:
            frame_gt = frame_gt[:, border_crop:-border_crop, border_crop:-border_crop]

        return burst, frame_gt, flow_vectors, meta_info
================================================
FILE: code/real/bsrt/datasets/synthetic_burst_val_set.py
================================================
import os
import torch
import cv2
import numpy as np
import pickle as pkl
class SyntheticBurstVal(torch.utils.data.Dataset):
    """ Synthetic burst validation set introduced in [1]. The validation bursts have been generated using a
    synthetic data generation pipeline. The dataset can be downloaded from
    https://data.vision.ee.ethz.ch/bhatg/SyntheticBurstVal.zip

    [1] Deep Burst Super-Resolution. Goutam Bhat, Martin Danelljan, Luc Van Gool, and Radu Timofte. CVPR 2021
    """

    def __init__(self, root=None, initialize=True):
        """
        args:
            root - Path to root dataset directory; data is read from '<root>/val'.
            initialize - Unused; kept for interface compatibility (all data is read lazily in __getitem__).
        raises:
            ValueError - if root is not given.
        """
        if root is None:
            raise ValueError('SyntheticBurstVal requires the root dataset directory')
        self.root = os.path.join(root, 'val')
        self.burst_list = list(range(300))
        self.burst_size = 14

    def initialize(self):
        """No-op kept for interface compatibility."""
        pass

    def __len__(self):
        return len(self.burst_list)

    @staticmethod
    def _imread_or_raise(path):
        """Read ``path`` unchanged with cv2, raising FileNotFoundError when it cannot be read."""
        im = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if im is None:
            # cv2.imread reports missing/unreadable files by returning None.
            raise FileNotFoundError('Could not read image {}'.format(path))
        return im

    def _read_burst_image(self, index, image_id):
        """Read one burst frame as a channel-first float tensor divided by 2**14."""
        im = self._imread_or_raise('{}/bursts/{:04d}/im_raw_{:02d}.png'.format(self.root, index, image_id))
        return torch.from_numpy(im.astype(np.float32)).permute(2, 0, 1).float() / (2 ** 14)

    def _read_gt_image(self, index):
        """Read the ground-truth image as a channel-first float tensor divided by 2**14."""
        gt = self._imread_or_raise('{}/gt/{:04d}/im_rgb.png'.format(self.root, index))
        return (torch.from_numpy(gt.astype(np.float32)) / 2 ** 14).permute(2, 0, 1).float()

    def _read_meta_info(self, index):
        """Load the pickled per-burst metadata.

        NOTE(review): this unpickles files from disk; only use trusted dataset files.
        """
        with open('{}/gt/{:04d}/meta_info.pkl'.format(self.root, index), "rb") as input_file:
            meta_info = pkl.load(input_file)
        return meta_info

    def __getitem__(self, index):
        """ Loads a validation burst
        args:
            index: Index of the burst
        returns:
            burst: LR RAW burst, a torch tensor of shape
                [14, 4, 48, 48]
                The 4 channels correspond to 'R', 'G', 'G', and 'B' values in the RGGB bayer mosaic.
            gt : Ground truth linear image
            meta_info: Meta info about the burst which can be used to convert gt to sRGB space
        """
        burst_name = '{:04d}'.format(index)
        burst = [self._read_burst_image(index, i) for i in range(self.burst_size)]
        burst = torch.stack(burst, 0)
        gt = self._read_gt_image(index)
        meta_info = self._read_meta_info(index)
        meta_info['burst_name'] = burst_name
        return burst, gt, meta_info
================================================
FILE: code/real/bsrt/datasets/zurich_raw2rgb_dataset.py
================================================
import torch
import os
import numpy as np
from cv2 import imread
class ZurichRAW2RGB(torch.utils.data.Dataset):
    """ Canon RGB images from the "Zurich RAW to RGB mapping" dataset. You can download the full
    dataset (22 GB) from http://people.ee.ethz.ch/~ihnatova/pynet.html#dataset. Alternatively, you can only download the
    Canon RGB images (5.5 GB) from https://data.vision.ee.ethz.ch/bhatg/zurich-raw-to-rgb.zip

    Images are returned exactly as read by ``cv2.imread`` with default flags.
    """

    def __init__(self, root, split='train'):
        """
        args:
            root: Dataset root; images are read from '<root>/<split>/canon'.
            split: Either 'train' or 'test'.
        raises:
            Exception: for any other split.
        """
        super().__init__()
        if split not in ['train', 'test']:
            raise Exception('Unknown split {}'.format(split))
        self.img_pth = os.path.join(root, split, 'canon')
        self.image_list = self._get_image_list(split)

    def _get_image_list(self, split):
        """Return the image file names of ``split`` (consecutive integers '0.jpg', '1.jpg', ...)."""
        if split == 'train':
            return ['{:d}.jpg'.format(i) for i in range(46839)]
        if split == 'test':
            return ['{:d}.jpg'.format(i) for i in range(1200)]
        raise Exception('Unknown split {}'.format(split))

    def _get_image(self, im_id):
        """Read image number ``im_id``; raises FileNotFoundError if it cannot be read."""
        path = os.path.join(self.img_pth, self.image_list[im_id])
        img = imread(path)
        if img is None:
            # cv2.imread reports missing/unreadable files by returning None.
            raise FileNotFoundError('Could not read image {}'.format(path))
        return img

    def get_image(self, im_id):
        """Public alias of ``_get_image``."""
        return self._get_image(im_id)

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        return self._get_image(index)
================================================
FILE: code/real/bsrt/demo.sh
================================================
#!/usr/bin/env bash
# Training / evaluation commands for BSRT on the real-world BurstSR data.
# Active command: train the small model (--model_level S) on 8 GPUs with 14-frame bursts and
# 80x80 patches, initialised from the checkpoint given by --pre_train (a model from the synthetic track).
python main.py --n_GPUs 8 --print_every 20 --lr 0.00004 --decay 40-80 --save bsrt_tiny --model BSRT --fp16 --model_level S --swinfeature --batch_size 8 --burst_size 14 --patch_size 80 --pre_train ../../synthetic/train_log/bsrt/real_models/bsrt_tiny/bsrt_best_epoch.pth
# Alternative: same training for the large model (--model_level L) with 48x48 patches.
# python main.py --n_GPUs 8 --print_every 20 --lr 0.00004 --decay 40-80 --save bsrt_large --model BSRT --fp16 --model_level L --swinfeature --batch_size 8 --burst_size 14 --patch_size 48 --pre_train ../../synthetic/train_log/bsrt/real_models/bsrt_large/bsrt_best_epoch.pth
# Evaluation with test_real.py for the small and large models (adjust --pre_train and --root to local paths).
# python test_real.py --n_GPUs 1 --model BSRT --model_level S --swinfeature --batch_size 1 --burst_size 14 --patch_size 80 --pre_train ../train_log/bsrt/real_models/bsrt_tiny/bsrtbest_epoch.pth --root /data/dataset/ntire21/burstsr/real
# python test_real.py --n_GPUs 1 --model BSRT --model_level L --swinfeature --batch_size 1 --burst_size 14 --patch_size 80 --pre_train ../train_log/bsrt/real_models/bsrt_large/bsrt_realworld.pth --root /data/dataset/ntire21/burstsr/real
================================================
FILE: code/real/bsrt/loss/Charbonnier.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class CharbonnierLoss(nn.Module):
    """L1 Charbonnier loss: ``sqrt((X - Y)^2 + epsilon^2)``.

    Args:
        epsilon: Smoothing constant; its square is added under the square root.
        reduce: If True return the mean over all elements, otherwise the
            element-wise error tensor.
    """

    def __init__(self, epsilon=1e-3, reduce=True):
        super().__init__()
        self.eps = epsilon * epsilon
        self.reduce = reduce

    def forward(self, X, Y):
        residual = X - Y
        per_element = torch.sqrt(residual * residual + self.eps)
        return torch.mean(per_element) if self.reduce else per_element
================================================
FILE: code/real/bsrt/loss/__init__.py
================================================
import os
from importlib import import_module
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class Loss(nn.modules.loss._Loss):
    """Weighted sum of the loss terms listed in ``args.loss``.

    ``args.loss`` is a ``+``-separated list of ``weight*TYPE`` terms, e.g.
    ``"1*L1+0.1*VGG54"``. Supported types are ``MSE``, ``L1``, ``VGG<idx>``,
    any type containing ``GAN`` (which adds an extra ``DIS`` log column),
    ``FILTER``, ``SSIM`` and ``MSSSIM``. When more than one entry is
    configured a ``Total`` column is logged as well.

    ``self.log`` holds one row per epoch (appended by ``start_log``) and one
    column per entry of ``self.loss``; ``forward`` accumulates into the last row.

    Raises:
        ValueError: for an unknown loss type in ``args.loss``.
    """

    def __init__(self, args, ckp):
        super(Loss, self).__init__()
        if args.local_rank == 0:
            print('Preparing loss function:')

        self.n_GPUs = args.n_GPUs
        self.loss = []
        self.loss_module = nn.ModuleList()
        for loss in args.loss.split('+'):
            weight, loss_type = loss.split('*')
            if loss_type == 'MSE':
                loss_function = nn.MSELoss()
            elif loss_type == 'L1':
                loss_function = nn.L1Loss()
            elif loss_type.find('VGG') >= 0:
                module = import_module('loss.vgg')
                loss_function = getattr(module, 'VGG')(
                    loss_type[3:],
                    rgb_range=args.rgb_range
                )
            elif loss_type.find('GAN') >= 0:
                module = import_module('loss.adversarial')
                loss_function = getattr(module, 'Adversarial')(
                    args,
                    loss_type
                )
            elif loss_type == 'FILTER':
                module = import_module('loss.filter')
                loss_function = getattr(module, 'Filter')(args)
            elif loss_type == 'SSIM':
                # SSIM/MSSSIM take ``window_size`` (not ``args``) as their first
                # argument; passing ``args`` there made them fail. They return a
                # similarity (higher is better), so configure them with a
                # negative weight, e.g. ``-1*SSIM``.
                module = import_module('loss.mssim')
                loss_function = getattr(module, 'SSIM')()
            elif loss_type == 'MSSSIM':
                module = import_module('loss.mssim')
                loss_function = getattr(module, 'MSSSIM')()
            else:
                # Previously an unknown type either raised UnboundLocalError or
                # silently reused the previous term's loss function.
                raise ValueError('Unknown loss type: {}'.format(loss_type))

            self.loss.append({
                'type': loss_type,
                'weight': float(weight),
                'function': loss_function}
            )
            if loss_type.find('GAN') >= 0:
                # Extra column that logs the discriminator's own loss.
                self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})

        if len(self.loss) > 1:
            self.loss.append({'type': 'Total', 'weight': 0, 'function': None})

        for entry in self.loss:
            if entry['function'] is not None:
                if args.local_rank == 0:
                    print('{:.3f} * {}'.format(entry['weight'], entry['type']))
                self.loss_module.append(entry['function'])

        # Empty 2-D log so ``start_log`` can always ``torch.cat`` a new row.
        self.log = torch.zeros(0, len(self.loss))

        device = torch.device('cpu' if args.cpu else 'cuda')
        self.loss_module.to(device)
        if args.precision == 'half':
            self.loss_module.half()
        if not args.cpu and args.n_GPUs > 1:
            self.loss_module = nn.DataParallel(
                self.loss_module, range(args.n_GPUs)
            )

        if args.load != '':
            self.load(ckp.dir, cpu=args.cpu)

    def forward(self, sr, hr):
        """Return the weighted sum of all loss terms and update the log row."""
        losses = []
        for i, entry in enumerate(self.loss):
            if entry['function'] is not None:
                loss = entry['function'](sr, hr)
                effective_loss = entry['weight'] * loss
                losses.append(effective_loss)
                self.log[-1, i] += effective_loss.item()
            elif entry['type'] == 'DIS':
                # The DIS column follows its GAN term; read the discriminator
                # loss that the adversarial module stored during its forward.
                self.log[-1, i] += self.loss[i - 1]['function'].loss

        loss_sum = sum(losses)
        if len(self.loss) > 1:
            self.log[-1, -1] += loss_sum.item()

        return loss_sum

    def step(self):
        """Advance any per-term schedulers (e.g. a discriminator's)."""
        for module in self.get_loss_module():
            if hasattr(module, 'scheduler'):
                module.scheduler.step()

    def start_log(self):
        """Append a zeroed row for a new epoch."""
        self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))

    def end_log(self, n_batches):
        """Turn the last row's running sums into per-batch means."""
        self.log[-1].div_(n_batches)

    def display_loss(self, batch):
        """Format the running per-batch mean of each term after ``batch + 1`` batches."""
        n_samples = batch + 1
        log = []
        for entry, c in zip(self.loss, self.log[-1]):
            log.append('[{}: {:.4f}]'.format(entry['type'], c / n_samples))

        return ''.join(log)

    def plot_loss(self, apath, epoch):
        """Write one ``loss_<type>.pdf`` curve per term into ``apath``."""
        axis = np.linspace(1, epoch, epoch)
        for i, entry in enumerate(self.loss):
            label = '{} Loss'.format(entry['type'])
            fig = plt.figure()
            plt.title(label)
            plt.plot(axis, self.log[:, i].numpy(), label=label)
            plt.legend()
            plt.xlabel('Epochs')
            plt.ylabel('Loss')
            plt.grid(True)
            plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(entry['type'])))
            plt.close(fig)

    def get_loss_module(self):
        """Return the ``ModuleList`` of loss functions, unwrapping DataParallel.

        Checks the wrapper type instead of ``n_GPUs``, because the wrapper is
        only added on GPU runs with ``n_GPUs > 1``.
        """
        if isinstance(self.loss_module, nn.DataParallel):
            return self.loss_module.module
        return self.loss_module

    def save(self, apath):
        """Save the loss state dict and the epoch log into ``apath``."""
        torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
        torch.save(self.log, os.path.join(apath, 'loss_log.pt'))

    def load(self, apath, cpu=False):
        """Restore state and log from ``apath`` and fast-forward schedulers."""
        if cpu:
            kwargs = {'map_location': lambda storage, loc: storage}
        else:
            kwargs = {}

        self.load_state_dict(torch.load(
            os.path.join(apath, 'loss.pt'),
            **kwargs
        ))
        self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
        for module in self.get_loss_module():
            if hasattr(module, 'scheduler'):
                # One scheduler step per already-logged epoch.
                for _ in range(len(self.log)):
                    module.scheduler.step()
================================================
FILE: code/real/bsrt/loss/adversarial.py
================================================
import utility
from types import SimpleNamespace
from model import common
from loss import discriminator
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class Adversarial(nn.Module):
    """Adversarial loss that also trains its own discriminator.

    Each ``forward(fake, real)`` runs ``args.gan_k`` discriminator updates,
    storing their mean loss in ``self.loss``. It then returns the generator
    loss for ``fake``.

    ``gan_type`` selects the objective:
      * ``'GAN'``/``'SNGAN'``: binary cross-entropy;
      * any type containing ``'WGAN'``: Wasserstein. ``'WGAN'`` clips the
        discriminator weights to [-1, 1]; types containing ``'GP'`` add a
        gradient penalty;
      * ``'RGAN'``: relativistic, as in ESRGAN.

    Raises:
        ValueError: from ``forward`` for an unsupported ``gan_type``.
    """

    def __init__(self, args, gan_type):
        super(Adversarial, self).__init__()
        self.gan_type = gan_type
        self.gan_k = args.gan_k
        self.dis = discriminator.Discriminator(args)
        # The discriminator always uses these Adam settings, whatever gan_type
        # is (see https://arxiv.org/pdf/1704.00028.pdf pp.4).
        optim_dict = {
            'optimizer': 'ADAM',
            'betas': (0.5, 0.9),
            'epsilon': 1e-8,
            'lr': 1e-5,
            'weight_decay': args.weight_decay,
            'decay': args.decay,
            'gamma': args.gamma
        }
        optim_args = SimpleNamespace(**optim_dict)
        self.optimizer = utility.make_optimizer(optim_args, self.dis)

    def _unsupported(self):
        """Build the error raised for an unknown ``gan_type``."""
        return ValueError('Unsupported gan_type: {}'.format(self.gan_type))

    def forward(self, fake, real):
        # updating discriminator...
        self.loss = 0
        fake_detach = fake.detach()  # do not backpropagate through G
        for _ in range(self.gan_k):
            self.optimizer.zero_grad()
            # d: B x 1 tensor
            d_fake = self.dis(fake_detach)
            d_real = self.dis(real)
            retain_graph = False
            if self.gan_type in ['GAN', 'SNGAN']:
                loss_d = self.bce(d_real, d_fake)
            elif self.gan_type.find('WGAN') >= 0:
                loss_d = (d_fake - d_real).mean()
                if self.gan_type.find('GP') >= 0:
                    # One interpolation coefficient per sample, broadcast over
                    # the remaining dims. ``rand_like(fake).view(-1, 1, 1, 1)``
                    # produced one coefficient per *element* and could not
                    # broadcast against ``fake`` for batch sizes > 1.
                    eps_shape = (fake.size(0),) + (1,) * (fake.dim() - 1)
                    epsilon = torch.rand(eps_shape, device=fake.device, dtype=fake.dtype)
                    hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
                    hat.requires_grad = True
                    d_hat = self.dis(hat)
                    gradients = torch.autograd.grad(
                        outputs=d_hat.sum(), inputs=hat,
                        retain_graph=True, create_graph=True
                    )[0]
                    gradients = gradients.view(gradients.size(0), -1)
                    gradient_norm = gradients.norm(2, dim=1)
                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
                    loss_d += gradient_penalty
            # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
            elif self.gan_type == 'RGAN':
                better_real = d_real - d_fake.mean(dim=0, keepdim=True)
                better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
                loss_d = self.bce(better_real, better_fake)
                retain_graph = True
            else:
                raise self._unsupported()

            # Discriminator update
            self.loss += loss_d.item()
            loss_d.backward(retain_graph=retain_graph)
            self.optimizer.step()

            if self.gan_type == 'WGAN':
                for p in self.dis.parameters():
                    p.data.clamp_(-1, 1)

        self.loss /= self.gan_k

        # updating generator...
        d_fake_bp = self.dis(fake)  # for backpropagation, use fake as it is
        if self.gan_type in ['GAN', 'SNGAN']:
            label_real = torch.ones_like(d_fake_bp)
            loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)
        elif self.gan_type.find('WGAN') >= 0:
            loss_g = -d_fake_bp.mean()
        elif self.gan_type == 'RGAN':
            # Reuses d_real from the last discriminator iteration (detached).
            better_real = d_real.detach() - d_fake_bp.mean(dim=0, keepdim=True)
            better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True).detach()
            loss_g = self.bce(better_fake, better_real)
        else:
            raise self._unsupported()

        # Generator loss
        return loss_g

    def state_dict(self, *args, **kwargs):
        """Return the discriminator's state merged with the optimizer's state."""
        state_discriminator = self.dis.state_dict(*args, **kwargs)
        state_optimizer = self.optimizer.state_dict()

        return dict(**state_discriminator, **state_optimizer)

    def bce(self, real, fake):
        """BCE-with-logits pushing ``real`` logits to 1 and ``fake`` logits to 0."""
        label_real = torch.ones_like(real)
        label_fake = torch.zeros_like(fake)
        bce_real = F.binary_cross_entropy_with_logits(real, label_real)
        bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)
        bce_loss = bce_real + bce_fake
        return bce_loss
# Some references
# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
# OR
# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
================================================
FILE: code/real/bsrt/loss/discriminator.py
================================================
from model import common
import torch.nn as nn
class Discriminator(nn.Module):
    """Convolutional discriminator returning one unnormalized logit per sample.

    Seven blocks of 3x3 conv + BatchNorm + LeakyReLU(0.2). The first block has
    stride 1 and 32 channels; each of the next six has stride 2 and doubles the
    channels, ending at 2048. A Flatten -> Linear(512) -> LeakyReLU -> Linear(1)
    classifier follows. With ``gan_type='SNGAN'`` every conv is wrapped in
    spectral normalization.

    NOTE(review): the classifier input size assumes the final feature map is
    ``args.patch_size // 32`` pixels per side. Six stride-2 convs shrink by
    64x, so this implies inputs of about ``2 * args.patch_size`` pixels per
    side. Confirm against the caller.
    """

    def __init__(self, args, gan_type='GAN'):
        super(Discriminator, self).__init__()

        in_channels = args.n_colors
        out_channels = 32
        depth = 6

        def _block(_in_channels, _out_channels, stride=1):
            # One conv(3x3, pad 1, no bias) + BN + LeakyReLU stage.
            conv = nn.Conv2d(
                _in_channels,
                _out_channels,
                3,
                padding=1,
                stride=stride,
                bias=False
            )
            if gan_type == 'SNGAN':
                # ``spectral_norm`` used to be referenced unqualified without
                # an import, so the SNGAN path raised NameError.
                conv = nn.utils.spectral_norm(conv)
            return nn.Sequential(
                conv,
                nn.BatchNorm2d(_out_channels),
                nn.LeakyReLU(negative_slope=0.2, inplace=True)
            )

        m_features = [_block(in_channels, out_channels)]
        for _ in range(depth):
            in_channels = out_channels
            out_channels *= 2
            m_features.append(_block(in_channels, out_channels, stride=2))

        patch_size = args.patch_size // 2**(depth - 1)
        m_classifier = [
            nn.Flatten(),
            nn.Linear(out_channels * patch_size**2, 512),
            nn.LeakyReLU(0.2, True),
            nn.Linear(512, 1)
        ]

        self.features = nn.Sequential(*m_features)
        self.classifier = nn.Sequential(*m_classifier)

    def forward(self, x):
        features = self.features(x)
        output = self.classifier(features)

        return output
================================================
FILE: code/real/bsrt/loss/filter.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class Filter(nn.Module):
    """L1 distance between fixed high-pass (Laplacian-style) filter responses.

    Both inputs go through the same ``n_colors -> n_colors`` convolution. Every
    3x3 kernel slice is ``[[1, 4, 1], [4, -20, 4], [1, 4, 1]]``, so each output
    channel filters the sum of the input channels. The kernel sums to zero, so
    constant regions give no response, and the conv bias cancels in the
    difference. The L1 loss of the two responses is returned.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        kernel = torch.tensor([[1, 4, 1], [4, -20, 4], [1, 4, 1]])
        # NOTE(review): the 4th positional argument of Conv2d is ``stride``,
        # so this is kernel_size=3 with stride=3 (non-overlapping windows).
        # Kept for compatibility; confirm whether stride 1 was intended.
        self.conv = nn.Conv2d(args.n_colors, args.n_colors, 3, 3)
        with torch.no_grad():
            self.conv.weight.copy_(kernel.float())
        # The filter is fixed. Exclude it from autograd so it never accumulates
        # gradients; gradients still flow to the inputs. The state-dict keys
        # are unchanged.
        self.conv.requires_grad_(False)
        self.loss = nn.L1Loss()

    def forward(self, x, y):
        preds_x = self.conv(x)
        preds_y = self.conv(y)
        return self.loss(preds_x, preds_y)
================================================
FILE: code/real/bsrt/loss/hist_entropy.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class HistEntropy(nn.Module):
    """Mean entropy of the softmax distribution over dim 1.

    ``forward`` takes the softmax of ``x`` over dim 1, sums ``-p * log p`` over
    dims 2 and 3, and averages the remaining values into a scalar.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

    def forward(self, x):
        log_probs = torch.log_softmax(x, dim=1)
        probs = torch.softmax(x, dim=1)
        per_channel = (-probs * log_probs).sum(dim=(2, 3))
        return per_channel.mean()
================================================
FILE: code/real/bsrt/loss/mssim.py
================================================
import torch
import torch.nn.functional as F
from math import exp
import numpy as np
def gaussian(window_size, sigma):
    """Return a normalized 1-D Gaussian of length ``window_size``.

    The peak is at index ``window_size // 2``.
    """
    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
    return gauss/gauss.sum()


def create_window(window_size, channel=1):
    """Return a ``(channel, 1, window_size, window_size)`` Gaussian window.

    The window uses sigma 1.5 and is meant for depthwise ``conv2d`` with
    ``groups=channel``.
    """
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window


def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
    """Structural similarity between two ``(N, C, H, W)`` image batches.

    Args:
        window_size: Gaussian window size, capped at the image size when
            ``window`` is not given.
        window: optional precomputed window from ``create_window``.
        size_average: if True, return a scalar; otherwise one value per sample.
        full: if True, also return the mean contrast-structure term ``cs``.
            ``cs`` is always a scalar.
        val_range: dynamic range ``L``. If None it is guessed from ``img1``:
            max > 128 gives 255, min < -0.5 gives ``[-1, ...]``.
    """
    # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
    if val_range is None:
        if torch.max(img1) > 128:
            max_val = 255
        else:
            max_val = 1

        if torch.min(img1) < -0.5:
            min_val = -1
        else:
            min_val = 0
        L = max_val - min_val
    else:
        L = val_range

    padd = 0
    (_, channel, height, width) = img1.size()
    if window is None:
        real_size = min(window_size, height, width)
        window = create_window(real_size, channel=channel).to(img1.device)

    mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
    mu2 = F.conv2d(img2, window, padding=padd, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2

    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2

    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2
    cs = torch.mean(v1 / v2)  # contrast sensitivity

    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

    if size_average:
        ret = ssim_map.mean()
    else:
        ret = ssim_map.mean(1).mean(1).mean(1)

    if full:
        return ret, cs
    return ret


def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=None):
    """Multi-scale SSIM over 5 dyadic scales (Wang et al., 2003).

    Returns ``prod_{j<M} cs_j ** w_j * ssim_M ** w_M``: the contrast-structure
    terms of the first M-1 scales times the full SSIM of the coarsest scale.
    The result is a scalar when ``size_average`` is True, otherwise one value
    per sample.

    ``normalize="relu"`` clamps each term at 0. ``normalize="simple"`` (or
    True) maps terms from [-1, 1] to [0, 1]. Neither is part of the original
    definition; without them, negative terms give NaN under fractional
    powers.
    """
    device = img1.device
    weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
    levels = weights.size()[0]
    ssims = []
    mcs = []
    for _ in range(levels):
        sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)

        # Relu normalize (not compliant with original definition)
        if normalize == "relu":
            ssims.append(torch.relu(sim))
            mcs.append(torch.relu(cs))
        else:
            ssims.append(sim)
            mcs.append(cs)

        img1 = F.avg_pool2d(img1, (2, 2))
        img2 = F.avg_pool2d(img2, (2, 2))

    ssims = torch.stack(ssims)
    mcs = torch.stack(mcs)

    # Simple normalize (not compliant with original definition)
    # TODO: remove support for normalize == True (kept for backward support)
    if normalize == "simple" or normalize == True:
        ssims = (ssims + 1) / 2
        mcs = (mcs + 1) / 2

    # Put the per-scale weights on dim 0 so they broadcast over the per-sample
    # dim that ``ssims`` has when size_average is False. ``cs`` is always a
    # scalar per scale.
    w_ssim = weights.view(-1, *([1] * (ssims.dim() - 1)))
    w_cs = weights.view(-1, *([1] * (mcs.dim() - 1)))
    pow1 = mcs ** w_cs
    pow2 = ssims ** w_ssim

    # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
    # The former ``torch.prod(pow1[:-1] * pow2[-1])`` raised the coarsest-scale
    # SSIM term to the (M-1)-th power instead of the first.
    output = torch.prod(pow1[:-1], dim=0) * pow2[-1]
    return output
# Classes to re-use window
class SSIM(torch.nn.Module):
    """Module wrapper around ``ssim`` that caches its Gaussian window.

    Returns the SSIM similarity (higher is better), not a loss.

    Args:
        window_size: Gaussian window size.
        size_average: scalar result if True, otherwise one value per sample.
        val_range: dynamic range forwarded to ``ssim``; None lets ``ssim``
            guess it from the input.
    """

    def __init__(self, window_size=11, size_average=True, val_range=None):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.val_range = val_range

        # Assume 1 channel for SSIM
        self.channel = 1
        self.window = create_window(window_size)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        # ``self.window`` is a plain attribute, so ``module.to()`` does not move
        # it. Rebuild it whenever channel count, dtype or device differs.
        if (channel == self.channel and self.window.dtype == img1.dtype
                and self.window.device == img1.device):
            window = self.window
        else:
            window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
            self.window = window
            self.channel = channel

        # ``val_range`` used to be stored but never forwarded.
        return ssim(img1, img2, window=window, window_size=self.window_size,
                    size_average=self.size_average, val_range=self.val_range)
class MSSSIM(torch.nn.Module):
    """Module wrapper around ``msssim``; the window is rebuilt on every call."""

    def __init__(self, window_size=11, size_average=True, channel=3):
        super(MSSSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        # Stored for API compatibility; ``forward`` does not use it.
        self.channel = channel

    def forward(self, img1, img2):
        # TODO: store window between calls if possible
        options = dict(window_size=self.window_size, size_average=self.size_average)
        return msssim(img1, img2, **options)
================================================
FILE: code/real/bsrt/loss/vgg.py
================================================
from model import common
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
class VGG(nn.Module):
    """Perceptual loss: MSE between VGG19 feature maps of ``sr`` and ``hr``.

    Args:
        conv_index: contains ``'22'`` to use the first 8 layers of
            ``vgg19().features``, or ``'54'`` to use the first 35.
        rgb_range: passed to the ``sub_mean`` MeanShift module. That module is
            built but its call in ``forward`` is commented out.

    Raises:
        ValueError: if ``conv_index`` names neither supported depth.
    """

    def __init__(self, conv_index, rgb_range=1):
        super(VGG, self).__init__()
        # Validate before downloading/building the pretrained network.
        # Previously an unknown index left ``self.vgg`` undefined.
        if conv_index.find('22') >= 0:
            n_layers = 8
        elif conv_index.find('54') >= 0:
            n_layers = 35
        else:
            raise ValueError('Unsupported VGG conv_index: {}'.format(conv_index))
        vgg_features = models.vgg19(pretrained=True).features
        self.vgg = nn.Sequential(*list(vgg_features)[:n_layers])

        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, sr, hr):
        def _forward(x):
            # x = self.sub_mean(x)  (input normalization currently disabled)
            return self.vgg(x)

        # VGG expects 3 input channels, so replicate single-channel inputs.
        # Repeating 3-channel inputs unconditionally fed 9 channels and failed.
        if sr.size(1) == 1:
            sr = sr.repeat(1, 3, 1, 1)
        if hr.size(1) == 1:
            hr = hr.repeat(1, 3, 1, 1)
        vgg_sr = _forward(sr)
        with torch.no_grad():
            vgg_hr = _forward(hr.detach())

        loss = F.mse_loss(vgg_sr, vgg_hr)

        return loss
================================================
FILE: code/real/bsrt/main.py
================================================
import torch
import random
import numpy as np
from torch.utils.data import DataLoader
import os
import utility
import model
import loss
from option import args
from trainer import Trainer
from datasets.burstsr_dataset import BurstSRDataset, flatten_raw_image
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.utils.data.distributed
def init_seeds(seed=0, cuda_deterministic=True):
    """Seed Python, NumPy and PyTorch RNGs and choose the cuDNN mode.

    ``cuda_deterministic=True`` selects deterministic cuDNN kernels with
    benchmarking off (slower, more reproducible). ``False`` does the reverse
    (faster, less reproducible).
    See https://pytorch.org/docs/stable/notes/randomness.html
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    deterministic = bool(cuda_deterministic)
    cudnn.deterministic = deterministic
    cudnn.benchmark = not deterministic
# Module-level experiment checkpoint shared by main_worker, which checks
# ``checkpoint.ok`` and hands it to model.Model, loss.Loss and Trainer.
# NOTE(review): this runs at import time. mp.spawn starts children that
# re-import this module, so each process presumably builds its own instance.
# Confirm that utility.checkpoint tolerates that.
checkpoint = utility.checkpoint(args)
def main():
    """Spawn one ``main_worker`` process per GPU (``args.n_GPUs`` in total)."""
    world_size = args.n_GPUs
    mp.spawn(main_worker, args=(world_size, args), nprocs=world_size)
def main_worker(local_rank, nprocs, args):
    """Per-process training entry point launched by ``mp.spawn``.

    Args:
        local_rank: index of this process; also used as its CUDA device.
        nprocs: number of spawned processes (``args.n_GPUs``).
        args: parsed command-line options.

    Builds the BurstSR train/val datasets and loaders (distributed samplers
    when ``nprocs > 1``), the model, the loss and the Trainer. It then calls
    ``t.train()`` until ``t.terminate()`` is true. Nothing runs unless
    ``checkpoint.ok``.
    """
    # print(local_rank)
    if checkpoint.ok:
        args.local_rank = local_rank
        # Distinct seed per rank. init_seeds() turns cudnn.benchmark off, but
        # the next line turns it back on; cudnn.deterministic stays True.
        init_seeds(local_rank+1)
        cudnn.benchmark = True
        utility.setup(local_rank, nprocs)
        torch.cuda.set_device(local_rank)
        # Per-process batch size; the global batch is split across processes.
        batch_size = int(args.batch_size / nprocs)
        train_data = BurstSRDataset(root=args.root,
                                    burst_size=args.burst_size,
                                    crop_sz=args.patch_size, random_flip=True,
                                    center_crop=True, split='train')
        # Validation always uses 14-frame bursts with 80-pixel crops.
        valid_data = BurstSRDataset(root=args.root,
                                    burst_size=14,
                                    crop_sz=80, split='val')
        if local_rank <= 0:
            print(f"train data: {len(train_data)}, test data: {len(valid_data)}")
        if nprocs > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
            valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_data, shuffle=False)
            # NOTE(review): num_workers is tied to args.batch_size here (the
            # trailing comments suggest args.cpus was intended).
            train_loader = DataLoader(dataset=train_data, batch_size=batch_size, num_workers=args.batch_size,
                                      pin_memory=True, drop_last=True, sampler=train_sampler) # args.cpus
            # NOTE(review): drop_last=True on validation skips the final
            # incomplete batch.
            valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size, num_workers=args.batch_size,
                                      pin_memory=True, drop_last=True, sampler=valid_sampler) # args.cpus
        else:
            train_sampler = None
            train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, num_workers=8,
                                      shuffle=True, pin_memory=True, drop_last=True) # args.cpus
            valid_loader = DataLoader(dataset=valid_data, batch_size=args.batch_size, num_workers=4, shuffle=False,
                                      pin_memory=True, drop_last=True) # args.cpus
        _model = model.Model(args, checkpoint)
        # No loss object is needed in test-only mode.
        _loss = loss.Loss(args, checkpoint) if not args.test_only else None
        t = Trainer(args, train_loader, train_sampler, valid_loader, _model, _loss, checkpoint)
        while not t.terminate():
            t.train()
        # Release references explicitly before the process exits.
        del _model
        del _loss
        del train_loader
        del valid_loader
        # checkpoint.done()
if __name__ == '__main__':
    # Script entry point: main() spawns one training process per GPU.
    # (Device selection happens per process inside main_worker.)
    main()
================================================
FILE: code/real/bsrt/model/DCNv2/LICENSE
================================================
BSD 3-Clause License
Copyright (c) 2019, Charles Shang
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================
FILE: code/real/bsrt/model/DCNv2/README.md
================================================
## Deformable Convolutional Networks V2 with Pytorch 1.0
### Build
```bash
./make.sh # build
python test.py # run examples and gradient check
```
### An Example
- deformable conv
```python
from dcn_v2 import DCN
input = torch.randn(2, 64, 128, 128).cuda()
# wrap all things (offset and mask) in DCN
dcn = DCN(64, 64, kernel_size=(3,3), stride=1, padding=1, deformable_groups=2).cuda()
output = dcn(input)
print(output.shape)
```
- deformable roi pooling
```python
from dcn_v2 import DCNPooling
input = torch.randn(2, 32, 64, 64).cuda()
batch_inds = torch.randint(2, (20, 1)).cuda().float()
x = torch.randint(256, (20, 1)).cuda().float()
y = torch.randint(256, (20, 1)).cuda().float()
w = torch.randint(64, (20, 1)).cuda().float()
h = torch.randint(64, (20, 1)).cuda().float()
rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
# modulated deformable pooling (V2)
# wrap all things (offset and mask) in DCNPooling
dpooling = DCNPooling(spatial_scale=1.0 / 4,
pooled_size=7,
output_dim=32,
no_trans=False,
group_size=1,
trans_std=0.1).cuda()
dout = dpooling(input, rois)
```
### Note
The master branch targets PyTorch 1.0 (new ATen API); you can switch back to PyTorch 0.4 with:
```bash
git checkout pytorch_0.4
```
### Known Issues:
- [x] Gradient check w.r.t offset (solved)
- [ ] Backward is not reentrant (minor)
This is an adaptation of the official [Deformable-ConvNets](https://github.com/msracver/Deformable-ConvNets/tree/master/DCNv2_op).
<s>I have run the gradient check many times with DOUBLE type. Every tensor **except offset** passes.
However, when I set the offset to 0.5, it passes. I'm still wondering what causes this problem. Is it because of some
non-differentiable points? </s>
Update: all gradient check passes with double precision.
Another issue is that it raises `RuntimeError: Backward is not reentrant`. However, the error is very small (`<1e-7` for
float, `<1e-15` for double), so it may not be a serious problem (?)
Please post an issue or PR if you have any comments.
================================================
FILE: code/real/bsrt/model/DCNv2/__init__.py
================================================
================================================
FILE: code/real/bsrt/model/DCNv2/dcn_v2.py
================================================
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
import math
import torch
from torch import nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from torch.cuda.amp import custom_fwd, custom_bwd
# from apex import amp
import _ext as _backend
class _DCNv2(Function):
    """Autograd function for modulated deformable convolution (DCNv2).

    Forward and backward are delegated to the compiled ``_ext`` extension
    (``dcn_v2_forward`` / ``dcn_v2_backward``). ``symbolic`` describes the op
    for ONNX export. Call it through ``dcn_v2_conv = _DCNv2.apply``.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    # @amp.float_function
    def forward(
        ctx, input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups
    ):
        """Run the deformable convolution and save tensors for backward.

        ``stride``, ``padding`` and ``dilation`` are normalized with ``_pair``.
        The kernel size is read from ``weight.shape[2:4]``. Inside an autocast
        region, ``custom_fwd(cast_inputs=torch.float32)`` casts floating-point
        CUDA inputs to float32 before calling the extension.
        """
        ctx.stride = _pair(stride)
        ctx.padding = _pair(padding)
        ctx.dilation = _pair(dilation)
        ctx.kernel_size = _pair(weight.shape[2:4])
        ctx.deformable_groups = deformable_groups
        output = _backend.dcn_v2_forward(
            input,
            weight,
            bias,
            offset,
            mask,
            ctx.kernel_size[0],
            ctx.kernel_size[1],
            ctx.stride[0],
            ctx.stride[1],
            ctx.padding[0],
            ctx.padding[1],
            ctx.dilation[0],
            ctx.dilation[1],
            ctx.deformable_groups,
        )
        ctx.save_for_backward(input, offset, mask, weight, bias)
        return output

    @staticmethod
    @once_differentiable
    @custom_bwd
    # @amp.float_function
    def backward(ctx, grad_output):
        """Compute gradients with ``_ext.dcn_v2_backward``.

        Returns gradients for input, offset, mask, weight and bias, plus
        ``None`` for stride, padding, dilation and deformable_groups.
        ``once_differentiable`` means this backward cannot itself be
        differentiated (no double backward).
        """
        input, offset, mask, weight, bias = ctx.saved_tensors
        grad_input, grad_offset, grad_mask, grad_weight, grad_bias = _backend.dcn_v2_backward(
            input,
            weight,
            bias,
            offset,
            mask,
            grad_output,
            ctx.kernel_size[0],
            ctx.kernel_size[1],
            ctx.stride[0],
            ctx.stride[1],
            ctx.padding[0],
            ctx.padding[1],
            ctx.dilation[0],
            ctx.dilation[1],
            ctx.deformable_groups,
        )

        return grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None

    @staticmethod
    def symbolic(
        g, input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups
    ):
        """Emit a custom ``DCNv2_2`` ONNX node.

        stride, padding, dilation and deformable_groups are attached as
        integer attributes (the ``_i`` suffix).
        """
        from torch.nn.modules.utils import _pair

        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        # as of trt 7, the dcn operation will be translated again by modifying the onnx file
        # so the exporting code is kept to resemble the forward()
        return g.op(
            "DCNv2_2",
            input,
            offset,
            mask,
            weight,
            bias,
            stride_i=stride,
            padding_i=padding,
            dilation_i=dilation,
            deformable_groups_i=deformable_groups,
        )


# Functional entry point used by the DCN modules in this file.
dcn_v2_conv = _DCNv2.apply
class DCNv2(nn.Module):
    """Modulated deformable convolution with caller-supplied offsets and masks.

    ``forward(input, offset, mask)`` requires ``offset`` to have
    ``2 * deformable_groups * kh * kw`` channels and ``mask`` to have
    ``deformable_groups * kh * kw`` channels; both are checked with asserts.
    Weights are drawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in)) and the bias
    starts at zero.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation=1,
        deformable_groups=1,
    ):
        super(DCNv2, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size, self.stride = _pair(kernel_size), _pair(stride)
        self.padding, self.dilation = _pair(padding), _pair(dilation)
        self.deformable_groups = deformable_groups

        weight_shape = (out_channels, in_channels) + tuple(self.kernel_size)
        self.weight = nn.Parameter(torch.Tensor(*weight_shape))
        self.bias = nn.Parameter(torch.Tensor(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        fan_in = self.in_channels
        for extent in self.kernel_size:
            fan_in *= extent
        bound = 1.0 / math.sqrt(fan_in)
        self.weight.data.uniform_(-bound, bound)
        self.bias.data.zero_()

    def forward(self, input, offset, mask):
        taps = self.deformable_groups * self.kernel_size[0] * self.kernel_size[1]
        assert 2 * taps == offset.shape[1]
        assert taps == mask.shape[1]
        return dcn_v2_conv(
            input, offset, mask,
            self.weight, self.bias,
            self.stride, self.padding, self.dilation,
            self.deformable_groups,
        )
class DCN(DCNv2):
    """Modulated deformable conv that predicts its own offsets and masks.

    ``conv_offset_mask`` is a zero-initialized ``nn.Conv2d`` that maps the
    input to ``3 * deformable_groups * kh * kw`` channels. Its output is split
    into three chunks: the first two form the offsets and the third, after a
    sigmoid, is the mask.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation=1,
        deformable_groups=1,
    ):
        super(DCN, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation, deformable_groups
        )
        offset_mask_channels = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
        self.conv_offset_mask = nn.Conv2d(
            self.in_channels,
            offset_mask_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            bias=True,
        )
        self.init_offset()

    def init_offset(self):
        # Zero offsets and sigmoid(0) = 0.5 masks at initialization.
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, input):
        first, second, raw_mask = torch.chunk(self.conv_offset_mask(input), 3, dim=1)
        offset = torch.cat((first, second), dim=1)
        return dcn_v2_conv(
            input,
            offset,
            torch.sigmoid(raw_mask),
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.deformable_groups,
        )
class DCN_sep(DCNv2):
    """Modulated deformable conv whose offsets/masks come from other features.

    ``forward(input, fea)``: ``conv_offset_mask`` (zero-initialized) maps
    ``fea`` to ``3 * deformable_groups * kh * kw`` channels. ``fea`` must
    therefore have ``in_channels`` channels. The output is split into two
    offset chunks and a sigmoid mask, which drive the deformable conv over
    ``input``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 dilation=1,
                 deformable_groups=1):
        super(DCN_sep, self).__init__(in_channels, out_channels, kernel_size, stride, padding,
                                      dilation, deformable_groups)

        channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
        self.conv_offset_mask = nn.Conv2d(
            self.in_channels,
            channels_,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            bias=True)
        self.init_offset()

    def init_offset(self):
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, input, fea):
        '''input: input features for deformable conv
        fea: other features used for generating offsets and mask'''
        out = self.conv_offset_mask(fea)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)

        # Diagnostic only: report unusually large offsets (they are not clamped).
        offset_mean = torch.mean(torch.abs(offset))
        if offset_mean > 250:
            # The message previously said "larger than 100" although the
            # threshold is 250.
            print('Offset mean is {}, larger than 250.'.format(offset_mean))

        mask = torch.sigmoid(mask)
        return dcn_v2_conv(input, offset, mask, self.weight, self.bias, self.stride, self.padding,
                           self.dilation, self.deformable_groups)
class FlowGuidedDCN(DCNv2):
    """Deformable conv whose offsets are a bounded residual on top of a flow.

    ``forward(input, fea, flows)``: ``conv_offset_mask`` (zero-initialized)
    maps ``fea`` to two offset chunks and a mask. The offsets are squashed to
    ``10 * tanh(.)``. The channel-flipped ``flows`` (N, 2, H, W) are repeated
    across the offset channels and added to them.

    NOTE(review): ``flows.flip(1)`` presumably reorders (dx, dy) into the
    (dy, dx) order the offset layout expects; confirm.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 dilation=1,
                 deformable_groups=1):
        super(FlowGuidedDCN, self).__init__(in_channels, out_channels, kernel_size, stride, padding,
                                            dilation, deformable_groups)

        channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
        self.conv_offset_mask = nn.Conv2d(
            in_channels, channels_, kernel_size, stride, padding, bias=True)

        self.init_offset()

    def init_offset(self):
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, input, fea, flows):
        '''input: input features for deformable conv: N, C, H, W.
           fea: other features used for generating offsets and mask: N, C, H, W.
           flows: N, 2, H, W.
        '''
        out = self.conv_offset_mask(fea)
        o1, o2, mask = torch.chunk(out, 3, dim=1)

        offset = torch.tanh(torch.cat((o1, o2), dim=1)) * 10  # max_residue_magnitude
        offset = offset + flows.flip(1).repeat(1, offset.size(1)//2, 1, 1)

        # Diagnostic only: report unusually large offsets (they are not clamped).
        offset_mean = torch.mean(torch.abs(offset))
        if offset_mean > 250:
            # The message previously said "larger than 100" although the
            # threshold is 250.
            print('FlowGuidedDCN: Offset mean is {}, larger than 250.'.format(offset_mean))

        mask = torch.sigmoid(mask)
        return dcn_v2_conv(input, offset, mask, self.weight, self.bias, self.stride, self.padding,
                           self.dilation, self.deformable_groups)
class InsideFlowGuidedDCN(DCNv2):
    """Flow-guided deformable conv that predicts offsets from warped/ref features.

    ``forward(input, warped, ref, flows)``: a three-layer conv stack
    (``conv_offset_mask``, last layer zero-initialized) reads
    ``cat([warped, ref, flows])``. Its first conv expects
    ``2 * in_channels + 2`` channels. The stack's output gives two offset
    chunks, squashed to ``10 * tanh(.)`` and added to the channel-flipped,
    repeated ``flows``, plus a sigmoid mask.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 dilation=1,
                 deformable_groups=1):
        super(InsideFlowGuidedDCN, self).__init__(in_channels, out_channels, kernel_size, stride, padding,
                                                  dilation, deformable_groups)

        channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
        self.conv_offset_mask = nn.Sequential(
            nn.Conv2d(in_channels*2+2, out_channels, kernel_size, stride, padding, bias=True),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=True),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(out_channels, channels_, kernel_size, stride, padding, bias=True)
        )
        # Re-draws the DCN weight/bias with the inherited reset_parameters
        # (the base __init__ already drew them once). The previous identical
        # override of reset_parameters was redundant and has been removed;
        # the call is kept so initialization is unchanged.
        self.reset_parameters()
        self.init_offset()

    def init_offset(self):
        # Zero only the last conv so offsets start as the pure flow.
        self.conv_offset_mask[-1].weight.data.zero_()
        self.conv_offset_mask[-1].bias.data.zero_()

    def forward(self, input, warped, ref, flows):
        '''input: input features for deformable conv: N, C, H, W.
           warped, ref: features used (with flows) to predict offsets and mask:
               N, C, H, W each.
           flows: N, 2, H, W.
        '''
        out = self.conv_offset_mask(torch.cat([warped, ref, flows], dim=1))
        o1, o2, mask = torch.chunk(out, 3, dim=1)

        offset = torch.tanh(torch.cat((o1, o2), dim=1)) * 10  # max_residue_magnitude
        offset = offset + flows.flip(1).repeat(1, offset.size(1)//2, 1, 1)

        offset_mean = torch.mean(torch.abs(offset))
        if offset_mean > 250:
            # The message previously said "larger than 100" although the
            # threshold is 250.
            print('InsideFlowGuidedDCN: Offset mean is {}, larger than 250.'.format(offset_mean))
            print('flow mean is {}'.format(torch.abs(flows).mean()))
            # Emergency guard: clamp offsets only when they blow up.
            offset = offset.clamp(-50, 50)

        mask = torch.sigmoid(mask)
        return dcn_v2_conv(input, offset, mask, self.weight, self.bias, self.stride, self.padding,
                           self.dilation, self.deformable_groups)
class _DCNv2Pooling(Function):
    """Autograd function for deformable PS-RoI pooling.

    Forward and backward are delegated to the compiled ``_ext`` extension
    (``dcn_v2_psroi_pooling_forward`` / ``_backward``). Call it through
    ``dcn_v2_pooling = _DCNv2Pooling.apply``.
    """

    @staticmethod
    def forward(
        ctx,
        input,
        rois,
        offset,
        spatial_scale,
        pooled_size,
        output_dim,
        no_trans,
        group_size=1,
        part_size=None,
        sample_per_part=4,
        trans_std=0.0,
    ):
        """Pool ``input`` over ``rois`` (optionally shifted by ``offset``).

        ``part_size`` defaults to ``pooled_size``, and ``no_trans`` is passed
        to the backend as an int. The backend's ``output_count`` is saved for
        backward. (The README example builds ``rois`` as rows of
        ``[batch_idx, x1, y1, x2, y2]``.)
        """
        ctx.spatial_scale = spatial_scale
        ctx.no_trans = int(no_trans)
        ctx.output_dim = output_dim
        ctx.group_size = group_size
        ctx.pooled_size = pooled_size
        ctx.part_size = pooled_size if part_size is None else part_size
        ctx.sample_per_part = sample_per_part
        ctx.trans_std = trans_std

        output, output_count = _backend.dcn_v2_psroi_pooling_forward(
            input,
            rois,
            offset,
            ctx.no_trans,
            ctx.spatial_scale,
            ctx.output_dim,
            ctx.group_size,
            ctx.pooled_size,
            ctx.part_size,
            ctx.sample_per_part,
            ctx.trans_std,
        )
        ctx.save_for_backward(input, rois, offset, output_count)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Return gradients for ``input`` and ``offset`` only.

        ``rois`` and the eight non-tensor arguments get ``None``.
        """
        input, rois, offset, output_count = ctx.saved_tensors
        grad_input, grad_offset = _backend.dcn_v2_psroi_pooling_backward(
            grad_output,
            input,
            rois,
            offset,
            output_count,
            ctx.no_trans,
            ctx.spatial_scale,
            ctx.output_dim,
            ctx.group_size,
            ctx.pooled_size,
            ctx.part_size,
            ctx.sample_per_part,
            ctx.trans_std,
        )

        return grad_input, None, grad_offset, None, None, None, None, None, None, None, None


# Functional entry point used by DCNv2Pooling / DCNPooling.
dcn_v2_pooling = _DCNv2Pooling.apply
class DCNv2Pooling(nn.Module):
    """Deformable PS-RoI pooling with caller-supplied offsets.

    ``forward(input, rois, offset)`` asserts that ``input`` has ``output_dim``
    channels. When ``no_trans`` is set, the given ``offset`` is replaced by an
    empty tensor. ``part_size`` defaults to ``pooled_size``.
    """

    def __init__(
        self,
        spatial_scale,
        pooled_size,
        output_dim,
        no_trans,
        group_size=1,
        part_size=None,
        sample_per_part=4,
        trans_std=0.0,
    ):
        super(DCNv2Pooling, self).__init__()
        self.spatial_scale = spatial_scale
        self.pooled_size = pooled_size
        self.output_dim = output_dim
        self.no_trans = no_trans
        self.group_size = group_size
        self.part_size = part_size if part_size is not None else pooled_size
        self.sample_per_part = sample_per_part
        self.trans_std = trans_std

    def forward(self, input, rois, offset):
        assert input.shape[1] == self.output_dim
        effective_offset = input.new() if self.no_trans else offset
        return dcn_v2_pooling(
            input, rois, effective_offset,
            self.spatial_scale, self.pooled_size, self.output_dim, self.no_trans,
            self.group_size, self.part_size, self.sample_per_part, self.trans_std,
        )
class DCNPooling(DCNv2Pooling):
    """Deformable PS-RoI pooling that predicts its own offsets and mask.

    With ``no_trans`` falsy, a small MLP (``offset_mask_fc``) maps the result
    of a plain (untranslated) pooling pass to three values per bin: two offset
    components and one mask logit. A second pooling pass uses the predicted
    offsets, and its output is multiplied by ``sigmoid(mask logit)``. The MLP's
    last layer is zero-initialised, so it starts with zero offsets and a
    mask of 0.5.

    With ``no_trans`` truthy, no MLP is built and only the plain pass runs.

    Args:
        deform_fc_dim: width of the two hidden layers of ``offset_mask_fc``.
        Remaining arguments: as for ``DCNv2Pooling``.
    """

    def __init__(
        self,
        spatial_scale,
        pooled_size,
        output_dim,
        no_trans,
        group_size=1,
        part_size=None,
        sample_per_part=4,
        trans_std=0.0,
        deform_fc_dim=1024,
    ):
        super().__init__(
            spatial_scale,
            pooled_size,
            output_dim,
            no_trans,
            group_size,
            part_size,
            sample_per_part,
            trans_std,
        )
        self.deform_fc_dim = deform_fc_dim
        if no_trans:
            # Offsets are never predicted, so no offset/mask head is needed.
            return
        num_bins = self.pooled_size * self.pooled_size
        self.offset_mask_fc = nn.Sequential(
            nn.Linear(num_bins * self.output_dim, self.deform_fc_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.deform_fc_dim, self.deform_fc_dim),
            nn.ReLU(inplace=True),
            # Three outputs per bin: two offset components and a mask logit.
            nn.Linear(self.deform_fc_dim, num_bins * 3),
        )
        head = self.offset_mask_fc[4]
        head.weight.data.zero_()
        head.bias.data.zero_()

    def forward(self, input, rois):
        def pool(offset, no_trans):
            # Shared call into the pooling op; only offset and no_trans vary.
            return dcn_v2_pooling(
                input,
                rois,
                offset,
                self.spatial_scale,
                self.pooled_size,
                self.output_dim,
                no_trans,
                self.group_size,
                self.part_size,
                self.sample_per_part,
                self.trans_std,
            )

        empty_offset = input.new()
        if self.no_trans:
            # Plain pooling (roi_align-like), no learned offsets.
            return pool(empty_offset, self.no_trans)

        n = rois.shape[0]
        # First pass without translation feeds the offset/mask head.
        roi_feat = pool(empty_offset, True)
        pred = self.offset_mask_fc(roi_feat.view(n, -1))
        pred = pred.view(n, 3, self.pooled_size, self.pooled_size)
        o1, o2, mask_logit = torch.chunk(pred, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask_logit)
        # Second pass with the predicted offsets, modulated by the mask.
        return pool(offset, self.no_trans) * mask
================================================
FILE: code/real/bsrt/model/DCNv2/files.txt
================================================
/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/_ext.cpython-37m-x86_64-linux-gnu.so
/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/_ext.py
/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/PKG-INFO
/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/SOURCES.txt
/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/dependency_links.txt
/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/native_libs.txt
/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/not-zip-safe
/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/top_level.txt
/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/__pycache__/_ext.cpython-37.pyc
================================================
FILE: code/real/bsrt/model/DCNv2/make.sh
================================================
#!/usr/bin/env bash
# Compile the DCNv2 `_ext` extension (`setup.py build`) and install it in
# setuptools development mode (`setup.py develop`).
# Change into this script's own directory first so the relative `setup.py`
# path resolves no matter where the script is invoked from.
cd "$(dirname -- "$0")" || exit 1
python setup.py build develop
================================================
FILE: code/real/bsrt/model/DCNv2/setup.py
================================================
#!/usr/bin/env python
import glob
import os
import torch
from setuptools import find_packages, setup
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
# Runtime dependency list for the package (the install_requires argument
# that would consume it is commented out in the setup() call).
requirements = [
    "torch",
    "torchvision",
]
def get_extensions():
    """Describe the ``_ext`` C++/CUDA extension that backs DCNv2.

    Collects the top-level ``src/*.cpp`` file(s), the CPU sources in
    ``src/cpu`` and the CUDA sources in ``src/cuda``, and wraps them all in a
    single ``CUDAExtension`` named ``_ext`` with ``WITH_CUDA`` defined.

    The CUDA build is unconditional: it is not gated on
    ``torch.cuda.is_available()``. A CUDA toolkit (``CUDA_HOME``) is therefore
    always required; ``CUDAExtension`` raises if it cannot be located.

    Side effect: sets ``CC=g++`` in ``os.environ`` for the compiler calls
    that follow.

    Returns:
        list: a one-element list containing the configured extension.
    """
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, "src")

    # glob() returns paths already rooted at extensions_dir (absolute, since
    # this_dir is absolute), so they can be passed to the extension as-is.
    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
    source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
    source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))

    os.environ["CC"] = "g++"

    # CPU kernels are always compiled alongside the CUDA ones; WITH_CUDA
    # presumably enables the GPU code paths in the C++ sources.
    sources = main_file + source_cpu + source_cuda
    define_macros = [("WITH_CUDA", None)]
    extra_compile_args = {
        "cxx": [],
        "nvcc": [
            "-DCUDA_HAS_FP16=1",
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
        ],
    }

    return [
        CUDAExtension(
            "_ext",
            sources,
            include_dirs=[extensions_dir],
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]
# Package metadata plus the compiled `_ext` extension. BuildExtension from
# torch.utils.cpp_extension drives compilation of the C++/CUDA sources.
setup(
    name="DCNv2",
    version="0.1",
    description="deformable convolutional networks",
    author="charlesshang",
    url="https://github.com/charlesshang/DCNv2",
    packages=find_packages(exclude=("configs", "tests")),
    # install_requires=requirements,
    ext_modules=get_extensions(),
    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)
================================================
FILE: code/real/bsrt/model/DCNv2/src/cpu/dcn_v2_cpu.cpp
================================================
#include <vector>
#include "cpu/dcn_v2_im2col_cpu.h"
#include <ATen/ATen.h>
//#include <ATen/cuda/CUDAContext.h>
#include <TH/TH.h>
//#include <THC/THCAtomics.cuh>
//#include <THC/THCDeviceUtils.cuh>
//extern THCState *state;
// author: Charles Shang
// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
// modified from the CUDA version for CPU use by Daniel K. Suhendro
// Forward pass of the modulated deformable convolution (DCNv2) on the CPU.
//
// For each batch element:
//   1. output_n is initialised with the bias broadcast over all output pixels;
//   2. modulated_deformable_im2col_cpu expands the input into a column buffer
//      (samples taken at offset-shifted kernel taps, scaled by mask);
//   3. weight x columns is accumulated into output_n.
//
// Only float32 is supported (scalar_t is fixed to float below).
// Returns a tensor of shape [batch, channels_out, height_out, width_out].
at::Tensor
dcn_v2_cpu_forward(const at::Tensor &input,
                   const at::Tensor &weight,
                   const at::Tensor &bias,
                   const at::Tensor &offset,
                   const at::Tensor &mask,
                   const int kernel_h,
                   const int kernel_w,
                   const int stride_h,
                   const int stride_w,
                   const int pad_h,
                   const int pad_w,
                   const int dilation_h,
                   const int dilation_w,
                   const int deformable_group)
{
    // The BLAS and im2col helpers walk raw float pointers and assume dense
    // row-major storage; contiguous() is a no-op for tensors that already are.
    const at::Tensor input_c = input.contiguous();
    const at::Tensor weight_c = weight.contiguous();
    const at::Tensor bias_c = bias.contiguous();
    const at::Tensor offset_c = offset.contiguous();
    const at::Tensor mask_c = mask.contiguous();

    const int batch = input_c.size(0);
    const int channels = input_c.size(1);
    const int height = input_c.size(2);
    const int width = input_c.size(3);

    const int channels_out = weight_c.size(0);
    const int channels_kernel = weight_c.size(1);
    const int kernel_h_ = weight_c.size(2);
    const int kernel_w_ = weight_c.size(3);

    // AT_ASSERTM concatenates its trailing arguments into the message
    // (printf-style specifiers are not expanded), so values are interleaved.
    AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
               "Input shape and kernel shape won't match: (",
               kernel_h, " x ", kernel_w, " vs ", kernel_h_, " x ", kernel_w_, ").");
    AT_ASSERTM(channels == channels_kernel,
               "Input shape and kernel channels won't match: (",
               channels, " vs ", channels_kernel, ").");

    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

    auto ones = at::ones({height_out, width_out}, input_c.options());
    auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input_c.options());
    auto output = at::empty({batch, channels_out, height_out, width_out}, input_c.options());

    using scalar_t = float;
    for (int b = 0; b < batch; b++)
    {
        auto input_n = input_c.select(0, b);
        auto offset_n = offset_c.select(0, b);
        auto mask_n = mask_c.select(0, b);
        auto output_n = output.select(0, b);

        // Bias first: output_n (channels_out x HW) = bias (channels_out x 1) * ones (1 x HW).
        // BLAS gemm is column-major, so the row-major product is expressed
        // with m and n swapped.
        long m_ = channels_out;
        long n_ = height_out * width_out;
        long k_ = 1;
        THFloatBlas_gemm('t', 'n', n_, m_, k_, 1.0f,
                         ones.data<scalar_t>(), k_,
                         bias_c.data<scalar_t>(), k_, 0.0f,
                         output_n.data<scalar_t>(), n_);

        modulated_deformable_im2col_cpu(input_n.data<scalar_t>(),
                                        offset_n.data<scalar_t>(),
                                        mask_n.data<scalar_t>(),
                                        1, channels, height, width,
                                        height_out, width_out, kernel_h, kernel_w,
                                        pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
                                        deformable_group,
                                        columns.data<scalar_t>());

        // output_n += weight (channels_out x C*kh*kw) * columns (C*kh*kw x HW)
        long m = channels_out;
        long n = height_out * width_out;
        long k = channels * kernel_h * kernel_w;
        THFloatBlas_gemm('n', 'n', n, m, k, 1.0f,
                         columns.data<scalar_t>(), n,
                         weight_c.data<scalar_t>(), k, 1.0f,
                         output_n.data<scalar_t>(), n);
    }
    return output;
}
// Backward pass of the modulated deformable convolution (DCNv2) on the CPU,
// processed one batch element at a time. Only float32 is supported.
//
// Returns {grad_input, grad_offset, grad_mask, grad_weight, grad_bias}.
std::vector<at::Tensor> dcn_v2_cpu_backward(const at::Tensor &input,
                                            const at::Tensor &weight,
                                            const at::Tensor &bias,
                                            const at::Tensor &offset,
                                            const at::Tensor &mask,
                                            const at::Tensor &grad_output,
                                            int kernel_h, int kernel_w,
                                            int stride_h, int stride_w,
                                            int pad_h, int pad_w,
                                            int dilation_h, int dilation_w,
                                            int deformable_group)
{
    THArgCheck(input.is_contiguous(), 1, "input tensor has to be contiguous");
    THArgCheck(weight.is_contiguous(), 2, "weight tensor has to be contiguous");

    // The BLAS/im2col helpers below read these tensors as dense row-major
    // arrays. grad_output in particular can arrive as a strided or expanded
    // view (e.g. the gradient of a sum()); contiguous() is a no-op otherwise.
    const at::Tensor grad_output_c = grad_output.contiguous();
    const at::Tensor offset_c = offset.contiguous();
    const at::Tensor mask_c = mask.contiguous();

    const int batch = input.size(0);
    const int channels = input.size(1);
    const int height = input.size(2);
    const int width = input.size(3);

    const int channels_out = weight.size(0);
    const int channels_kernel = weight.size(1);
    const int kernel_h_ = weight.size(2);
    const int kernel_w_ = weight.size(3);

    // AT_ASSERTM concatenates its trailing arguments into the message
    // (printf-style specifiers are not expanded), so values are interleaved.
    AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
               "Input shape and kernel shape won't match: (",
               kernel_h, " x ", kernel_w, " vs ", kernel_h_, " x ", kernel_w_, ").");
    AT_ASSERTM(channels == channels_kernel,
               "Input shape and kernel channels won't match: (",
               channels, " vs ", channels_kernel, ").");

    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

    // ones (HW) is the vector that sums grad_output over pixels for grad_bias.
    auto ones = at::ones({height_out, width_out}, input.options());
    auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());

    // col2im/col2im_coord accumulate with +=, so gradients must start at zero.
    auto grad_input = at::zeros_like(input);
    auto grad_weight = at::zeros_like(weight);
    auto grad_bias = at::zeros_like(bias);
    auto grad_offset = at::zeros_like(offset_c);
    auto grad_mask = at::zeros_like(mask_c);

    using scalar_t = float;
    for (int b = 0; b < batch; b++)
    {
        auto input_n = input.select(0, b);
        auto offset_n = offset_c.select(0, b);
        auto mask_n = mask_c.select(0, b);
        auto grad_output_n = grad_output_c.select(0, b);
        auto grad_input_n = grad_input.select(0, b);
        auto grad_offset_n = grad_offset.select(0, b);
        auto grad_mask_n = grad_mask.select(0, b);

        // columns (C*kh*kw x HW) = weight^T (C*kh*kw x channels_out) * grad_output_n (channels_out x HW)
        long m = channels * kernel_h * kernel_w;
        long n = height_out * width_out;
        long k = channels_out;
        THFloatBlas_gemm('n', 't', n, m, k, 1.0f,
                         grad_output_n.data<scalar_t>(), n,
                         weight.data<scalar_t>(), m, 0.0f,
                         columns.data<scalar_t>(), n);

        // gradient w.r.t. input coordinate data (offsets) and masks
        modulated_deformable_col2im_coord_cpu(columns.data<scalar_t>(),
                                              input_n.data<scalar_t>(),
                                              offset_n.data<scalar_t>(),
                                              mask_n.data<scalar_t>(),
                                              1, channels, height, width,
                                              height_out, width_out, kernel_h, kernel_w,
                                              pad_h, pad_w, stride_h, stride_w,
                                              dilation_h, dilation_w, deformable_group,
                                              grad_offset_n.data<scalar_t>(),
                                              grad_mask_n.data<scalar_t>());
        // gradient w.r.t. input data
        modulated_deformable_col2im_cpu(columns.data<scalar_t>(),
                                        offset_n.data<scalar_t>(),
                                        mask_n.data<scalar_t>(),
                                        1, channels, height, width,
                                        height_out, width_out, kernel_h, kernel_w,
                                        pad_h, pad_w, stride_h, stride_w,
                                        dilation_h, dilation_w, deformable_group,
                                        grad_input_n.data<scalar_t>());

        // gradient w.r.t. weight, dWeight should accumulate across the batch and group:
        // rebuild the forward column buffer, then
        // grad_weight (channels_out x C*kh*kw) += grad_output_n (channels_out x HW) * columns^T
        modulated_deformable_im2col_cpu(input_n.data<scalar_t>(),
                                        offset_n.data<scalar_t>(),
                                        mask_n.data<scalar_t>(),
                                        1, channels, height, width,
                                        height_out, width_out, kernel_h, kernel_w,
                                        pad_h, pad_w, stride_h, stride_w,
                                        dilation_h, dilation_w, deformable_group,
                                        columns.data<scalar_t>());

        long m_ = channels_out;
        long n_ = channels * kernel_h * kernel_w;
        long k_ = height_out * width_out;
        THFloatBlas_gemm('t', 'n', n_, m_, k_, 1.0f,
                         columns.data<scalar_t>(), k_,
                         grad_output_n.data<scalar_t>(), k_, 1.0f,
                         grad_weight.data<scalar_t>(), n_);

        // gradient w.r.t. bias: grad_bias += per-channel sum of grad_output_n
        // over all output pixels (grad_output_n^T * ones in column-major form).
        // Without this term grad_bias stays zero and the bias is never trained.
        // NOTE(review): assumes bias holds exactly channels_out elements.
        THFloatBlas_gemv('t', k_, m_, 1.0f,
                         grad_output_n.data<scalar_t>(), k_,
                         ones.data<scalar_t>(), 1, 1.0f,
                         grad_bias.data<scalar_t>(), 1);
    }

    return {
        grad_input, grad_offset, grad_mask, grad_weight, grad_bias
    };
}
================================================
FILE: code/real/bsrt/model/DCNv2/src/cpu/dcn_v2_im2col_cpu.cpp
================================================
#include "dcn_v2_im2col_cpu.h"
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <ATen/ATen.h>
//#include <ATen/cuda/CUDAContext.h>
#include <TH/TH.h>
//#include <THC/THCAtomics.cuh>
//#include <THC/THCDeviceUtils.cuh>
// modified from the CUDA version for CPU use by Daniel K. Suhendro
/*#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N)
{
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}*/
// Bilinearly interpolate a single row-major plane (row stride data_width) at
// the real-valued point (h, w). Corner pixels outside the image read as 0.
// Expects (h, w) inside (-1, height) x (-1, width), as both callers in this
// file ensure; only the bounds such a point can violate are checked.
float dmcn_im2col_bilinear_cpu(const float *bottom_data, const int data_width,
                               const int height, const int width, float h, float w)
{
  // Integer corners of the cell that contains (h, w).
  const int top = floor(h);
  const int left = floor(w);
  const int bottom = top + 1;
  const int right = left + 1;

  // Fractional position inside the cell and its complements.
  const float fy = h - top;
  const float fx = w - left;
  const float gy = 1 - fy;
  const float gx = 1 - fx;

  const float v_tl = (top >= 0 && left >= 0) ? bottom_data[top * data_width + left] : 0;
  const float v_tr = (top >= 0 && right <= width - 1) ? bottom_data[top * data_width + right] : 0;
  const float v_bl = (bottom <= height - 1 && left >= 0) ? bottom_data[bottom * data_width + left] : 0;
  const float v_br = (bottom <= height - 1 && right <= width - 1) ? bottom_data[bottom * data_width + right] : 0;

  return gy * gx * v_tl + gy * fx * v_tr + fy * gx * v_bl + fy * fx * v_br;
}
// Bilinear weight that pixel (h, w) received when the image was sampled at
// (argmax_h, argmax_w), i.e. how much of the sample's gradient flows back to
// that pixel. Returns 0 for pixels that are not one of the four cell corners
// and for sample points outside (-1, height) x (-1, width).
float dmcn_get_gradient_weight_cpu(float argmax_h, float argmax_w,
                                   const int h, const int w, const int height, const int width)
{
  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
    return 0;

  const int top = floor(argmax_h);
  const int left = floor(argmax_w);

  // Vertical factor: depends on whether h is the upper or lower corner row.
  float wy;
  if (h == top)
    wy = h + 1 - argmax_h;
  else if (h == top + 1)
    wy = argmax_h + 1 - h;
  else
    return 0;

  // Horizontal factor: depends on whether w is the left or right corner column.
  float wx;
  if (w == left)
    wx = w + 1 - argmax_w;
  else if (w == left + 1)
    wx = argmax_w + 1 - w;
  else
    return 0;

  return wy * wx;
}
// Partial derivative of the bilinear sample of im_data at (argmax_h,
// argmax_w): with respect to the vertical coordinate when bp_dir == 0, and
// the horizontal coordinate when bp_dir == 1 (0 for any other bp_dir).
// Corner pixels outside the image contribute nothing, and sample points
// outside (-1, height) x (-1, width) yield 0.
float dmcn_get_coordinate_weight_cpu(float argmax_h, float argmax_w,
                                     const int height, const int width, const float *im_data,
                                     const int data_width, const int bp_dir)
{
  const bool outside = argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width;
  if (outside)
    return 0;

  // Integer corners of the bilinear cell.
  const int y0 = floor(argmax_h);
  const int x0 = floor(argmax_w);
  const int y1 = y0 + 1;
  const int x1 = x0 + 1;

  // Which corner rows/columns lie inside the image.
  const bool has_y0 = y0 >= 0;
  const bool has_x0 = x0 >= 0;
  const bool has_y1 = y1 <= height - 1;
  const bool has_x1 = x1 <= width - 1;

  float weight = 0;
  if (bp_dir == 0)
  {
    // d/dh: upper row enters negatively, lower row positively, each weighted
    // by the horizontal interpolation factors.
    const float near_w = x0 + 1 - argmax_w;
    const float far_w = argmax_w - x0;
    if (has_y0 && has_x0)
      weight += -near_w * im_data[y0 * data_width + x0];
    if (has_y0 && has_x1)
      weight += -far_w * im_data[y0 * data_width + x1];
    if (has_y1 && has_x0)
      weight += near_w * im_data[y1 * data_width + x0];
    if (has_y1 && has_x1)
      weight += far_w * im_data[y1 * data_width + x1];
  }
  else if (bp_dir == 1)
  {
    // d/dw: left column enters negatively, right column positively, each
    // weighted by the vertical interpolation factors.
    const float near_h = y0 + 1 - argmax_h;
    const float far_h = argmax_h - y0;
    if (has_y0 && has_x0)
      weight += -near_h * im_data[y0 * data_width + x0];
    if (has_y0 && has_x1)
      weight += near_h * im_data[y0 * data_width + x1];
    if (has_y1 && has_x0)
      weight += -far_h * im_data[y1 * data_width + x0];
    if (has_y1 && has_x1)
      weight += far_h * im_data[y1 * data_width + x1];
  }
  return weight;
}
// Sequential CPU version of the modulated deformable im2col kernel.
//
// Visits n work items; the wrapper modulated_deformable_im2col_cpu passes
// n = channels * batch_size * height_col * width_col. Item `index` is split
// into (b_col, c_im, h_col, w_col), with w_col varying fastest and b_col
// slowest. For each kernel tap (i, j) the item does three things:
//   1. bilinearly samples input channel c_im at the tap position shifted by
//      the learned (offset_h, offset_w);
//   2. multiplies the sample by the modulation mask;
//   3. writes the result into data_col.
//
// Layouts, as indexed below (dense float arrays):
//   data_im     : [batch_size, num_channels, height, width]
//   data_offset : [batch_size, deformable_group, 2 * kernel_h * kernel_w, height_col, width_col]
//   data_mask   : [batch_size, deformable_group, kernel_h * kernel_w, height_col, width_col]
//   data_col    : [batch_size, num_channels * kernel_h * kernel_w, height_col, width_col]
void modulated_deformable_im2col_cpu_kernel(const int n, const float *data_im, const float *data_offset, const float *data_mask,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int num_channels, const int deformable_group,
const int height_col, const int width_col,
float *data_col)
{
// Inherited from the CUDA kernel, which launched one thread per item; this
// port simply loops over the same channels * batch_size * height_col * width_col items.
for(int index=0; index<n; index++)
{
// NOTE(CharlesShang): different from Dai Jifeng's MXNet implementation, col_buffer is of shape (c*kw*kh, N, oh, ow)
// here columns is of shape (N, c*kw*kh, oh * ow), need to adapt axis
// index index of output matrix
const int w_col = index % width_col;
const int h_col = (index / width_col) % height_col;
// const int b_col = (index / width_col / height_col) % batch_size;
const int b_col = (index / width_col / height_col / num_channels) % batch_size;
// const int c_im = (index / width_col / height_col) / batch_size;
const int c_im = (index / width_col / height_col) % num_channels;
// const int c_col = c_im * kernel_h * kernel_w;
// First data_col row for this input channel; tap (i, j) adds i * kernel_w + j.
const int c_col = c_im * kernel_h * kernel_w;
// compute deformable group index
// (offsets/masks are shared by channel_per_deformable_group consecutive channels)
const int deformable_group_index = c_im / channel_per_deformable_group;
// Top-left corner of the (undeformed) receptive field in input coordinates.
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
// float *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
float *data_col_ptr = data_col + ((b_col * num_channels * kernel_w * kernel_h + c_col) * height_col + h_col) * width_col + w_col;
//const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
const float *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
const float *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
const float *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
for (int i = 0; i < kernel_h; ++i)
{
for (int j = 0; j < kernel_w; ++j)
{
// Offset plane 2*(i*kw+j) holds the vertical offset of tap (i, j) and
// plane 2*(i*kw+j)+1 the horizontal one; mask plane i*kw+j its modulation.
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
const float offset_h = data_offset_ptr[data_offset_h_ptr];
const float offset_w = data_offset_ptr[data_offset_w_ptr];
const float mask = data_mask_ptr[data_mask_hw_ptr];
float val = static_cast<float>(0);
const float h_im = h_in + i * dilation_h + offset_h;
const float w_im = w_in + j * dilation_w + offset_w;
//if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
// Points within (-1, height) x (-1, width) are interpolated (out-of-image
// corners read as 0); points farther out leave val at 0.
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
{
//const float map_h = i * dilation_h + offset_h;
//const float map_w = j * dilation_w + offset_w;
//const int cur_height = height - h_in;
//const int cur_width = width - w_in;
//val = dmcn_im2col_bilinear_cpu(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
val = dmcn_im2col_bilinear_cpu(data_im_ptr, width, height, width, h_im, w_im);
}
*data_col_ptr = val * mask;
// data_col_ptr += batch_size * height_col * width_col;
// Advance to the data_col row of the next kernel tap.
data_col_ptr += height_col * width_col;
}
}
}
}
// Sequential CPU version of the modulated deformable col2im kernel: scatters
// the column-buffer gradient back onto the input-image gradient.
//
// Visits n items; modulated_deformable_col2im_cpu passes
// n = channels * kernel_h * kernel_w * batch_size * height_col * width_col.
// Item `index` is split into (c, i, j, b, h_out, w_out), w_out varying
// fastest, so data_col is read as
// [channels * kernel_h * kernel_w, batch_size, height_col, width_col].
// With batch_size == 1 (what dcn_v2_cpu_backward passes) this coincides with
// the layout modulated_deformable_im2col_cpu_kernel writes.
//
// The gradient of each column entry, scaled by its mask, is distributed with
// bilinear weights onto the pixels within distance < 1 of the sampling point.
// Results are ACCUMULATED (+=) into grad_im, which is not cleared here.
void modulated_deformable_col2im_cpu_kernel(const int n, const float *data_col, const float *data_offset, const float *data_mask,
const int channels, const int height, const int width,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int deformable_group,
const int height_col, const int width_col,
float *grad_im)
{
for(int index = 0; index < n; index++)
{
const int j = (index / width_col / height_col / batch_size) % kernel_w;
const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
// compute the start and end of the output
const int deformable_group_index = c / channel_per_deformable_group;
int w_out = index % width_col;
int h_out = (index / width_col) % height_col;
int b = (index / width_col / height_col) % batch_size;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
const float offset_h = data_offset_ptr[data_offset_h_ptr];
const float offset_w = data_offset_ptr[data_offset_w_ptr];
const float mask = data_mask_ptr[data_mask_hw_ptr];
// Sampling position used in the forward pass for this (tap, output pixel).
const float cur_inv_h_data = h_in + i * dilation_h + offset_h;
const float cur_inv_w_data = w_in + j * dilation_w + offset_w;
const float cur_top_grad = data_col[index] * mask;
// Truncation toward zero; the 5x5 window below is wide enough to still
// cover the floor/ceil neighbours of the sampling point.
const int cur_h = (int)cur_inv_h_data;
const int cur_w = (int)cur_inv_w_data;
for (int dy = -2; dy <= 2; dy++)
{
for (int dx = -2; dx <= 2; dx++)
{
// Only in-image pixels closer than 1 in both axes received a bilinear weight.
if (cur_h + dy >= 0 && cur_h + dy < height &&
cur_w + dx >= 0 && cur_w + dx < width &&
abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
abs(cur_inv_w_data - (cur_w + dx)) < 1)
{
int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
float weight = dmcn_get_gradient_weight_cpu(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
//atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
// Plain += suffices: this loop is single-threaded (the CUDA original used atomicAdd).
*(grad_im + cur_bottom_grad_pos) += weight * cur_top_grad;
}
}
}
}
}
// Sequential CPU version of the kernel computing gradients w.r.t. the
// sampling offsets and the modulation masks.
//
// Visits n = batch_size * offset_channels * height_col * width_col items,
// where offset_channels = 2 * kernel_h * kernel_w * deformable_group (as
// passed by modulated_deformable_col2im_coord_cpu). Item `index` is split
// into (b, c, h, w) with w varying fastest. c selects one offset plane:
// even planes hold vertical offsets, odd planes horizontal ones (matching
// data_offset_h_ptr / data_offset_w_ptr). The parity becomes bp_dir.
//
// For that plane the kernel sums over every input channel of the deformable
// group:
//   column_grad * d(bilinear sample)/d(coordinate) * mask -> grad_offset[index]
//   column_grad * bilinear sample                         -> grad_mask
// grad_mask is written only for even c, so each mask entry is written once.
// Both outputs are assigned (=), not accumulated.
//
// channel_per_deformable_group here counts column-buffer rows per group, i.e.
// (input channels per group) * kernel_h * kernel_w; the wrapper passes
// channels * kernel_h * kernel_w / deformable_group.
void modulated_deformable_col2im_coord_cpu_kernel(const int n, const float *data_col, const float *data_im,
const float *data_offset, const float *data_mask,
const int channels, const int height, const int width,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int offset_channels, const int deformable_group,
const int height_col, const int width_col,
float *grad_offset, float *grad_mask)
{
for(int index = 0; index < n; index++)
{
float val = 0, mval = 0;
int w = index % width_col;
int h = (index / width_col) % height_col;
int c = (index / width_col / height_col) % offset_channels;
int b = (index / width_col / height_col) / offset_channels;
// compute the start and end of the output
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
const int col_step = kernel_h * kernel_w;
// cnt = index of the input channel (within the group) of the current column row.
int cnt = 0;
const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
const float *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
// Offset plane index within the group; offset_c / 2 is the kernel tap index.
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
// Visit the column rows of the same kernel tap for every input channel of the group.
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
{
const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
const int bp_dir = offset_c % 2;
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
int w_out = col_pos % width_col;
int h_out = (col_pos / width_col) % height_col;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
const float offset_h = data_offset_ptr[data_offset_h_ptr];
const float offset_w = data_offset_ptr[data_offset_w_ptr];
const float mask = data_mask_ptr[data_mask_hw_ptr];
float inv_h = h_in + i * dilation_h + offset_h;
float inv_w = w_in + j * dilation_w + offset_w;
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
{
// Out-of-range sample: -2 makes dmcn_get_coordinate_weight_cpu return 0.
inv_h = inv_w = -2;
}
else
{
mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear_cpu(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
}
const float weight = dmcn_get_coordinate_weight_cpu(
inv_h, inv_w,
height, width, data_im_ptr + cnt * height * width, width, bp_dir);
val += weight * data_col_ptr[col_pos] * mask;
cnt += 1;
}
// KERNEL_ASSIGN(grad_offset[index], offset_req, val);
grad_offset[index] = val;
if (offset_c % 2 == 0)
// KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
}
}
// Expand data_im into the column buffer data_col: every kernel tap is
// sampled at its offset-shifted position (bilinearly) and scaled by mask.
// Thin wrapper that sizes the work and runs the sequential kernel.
void modulated_deformable_im2col_cpu(const float* data_im, const float* data_offset, const float* data_mask,
  const int batch_size, const int channels, const int height_im, const int width_im,
  const int height_col, const int width_col, const int kernel_h, const int kernel_w,
  const int pad_h, const int pad_w, const int stride_h, const int stride_w,
  const int dilation_h, const int dilation_w,
  const int deformable_group, float* data_col) {
  // One work item per (batch, channel, output row, output column); each item
  // fills kernel_h * kernel_w rows of data_col.
  const int total_items = batch_size * channels * height_col * width_col;
  // Input channels that share one set of offsets/masks.
  const int group_channels = channels / deformable_group;

  modulated_deformable_im2col_cpu_kernel(
      total_items, data_im, data_offset, data_mask,
      height_im, width_im, kernel_h, kernel_w,
      pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
      group_channels, batch_size, channels, deformable_group,
      height_col, width_col, data_col);
}
// Scatter the column-buffer gradient data_col back onto the input gradient
// grad_im, using the same offsets and masks as the forward im2col.
// The kernel accumulates with +=; grad_im is not cleared here.
void modulated_deformable_col2im_cpu(const float* data_col, const float* data_offset, const float* data_mask,
  const int batch_size, const int channels, const int height_im, const int width_im,
  const int height_col, const int width_col, const int kernel_h, const int kernel_w,
  const int pad_h, const int pad_w, const int stride_h, const int stride_w,
  const int dilation_h, const int dilation_w,
  const int deformable_group, float* grad_im){
  // Input channels that share one set of offsets/masks.
  const int channel_per_deformable_group = channels / deformable_group;
  // One work item per element of data_col.
  const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
  // The horizontal padding must be pad_w (not pad_h); otherwise gradients are
  // scattered to the wrong columns whenever pad_h != pad_w.
  modulated_deformable_col2im_cpu_kernel(
      num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im,
      kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
      dilation_h, dilation_w, channel_per_deformable_group,
      batch_size, deformable_group, height_col, width_col, grad_im);
}
// Compute gradients w.r.t. the sampling offsets (grad_offset) and the
// modulation masks (grad_mask) from the column-buffer gradient data_col.
// Thin wrapper that sizes the work and runs the sequential kernel.
void modulated_deformable_col2im_coord_cpu(const float* data_col, const float* data_im, const float* data_offset, const float* data_mask,
  const int batch_size, const int channels, const int height_im, const int width_im,
  const int height_col, const int width_col, const int kernel_h, const int kernel_w,
  const int pad_h, const int pad_w, const int stride_h, const int stride_w,
  const int dilation_h, const int dilation_w,
  const int deformable_group,
  float* grad_offset, float* grad_mask) {
  // Offset planes: an (h, w) pair for every kernel tap of every group.
  const int offset_planes = 2 * kernel_h * kernel_w * deformable_group;
  // One work item per element of grad_offset.
  const int work_items = batch_size * offset_planes * height_col * width_col;
  // Column-buffer rows that share one group's offsets.
  const int col_rows_per_group = channels * kernel_h * kernel_w / deformable_group;

  modulated_deformable_col2im_coord_cpu_kernel(
      work_items, data_col, data_im, data_offset, data_mask,
      channels, height_im, width_im, kernel_h, kernel_w,
      pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
      col_rows_per_group, batch_size, offset_planes, deformable_group,
      height_col, width_col, grad_offset, grad_mask);
}
================================================
FILE: code/real/bsrt/model/DCNv2/src/cpu/dcn_v2_im2col_cpu.h
================================================
/*!
******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
*
* COPYRIGHT
*
* All contributions by the University of California:
* Copyright (c) 2014-2017 The Regents of the University of California (Regents)
* All rights reserved.
*
* All other contributions:
* Copyright (c) 2014-2017, the respective contributors
* All rights reserved.
*
* Caffe uses a shared copyright model: each contributor holds copyright over
* their contributions to Caffe. The project versioning records all such
* contribution and copyright details. If a contributor wants to further mark
* their specific copyright on a particular contribution, they should indicate
* their copyright solely in the commit message of the change when it is
* committed.
*
* LICENSE
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* CONTRIBUTION AGREEMENT
*
* By contributing to the BVLC/caffe repository through pull-request, comment,
* or otherwise, the contributor releases their content to the
* license and copyright terms herein.
*
***************** END Caffe Copyright Notice and Disclaimer ********************
*
* Copyright (c) 2018 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file modulated_deformable_im2col.h
* \brief Function definitions of converting an image to
* column matrix based on kernel, padding, dilation, and offset.
* These functions are mainly used in deformable convolution operators.
* \ref: https://arxiv.org/abs/1811.11168
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu
*/
/***************** Adapted by Charles Shang *********************/
// modified from the CUDA version for CPU use by Daniel K. Suhendro
#ifndef DCN_V2_IM2COL_CPU
#define DCN_V2_IM2COL_CPU

#ifdef __cplusplus
extern "C"
{
#endif

// CPU ports of the modulated deformable (DCNv2) im2col / col2im helpers.
// All buffers are raw float pointers. The geometry arguments mirror those of
// a regular convolution (kernel size, padding, stride, dilation) plus the
// number of deformable groups.
// NOTE(review): exact buffer layouts are defined in dcn_v2_im2col_cpu.cpp —
// confirm there before relying on them.

// Converts the input image to a column matrix based on kernel, padding,
// dilation and offset, writing the result to data_col. data_mask presumably
// modulates each sampled value — confirm in the implementation.
void modulated_deformable_im2col_cpu(const float *data_im, const float *data_offset, const float *data_mask,
                                     const int batch_size, const int channels, const int height_im, const int width_im,
                                     const int height_col, const int width_col, const int kernel_h, const int kernel_w,
                                     const int pad_h, const int pad_w, const int stride_h, const int stride_w,
                                     const int dilation_h, const int dilation_w,
                                     const int deformable_group, float *data_col);

// Reverse of the above for the backward pass: maps the column-matrix
// gradient data_col back onto the image gradient grad_im.
void modulated_deformable_col2im_cpu(const float *data_col, const float *data_offset, const float *data_mask,
                                     const int batch_size, const int channels, const int height_im, const int width_im,
                                     const int height_col, const int width_col, const int kernel_h, const int kernel_w,
                                     const int pad_h, const int pad_w, const int stride_h, const int stride_w,
                                     const int dilation_h, const int dilation_w,
                                     const int deformable_group, float *grad_im);

// Backward pass w.r.t. the sampling geometry: writes the gradients of the
// offsets (grad_offset) and of the modulation masks (grad_mask).
void modulated_deformable_col2im_coord_cpu(const float *data_col, const float *data_im, const float *data_offset, const float *data_mask,
                                           const int batch_size, const int channels, const int height_im, const int width_im,
                                           const int height_col, const int width_col, const int kernel_h, const int kernel_w,
                                           const int pad_h, const int pad_w, const int stride_h, const int stride_w,
                                           const int dilation_h, const int dilation_w,
                                           const int deformable_group,
                                           float *grad_offset, float *grad_mask);

#ifdef __cplusplus
}
#endif

#endif
================================================
FILE: code/real/bsrt/model/DCNv2/src/cpu/dcn_v2_psroi_pooling_cpu.cpp
================================================
/*!
* Copyright (c) 2017 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file deformable_psroi_pooling.cu
* \brief
* \author Yi Li, Guodong Zhang, Jifeng Dai
*/
/***************** Adapted by Charles Shang *********************/
// modified from the CUDA version for CPU use by Daniel K. Suhendro
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <ATen/ATen.h>
//#include <ATen/cuda/CUDAContext.h>
#include <TH/TH.h>
//#include <THC/THCAtomics.cuh>
//#include <THC/THCDeviceUtils.cuh>
/*#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N)
{
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}*/
template <typename T>
T bilinear_interp_cpu(
    const T *data,
    const T x,
    const T y,
    const int width,
    const int height)
{
  // Bilinearly samples the row-major plane `data` (row stride `width`) at the
  // fractional position (x, y). No bounds checking happens here: the caller
  // must pass coordinates already clamped into the plane. `height` is not read.
  const int col_lo = floor(x);
  const int col_hi = ceil(x);
  const int row_lo = floor(y);
  const int row_hi = ceil(y);

  // Fractional distance from the low corner along each axis.
  const T fx = static_cast<T>(x - col_lo);
  const T fy = static_cast<T>(y - row_lo);

  const T *upper_row = data + row_lo * width;
  const T *lower_row = data + row_hi * width;

  return (1 - fx) * (1 - fy) * upper_row[col_lo] +
         (1 - fx) * fy * lower_row[col_lo] +
         fx * (1 - fy) * upper_row[col_hi] +
         fx * fy * lower_row[col_hi];
}
template <typename T>
void DeformablePSROIPoolForwardKernelCpu(
    const int count,
    const T *bottom_data,
    const T spatial_scale,
    const int channels,
    const int height, const int width,
    const int pooled_height, const int pooled_width,
    const T *bottom_rois, const T *bottom_trans,
    const int no_trans,
    const T trans_std,
    const int sample_per_part,
    const int output_dim,
    const int group_size,
    const int part_size,
    const int num_classes,
    const int channels_each_class,
    T *top_data,
    T *top_count)
{
  // Serial CPU port of the deformable position-sensitive ROI pooling forward
  // kernel. Each of the `count` outputs is ordered (n, ctop, ph, pw), with n
  // the ROI index. Its value is the mean of the bilinear samples, taken on a
  // sample_per_part x sample_per_part grid inside bin (ph, pw) of ROI n, that
  // land inside the feature map. The bin is optionally shifted by an offset.
  //
  // bottom_rois:  5 values per ROI, read as (batch index, x1, y1, x2, y2).
  //               Corners are rounded, scaled by spatial_scale and shifted by -0.5.
  // bottom_trans: per (ROI, class) pair of x/y maps of size part_size^2,
  //               scaled by trans_std and the ROI size. Not read when no_trans != 0.
  // top_data:     pooled value per output (0 when no sample was in bounds).
  // top_count:    number of in-bounds samples per output.
  for (int index = 0; index < count; index++)
  {
    // Decompose the flat output index.
    const int pw = index % pooled_width;
    const int ph = (index / pooled_width) % pooled_height;
    const int ctop = (index / pooled_width / pooled_height) % output_dim;
    const int n = index / pooled_width / pooled_height / output_dim;

    // ROI box in feature-map coordinates.
    const T *offset_bottom_rois = bottom_rois + n * 5;
    const int roi_batch_ind = offset_bottom_rois[0];
    const T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
    const T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
    const T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
    const T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;

    // Force degenerate ROIs to a minimum extent so the bin sizes never hit 0.
    const T roi_width = std::max(roi_end_w - roi_start_w, T(0.1));
    const T roi_height = std::max(roi_end_h - roi_start_h, T(0.1));

    // Bin and sub-bin (sample spacing) sizes.
    const T bin_size_h = roi_height / static_cast<T>(pooled_height);
    const T bin_size_w = roi_width / static_cast<T>(pooled_width);
    const T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);
    const T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);

    // Learned offset for this bin; its part cell is picked by the bin position.
    const int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);
    const int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);
    const int class_id = ctop / channels_each_class;
    const T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;
    const T trans_y = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;

    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
    wstart += trans_x * roi_width;
    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
    hstart += trans_y * roi_height;

    // Position-sensitive group cell of this bin.
    int gw = floor(static_cast<T>(pw) * group_size / pooled_width);
    int gh = floor(static_cast<T>(ph) * group_size / pooled_height);
    gw = std::min(std::max(gw, 0), group_size - 1);
    gh = std::min(std::max(gh, 0), group_size - 1);

    // The sampled input channel does not depend on the sample position, so
    // resolve its plane once per output instead of once per sample.
    const T *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
    const int c = (ctop * group_size + gh) * group_size + gw;
    const T *channel_data = offset_bottom_data + c * height * width;

    T sum = 0;
    // Named sample_count so it does not shadow the `count` parameter.
    int sample_count = 0;
    for (int ih = 0; ih < sample_per_part; ih++)
    {
      for (int iw = 0; iw < sample_per_part; iw++)
      {
        T w = wstart + iw * sub_bin_size_w;
        T h = hstart + ih * sub_bin_size_h;
        // Skip samples more than half a pixel outside the feature map.
        if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
        {
          continue;
        }
        // Clamp the remaining samples onto the valid pixel range.
        w = std::min(std::max(w, T(0.)), width - T(1.));
        h = std::min(std::max(h, T(0.)), height - T(1.));
        sum += bilinear_interp_cpu(channel_data, w, h, width, height);
        sample_count++;
      }
    }
    top_data[index] = sample_count == 0 ? static_cast<T>(0) : sum / sample_count;
    top_count[index] = sample_count;
  }
}
// Serial CPU port of the deformable PS-ROI pooling backward kernel.
// For every pooled output (n, ctop, ph, pw) it re-derives the same sample
// positions as the forward kernel and:
//   * scatters top_diff / top_count through the bilinear weights into
//     bottom_data_diff (gradient w.r.t. the input feature map), and
//   * unless no_trans != 0, accumulates the gradient w.r.t. the x/y offsets
//     into bottom_trans_diff.
// Both diff buffers are updated with +=, so callers are expected to pass
// zero-initialised buffers. `num_rois` is not read by this body.
// Because the loop runs serially, plain += replaces the CUDA atomicAdd calls
// (kept below in comments).
template <typename T>
void DeformablePSROIPoolBackwardAccKernelCpu(
const int count,
const T *top_diff,
const T *top_count,
const int num_rois,
const T spatial_scale,
const int channels,
const int height, const int width,
const int pooled_height, const int pooled_width,
const int output_dim,
T *bottom_data_diff, T *bottom_trans_diff,
const T *bottom_data,
const T *bottom_rois,
const T *bottom_trans,
const int no_trans,
const T trans_std,
const int sample_per_part,
const int group_size,
const int part_size,
const int num_classes,
const int channels_each_class)
{
for(int index = 0; index < count; index++)
{
// The output is in order (n, ctop, ph, pw); n indexes the ROI.
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int ctop = (index / pooled_width / pooled_height) % output_dim;
int n = index / pooled_width / pooled_height / output_dim;
// ROI box (batch index, x1, y1, x2, y2) mapped to feature-map coordinates.
const T *offset_bottom_rois = bottom_rois + n * 5;
int roi_batch_ind = offset_bottom_rois[0];
T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
// Force too small ROIs to a minimum extent of 0.1
T roi_width = std::max(roi_end_w - roi_start_w, T(0.1)); //avoid 0
T roi_height = std::max(roi_end_h - roi_start_h, T(0.1));
// Compute bin and sub-bin (sample spacing) sizes at bottom
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);
T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);
// Offset lookup: the same part cell / class as in the forward pass.
int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);
int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);
int class_id = ctop / channels_each_class;
T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;
T trans_y = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;
T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
wstart += trans_x * roi_width;
T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
hstart += trans_y * roi_height;
// Outputs with no in-bounds samples had no dependence on the input.
if (top_count[index] <= 0)
{
continue;
}
// The forward value was sum / count, so each sample gets an equal share.
T diff_val = top_diff[index] / top_count[index];
const T *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
T *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;
int gw = floor(static_cast<T>(pw) * group_size / pooled_width);
int gh = floor(static_cast<T>(ph) * group_size / pooled_height);
gw = std::min(std::max(gw, 0), group_size - 1);
gh = std::min(std::max(gh, 0), group_size - 1);
for (int ih = 0; ih < sample_per_part; ih++)
{
for (int iw = 0; iw < sample_per_part; iw++)
{
T w = wstart + iw * sub_bin_size_w;
T h = hstart + ih * sub_bin_size_h;
// Same in-bounds test and clamping as the forward bilinear sample.
if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
{
continue;
}
w = std::min(std::max(w, T(0.)), width - T(1.));
h = std::min(std::max(h, T(0.)), height - T(1.));
int c = (ctop * group_size + gh) * group_size + gw;
// backward on feature: (x0, y0) is the low corner, (x1, y1) the high
// corner (equal to the low one when the coordinate is integral).
int x0 = floor(w);
int x1 = ceil(w);
int y0 = floor(h);
int y1 = ceil(h);
T dist_x = w - x0, dist_y = h - y0;
// qYX: bilinear weight of pixel (y_Y, x_X).
T q00 = (1 - dist_x) * (1 - dist_y);
T q01 = (1 - dist_x) * dist_y;
T q10 = dist_x * (1 - dist_y);
T q11 = dist_x * dist_y;
int bottom_index_base = c * height * width;
/*atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);
atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);
atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);
atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);*/
*(offset_bottom_data_diff + bottom_index_base + y0 * width + x0) += q00 * diff_val;
*(offset_bottom_data_diff + bottom_index_base + y1 * width + x0) += q01 * diff_val;
*(offset_bottom_data_diff + bottom_index_base + y0 * width + x1) += q10 * diff_val;
*(offset_bottom_data_diff + bottom_index_base + y1 * width + x1) += q11 * diff_val;
if (no_trans)
{
continue;
}
// UYX: input values at the four neighbouring pixels.
T U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];
T U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];
T U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];
T U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];
// d(sample)/dw times the chain-rule factors of wstart += trans_x * roi_width,
// where trans_x = bottom_trans * trans_std (and likewise for h / y).
T diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;
diff_x *= roi_width;
T diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;
diff_y *= roi_height;
/*atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x);
atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);*/
*(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w) += diff_x;
*(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w) += diff_y;
}
}
}
}
std::tuple<at::Tensor, at::Tensor>
dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,
                                 const at::Tensor &bbox,
                                 const at::Tensor &trans,
                                 const int no_trans,
                                 const float spatial_scale,
                                 const int output_dim,
                                 const int group_size,
                                 const int pooled_size,
                                 const int part_size,
                                 const int sample_per_part,
                                 const float trans_std)
{
  // CPU entry point of deformable PS-ROI pooling (forward).
  //
  // input: 4-D feature map; dims 1..3 are read as channels, height, width.
  //        The channel count must equal output_dim.
  // bbox:  one row of 5 values per ROI (dim 0 is the ROI count).
  // trans: offsets; trans.size(1) / 2 gives the number of classes.
  //        Not read when no_trans != 0.
  // Returns (out, top_count), both [num_bbox, output_dim, pooled_size, pooled_size].
  // top_count holds the per-bin sample counts consumed by the backward pass.
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);
  const int channels_trans = no_trans ? 2 : trans.size(1);
  const int num_bbox = bbox.size(0);
  AT_ASSERTM(channels == output_dim, "input channels and output channels must equal");

  const int pooled_height = pooled_size;
  const int pooled_width = pooled_size;
  auto out = at::empty({num_bbox, output_dim, pooled_height, pooled_width}, input.options());
  auto top_count = at::zeros({num_bbox, output_dim, pooled_height, pooled_width}, input.options());
  const int num_classes = no_trans ? 1 : channels_trans / 2;
  const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;

  if (out.numel() == 0)
  {
    return std::make_tuple(out, top_count);
  }

  // Use the tensor's 64-bit element count rather than an int product that
  // could overflow before being widened.
  const int64_t out_size = out.numel();

  // The kernel indexes raw dense buffers; keep the contiguous versions alive
  // in named locals for the whole call.
  const auto input_c = input.contiguous();
  const auto bbox_c = bbox.contiguous();
  const auto trans_c = trans.contiguous();

  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "dcn_v2_psroi_pooling_cpu_forward", [&] {
    DeformablePSROIPoolForwardKernelCpu<scalar_t>(
        out_size,
        input_c.data_ptr<scalar_t>(),
        spatial_scale,
        channels,
        height, width,
        pooled_height,
        pooled_width,
        bbox_c.data_ptr<scalar_t>(),
        trans_c.data_ptr<scalar_t>(),
        no_trans,
        trans_std,
        sample_per_part,
        output_dim,
        group_size,
        part_size,
        num_classes,
        channels_each_class,
        out.data_ptr<scalar_t>(),
        top_count.data_ptr<scalar_t>());
  });
  return std::make_tuple(out, top_count);
}
std::tuple<at::Tensor, at::Tensor>
dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad,
                                  const at::Tensor &input,
                                  const at::Tensor &bbox,
                                  const at::Tensor &trans,
                                  const at::Tensor &top_count,
                                  const int no_trans,
                                  const float spatial_scale,
                                  const int output_dim,
                                  const int group_size,
                                  const int pooled_size,
                                  const int part_size,
                                  const int sample_per_part,
                                  const float trans_std)
{
  // CPU entry point of deformable PS-ROI pooling (backward).
  //
  // out_grad:  gradient of the pooled output.
  // top_count: per-bin sample counts, as returned by the forward pass.
  // Returns (input_grad, trans_grad), shaped like input and trans.
  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);
  const int channels_trans = no_trans ? 2 : trans.size(1);
  const int num_bbox = bbox.size(0);
  AT_ASSERTM(channels == output_dim, "input channels and output channels must equal");

  const int pooled_height = pooled_size;
  const int pooled_width = pooled_size;
  // Widen before multiplying so the product cannot overflow int.
  const int64_t out_size = static_cast<int64_t>(num_bbox) * output_dim * pooled_height * pooled_width;
  const int num_classes = no_trans ? 1 : channels_trans / 2;
  const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;

  // Both gradients are accumulated in place through raw pointers, so they
  // must be zeroed AND contiguous. at::zeros_like may keep a non-contiguous
  // stride layout of `trans`. The kernel would then write into a .contiguous()
  // temporary and the offset gradient would be silently dropped, so allocate
  // with at::zeros instead.
  auto input_grad = at::zeros({batch, channels, height, width}, out_grad.options());
  auto trans_grad = at::zeros(trans.sizes(), trans.options());
  if (input_grad.numel() == 0)
  {
    return std::make_tuple(input_grad, trans_grad);
  }

  // Read-only inputs: dense copies only when needed, kept alive for the call.
  const auto out_grad_c = out_grad.contiguous();
  const auto top_count_c = top_count.contiguous();
  const auto input_c = input.contiguous();
  const auto bbox_c = bbox.contiguous();
  const auto trans_c = trans.contiguous();

  AT_DISPATCH_FLOATING_TYPES(out_grad.scalar_type(), "dcn_v2_psroi_pooling_cpu_backward", [&] {
    DeformablePSROIPoolBackwardAccKernelCpu<scalar_t>(
        out_size,
        out_grad_c.data_ptr<scalar_t>(),
        top_count_c.data_ptr<scalar_t>(),
        num_bbox,
        spatial_scale,
        channels,
        height,
        width,
        pooled_height,
        pooled_width,
        output_dim,
        input_grad.data_ptr<scalar_t>(),
        trans_grad.data_ptr<scalar_t>(),
        input_c.data_ptr<scalar_t>(),
        bbox_c.data_ptr<scalar_t>(),
        trans_c.data_ptr<scalar_t>(),
        no_trans,
        trans_std,
        sample_per_part,
        group_size,
        part_size,
        num_classes,
        channels_each_class);
  });
  return std::make_tuple(input_grad, trans_grad);
}
================================================
FILE: code/real/bsrt/model/DCNv2/src/cpu/vision.h
================================================
#pragma once
#include <torch/extension.h>
// Declarations of the CPU implementations of the DCNv2 extension ops.

// Modulated deformable convolution forward on CPU; returns the output tensor.
at::Tensor
dcn_v2_cpu_forward(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
const at::Tensor &offset,
const at::Tensor &mask,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const int dilation_h,
const int dilation_w,
const int deformable_group);
// Backward of dcn_v2_cpu_forward given grad_output; returns the gradient
// tensors as a vector. NOTE(review): element order is defined by the
// implementation — confirm before indexing into the result.
std::vector<at::Tensor>
dcn_v2_cpu_backward(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
const at::Tensor &offset,
const at::Tensor &mask,
const at::Tensor &grad_output,
int kernel_h, int kernel_w,
int stride_h, int stride_w,
int pad_h, int pad_w,
int dilation_h, int dilation_w,
int deformable_group);
// Deformable PS-ROI pooling forward on CPU; returns (out, top_count), where
// top_count holds the per-bin sample counts needed by the backward pass.
std::tuple<at::Tensor, at::Tensor>
dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,
const at::Tensor &bbox,
const at::Tensor &trans,
const int no_trans,
const float spatial_scale,
const int output_dim,
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std);
// Deformable PS-ROI pooling backward on CPU; returns (input_grad, trans_grad).
std::tuple<at::Tensor, at::Tensor>
dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad,
const at::Tensor &input,
const at::Tensor &bbox,
const at::Tensor &trans,
const at::Tensor &top_count,
const int no_trans,
const float spatial_scale,
const int output_dim,
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std);
================================================
FILE: code/real/bsrt/model/DCNv2/src/cuda/dcn_v2_cuda.cu
================================================
#include <vector>
#include "cuda/dcn_v2_im2col_cuda.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/Dispatch.h>
#include <ATen/div_rtn.h>
#include <THC/THC.h>
#include <THC/THCAtomics.cuh>
#include <THC/THCDeviceUtils.cuh>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/Exceptions.h>
// File-global THC state handle. It is passed as the first argument to the THC
// helpers used in this file (e.g. THCudaBlas_Sgemm). Being a namespace-scope
// variable with a dynamic initializer, it is set up while this translation unit
// is statically initialised. NOTE(review): lazyInitCUDA() presumably
// initialises CUDA when the extension library is loaded — confirm intended.
THCState *state = at::globalContext().lazyInitCUDA();
// Maps a BLAS-style transpose flag ('n'/'N', 't'/'T', 'c'/'C') to the matching
// cuBLAS operation enum. Any other character raises an error.
static cublasOperation_t _cublasOpFromChar(char op) {
  if (op == 'n' || op == 'N') {
    return CUBLAS_OP_N;
  }
  if (op == 't' || op == 'T') {
    return CUBLAS_OP_T;
  }
  if (op == 'c' || op == 'C') {
    return CUBLAS_OP_C;
  }
  AT_ERROR(
      "_cublasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
}
static void _cublasAdjustLdLevel2(int64_t m, int64_t n, int64_t* lda) {
  // cuBLAS validates leading dimensions even when they are unused. For a
  // level-2 call whose matrix has at most one column (n <= 1), replace lda
  // with max(m, 1) so the check passes. Level 2 sizes (m, n) describe A
  // itself regardless of the transpose flag, so no trans argument is needed.
  if (n > 1) {
    return;
  }
  *lda = (m > 1) ? m : 1;
}
// author: Charles Shang
// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
// [batch gemm]
// https://github.com/pytorch/pytorch/blob/master/aten/src/THC/generic/THCTensorMathBlas.cu
//
// Fills the device-side pointer tables consumed by the batched GEMM calls,
// using one thread per batch element. Entry idx of input_b / output_b /
// columns_b / ones_b points at sample idx of its buffer (base + idx * stride).
// Every entry of weight_b / bias_b points at the shared weight and bias.
__global__ void createBatchGemmBuffer(const float **input_b, float **output_b,
                                      float **columns_b, const float **ones_b,
                                      const float **weight_b, const float **bias_b,
                                      float *input, float *output,
                                      float *columns, float *ones,
                                      float *weight, float *bias,
                                      const int input_stride, const int output_stride,
                                      const int columns_stride, const int ones_stride,
                                      const int num_batches)
{
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= num_batches)
  {
    return;
  }
  // idx * stride is an element offset into the whole batch and can exceed
  // INT_MAX for large inputs, so do the multiplication in 64 bits.
  const long long sample = idx;
  input_b[idx] = input + sample * input_stride;
  output_b[idx] = output + sample * output_stride;
  columns_b[idx] = columns + sample * columns_stride;
  ones_b[idx] = ones + sample * ones_stride;
  // share weights and bias within a Mini-Batch
  weight_b[idx] = weight;
  bias_b[idx] = bias;
}
at::Tensor
dcn_v2_cuda_forward(const at::Tensor &input,
                    const at::Tensor &weight,
                    const at::Tensor &bias,
                    const at::Tensor &offset,
                    const at::Tensor &mask,
                    const int kernel_h,
                    const int kernel_w,
                    const int stride_h,
                    const int stride_w,
                    const int pad_h,
                    const int pad_w,
                    const int dilation_h,
                    const int dilation_w,
                    const int deformable_group)
{
  // Modulated deformable convolution (DCNv2) forward on CUDA, float32 only.
  //
  // input  [batch, channels, height, width]
  // weight [channels_out, channels, kernel_h, kernel_w]
  // offset / mask are consumed by modulated_deformable_im2col_cuda; their
  //   layout is defined there.
  // Returns output [batch, channels_out, height_out, width_out], where
  //   out = (in + 2 * pad - (dilation * (k - 1) + 1)) / stride + 1.
  //
  // For each batch element, two batched GEMMs run:
  //   1. output = bias broadcast over all positions (ones x bias);
  //   2. output += weight (viewed as [channels_out, channels*kh*kw]) x columns,
  //      where columns is the deformable im2col of the input.
  using scalar_t = float;
  // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask));
  AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
  AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
  AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
  AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
  AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);
  const int channels_out = weight.size(0);
  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);
  // AT_ASSERTM concatenates its message arguments; it does not printf-format
  // them, so the values are interleaved with the text.
  AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
             "Input shape and kernel shape wont match: (", kernel_h, " x ", kernel_w,
             " vs ", kernel_h_, " x ", kernel_w_, ").");
  AT_ASSERTM(channels == channels_kernel,
             "Input shape and kernel channels wont match: (", channels, " vs ", channels_kernel, ").");

  const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());
  if (output.numel() == 0)
  {
    // Nothing to compute. This also avoids launching a kernel with an empty grid.
    return output;
  }

  // The kernels and GEMMs below index dense row-major buffers through raw
  // pointers, so work on contiguous tensors (a no-op when already contiguous).
  const auto input_c = input.contiguous();
  const auto weight_c = weight.contiguous();
  const auto bias_c = bias.contiguous();
  const auto offset_c = offset.contiguous();
  const auto mask_c = mask.contiguous();

  auto ones = at::ones({batch, height_out, width_out}, input.options());
  auto columns = at::empty({batch, channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());

  // Batch-wise computing is significantly faster than instance-wise computing
  // when the batch size is large. The six device-side pointer tables (batch
  // pointers each) live in a single tensor. Their lifetime is then tied to this
  // scope and they are released even if a launch or BLAS call below throws.
  const int64_t table_bytes = static_cast<int64_t>(6) * batch * static_cast<int64_t>(sizeof(float *));
  auto ptr_tables = at::empty({table_bytes}, input.options().dtype(at::kByte));
  float **tables = reinterpret_cast<float **>(ptr_tables.data_ptr<uint8_t>());
  auto input_b = reinterpret_cast<const float **>(tables + 0 * batch);
  auto output_b = tables + 1 * batch;
  auto columns_b = tables + 2 * batch;
  auto ones_b = reinterpret_cast<const float **>(tables + 3 * batch);
  auto weight_b = reinterpret_cast<const float **>(tables + 4 * batch);
  auto bias_b = reinterpret_cast<const float **>(tables + 5 * batch);

  const int block = 128;
  const int grid = (batch + block - 1) / block;
  createBatchGemmBuffer<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
      input_b, output_b,
      columns_b, ones_b,
      weight_b, bias_b,
      input_c.data_ptr<scalar_t>(),
      output.data_ptr<scalar_t>(),
      columns.data_ptr<scalar_t>(),
      ones.data_ptr<scalar_t>(),
      weight_c.data_ptr<scalar_t>(),
      bias_c.data_ptr<scalar_t>(),
      channels * width * height,
      channels_out * width_out * height_out,
      channels * kernel_h * kernel_w * height_out * width_out,
      height_out * width_out,
      batch);

  // output = ones^T x bias (k = 1): broadcasts the bias over every position.
  long m_ = channels_out;
  long n_ = height_out * width_out;
  long k_ = 1;
  THCudaBlas_SgemmBatched(state,
                          't',
                          'n',
                          n_,
                          m_,
                          k_,
                          1.0f,
                          ones_b, k_,
                          bias_b, k_,
                          0.0f,
                          output_b, n_,
                          batch);

  modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(),
                                   input_c.data_ptr<scalar_t>(),
                                   offset_c.data_ptr<scalar_t>(),
                                   mask_c.data_ptr<scalar_t>(),
                                   batch, channels, height, width,
                                   height_out, width_out, kernel_h, kernel_w,
                                   pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
                                   deformable_group,
                                   columns.data_ptr<scalar_t>());

  // output += weight x columns (beta = 1 keeps the bias term).
  long m = channels_out;
  long n = height_out * width_out;
  long k = channels * kernel_h * kernel_w;
  THCudaBlas_SgemmBatched(state,
                          'n',
                          'n',
                          n,
                          m,
                          k,
                          1.0f,
                          (const float **)columns_b, n,
                          weight_b, k,
                          1.0f,
                          output_b, n,
                          batch);
  return output;
}
// Backward counterpart of createBatchGemmBuffer, using one thread per batch
// element. Entry idx of grad_output_b / columns_b / ones_b points at sample
// idx of its buffer (base + idx * stride). Every entry of weight_b,
// grad_weight_b and grad_bias_b points at the shared tensor.
__global__ void createBatchGemmBufferBackward(
    float **grad_output_b,
    float **columns_b,
    float **ones_b,
    float **weight_b,
    float **grad_weight_b,
    float **grad_bias_b,
    float *grad_output,
    float *columns,
    float *ones,
    float *weight,
    float *grad_weight,
    float *grad_bias,
    const int grad_output_stride,
    const int columns_stride,
    const int ones_stride,
    const int num_batches)
{
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= num_batches)
  {
    return;
  }
  // Compute the per-sample element offset in 64 bits; idx * stride can
  // exceed INT_MAX for large batches.
  const long long sample = idx;
  grad_output_b[idx] = grad_output + sample * grad_output_stride;
  columns_b[idx] = columns + sample * columns_stride;
  ones_b[idx] = ones + sample * ones_stride;
  // share weights and bias within a Mini-Batch
  weight_b[idx] = weight;
  grad_weight_b[idx] = grad_weight;
  grad_bias_b[idx] = grad_bias;
}
std::vector<at::Tensor> dcn_v2_cuda_backward(const at::Tensor &input,
                                             const at::Tensor &weight,
                                             const at::Tensor &bias,
                                             const at::Tensor &offset,
                                             const at::Tensor &mask,
                                             const at::Tensor &grad_output,
                                             int kernel_h, int kernel_w,
                                             int stride_h, int stride_w,
                                             int pad_h, int pad_w,
                                             int dilation_h, int dilation_w,
                                             int deformable_group)
{
  // Backward of the modulated deformable convolution on CUDA, float32 only.
  // Samples are processed one at a time:
  //   columns      = weight^T x grad_output_n
  //   grad_offset, grad_mask <- col2im_coord(columns)
  //   grad_input            <- col2im(columns)
  //   columns      = im2col(input_n)
  //   grad_weight += grad_output_n x columns^T
  //   grad_bias   += grad_output_n summed over all output positions
  // Returns {grad_input, grad_offset, grad_mask, grad_weight, grad_bias}.
  THArgCheck(input.is_contiguous(), 1, "input tensor has to be contiguous");
  THArgCheck(weight.is_contiguous(), 2, "weight tensor has to be contiguous");
  AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
  AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
  AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
  AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
  AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);
  const int channels_out = weight.size(0);
  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);
  // AT_ASSERTM concatenates its message arguments (no printf formatting).
  AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
             "Input shape and kernel shape wont match: (", kernel_h, " x ", kernel_w,
             " vs ", kernel_h_, " x ", kernel_w_, ").");
  AT_ASSERTM(channels == channels_kernel,
             "Input shape and kernel channels wont match: (", channels, " vs ", channels_kernel, ").");

  const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  // offset, mask and grad_output are read through raw pointers below. Autograd
  // may hand over a non-contiguous (e.g. expanded) grad_output, so densify them.
  const auto offset_c = offset.contiguous();
  const auto mask_c = mask.contiguous();
  const auto grad_output_c = grad_output.contiguous();

  auto ones = at::ones({height_out, width_out}, input.options());
  auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());

  // Gradient buffers are written through raw pointers: allocate them dense.
  auto grad_input = at::zeros_like(input);
  auto grad_weight = at::zeros_like(weight);
  auto grad_bias = at::zeros(bias.sizes(), bias.options());
  auto grad_offset = at::zeros_like(offset_c);
  auto grad_mask = at::zeros_like(mask_c);

  using scalar_t = float;
  for (int b = 0; b < batch; b++)
  {
    auto input_n = input.select(0, b);
    auto offset_n = offset_c.select(0, b);
    auto mask_n = mask_c.select(0, b);
    auto grad_output_n = grad_output_c.select(0, b);
    auto grad_input_n = grad_input.select(0, b);
    auto grad_offset_n = grad_offset.select(0, b);
    auto grad_mask_n = grad_mask.select(0, b);

    // columns = weight^T x grad_output_n
    long m = channels * kernel_h * kernel_w;
    long n = height_out * width_out;
    long k = channels_out;
    THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f,
                     grad_output_n.data_ptr<scalar_t>(), n,
                     weight.data_ptr<scalar_t>(), m, 0.0f,
                     columns.data_ptr<scalar_t>(), n);

    // gradient w.r.t. input coordinate data
    modulated_deformable_col2im_coord_cuda(c10::cuda::getCurrentCUDAStream(),
                                           columns.data_ptr<scalar_t>(),
                                           input_n.data_ptr<scalar_t>(),
                                           offset_n.data_ptr<scalar_t>(),
                                           mask_n.data_ptr<scalar_t>(),
                                           1, channels, height, width,
                                           height_out, width_out, kernel_h, kernel_w,
                                           pad_h, pad_w, stride_h, stride_w,
                                           dilation_h, dilation_w, deformable_group,
                                           grad_offset_n.data_ptr<scalar_t>(),
                                           grad_mask_n.data_ptr<scalar_t>());
    // gradient w.r.t. input data
    modulated_deformable_col2im_cuda(c10::cuda::getCurrentCUDAStream(),
                                     columns.data_ptr<scalar_t>(),
                                     offset_n.data_ptr<scalar_t>(),
                                     mask_n.data_ptr<scalar_t>(),
                                     1, channels, height, width,
                                     height_out, width_out, kernel_h, kernel_w,
                                     pad_h, pad_w, stride_h, stride_w,
                                     dilation_h, dilation_w, deformable_group,
                                     grad_input_n.data_ptr<scalar_t>());

    // gradient w.r.t. weight, dWeight should accumulate across the batch and group
    modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(),
                                     input_n.data_ptr<scalar_t>(),
                                     offset_n.data_ptr<scalar_t>(),
                                     mask_n.data_ptr<scalar_t>(),
                                     1, channels, height, width,
                                     height_out, width_out, kernel_h, kernel_w,
                                     pad_h, pad_w, stride_h, stride_w,
                                     dilation_h, dilation_w, deformable_group,
                                     columns.data_ptr<scalar_t>());
    long m_ = channels_out;
    long n_ = channels * kernel_h * kernel_w;
    long k_ = height_out * width_out;
    THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f,
                     columns.data_ptr<scalar_t>(), k_,
                     grad_output_n.data_ptr<scalar_t>(), k_, 1.0f,
                     grad_weight.data_ptr<scalar_t>(), n_);

    // grad_bias += grad_output_n (viewed as [channels_out, k_]) x ones.
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
    cublasOperation_t op = _cublasOpFromChar('t');
    _cublasAdjustLdLevel2(k_, m_, &k_);
    scalar_t* grad_output_n_float = grad_output_n.data_ptr<scalar_t>();
    scalar_t* one_float = ones.data_ptr<scalar_t>();
    scalar_t alpha = 1.0;
    scalar_t beta = 1.0;
    const cublasStatus_t gemv_status =
        cublasSgemv(handle, op, k_, m_, &alpha, grad_output_n_float, k_, one_float, 1, &beta, grad_bias.data_ptr<scalar_t>(), 1);
    // Don't silently drop a failed bias-gradient reduction.
    AT_ASSERTM(gemv_status == CUBLAS_STATUS_SUCCESS,
               "cublasSgemv failed with status ", static_cast<int>(gemv_status));
  }

  return {
      grad_input, grad_offset, grad_mask, grad_weight, grad_bias
  };
}
================================================
FILE: code/real/bsrt/model/DCNv2/src/cuda/dcn_v2_im2col_cuda.cu
================================================
#include "dcn_v2_im2col_cuda.h"
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THC.h>
#include <THC/THCAtomics.cuh>
#include <THC/THCDeviceUtils.cuh>
// Visits every index i in [0, n). Each thread starts at its global thread
// index and advances by the total number of launched threads
// (blockDim.x * gridDim.x), so a grid smaller than n still covers the range.
// The loop variable is an int, so n must fit in int.
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
// Threads per block; GET_BLOCKS below sizes grids in units of this value.
const int CUDA_NUM_THREADS = 1024;
// Number of CUDA_NUM_THREADS-sized blocks needed to cover N work items
// (ceiling division).
inline int GET_BLOCKS(const int N)
{
  const int padded = N + (CUDA_NUM_THREADS - 1);
  return padded / CUDA_NUM_THREADS;
}
__device__ float dmcn_im2col_bilinear_cuda(const float *bottom_data, const int data_width,
                                           const int height, const int width, float h, float w)
{
  // Bilinearly samples bottom_data (row stride data_width) at (h, w). Each of
  // the four neighbouring taps that lies outside [0, height-1] x [0, width-1]
  // contributes zero instead of being read.
  const int row0 = floor(h);
  const int col0 = floor(w);
  const int row1 = row0 + 1;
  const int col1 = col0 + 1;

  // Fractional offsets (dy, dx) and their complements (ry, rx).
  const float dy = h - row0;
  const float dx = w - col0;
  const float ry = 1 - dy;
  const float rx = 1 - dx;

  const bool row0_ok = row0 >= 0;
  const bool row1_ok = row1 <= height - 1;
  const bool col0_ok = col0 >= 0;
  const bool col1_ok = col1 <= width - 1;

  const float top_left = (row0_ok && col0_ok) ? bottom_data[row0 * data_width + col0] : 0;
  const float top_right = (row0_ok && col1_ok) ? bottom_data[row0 * data_width + col1] : 0;
  const float bottom_left = (row1_ok && col0_ok) ? bottom_data[row1 * data_width + col0] : 0;
  const float bottom_right = (row1_ok && col1_ok) ? bottom_data[row1 * data_width + col1] : 0;

  return ry * rx * top_left + ry * dx * top_right + dy * rx * bottom_left + dy * dx * bottom_right;
}
__device__ float dmcn_get_gradient_weight_cuda(float argmax_h, float argmax_w,
                                               const int h, const int w, const int height, const int width)
{
  // Bilinear weight with which integer pixel (h, w) contributes to a sample
  // taken at the fractional location (argmax_h, argmax_w). Only the four
  // pixels surrounding the sample get a non-zero weight. Samples at least one
  // pixel outside the image return 0.
  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
  {
    return 0;
  }

  const int top = floor(argmax_h);
  const int left = floor(argmax_w);

  // Vertical factor: pixel must be the row just above or just below the sample.
  float along_h;
  if (h == top)
  {
    along_h = h + 1 - argmax_h;
  }
  else if (h == top + 1)
  {
    along_h = argmax_h + 1 - h;
  }
  else
  {
    return 0;
  }

  // Horizontal factor: same test on the columns.
  float along_w;
  if (w == left)
  {
    along_w = w + 1 - argmax_w;
  }
  else if (w == left + 1)
  {
    along_w = argmax_w + 1 - w;
  }
  else
  {
    return 0;
  }

  return along_h * along_w;
}
// Partial derivative of the bilinear sample of im_data (row stride
// data_width) at (argmax_h, argmax_w) with respect to the sampling coordinate.
//   bp_dir == 0: derivative w.r.t. the vertical coordinate (argmax_h)
//   bp_dir == 1: derivative w.r.t. the horizontal coordinate (argmax_w)
// Any other bp_dir returns 0. Neighbouring taps outside
// [0, height-1] x [0, width-1] contribute nothing; these are the same bounds
// tests used by dmcn_im2col_bilinear_cuda.
__device__ float dmcn_get_coordinate_weight_cuda(float argmax_h, float argmax_w,
const int height, const int width, const float *im_data,
const int data_width, const int bp_dir)
{
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
{
//empty: the sample lies fully outside the image, so the gradient is zero
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
float weight = 0;
if (bp_dir == 0)
{
// d/dh: the low row enters with sign -1 and the high row with +1, each
// weighted by its horizontal bilinear factor.
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
}
else if (bp_dir == 1)
{
// d/dw: the low column enters with sign -1 and the high column with +1,
// each weighted by its vertical bilinear factor.
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
}
return weight;
}
// Modulated deformable im2col (DCNv2 forward).
//
// For every (batch, input channel, output row, output column) and every kernel tap,
// this kernel does three things:
//   1. Bilinearly samples the input at the tap's regular grid position plus its
//      learned offset.
//   2. Multiplies the sample by the tap's learned modulation mask.
//   3. Writes the product into the column buffer.
//
// Layouts implied by the index arithmetic below (row-major):
//   data_im     : (batch_size, num_channels, height, width)
//   data_offset : (batch_size, deformable_group, 2 * kernel_h * kernel_w, height_col, width_col)
//                 plane 2*k holds the h (row) offset of tap k, plane 2*k+1 the w (column) offset
//   data_mask   : (batch_size, deformable_group, kernel_h * kernel_w, height_col, width_col)
//   data_col    : (batch_size, num_channels * kernel_h * kernel_w, height_col, width_col)
// where tap k = i * kernel_w + j for kernel row i and kernel column j. Input channel
// c_im uses deformable group c_im / channel_per_deformable_group.
//
// Each index is decoded as (b_col, c_im, h_col, w_col), so n is expected to equal
// batch_size * num_channels * height_col * width_col (presumably computed by the
// host launcher; confirm at the call site).
// NOTE(review): all offsets are computed in 32-bit int, so very large tensors could
// overflow the index arithmetic.
__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
                                                       const float *data_im, const float *data_offset, const float *data_mask,
                                                       const int height, const int width, const int kernel_h, const int kernel_w,
                                                       const int pad_h, const int pad_w,
                                                       const int stride_h, const int stride_w,
                                                       const int dilation_h, const int dilation_w,
                                                       const int channel_per_deformable_group,
                                                       const int batch_size, const int num_channels, const int deformable_group,
                                                       const int height_col, const int width_col,
                                                       float *data_col)
{
  // CUDA_KERNEL_LOOP strides over all n work items regardless of the launched grid size.
  CUDA_KERNEL_LOOP(index, n)
  {
    // NOTE(CharlesShang): different from Dai Jifeng's MXNet implementation, col_buffer is of shape (c*kw*kh, N, oh, ow)
    // here columns is of shape (N, c*kw*kh, oh * ow), need to adapt axis
    // Decode index: w_col varies fastest, then h_col, then c_im, then b_col.
    const int w_col = index % width_col;
    const int h_col = (index / width_col) % height_col;
    const int b_col = (index / width_col / height_col / num_channels) % batch_size;
    const int c_im = (index / width_col / height_col) % num_channels;
    // First of the kernel_h * kernel_w data_col planes owned by input channel c_im.
    const int c_col = c_im * kernel_h * kernel_w;
    // compute deformable group index
    const int deformable_group_index = c_im / channel_per_deformable_group;
    // Top-left input coordinate of the undeformed receptive field.
    const int h_in = h_col * stride_h - pad_h;
    const int w_in = w_col * stride_w - pad_w;
    // Output element for tap 0 of this channel at (h_col, w_col).
    float *data_col_ptr = data_col + ((b_col * num_channels * kernel_w * kernel_h + c_col) * height_col + h_col) * width_col + w_col;
    // Start of this (batch, channel) input plane.
    const float *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
    // Start of the offset and mask planes for this (batch, deformable group).
    const float *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
    const float *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
    for (int i = 0; i < kernel_h; ++i)
    {
      for (int j = 0; j < kernel_w; ++j)
      {
        // Tap k = i * kernel_w + j: plane 2k is the h offset, plane 2k + 1 the w offset.
        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
        const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
        const float offset_h = data_offset_ptr[data_offset_h_ptr];
        const float offset_w = data_offset_ptr[data_offset_w_ptr];
        const float mask = data_mask_ptr[data_mask_hw_ptr];
        float val = static_cast<float>(0);
        // Sampling position = dilated grid position + learned offset.
        const float h_im = h_in + i * dilation_h + offset_h;
        const float w_im = w_in + j * dilation_w + offset_w;
        // Positions inside (-1, height) x (-1, width) can still overlap the image.
        // The bilinear helper zero-fills neighbours outside it. Farther positions
        // leave val at 0.
        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
        {
          val = dmcn_im2col_bilinear_cuda(data_im_ptr, width, height, width, h_im, w_im);
        }
        // Modulate by the learned mask.
        *data_col_ptr = val * mask;
        // Advance to the next tap's plane in data_col.
        data_col_ptr += height_col * width_col;
      }
    }
  }
}
__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
const float *data_col, const float *data_offset, const float *data_mask,
const int channels, const int height, const int width,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int deformable_group,
const int height_col, const int width_col,
float *grad_im)
{
CUDA_KERNEL_LOOP(index, n)
{
const int j = (index / width_col / height_col / batch_size) % kernel_w;
const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
// compute the start and end of the output
const int deformable_group_index = c / channel_per_deformable_group;
int w_out = index % width_col;
int h_out = (index / width_col) % height_col;
int b = (index / width_col / height_col) % batch_size;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col
gitextract_kbz2i0m2/ ├── .gitignore ├── LICENSE ├── README.md ├── code/ │ ├── real/ │ │ └── bsrt/ │ │ ├── README.md │ │ ├── data_processing/ │ │ │ ├── __init__.py │ │ │ ├── camera_pipeline.py │ │ │ └── synthetic_burst_generation.py │ │ ├── datasets/ │ │ │ ├── __init__.py │ │ │ ├── burstsr_dataset.py │ │ │ ├── burstsr_test_dataset.py │ │ │ ├── data_sampler.py │ │ │ ├── realworld_burst_test_set.py │ │ │ ├── synthetic_burst_test_set.py │ │ │ ├── synthetic_burst_train_set.py │ │ │ ├── synthetic_burst_val_set.py │ │ │ └── zurich_raw2rgb_dataset.py │ │ ├── demo.sh │ │ ├── loss/ │ │ │ ├── Charbonnier.py │ │ │ ├── __init__.py │ │ │ ├── adversarial.py │ │ │ ├── discriminator.py │ │ │ ├── filter.py │ │ │ ├── hist_entropy.py │ │ │ ├── mssim.py │ │ │ └── vgg.py │ │ ├── main.py │ │ ├── model/ │ │ │ ├── DCNv2/ │ │ │ │ ├── LICENSE │ │ │ │ ├── README.md │ │ │ │ ├── __init__.py │ │ │ │ ├── dcn_v2.py │ │ │ │ ├── files.txt │ │ │ │ ├── make.sh │ │ │ │ ├── setup.py │ │ │ │ ├── src/ │ │ │ │ │ ├── cpu/ │ │ │ │ │ │ ├── dcn_v2_cpu.cpp │ │ │ │ │ │ ├── dcn_v2_im2col_cpu.cpp │ │ │ │ │ │ ├── dcn_v2_im2col_cpu.h │ │ │ │ │ │ ├── dcn_v2_psroi_pooling_cpu.cpp │ │ │ │ │ │ └── vision.h │ │ │ │ │ ├── cuda/ │ │ │ │ │ │ ├── dcn_v2_cuda.cu │ │ │ │ │ │ ├── dcn_v2_im2col_cuda.cu │ │ │ │ │ │ ├── dcn_v2_im2col_cuda.h │ │ │ │ │ │ ├── dcn_v2_psroi_pooling_cuda.cu │ │ │ │ │ │ └── vision.h │ │ │ │ │ ├── dcn_v2.h │ │ │ │ │ └── vision.cpp │ │ │ │ └── test.py │ │ │ ├── __init__.py │ │ │ ├── arch_util.py │ │ │ ├── bsrt.py │ │ │ ├── checkpoint.py │ │ │ ├── common.py │ │ │ ├── non_local/ │ │ │ │ ├── network.py │ │ │ │ ├── non_local_concatenation.py │ │ │ │ ├── non_local_cross_dot_product.py │ │ │ │ ├── non_local_dot_product.py │ │ │ │ ├── non_local_embedded_gaussian.py │ │ │ │ └── non_local_gaussian.py │ │ │ ├── swin_util.py │ │ │ └── utils/ │ │ │ ├── interp_methods.py │ │ │ ├── psconv.py │ │ │ └── resize_right.py │ │ ├── option.py │ │ ├── pwcnet/ │ │ │ ├── LICENSE │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── 
comparison/ │ │ │ │ └── comparison.py │ │ │ ├── correlation/ │ │ │ │ ├── README.md │ │ │ │ └── correlation.py │ │ │ ├── download.bash │ │ │ ├── images/ │ │ │ │ └── README.md │ │ │ ├── out.flo │ │ │ ├── pwcnet.py │ │ │ ├── requirements.txt │ │ │ └── run.py │ │ ├── requirements.txt │ │ ├── scripts/ │ │ │ ├── __init__.py │ │ │ ├── cal_mean_std.py │ │ │ ├── demo.sh │ │ │ ├── download_burstsr_dataset.py │ │ │ ├── evaluate.sh │ │ │ ├── evaluate_burstsr_val.py │ │ │ ├── save_results_synburst_val.py │ │ │ ├── test_burstsr_dataset.py │ │ │ └── test_synthetic_bursts.py │ │ ├── test.py │ │ ├── test_real.py │ │ ├── trainer.py │ │ ├── utility.py │ │ ├── utils/ │ │ │ ├── __init__.py │ │ │ ├── data_format_utils.py │ │ │ ├── debayer.py │ │ │ ├── interp_methods.py │ │ │ ├── metrics.py │ │ │ ├── postprocessing_functions.py │ │ │ ├── resize_right.py │ │ │ ├── spatial_color_alignment.py │ │ │ ├── stn.py │ │ │ └── warp.py │ │ └── validate.py │ └── synthetic/ │ └── bsrt/ │ ├── README.md │ ├── data_processing/ │ │ ├── __init__.py │ │ ├── camera_pipeline.py │ │ └── synthetic_burst_generation.py │ ├── datasets/ │ │ ├── __init__.py │ │ ├── burstsr_dataset.py │ │ ├── burstsr_test_dataset.py │ │ ├── data_sampler.py │ │ ├── realworld_burst_test_set.py │ │ ├── synthetic_burst_test_set.py │ │ ├── synthetic_burst_train_set.py │ │ ├── synthetic_burst_val_set.py │ │ └── zurich_raw2rgb_dataset.py │ ├── demo.sh │ ├── loss/ │ │ ├── Charbonnier.py │ │ ├── __init__.py │ │ ├── adversarial.py │ │ ├── discriminator.py │ │ ├── filter.py │ │ ├── hist_entropy.py │ │ ├── mssim.py │ │ └── vgg.py │ ├── main.py │ ├── model/ │ │ ├── DCNv2/ │ │ │ ├── LICENSE │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── dcn_v2.py │ │ │ ├── files.txt │ │ │ ├── make.sh │ │ │ ├── setup.py │ │ │ ├── src/ │ │ │ │ ├── cpu/ │ │ │ │ │ ├── dcn_v2_cpu.cpp │ │ │ │ │ ├── dcn_v2_im2col_cpu.cpp │ │ │ │ │ ├── dcn_v2_im2col_cpu.h │ │ │ │ │ ├── dcn_v2_psroi_pooling_cpu.cpp │ │ │ │ │ └── vision.h │ │ │ │ ├── cuda/ │ │ │ │ │ ├── dcn_v2_cuda.cu 
│ │ │ │ │ ├── dcn_v2_im2col_cuda.cu │ │ │ │ │ ├── dcn_v2_im2col_cuda.h │ │ │ │ │ ├── dcn_v2_psroi_pooling_cuda.cu │ │ │ │ │ └── vision.h │ │ │ │ ├── dcn_v2.h │ │ │ │ └── vision.cpp │ │ │ └── test.py │ │ ├── __init__.py │ │ ├── arch_util.py │ │ ├── bsrt.py │ │ ├── checkpoint.py │ │ ├── common.py │ │ ├── ebsr.py │ │ ├── non_local/ │ │ │ ├── network.py │ │ │ ├── non_local_concatenation.py │ │ │ ├── non_local_cross_dot_product.py │ │ │ ├── non_local_dot_product.py │ │ │ ├── non_local_embedded_gaussian.py │ │ │ └── non_local_gaussian.py │ │ ├── swin_util.py │ │ └── utils/ │ │ ├── interp_methods.py │ │ ├── psconv.py │ │ └── resize_right.py │ ├── option.py │ ├── requirements.txt │ ├── scripts/ │ │ ├── __init__.py │ │ ├── cal_mean_std.py │ │ ├── demo.sh │ │ ├── download_burstsr_dataset.py │ │ ├── evaluate.sh │ │ ├── evaluate_burstsr_val.py │ │ ├── save_results_synburst_val.py │ │ ├── test_burstsr_dataset.py │ │ └── test_synthetic_bursts.py │ ├── test.py │ ├── test_synburst.py │ ├── trainer.py │ ├── utility.py │ └── utils/ │ ├── __init__.py │ ├── data_format_utils.py │ ├── debayer.py │ ├── interp_methods.py │ ├── metrics.py │ ├── postprocessing_functions.py │ ├── resize_right.py │ ├── spatial_color_alignment.py │ ├── stn.py │ └── warp.py └── requirements.txt
SYMBOL INDEX (1285 symbols across 125 files)
FILE: code/real/bsrt/data_processing/camera_pipeline.py
function random_ccm (line 13) | def random_ccm():
function random_gains (line 47) | def random_gains():
function apply_smoothstep (line 58) | def apply_smoothstep(image):
function invert_smoothstep (line 64) | def invert_smoothstep(image):
function gamma_expansion (line 70) | def gamma_expansion(image):
function gamma_compression (line 76) | def gamma_compression(image):
function apply_ccm (line 82) | def apply_ccm(image, ccm):
function apply_gains (line 95) | def apply_gains(image, rgb_gain, red_gain, blue_gain):
function safe_invert_gains (line 109) | def safe_invert_gains(image, rgb_gain, red_gain, blue_gain):
function mosaic (line 125) | def mosaic(image, mode='rggb'):
function demosaic (line 151) | def demosaic(image):
function random_noise_levels (line 188) | def random_noise_levels():
function add_noise (line 201) | def add_noise(image, shot_noise=0.01, read_noise=0.0005):
function process_linear_image_rgb (line 208) | def process_linear_image_rgb(image, meta_info, return_np=False):
function process_linear_image_raw (line 225) | def process_linear_image_raw(image, meta_info):
FILE: code/real/bsrt/data_processing/synthetic_burst_generation.py
function random_crop (line 10) | def random_crop(frames, crop_sz):
function rgb2rawburst (line 49) | def rgb2rawburst(image, burst_size, downsample_factor=1, burst_transform...
function get_tmat (line 123) | def get_tmat(image_shape, translation, theta, shear_values, scale_factors):
function single2lrburst (line 149) | def single2lrburst(image, burst_size, downsample_factor=1, transformatio...
FILE: code/real/bsrt/datasets/burstsr_dataset.py
class SamsungRAWImage (line 10) | class SamsungRAWImage:
method load (line 12) | def load(path):
method __init__ (line 24) | def __init__(self, im_raw, black_level, cam_wb, daylight_wb, color_mat...
method get_all_meta_data (line 38) | def get_all_meta_data(self):
method get_exposure_time (line 42) | def get_exposure_time(self):
method get_noise_profile (line 45) | def get_noise_profile(self):
method get_f_number (line 51) | def get_f_number(self):
method get_iso (line 54) | def get_iso(self):
method get_image_data (line 57) | def get_image_data(self, substract_black_level=False, white_balance=Fa...
method shape (line 72) | def shape(self):
method crop_image (line 76) | def crop_image(self, r1, r2, c1, c2):
method get_crop (line 79) | def get_crop(self, r1, r2, c1, c2):
method postprocess (line 90) | def postprocess(self, return_np=True, norm_factor=None):
class CanonImage (line 113) | class CanonImage:
method load (line 115) | def load(path, split='train'):
method __init__ (line 125) | def __init__(self, im_raw, black_level, cam_wb, daylight_wb, rgb_xyz_m...
method shape (line 150) | def shape(self):
method get_all_meta_data (line 154) | def get_all_meta_data(self):
method get_exposure_time (line 159) | def get_exposure_time(self):
method get_f_number (line 162) | def get_f_number(self):
method get_iso (line 165) | def get_iso(self):
method get_image_data (line 168) | def get_image_data(self, substract_black_level=False, white_balance=Fa...
method set_image_data (line 182) | def set_image_data(self, im_data):
method crop_image (line 185) | def crop_image(self, r1, r2, c1, c2):
method get_crop (line 188) | def get_crop(self, r1, r2, c1, c2):
method set_crop_info (line 193) | def set_crop_info(self, crop_info):
method resize (line 196) | def resize(self, size=None, scale_factor=None):
method postprocess (line 201) | def postprocess(self, return_np=True):
function load_txt (line 216) | def load_txt(path):
class BurstSRDataset (line 223) | class BurstSRDataset(torch.utils.data.Dataset):
method __init__ (line 225) | def __init__(self, root, burst_size=8, crop_sz=80, center_crop=False, ...
method _get_burst_list (line 255) | def _get_burst_list(self):
method get_burst_info (line 260) | def get_burst_info(self, burst_id):
method _get_raw_image (line 264) | def _get_raw_image(self, burst_id, im_id):
method _get_gt_image (line 268) | def _get_gt_image(self, burst_id):
method get_burst (line 272) | def get_burst(self, burst_id, im_ids, info=None):
method _sample_images (line 281) | def _sample_images(self):
method __len__ (line 288) | def __len__(self):
method __getitem__ (line 291) | def __getitem__(self, index):
function pack_raw_image (line 400) | def pack_raw_image(im_raw):
function flatten_raw_image (line 415) | def flatten_raw_image(im_raw_4ch):
function pack_raw_image_batch (line 430) | def pack_raw_image_batch(im_raw):
function flatten_raw_image_batch (line 439) | def flatten_raw_image_batch(im_raw_4ch):
FILE: code/real/bsrt/datasets/burstsr_test_dataset.py
class BurstSRDataset (line 8) | class BurstSRDataset(torch.utils.data.Dataset):
method __init__ (line 10) | def __init__(self, root, burst_size=8, crop_sz=80, center_crop=False, ...
method _get_burst_list (line 40) | def _get_burst_list(self):
method get_burst_info (line 45) | def get_burst_info(self, burst_id):
method _get_raw_image (line 49) | def _get_raw_image(self, burst_id, im_id):
method get_burst (line 53) | def get_burst(self, burst_id, im_ids, info=None):
method _sample_images (line 61) | def _sample_images(self):
method __len__ (line 68) | def __len__(self):
method __getitem__ (line 71) | def __getitem__(self, index):
FILE: code/real/bsrt/datasets/data_sampler.py
class DistIterSampler (line 13) | class DistIterSampler(Sampler):
method __init__ (line 31) | def __init__(self, dataset, num_replicas=None, rank=None, ratio=100):
method __iter__ (line 47) | def __iter__(self):
method __len__ (line 64) | def __len__(self):
method set_epoch (line 67) | def set_epoch(self, epoch):
FILE: code/real/bsrt/datasets/realworld_burst_test_set.py
class RealWorldBurstTest (line 7) | class RealWorldBurstTest(torch.utils.data.Dataset):
method __init__ (line 10) | def __init__(self, root):
method __len__ (line 15) | def __len__(self):
method _read_burst_image (line 18) | def _read_burst_image(self, index, image_id):
method __getitem__ (line 23) | def __getitem__(self, index):
FILE: code/real/bsrt/datasets/synthetic_burst_test_set.py
class SyntheticBurstTest (line 7) | class SyntheticBurstTest(torch.utils.data.Dataset):
method __init__ (line 11) | def __init__(self, root):
method __len__ (line 16) | def __len__(self):
method _read_burst_image (line 19) | def _read_burst_image(self, index, image_id):
method __getitem__ (line 24) | def __getitem__(self, index):
FILE: code/real/bsrt/datasets/synthetic_burst_train_set.py
class SyntheticBurst (line 8) | class SyntheticBurst(torch.utils.data.Dataset):
method __init__ (line 18) | def __init__(self, base_dataset, burst_size=8, crop_sz=384, transform=...
method __len__ (line 37) | def __len__(self):
method __getitem__ (line 40) | def __getitem__(self, index):
FILE: code/real/bsrt/datasets/synthetic_burst_val_set.py
class SyntheticBurstVal (line 8) | class SyntheticBurstVal(torch.utils.data.Dataset):
method __init__ (line 15) | def __init__(self, root=None, initialize=True):
method initialize (line 25) | def initialize(self):
method __len__ (line 28) | def __len__(self):
method _read_burst_image (line 31) | def _read_burst_image(self, index, image_id):
method _read_gt_image (line 37) | def _read_gt_image(self, index):
method _read_meta_info (line 42) | def _read_meta_info(self, index):
method __getitem__ (line 48) | def __getitem__(self, index):
FILE: code/real/bsrt/datasets/zurich_raw2rgb_dataset.py
class ZurichRAW2RGB (line 7) | class ZurichRAW2RGB(torch.utils.data.Dataset):
method __init__ (line 12) | def __init__(self, root, split='train'):
method _get_image_list (line 22) | def _get_image_list(self, split):
method _get_image (line 33) | def _get_image(self, im_id):
method get_image (line 38) | def get_image(self, im_id):
method __len__ (line 43) | def __len__(self):
method __getitem__ (line 46) | def __getitem__(self, index):
FILE: code/real/bsrt/loss/Charbonnier.py
class CharbonnierLoss (line 5) | class CharbonnierLoss(nn.Module):
method __init__ (line 8) | def __init__(self, epsilon=1e-3, reduce=True):
method forward (line 13) | def forward(self, X, Y):
FILE: code/real/bsrt/loss/__init__.py
class Loss (line 14) | class Loss(nn.modules.loss._Loss):
method __init__ (line 15) | def __init__(self, args, ckp):
method forward (line 80) | def forward(self, sr, hr):
method step (line 97) | def step(self):
method start_log (line 102) | def start_log(self):
method end_log (line 105) | def end_log(self, n_batches):
method display_loss (line 108) | def display_loss(self, batch):
method plot_loss (line 116) | def plot_loss(self, apath, epoch):
method get_loss_module (line 130) | def get_loss_module(self):
method save (line 136) | def save(self, apath):
method load (line 140) | def load(self, apath, cpu=False):
FILE: code/real/bsrt/loss/adversarial.py
class Adversarial (line 12) | class Adversarial(nn.Module):
method __init__ (line 13) | def __init__(self, args, gan_type):
method forward (line 36) | def forward(self, fake, real):
method state_dict (line 96) | def state_dict(self, *args, **kwargs):
method bce (line 102) | def bce(self, real, fake):
FILE: code/real/bsrt/loss/discriminator.py
class Discriminator (line 5) | class Discriminator(nn.Module):
method __init__ (line 9) | def __init__(self, args, gan_type='GAN'):
method forward (line 65) | def forward(self, x):
FILE: code/real/bsrt/loss/filter.py
class Filter (line 5) | class Filter(nn.Module):
method __init__ (line 6) | def __init__(self, args):
method forward (line 16) | def forward(self, x, y):
FILE: code/real/bsrt/loss/hist_entropy.py
class HistEntropy (line 5) | class HistEntropy(nn.Module):
method __init__ (line 6) | def __init__(self, args):
method forward (line 10) | def forward(self, x):
FILE: code/real/bsrt/loss/mssim.py
function gaussian (line 7) | def gaussian(window_size, sigma):
function create_window (line 12) | def create_window(window_size, channel=1):
function ssim (line 19) | def ssim(img1, img2, window_size=11, window=None, size_average=True, ful...
function msssim (line 71) | def msssim(img1, img2, window_size=11, size_average=True, val_range=None...
class SSIM (line 109) | class SSIM(torch.nn.Module):
method __init__ (line 110) | def __init__(self, window_size=11, size_average=True, val_range=None):
method forward (line 120) | def forward(self, img1, img2):
class MSSSIM (line 132) | class MSSSIM(torch.nn.Module):
method __init__ (line 133) | def __init__(self, window_size=11, size_average=True, channel=3):
method forward (line 139) | def forward(self, img1, img2):
FILE: code/real/bsrt/loss/vgg.py
class VGG (line 8) | class VGG(nn.Module):
method __init__ (line 9) | def __init__(self, conv_index, rgb_range=1):
method forward (line 24) | def forward(self, sr, hr):
FILE: code/real/bsrt/main.py
function init_seeds (line 19) | def init_seeds(seed=0, cuda_deterministic=True):
function main (line 33) | def main():
function main_worker (line 37) | def main_worker(local_rank, nprocs, args):
FILE: code/real/bsrt/model/DCNv2/dcn_v2.py
class _DCNv2 (line 17) | class _DCNv2(Function):
method forward (line 21) | def forward(
method backward (line 52) | def backward(ctx, grad_output):
method symbolic (line 75) | def symbolic(
class DCNv2 (line 102) | class DCNv2(nn.Module):
method __init__ (line 103) | def __init__(
method reset_parameters (line 126) | def reset_parameters(self):
method forward (line 134) | def forward(self, input, offset, mask):
class DCN (line 153) | class DCN(DCNv2):
method __init__ (line 154) | def __init__(
method init_offset (line 179) | def init_offset(self):
method forward (line 183) | def forward(self, input):
class DCN_sep (line 201) | class DCN_sep(DCNv2):
method __init__ (line 204) | def __init__(self,
method init_offset (line 225) | def init_offset(self):
method forward (line 229) | def forward(self, input, fea):
class FlowGuidedDCN (line 249) | class FlowGuidedDCN(DCNv2):
method __init__ (line 252) | def __init__(self,
method init_offset (line 269) | def init_offset(self):
method forward (line 273) | def forward(self, input, fea, flows):
class InsideFlowGuidedDCN (line 296) | class InsideFlowGuidedDCN(DCNv2):
method __init__ (line 299) | def __init__(self,
method reset_parameters (line 322) | def reset_parameters(self):
method init_offset (line 331) | def init_offset(self):
method forward (line 335) | def forward(self, input, warped, ref, flows):
class _DCNv2Pooling (line 359) | class _DCNv2Pooling(Function):
method forward (line 361) | def forward(
method backward (line 402) | def backward(ctx, grad_output):
class DCNv2Pooling (line 426) | class DCNv2Pooling(nn.Module):
method __init__ (line 427) | def __init__(
method forward (line 448) | def forward(self, input, rois, offset):
class DCNPooling (line 467) | class DCNPooling(DCNv2Pooling):
method __init__ (line 468) | def __init__(
method forward (line 506) | def forward(self, input, rois):
FILE: code/real/bsrt/model/DCNv2/setup.py
function get_extensions (line 13) | def get_extensions():
FILE: code/real/bsrt/model/DCNv2/src/cpu/dcn_v2_cpu.cpp
function dcn_v2_cpu_forward (line 17) | at::Tensor
function dcn_v2_cpu_backward (line 109) | std::vector<at::Tensor> dcn_v2_cpu_backward(const at::Tensor &input,
FILE: code/real/bsrt/model/DCNv2/src/cpu/dcn_v2_im2col_cpu.cpp
function dmcn_im2col_bilinear_cpu (line 27) | float dmcn_im2col_bilinear_cpu(const float *bottom_data, const int data_...
function dmcn_get_gradient_weight_cpu (line 58) | float dmcn_get_gradient_weight_cpu(float argmax_h, float argmax_w,
function dmcn_get_coordinate_weight_cpu (line 84) | float dmcn_get_coordinate_weight_cpu(float argmax_h, float argmax_w,
function modulated_deformable_im2col_cpu_kernel (line 127) | void modulated_deformable_im2col_cpu_kernel(const int n, const float *da...
function modulated_deformable_col2im_cpu_kernel (line 198) | void modulated_deformable_col2im_cpu_kernel(const int n, const float *da...
function modulated_deformable_col2im_coord_cpu_kernel (line 259) | void modulated_deformable_col2im_coord_cpu_kernel(const int n, const flo...
function modulated_deformable_im2col_cpu (line 331) | void modulated_deformable_im2col_cpu(const float* data_im, const float* ...
function modulated_deformable_col2im_cpu (line 353) | void modulated_deformable_col2im_cpu(const float* data_col, const float*...
function modulated_deformable_col2im_coord_cpu (line 375) | void modulated_deformable_col2im_coord_cpu(const float* data_col, const ...
FILE: code/real/bsrt/model/DCNv2/src/cpu/dcn_v2_psroi_pooling_cpu.cpp
function T (line 34) | T bilinear_interp_cpu(
function DeformablePSROIPoolForwardKernelCpu (line 59) | void DeformablePSROIPoolForwardKernelCpu(
function DeformablePSROIPoolBackwardAccKernelCpu (line 149) | void DeformablePSROIPoolBackwardAccKernelCpu(
function dcn_v2_psroi_pooling_cpu_forward (line 278) | std::tuple<at::Tensor, at::Tensor>
function dcn_v2_psroi_pooling_cpu_backward (line 350) | std::tuple<at::Tensor, at::Tensor>
FILE: code/real/bsrt/model/DCNv2/src/vision.cpp
function PYBIND11_MODULE (line 4) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: code/real/bsrt/model/DCNv2/test.py
function conv_identify (line 20) | def conv_identify(weight, bias):
function check_zero_offset (line 32) | def check_zero_offset():
function check_gradient_dconv (line 69) | def check_gradient_dconv():
function check_pooling_zero_offset (line 100) | def check_pooling_zero_offset():
function check_gradient_dpooling (line 134) | def check_gradient_dpooling():
function example_dconv (line 169) | def example_dconv():
function example_dpooling (line 183) | def example_dpooling():
function example_mdpooling (line 226) | def example_mdpooling():
FILE: code/real/bsrt/model/__init__.py
class Model (line 10) | class Model(nn.Module):
method __init__ (line 11) | def __init__(self, args, ckp):
method forward (line 52) | def forward(self, x, idx_scale):
method save (line 72) | def save(self, apath, epoch, is_best=False):
method load (line 89) | def load(self, apath, pre_train='', resume=-1, cpu=False):
method forward_chop (line 219) | def forward_chop(self, *args, shave=10, min_size=160000):
method forward_x8 (line 278) | def forward_x8(self, *args, forward_function=None):
FILE: code/real/bsrt/model/arch_util.py
function initialize_weights (line 9) | def initialize_weights(net_l, scale=1):
function make_layer (line 29) | def make_layer(block, n_layers):
function conv_layer (line 38) | def conv_layer(in_channels, out_channels, kernel_size, stride=1, padding...
class ESA (line 42) | class ESA(nn.Module):
method __init__ (line 43) | def __init__(self, n_feats, conv=conv_layer):
method forward (line 56) | def forward(self, x):
class DWConv (line 71) | class DWConv(nn.Module):
method __init__ (line 72) | def __init__(self, dim=768):
method forward (line 76) | def forward(self, x):
class SELayer (line 82) | class SELayer(nn.Module):
method __init__ (line 86) | def __init__(self, channel, reduction=16):
method forward (line 96) | def forward(self, x):
class ResidualBlock_noBN (line 102) | class ResidualBlock_noBN(nn.Module):
method __init__ (line 108) | def __init__(self, nf=64):
method forward (line 116) | def forward(self, x):
class ResidualBlock_SE (line 123) | class ResidualBlock_SE(nn.Module):
method __init__ (line 129) | def __init__(self, nf=64, reduction=16):
method forward (line 138) | def forward(self, x):
class _PositionAttentionModule (line 148) | class _PositionAttentionModule(nn.Module):
method __init__ (line 151) | def __init__(self, in_channels, **kwargs):
method forward (line 159) | def forward(self, x):
class SALayer (line 171) | class SALayer(nn.Module):
method __init__ (line 172) | def __init__(self, wn=None):
method forward (line 178) | def forward(self, x):
class CALayerV2 (line 186) | class CALayerV2(nn.Module):
method __init__ (line 187) | def __init__(self, n_feat, reduction=16, wn=None):
method forward (line 200) | def forward(self, x):
class DALayer (line 207) | class DALayer(nn.Module):
method __init__ (line 208) | def __init__(self, channel, reduction, wn):
method forward (line 215) | def forward(self, x):
class CALayer (line 223) | class CALayer(nn.Module):
method __init__ (line 224) | def __init__(self, channel, reduction, wn):
method forward (line 236) | def forward(self, x):
class RCAB (line 243) | class RCAB(nn.Module):
method __init__ (line 244) | def __init__(
method forward (line 266) | def forward(self, x):
class ResidualGroup (line 273) | class ResidualGroup(nn.Module):
method __init__ (line 274) | def __init__(self, n_feat, n_resblocks, da=False):
method forward (line 292) | def forward(self, x):
function make_layer_idx (line 302) | def make_layer_idx(block, n_layers):
class LRSCRCAB (line 309) | class LRSCRCAB(nn.Module):
method __init__ (line 310) | def __init__(
method forward (line 332) | def forward(self, x):
class LRSCPYRCAB (line 339) | class LRSCPYRCAB(nn.Module):
method __init__ (line 340) | def __init__(
method forward (line 366) | def forward(self, x):
class LRSCResidualGroup (line 372) | class LRSCResidualGroup(nn.Module):
method __init__ (line 373) | def __init__(self, n_feat, n_resblocks, da=False, idx=0):
method forward (line 392) | def forward(self, x):
class LRSCPSResidualGroup (line 400) | class LRSCPSResidualGroup(nn.Module):
method __init__ (line 401) | def __init__(self, n_feat, n_resblocks, da=False, idx=0):
method forward (line 421) | def forward(self, x):
class LRSCPyResidualGroup (line 430) | class LRSCPyResidualGroup(nn.Module):
method __init__ (line 431) | def __init__(self, n_feat, n_resblocks, da=False, idx=0):
method forward (line 451) | def forward(self, x):
class LRSCWideActResBlock (line 458) | class LRSCWideActResBlock(nn.Module):
method __init__ (line 459) | def __init__(self, nf=64, idx=0):
method forward (line 482) | def forward(self, x):
class LRSCPyWideActResBlock (line 488) | class LRSCPyWideActResBlock(nn.Module):
method __init__ (line 489) | def __init__(self, nf=64, idx=0):
method forward (line 515) | def forward(self, x):
class LRSCPyWideActResGroup (line 523) | class LRSCPyWideActResGroup(nn.Module):
method __init__ (line 524) | def __init__(self, nf, n_resblocks, idx=0):
method forward (line 539) | def forward(self, x):
class LRSCWideActResGroup (line 548) | class LRSCWideActResGroup(nn.Module):
method __init__ (line 549) | def __init__(self, nf, n_resblocks, idx=0):
method forward (line 564) | def forward(self, x):
class PYRCAB (line 577) | class PYRCAB(nn.Module):
method __init__ (line 578) | def __init__(
method forward (line 603) | def forward(self, x):
class PyResidualGroup (line 609) | class PyResidualGroup(nn.Module):
method __init__ (line 610) | def __init__(self, n_feat, n_resblocks, da=False):
method forward (line 632) | def forward(self, x):
class WideActResBlock (line 637) | class WideActResBlock(nn.Module):
method __init__ (line 638) | def __init__(self, nf=64):
method forward (line 658) | def forward(self, x):
class PSWideActResBlock (line 664) | class PSWideActResBlock(nn.Module):
method __init__ (line 665) | def __init__(self, nf=64):
method forward (line 685) | def forward(self, x):
class PyWideActResBlock (line 691) | class PyWideActResBlock(nn.Module):
method __init__ (line 692) | def __init__(self, nf=64):
method forward (line 717) | def forward(self, x):
function flow_warp (line 723) | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', ali...
FILE: code/real/bsrt/model/bsrt.py
function make_model (line 27) | def make_model(args, parent=False):
class BasicModule (line 66) | class BasicModule(nn.Module):
method __init__ (line 70) | def __init__(self):
method forward (line 80) | def forward(self, tensor_input):
class SpyNet (line 84) | class SpyNet(nn.Module):
method __init__ (line 92) | def __init__(self, load_path=None, return_levels=[5]):
method preprocess (line 110) | def preprocess(self, tensor_input):
method process (line 114) | def process(self, ref, supp, w, h, w_floor, h_floor):
method forward (line 160) | def forward(self, ref, supp):
class FlowGuidedPCDAlign (line 176) | class FlowGuidedPCDAlign(nn.Module):
method __init__ (line 181) | def __init__(self, nf=64, groups=8):
method forward (line 209) | def forward(self, nbr_fea_l, nbr_fea_warped_l, ref_fea_l, flows_l):
class CrossNonLocal_Fusion (line 246) | class CrossNonLocal_Fusion(nn.Module):
method __init__ (line 249) | def __init__(self, nf=64, out_feat=96, nframes=5, center=2):
method forward (line 265) | def forward(self, aligned_fea):
class BSRT (line 287) | class BSRT(nn.Module):
method __init__ (line 288) | def __init__(self, args, nframes=8, img_size=64, patch_size=1, in_chan...
method _init_weights (line 452) | def _init_weights(self, m):
method no_weight_decay (line 462) | def no_weight_decay(self):
method no_weight_decay_keywords (line 466) | def no_weight_decay_keywords(self):
method _upsample_add (line 469) | def _upsample_add(self, x, y):
method check_image_size (line 472) | def check_image_size(self, x):
method pre_forward_features (line 479) | def pre_forward_features(self, x):
method forward_features (line 498) | def forward_features(self, x):
method forward (line 516) | def forward(self, x, print_time=False):
method get_ref_flows (line 593) | def get_ref_flows(self, x):
FILE: code/real/bsrt/model/checkpoint.py
function detach_variable (line 5) | def detach_variable(inputs):
function check_backward_validity (line 18) | def check_backward_validity(inputs):
class CheckpointFunction (line 23) | class CheckpointFunction(torch.autograd.Function):
method forward (line 25) | def forward(ctx, run_function, length, *args):
method backward (line 34) | def backward(ctx, *output_grads):
FILE: code/real/bsrt/model/common.py
function default_conv (line 9) | def default_conv(in_channels, out_channels, kernel_size, bias=True):
class MeanShift (line 15) | class MeanShift(nn.Conv2d):
method __init__ (line 16) | def __init__(
class BasicBlock (line 27) | class BasicBlock(nn.Sequential):
method __init__ (line 28) | def __init__(
class ResBlock (line 41) | class ResBlock(nn.Module):
method __init__ (line 42) | def __init__(
method forward (line 58) | def forward(self, x):
class Upsampler (line 65) | class Upsampler(nn.Sequential):
method __init__ (line 66) | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
class UpOnly (line 95) | class UpOnly(nn.Sequential):
method __init__ (line 96) | def __init__(self, scale):
function lanczos_kernel (line 114) | def lanczos_kernel(dx, a=3, N=None, dtype=None, device=None):
function lanczos_shift (line 160) | def lanczos_shift(img, shift, p=5, a=3):
FILE: code/real/bsrt/model/non_local/network.py
class Network (line 8) | class Network(nn.Module):
method __init__ (line 9) | def __init__(self):
method forward (line 43) | def forward(self, x):
method forward_with_nl_map (line 57) | def forward_with_nl_map(self, x):
FILE: code/real/bsrt/model/non_local/non_local_concatenation.py
class _NonLocalBlockND (line 6) | class _NonLocalBlockND(nn.Module):
method __init__ (line 7) | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_...
method forward (line 68) | def forward(self, x, return_nl_map=False):
class NONLocalBlock1D (line 109) | class NONLocalBlock1D(_NonLocalBlockND):
method __init__ (line 110) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
class NONLocalBlock2D (line 117) | class NONLocalBlock2D(_NonLocalBlockND):
method __init__ (line 118) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
class NONLocalBlock3D (line 125) | class NONLocalBlock3D(_NonLocalBlockND):
method __init__ (line 126) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
FILE: code/real/bsrt/model/non_local/non_local_cross_dot_product.py
class _NonLocalBlockND (line 6) | class _NonLocalBlockND(nn.Module):
method __init__ (line 7) | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_...
method forward (line 63) | def forward(self, x, ref, return_nl_map=False):
class NONLocalBlock1D (line 93) | class NONLocalBlock1D(_NonLocalBlockND):
method __init__ (line 94) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
class NONLocalBlock2D (line 101) | class NONLocalBlock2D(_NonLocalBlockND):
method __init__ (line 102) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
class NONLocalBlock3D (line 109) | class NONLocalBlock3D(_NonLocalBlockND):
method __init__ (line 110) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
FILE: code/real/bsrt/model/non_local/non_local_dot_product.py
class _NonLocalBlockND (line 6) | class _NonLocalBlockND(nn.Module):
method __init__ (line 7) | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_...
method forward (line 63) | def forward(self, x, return_nl_map=False):
class NONLocalBlock1D (line 93) | class NONLocalBlock1D(_NonLocalBlockND):
method __init__ (line 94) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
class NONLocalBlock2D (line 101) | class NONLocalBlock2D(_NonLocalBlockND):
method __init__ (line 102) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
class NONLocalBlock3D (line 109) | class NONLocalBlock3D(_NonLocalBlockND):
method __init__ (line 110) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
FILE: code/real/bsrt/model/non_local/non_local_embedded_gaussian.py
class _NonLocalBlockND (line 6) | class _NonLocalBlockND(nn.Module):
method __init__ (line 7) | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_...
method forward (line 70) | def forward(self, x, return_nl_map=False):
class NONLocalBlock1D (line 99) | class NONLocalBlock1D(_NonLocalBlockND):
method __init__ (line 100) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
class NONLocalBlock2D (line 107) | class NONLocalBlock2D(_NonLocalBlockND):
method __init__ (line 108) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
class NONLocalBlock3D (line 115) | class NONLocalBlock3D(_NonLocalBlockND):
method __init__ (line 116) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
FILE: code/real/bsrt/model/non_local/non_local_gaussian.py
class _NonLocalBlockND (line 6) | class _NonLocalBlockND(nn.Module):
method __init__ (line 7) | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_...
method forward (line 57) | def forward(self, x, return_nl_map=False):
class NONLocalBlock1D (line 95) | class NONLocalBlock1D(_NonLocalBlockND):
method __init__ (line 96) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
class NONLocalBlock2D (line 103) | class NONLocalBlock2D(_NonLocalBlockND):
method __init__ (line 104) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
class NONLocalBlock3D (line 111) | class NONLocalBlock3D(_NonLocalBlockND):
method __init__ (line 112) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
FILE: code/real/bsrt/model/swin_util.py
class Mlp (line 17) | class Mlp(nn.Module):
method __init__ (line 18) | def __init__(self, in_features, hidden_features=None, out_features=Non...
method forward (line 27) | def forward(self, x):
class Mlp_GEGLU (line 35) | class Mlp_GEGLU(nn.Module):
method __init__ (line 45) | def __init__(self, in_features, hidden_features=None, out_features=Non...
method forward (line 56) | def forward(self, x):
function window_partition (line 64) | def window_partition(x, window_size):
function window_reverse (line 79) | def window_reverse(windows, window_size, H, W):
class WindowAttention (line 96) | class WindowAttention(nn.Module):
method __init__ (line 110) | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scal...
method forward (line 145) | def forward(self, x, mask=None):
method extra_repr (line 180) | def extra_repr(self) -> str:
method flops (line 183) | def flops(self, N):
function calculate_mask (line 197) | def calculate_mask(x_size, window_size, shift_size):
class SwinTransformerBlock (line 221) | class SwinTransformerBlock(nn.Module):
method __init__ (line 240) | def __init__(self, dim, input_resolution, num_heads, window_size=7, sh...
method forward (line 275) | def forward(self, x, x_size):
method extra_repr (line 332) | def extra_repr(self) -> str:
method flops (line 336) | def flops(self):
class PatchMerging (line 351) | class PatchMerging(nn.Module):
method __init__ (line 360) | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
method forward (line 367) | def forward(self, x):
method extra_repr (line 390) | def extra_repr(self) -> str:
method flops (line 393) | def flops(self):
class BasicLayer (line 400) | class BasicLayer(nn.Module):
method __init__ (line 420) | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
method forward (line 448) | def forward(self, x, x_size):
method extra_repr (line 459) | def extra_repr(self) -> str:
method flops (line 462) | def flops(self):
class RSTB (line 471) | class RSTB(nn.Module):
method __init__ (line 494) | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
method forward (line 536) | def forward(self, x, x_size):
method flops (line 541) | def flops(self):
class PatchEmbed (line 552) | class PatchEmbed(nn.Module):
method __init__ (line 563) | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=9...
method forward (line 581) | def forward(self, x, use_norm=True):
method flops (line 587) | def flops(self):
class PatchUnEmbed (line 595) | class PatchUnEmbed(nn.Module):
method __init__ (line 606) | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=9...
method forward (line 619) | def forward(self, x, x_size):
method flops (line 624) | def flops(self):
class Upsample (line 629) | class Upsample(nn.Sequential):
method __init__ (line 637) | def __init__(self, scale, num_feat):
class UpsampleOneStep (line 651) | class UpsampleOneStep(nn.Sequential):
method __init__ (line 661) | def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
method flops (line 669) | def flops(self):
class SwinIR (line 675) | class SwinIR(nn.Module):
method __init__ (line 703) | def __init__(self, img_size=64, patch_size=1, in_chans=3,
method _init_weights (line 824) | def _init_weights(self, m):
method no_weight_decay (line 834) | def no_weight_decay(self):
method no_weight_decay_keywords (line 838) | def no_weight_decay_keywords(self):
method check_image_size (line 841) | def check_image_size(self, x):
method forward_features (line 848) | def forward_features(self, x):
method forward (line 863) | def forward(self, x):
method flops (line 899) | def flops(self):
FILE: code/real/bsrt/model/utils/interp_methods.py
function set_framework_dependencies (line 17) | def set_framework_dependencies(x):
function support_sz (line 28) | def support_sz(sz):
function cubic (line 35) | def cubic(x):
function lanczos2 (line 45) | def lanczos2(x):
function lanczos3 (line 51) | def lanczos3(x):
function linear (line 57) | def linear(x):
function box (line 63) | def box(x):
FILE: code/real/bsrt/model/utils/psconv.py
class PyConv2d (line 4) | class PyConv2d(nn.Module):
method __init__ (line 24) | def __init__(self, in_channels, out_channels, pyconv_kernels, pyconv_g...
method forward (line 36) | def forward(self, x):
class PSConv2d (line 45) | class PSConv2d(nn.Module):
method __init__ (line 46) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
method forward (line 69) | def forward(self, x):
class PSGConv2d (line 76) | class PSGConv2d(nn.Module):
method __init__ (line 77) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
method forward (line 102) | def forward(self, x):
FILE: code/real/bsrt/model/utils/resize_right.py
class NoneClass (line 6) | class NoneClass:
function resize (line 29) | def resize(input, scale_factors=None, out_shape=None,
class ResizeLayer (line 80) | class ResizeLayer(nnModuleWrapped):
method __init__ (line 81) | def __init__(self, in_shape, scale_factors=None, out_shape=None,
method forward (line 132) | def forward(self, input):
function prepare_weights_and_field_of_view_1d (line 147) | def prepare_weights_and_field_of_view_1d(dim, scale_factor, in_sz, out_sz,
function apply_weights (line 174) | def apply_weights(input, field_of_view, weights, dim, n_dims, fw):
function set_scale_and_out_sz (line 211) | def set_scale_and_out_sz(in_shape, out_shape, scale_factors, fw):
function get_projected_grid (line 251) | def get_projected_grid(in_sz, out_sz, scale_factor, fw, device=None):
function get_field_of_view (line 266) | def get_field_of_view(projected_grid, cur_support_sz, in_sz, fw, eps):
function get_weights (line 288) | def get_weights(interp_method, projected_grid, field_of_view):
function apply_antialiasing_if_needed (line 301) | def apply_antialiasing_if_needed(interp_method, support_sz, scale_factor,
function fw_ceil (line 315) | def fw_ceil(x, fw):
function fw_cat (line 322) | def fw_cat(x, fw):
function fw_swapaxes (line 329) | def fw_swapaxes(x, ax_1, ax_2, fw):
function fw_set_device (line 335) | def fw_set_device(x, device, fw):
FILE: code/real/bsrt/pwcnet/correlation/correlation.py
function cupy_kernel (line 236) | def cupy_kernel(strFunction, objVariables):
function cupy_launch (line 275) | def cupy_launch(strFunction, strKernel):
class _FunctionCorrelation (line 279) | class _FunctionCorrelation(torch.autograd.Function):
method forward (line 282) | def forward(self, first, second):
method backward (line 336) | def backward(self, gradOutput):
function FunctionCorrelation (line 388) | def FunctionCorrelation(tenFirst, tenSecond):
class ModuleCorrelation (line 392) | class ModuleCorrelation(torch.nn.Module):
method __init__ (line 393) | def __init__(self):
method forward (line 397) | def forward(self, tenFirst, tenSecond):
FILE: code/real/bsrt/pwcnet/pwcnet.py
function backwarp (line 21) | def backwarp(tenInput, tenFlow):
class Network (line 44) | class Network(torch.nn.Module):
method __init__ (line 45) | def __init__(self):
method forward (line 224) | def forward(self, tenFirst, tenSecond):
class PWCNet (line 237) | class PWCNet(torch.nn.Module):
method __init__ (line 238) | def __init__(self, load_pretrained=True, weights_path=None, rgb2bgr=Fa...
method forward (line 253) | def forward(self, source_img, target_img):
FILE: code/real/bsrt/pwcnet/run.py
function backwarp (line 46) | def backwarp(tenInput, tenFlow):
class Network (line 70) | class Network(torch.nn.Module):
method __init__ (line 71) | def __init__(self):
method forward (line 262) | def forward(self, tenFirst, tenSecond):
function estimate (line 280) | def estimate(tenFirst, tenSecond):
FILE: code/real/bsrt/scripts/cal_mean_std.py
function main (line 9) | def main():
FILE: code/real/bsrt/scripts/download_burstsr_dataset.py
function download_burstsr_dataset (line 8) | def download_burstsr_dataset(download_path):
function main (line 59) | def main():
FILE: code/real/bsrt/scripts/evaluate_burstsr_val.py
class SimpleBaseline (line 8) | class SimpleBaseline:
method __init__ (line 9) | def __init__(self):
method __call__ (line 12) | def __call__(self, burst):
function main (line 19) | def main():
FILE: code/real/bsrt/scripts/save_results_synburst_val.py
class SimpleBaseline (line 9) | class SimpleBaseline:
method __init__ (line 10) | def __init__(self):
method __call__ (line 13) | def __call__(self, burst):
function main (line 20) | def main():
FILE: code/real/bsrt/scripts/test_burstsr_dataset.py
function main (line 11) | def main():
FILE: code/real/bsrt/scripts/test_synthetic_bursts.py
function main (line 11) | def main():
FILE: code/real/bsrt/test.py
function main_worker (line 25) | def main_worker(local_rank, nprocs, args):
function main (line 58) | def main():
FILE: code/real/bsrt/test_real.py
function main (line 31) | def main():
function main_worker (line 35) | def main_worker(local_rank, nprocs, args):
FILE: code/real/bsrt/trainer.py
class Trainer (line 47) | class Trainer():
method __init__ (line 48) | def __init__(self, args, train_loader, train_sampler, valid_loader, my...
method train (line 103) | def train(self):
method test (line 210) | def test(self):
method save_model (line 313) | def save_model(self, filename):
method prepare (line 322) | def prepare(self, *args):
method terminate (line 332) | def terminate(self):
FILE: code/real/bsrt/utility.py
function reduce_mean (line 24) | def reduce_mean(tensor, nprocs):
function setup (line 31) | def setup(rank, world_size):
function cleanup (line 54) | def cleanup():
function mkdir (line 58) | def mkdir(path):
class timer (line 63) | class timer():
method __init__ (line 64) | def __init__(self):
method tic (line 68) | def tic(self):
method toc (line 71) | def toc(self, restart=False):
method hold (line 76) | def hold(self):
method release (line 79) | def release(self):
method reset (line 85) | def reset(self):
class checkpoint (line 89) | class checkpoint():
method __init__ (line 90) | def __init__(self, args):
method get_path (line 127) | def get_path(self, *subdir):
method save (line 130) | def save(self, trainer, epoch, is_best=False):
method add_log (line 139) | def add_log(self, log):
method write_log (line 142) | def write_log(self, log, refresh=False):
method done (line 149) | def done(self):
method plot_psnr (line 152) | def plot_psnr(self, epoch):
method begin_background (line 171) | def begin_background(self):
method end_background (line 188) | def end_background(self):
method save_results (line 193) | def save_results(self, dataset, filename, save_list, scale):
function quantize (line 207) | def quantize(img, rgb_range):
function calc_psnr (line 212) | def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
function make_optimizer (line 231) | def make_optimizer(args, target):
function write_gray_to_tfboard (line 287) | def write_gray_to_tfboard(img):
function bayer_unify (line 309) | def bayer_unify(raw, input_pattern, target_pattern, mode) -> np.ndarray:
function bayer_aug (line 341) | def bayer_aug(raw, flip_h=False, flip_w=False, transpose=False, input_pa...
FILE: code/real/bsrt/utils/data_format_utils.py
function numpy_to_torch (line 6) | def numpy_to_torch(a: np.ndarray):
function torch_to_numpy (line 10) | def torch_to_numpy(a: torch.Tensor):
function torch_to_npimage (line 14) | def torch_to_npimage(a: torch.Tensor, unnormalize=True):
function npimage_to_torch (line 23) | def npimage_to_torch(a, normalize=True, input_bgr=True):
function convert_dict (line 34) | def convert_dict(base_dict, batch_sz):
FILE: code/real/bsrt/utils/debayer.py
class Debayer3x3 (line 5) | class Debayer3x3(torch.nn.Module):
method __init__ (line 29) | def __init__(self):
method forward (line 71) | def forward(self, x):
class Debayer2x2 (line 91) | class Debayer2x2(torch.nn.Module):
method __init__ (line 99) | def __init__(self):
method forward (line 115) | def forward(self, x):
class DebayerSplit (line 133) | class DebayerSplit(torch.nn.Module):
method __init__ (line 140) | def __init__(self):
method forward (line 151) | def forward(self, x):
FILE: code/real/bsrt/utils/interp_methods.py
function set_framework_dependencies (line 17) | def set_framework_dependencies(x):
function support_sz (line 28) | def support_sz(sz):
function cubic (line 35) | def cubic(x):
function lanczos2 (line 45) | def lanczos2(x):
function lanczos3 (line 51) | def lanczos3(x):
function linear (line 57) | def linear(x):
function box (line 63) | def box(x):
FILE: code/real/bsrt/utils/metrics.py
class MSSSIMLoss (line 15) | class MSSSIMLoss(nn.Module):
method __init__ (line 16) | def __init__(self, boundary_ignore=None):
method forward (line 21) | def forward(self, pred, gt, valid=None):
class CharbonnierLoss (line 33) | class CharbonnierLoss(nn.Module):
method __init__ (line 34) | def __init__(self, boundary_ignore=None):
method forward (line 39) | def forward(self, pred, gt, valid=None):
class L1 (line 51) | class L1(nn.Module):
method __init__ (line 52) | def __init__(self, boundary_ignore=None):
method forward (line 56) | def forward(self, pred, gt, valid=None):
class L2 (line 78) | class L2(nn.Module):
method __init__ (line 79) | def __init__(self, boundary_ignore=None):
method forward (line 83) | def forward(self, pred, gt, valid=None):
class PSNR (line 106) | class PSNR(nn.Module):
method __init__ (line 107) | def __init__(self, boundary_ignore=None, max_value=1.0):
method psnr (line 112) | def psnr(self, pred, gt, valid=None):
method forward (line 119) | def forward(self, pred, gt, valid=None):
class AlignedL1 (line 130) | class AlignedL1(nn.Module):
method __init__ (line 131) | def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None):
method forward (line 139) | def forward(self, pred, gt, burst_input):
class AlignedL2 (line 180) | class AlignedL2(nn.Module):
method __init__ (line 181) | def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None):
method forward (line 190) | def forward(self, pred, gt, burst_input):
class AlignedPSNR (line 237) | class AlignedPSNR(nn.Module):
method __init__ (line 238) | def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None, m...
method psnr (line 243) | def psnr(self, pred, gt, burst_input):
method forward (line 250) | def forward(self, pred, gt, burst_input):
class AlignedSSIM (line 259) | class AlignedSSIM(nn.Module):
method __init__ (line 260) | def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None):
method _ssim (line 268) | def _ssim(self, pred, gt, burst_input):
method forward (line 307) | def forward(self, pred, gt, burst_input):
class AlignedLPIPS (line 313) | class AlignedLPIPS(nn.Module):
method __init__ (line 314) | def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None):
method _lpips (line 323) | def _lpips(self, pred, gt, burst_input):
method forward (line 357) | def forward(self, pred, gt, burst_input):
FILE: code/real/bsrt/utils/postprocessing_functions.py
class SimplePostProcess (line 7) | class SimplePostProcess:
method __init__ (line 8) | def __init__(self, gains=True, ccm=True, gamma=True, smoothstep=True, ...
method process (line 15) | def process(self, image, meta_info):
function process_linear_image_rgb (line 20) | def process_linear_image_rgb(image, meta_info, gains=True, ccm=True, gam...
class BurstSRPostProcess (line 40) | class BurstSRPostProcess:
method __init__ (line 41) | def __init__(self, no_white_balance=False, gamma=True, smoothstep=True...
method process (line 47) | def process(self, image, meta_info, external_norm_factor=None):
function process_burstsr_image_rgb (line 53) | def process_burstsr_image_rgb(im, meta_info, return_np=False, external_n...
FILE: code/real/bsrt/utils/resize_right.py
class NoneClass (line 6) | class NoneClass:
function resize (line 29) | def resize(input, scale_factors=None, out_shape=None,
class ResizeLayer (line 80) | class ResizeLayer(nnModuleWrapped):
method __init__ (line 81) | def __init__(self, in_shape, scale_factors=None, out_shape=None,
method forward (line 132) | def forward(self, input):
function prepare_weights_and_field_of_view_1d (line 147) | def prepare_weights_and_field_of_view_1d(dim, scale_factor, in_sz, out_sz,
function apply_weights (line 174) | def apply_weights(input, field_of_view, weights, dim, n_dims, fw):
function set_scale_and_out_sz (line 211) | def set_scale_and_out_sz(in_shape, out_shape, scale_factors, fw):
function get_projected_grid (line 251) | def get_projected_grid(in_sz, out_sz, scale_factor, fw, device=None):
function get_field_of_view (line 266) | def get_field_of_view(projected_grid, cur_support_sz, in_sz, fw, eps):
function get_weights (line 288) | def get_weights(interp_method, projected_grid, field_of_view):
function apply_antialiasing_if_needed (line 301) | def apply_antialiasing_if_needed(interp_method, support_sz, scale_factor,
function fw_ceil (line 315) | def fw_ceil(x, fw):
function fw_cat (line 322) | def fw_cat(x, fw):
function fw_swapaxes (line 329) | def fw_swapaxes(x, ax_1, ax_2, fw):
function fw_set_device (line 335) | def fw_set_device(x, device, fw):
FILE: code/real/bsrt/utils/spatial_color_alignment.py
function gauss_1d (line 6) | def gauss_1d(sz, sigma, center, end_pad=0, density=False):
function gauss_2d (line 15) | def gauss_2d(sz, sigma, center, end_pad=(0, 0), density=False):
function get_gaussian_kernel (line 29) | def get_gaussian_kernel(sd):
function apply_kernel (line 38) | def apply_kernel(im, ksz, gauss_kernel):
function match_colors (line 48) | def match_colors(im_ref, im_q, im_test, ksz, gauss_kernel):
FILE: code/real/bsrt/utils/stn.py
class SpatialTransformer (line 6) | class SpatialTransformer(nn.Module):
method __init__ (line 12) | def __init__(self, size, mode='bilinear'):
method forward (line 31) | def forward(self, src, flow):
FILE: code/real/bsrt/utils/warp.py
function warp (line 6) | def warp(feat, flow, mode='bilinear', padding_mode='zeros'):
FILE: code/real/bsrt/validate.py
function main (line 29) | def main():
function main_worker (line 33) | def main_worker(local_rank, nprocs, args):
FILE: code/synthetic/bsrt/data_processing/camera_pipeline.py
function random_ccm (line 13) | def random_ccm():
function random_gains (line 47) | def random_gains():
function apply_smoothstep (line 58) | def apply_smoothstep(image):
function invert_smoothstep (line 64) | def invert_smoothstep(image):
function gamma_expansion (line 70) | def gamma_expansion(image):
function gamma_compression (line 76) | def gamma_compression(image):
function apply_ccm (line 82) | def apply_ccm(image, ccm):
function apply_gains (line 95) | def apply_gains(image, rgb_gain, red_gain, blue_gain):
function safe_invert_gains (line 109) | def safe_invert_gains(image, rgb_gain, red_gain, blue_gain):
function mosaic (line 125) | def mosaic(image, mode='rggb'):
function demosaic (line 151) | def demosaic(image):
function random_noise_levels (line 188) | def random_noise_levels():
function add_noise (line 201) | def add_noise(image, shot_noise=0.01, read_noise=0.0005):
function process_linear_image_rgb (line 208) | def process_linear_image_rgb(image, meta_info, return_np=False):
function process_linear_image_raw (line 225) | def process_linear_image_raw(image, meta_info):
FILE: code/synthetic/bsrt/data_processing/synthetic_burst_generation.py
function random_crop (line 10) | def random_crop(frames, crop_sz):
function rgb2rawburst (line 49) | def rgb2rawburst(image, burst_size, downsample_factor=1, burst_transform...
function get_tmat (line 123) | def get_tmat(image_shape, translation, theta, shear_values, scale_factors):
function single2lrburst (line 149) | def single2lrburst(image, burst_size, downsample_factor=1, transformatio...
FILE: code/synthetic/bsrt/datasets/burstsr_dataset.py
class SamsungRAWImage (line 10) | class SamsungRAWImage:
method load (line 12) | def load(path):
method __init__ (line 24) | def __init__(self, im_raw, black_level, cam_wb, daylight_wb, color_mat...
method get_all_meta_data (line 38) | def get_all_meta_data(self):
method get_exposure_time (line 42) | def get_exposure_time(self):
method get_noise_profile (line 45) | def get_noise_profile(self):
method get_f_number (line 51) | def get_f_number(self):
method get_iso (line 54) | def get_iso(self):
method get_image_data (line 57) | def get_image_data(self, substract_black_level=False, white_balance=Fa...
method shape (line 72) | def shape(self):
method crop_image (line 76) | def crop_image(self, r1, r2, c1, c2):
method get_crop (line 79) | def get_crop(self, r1, r2, c1, c2):
method postprocess (line 90) | def postprocess(self, return_np=True, norm_factor=None):
class CanonImage (line 113) | class CanonImage:
method load (line 115) | def load(path, split='train'):
method __init__ (line 125) | def __init__(self, im_raw, black_level, cam_wb, daylight_wb, rgb_xyz_m...
method shape (line 150) | def shape(self):
method get_all_meta_data (line 154) | def get_all_meta_data(self):
method get_exposure_time (line 159) | def get_exposure_time(self):
method get_f_number (line 162) | def get_f_number(self):
method get_iso (line 165) | def get_iso(self):
method get_image_data (line 168) | def get_image_data(self, substract_black_level=False, white_balance=Fa...
method set_image_data (line 182) | def set_image_data(self, im_data):
method crop_image (line 185) | def crop_image(self, r1, r2, c1, c2):
method get_crop (line 188) | def get_crop(self, r1, r2, c1, c2):
method set_crop_info (line 193) | def set_crop_info(self, crop_info):
method resize (line 196) | def resize(self, size=None, scale_factor=None):
method postprocess (line 201) | def postprocess(self, return_np=True):
function load_txt (line 216) | def load_txt(path):
class BurstSRDataset (line 223) | class BurstSRDataset(torch.utils.data.Dataset):
method __init__ (line 225) | def __init__(self, root, burst_size=8, crop_sz=80, center_crop=False, ...
method _get_burst_list (line 255) | def _get_burst_list(self):
method get_burst_info (line 260) | def get_burst_info(self, burst_id):
method _get_raw_image (line 264) | def _get_raw_image(self, burst_id, im_id):
method _get_gt_image (line 268) | def _get_gt_image(self, burst_id):
method get_burst (line 272) | def get_burst(self, burst_id, im_ids, info=None):
method _sample_images (line 281) | def _sample_images(self):
method __len__ (line 288) | def __len__(self):
method __getitem__ (line 291) | def __getitem__(self, index):
function pack_raw_image (line 400) | def pack_raw_image(im_raw):
function flatten_raw_image (line 415) | def flatten_raw_image(im_raw_4ch):
function pack_raw_image_batch (line 430) | def pack_raw_image_batch(im_raw):
function flatten_raw_image_batch (line 439) | def flatten_raw_image_batch(im_raw_4ch):
FILE: code/synthetic/bsrt/datasets/burstsr_test_dataset.py
class BurstSRDataset (line 8) | class BurstSRDataset(torch.utils.data.Dataset):
method __init__ (line 10) | def __init__(self, root, burst_size=8, crop_sz=80, center_crop=False, ...
method _get_burst_list (line 40) | def _get_burst_list(self):
method get_burst_info (line 45) | def get_burst_info(self, burst_id):
method _get_raw_image (line 49) | def _get_raw_image(self, burst_id, im_id):
method get_burst (line 53) | def get_burst(self, burst_id, im_ids, info=None):
method _sample_images (line 61) | def _sample_images(self):
method __len__ (line 68) | def __len__(self):
method __getitem__ (line 71) | def __getitem__(self, index):
FILE: code/synthetic/bsrt/datasets/data_sampler.py
class DistIterSampler (line 13) | class DistIterSampler(Sampler):
method __init__ (line 31) | def __init__(self, dataset, num_replicas=None, rank=None, ratio=100):
method __iter__ (line 47) | def __iter__(self):
method __len__ (line 64) | def __len__(self):
method set_epoch (line 67) | def set_epoch(self, epoch):
FILE: code/synthetic/bsrt/datasets/realworld_burst_test_set.py
class RealWorldBurstTest (line 7) | class RealWorldBurstTest(torch.utils.data.Dataset):
method __init__ (line 10) | def __init__(self, root):
method __len__ (line 15) | def __len__(self):
method _read_burst_image (line 18) | def _read_burst_image(self, index, image_id):
method __getitem__ (line 23) | def __getitem__(self, index):
FILE: code/synthetic/bsrt/datasets/synthetic_burst_test_set.py
class SyntheticBurstTest (line 7) | class SyntheticBurstTest(torch.utils.data.Dataset):
method __init__ (line 11) | def __init__(self, root):
method __len__ (line 16) | def __len__(self):
method _read_burst_image (line 19) | def _read_burst_image(self, index, image_id):
method __getitem__ (line 24) | def __getitem__(self, index):
FILE: code/synthetic/bsrt/datasets/synthetic_burst_train_set.py
class SyntheticBurst (line 8) | class SyntheticBurst(torch.utils.data.Dataset):
method __init__ (line 18) | def __init__(self, base_dataset, burst_size=8, crop_sz=384, transform=...
method __len__ (line 37) | def __len__(self):
method __getitem__ (line 40) | def __getitem__(self, index):
FILE: code/synthetic/bsrt/datasets/synthetic_burst_val_set.py
class SyntheticBurstVal (line 8) | class SyntheticBurstVal(torch.utils.data.Dataset):
method __init__ (line 15) | def __init__(self, root=None, initialize=True):
method initialize (line 25) | def initialize(self):
method __len__ (line 28) | def __len__(self):
method _read_burst_image (line 31) | def _read_burst_image(self, index, image_id):
method _read_gt_image (line 37) | def _read_gt_image(self, index):
method _read_meta_info (line 42) | def _read_meta_info(self, index):
method __getitem__ (line 48) | def __getitem__(self, index):
FILE: code/synthetic/bsrt/datasets/zurich_raw2rgb_dataset.py
class ZurichRAW2RGB (line 7) | class ZurichRAW2RGB(torch.utils.data.Dataset):
method __init__ (line 12) | def __init__(self, root, split='train'):
method _get_image_list (line 22) | def _get_image_list(self, split):
method _get_image (line 33) | def _get_image(self, im_id):
method get_image (line 38) | def get_image(self, im_id):
method __len__ (line 43) | def __len__(self):
method __getitem__ (line 46) | def __getitem__(self, index):
FILE: code/synthetic/bsrt/loss/Charbonnier.py
class CharbonnierLoss (line 5) | class CharbonnierLoss(nn.Module):
method __init__ (line 8) | def __init__(self, epsilon=1e-3, reduce=True):
method forward (line 13) | def forward(self, X, Y):
FILE: code/synthetic/bsrt/loss/__init__.py
class Loss (line 14) | class Loss(nn.modules.loss._Loss):
method __init__ (line 15) | def __init__(self, args, ckp):
method forward (line 80) | def forward(self, sr, hr):
method step (line 97) | def step(self):
method start_log (line 102) | def start_log(self):
method end_log (line 105) | def end_log(self, n_batches):
method display_loss (line 108) | def display_loss(self, batch):
method plot_loss (line 116) | def plot_loss(self, apath, epoch):
method get_loss_module (line 130) | def get_loss_module(self):
method save (line 136) | def save(self, apath):
method load (line 140) | def load(self, apath, cpu=False):
FILE: code/synthetic/bsrt/loss/adversarial.py
class Adversarial (line 12) | class Adversarial(nn.Module):
method __init__ (line 13) | def __init__(self, args, gan_type):
method forward (line 36) | def forward(self, fake, real):
method state_dict (line 96) | def state_dict(self, *args, **kwargs):
method bce (line 102) | def bce(self, real, fake):
FILE: code/synthetic/bsrt/loss/discriminator.py
class Discriminator (line 5) | class Discriminator(nn.Module):
method __init__ (line 9) | def __init__(self, args, gan_type='GAN'):
method forward (line 65) | def forward(self, x):
FILE: code/synthetic/bsrt/loss/filter.py
class Filter (line 5) | class Filter(nn.Module):
method __init__ (line 6) | def __init__(self, args):
method forward (line 16) | def forward(self, x, y):
FILE: code/synthetic/bsrt/loss/hist_entropy.py
class HistEntropy (line 5) | class HistEntropy(nn.Module):
method __init__ (line 6) | def __init__(self, args):
method forward (line 10) | def forward(self, x):
FILE: code/synthetic/bsrt/loss/mssim.py
function gaussian (line 7) | def gaussian(window_size, sigma):
function create_window (line 12) | def create_window(window_size, channel=1):
function ssim (line 19) | def ssim(img1, img2, window_size=11, window=None, size_average=True, ful...
function msssim (line 71) | def msssim(img1, img2, window_size=11, size_average=True, val_range=None...
class SSIM (line 109) | class SSIM(torch.nn.Module):
method __init__ (line 110) | def __init__(self, window_size=11, size_average=True, val_range=None):
method forward (line 120) | def forward(self, img1, img2):
class MSSSIM (line 132) | class MSSSIM(torch.nn.Module):
method __init__ (line 133) | def __init__(self, window_size=11, size_average=True, channel=3):
method forward (line 139) | def forward(self, img1, img2):
FILE: code/synthetic/bsrt/loss/vgg.py
class VGG (line 8) | class VGG(nn.Module):
method __init__ (line 9) | def __init__(self, conv_index, rgb_range=1):
method forward (line 24) | def forward(self, sr, hr):
FILE: code/synthetic/bsrt/main.py
function init_seeds (line 23) | def init_seeds(seed=0, cuda_deterministic=True):
function main (line 38) | def main():
function main_worker (line 45) | def main_worker(local_rank, nprocs, args):
FILE: code/synthetic/bsrt/model/DCNv2/dcn_v2.py
class _DCNv2 (line 17) | class _DCNv2(Function):
method forward (line 21) | def forward(
method backward (line 52) | def backward(ctx, grad_output):
method symbolic (line 75) | def symbolic(
class DCNv2 (line 102) | class DCNv2(nn.Module):
method __init__ (line 103) | def __init__(
method reset_parameters (line 126) | def reset_parameters(self):
method forward (line 134) | def forward(self, input, offset, mask):
class DCN (line 153) | class DCN(DCNv2):
method __init__ (line 154) | def __init__(
method init_offset (line 179) | def init_offset(self):
method forward (line 183) | def forward(self, input):
class DCN_sep (line 201) | class DCN_sep(DCNv2):
method __init__ (line 204) | def __init__(self,
method init_offset (line 225) | def init_offset(self):
method forward (line 229) | def forward(self, input, fea):
class FlowGuidedDCN (line 249) | class FlowGuidedDCN(DCNv2):
method __init__ (line 252) | def __init__(self,
method init_offset (line 269) | def init_offset(self):
method forward (line 273) | def forward(self, input, fea, flows):
class InsideFlowGuidedDCN (line 296) | class InsideFlowGuidedDCN(DCNv2):
method __init__ (line 299) | def __init__(self,
method reset_parameters (line 322) | def reset_parameters(self):
method init_offset (line 331) | def init_offset(self):
method forward (line 335) | def forward(self, input, warped, ref, flows):
class _DCNv2Pooling (line 359) | class _DCNv2Pooling(Function):
method forward (line 361) | def forward(
method backward (line 402) | def backward(ctx, grad_output):
class DCNv2Pooling (line 426) | class DCNv2Pooling(nn.Module):
method __init__ (line 427) | def __init__(
method forward (line 448) | def forward(self, input, rois, offset):
class DCNPooling (line 467) | class DCNPooling(DCNv2Pooling):
method __init__ (line 468) | def __init__(
method forward (line 506) | def forward(self, input, rois):
FILE: code/synthetic/bsrt/model/DCNv2/setup.py
function get_extensions (line 13) | def get_extensions():
FILE: code/synthetic/bsrt/model/DCNv2/src/cpu/dcn_v2_cpu.cpp
function dcn_v2_cpu_forward (line 17) | at::Tensor
function dcn_v2_cpu_backward (line 109) | std::vector<at::Tensor> dcn_v2_cpu_backward(const at::Tensor &input,
FILE: code/synthetic/bsrt/model/DCNv2/src/cpu/dcn_v2_im2col_cpu.cpp
function dmcn_im2col_bilinear_cpu (line 27) | float dmcn_im2col_bilinear_cpu(const float *bottom_data, const int data_...
function dmcn_get_gradient_weight_cpu (line 58) | float dmcn_get_gradient_weight_cpu(float argmax_h, float argmax_w,
function dmcn_get_coordinate_weight_cpu (line 84) | float dmcn_get_coordinate_weight_cpu(float argmax_h, float argmax_w,
function modulated_deformable_im2col_cpu_kernel (line 127) | void modulated_deformable_im2col_cpu_kernel(const int n, const float *da...
function modulated_deformable_col2im_cpu_kernel (line 198) | void modulated_deformable_col2im_cpu_kernel(const int n, const float *da...
function modulated_deformable_col2im_coord_cpu_kernel (line 259) | void modulated_deformable_col2im_coord_cpu_kernel(const int n, const flo...
function modulated_deformable_im2col_cpu (line 331) | void modulated_deformable_im2col_cpu(const float* data_im, const float* ...
function modulated_deformable_col2im_cpu (line 353) | void modulated_deformable_col2im_cpu(const float* data_col, const float*...
function modulated_deformable_col2im_coord_cpu (line 375) | void modulated_deformable_col2im_coord_cpu(const float* data_col, const ...
FILE: code/synthetic/bsrt/model/DCNv2/src/cpu/dcn_v2_psroi_pooling_cpu.cpp
function T (line 34) | T bilinear_interp_cpu(
function DeformablePSROIPoolForwardKernelCpu (line 59) | void DeformablePSROIPoolForwardKernelCpu(
function DeformablePSROIPoolBackwardAccKernelCpu (line 149) | void DeformablePSROIPoolBackwardAccKernelCpu(
function dcn_v2_psroi_pooling_cpu_forward (line 278) | std::tuple<at::Tensor, at::Tensor>
function dcn_v2_psroi_pooling_cpu_backward (line 350) | std::tuple<at::Tensor, at::Tensor>
FILE: code/synthetic/bsrt/model/DCNv2/src/vision.cpp
function PYBIND11_MODULE (line 4) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: code/synthetic/bsrt/model/DCNv2/test.py
function conv_identify (line 20) | def conv_identify(weight, bias):
function check_zero_offset (line 32) | def check_zero_offset():
function check_gradient_dconv (line 69) | def check_gradient_dconv():
function check_pooling_zero_offset (line 100) | def check_pooling_zero_offset():
function check_gradient_dpooling (line 134) | def check_gradient_dpooling():
function example_dconv (line 169) | def example_dconv():
function example_dpooling (line 183) | def example_dpooling():
function example_mdpooling (line 226) | def example_mdpooling():
FILE: code/synthetic/bsrt/model/__init__.py
class Model (line 10) | class Model(nn.Module):
method __init__ (line 11) | def __init__(self, args, ckp):
method forward (line 53) | def forward(self, x, idx_scale):
method save (line 73) | def save(self, apath, epoch, is_best=False):
method load (line 90) | def load(self, apath, pre_train='', resume=-1, cpu=False):
method forward_chop (line 217) | def forward_chop(self, *args, shave=10, min_size=160000):
method forward_x8 (line 276) | def forward_x8(self, *args, forward_function=None):
FILE: code/synthetic/bsrt/model/arch_util.py
function initialize_weights (line 9) | def initialize_weights(net_l, scale=1):
function make_layer (line 29) | def make_layer(block, n_layers):
function conv_layer (line 38) | def conv_layer(in_channels, out_channels, kernel_size, stride=1, padding...
class ESA (line 42) | class ESA(nn.Module):
method __init__ (line 43) | def __init__(self, n_feats, conv=conv_layer):
method forward (line 56) | def forward(self, x):
class DWConv (line 71) | class DWConv(nn.Module):
method __init__ (line 72) | def __init__(self, dim=768):
method forward (line 76) | def forward(self, x):
class SELayer (line 82) | class SELayer(nn.Module):
method __init__ (line 86) | def __init__(self, channel, reduction=16):
method forward (line 96) | def forward(self, x):
class ResidualBlock_noBN (line 102) | class ResidualBlock_noBN(nn.Module):
method __init__ (line 108) | def __init__(self, nf=64):
method forward (line 116) | def forward(self, x):
class ResidualBlock_SE (line 123) | class ResidualBlock_SE(nn.Module):
method __init__ (line 129) | def __init__(self, nf=64, reduction=16):
method forward (line 138) | def forward(self, x):
class _PositionAttentionModule (line 148) | class _PositionAttentionModule(nn.Module):
method __init__ (line 151) | def __init__(self, in_channels, **kwargs):
method forward (line 159) | def forward(self, x):
class SALayer (line 171) | class SALayer(nn.Module):
method __init__ (line 172) | def __init__(self, wn=None):
method forward (line 178) | def forward(self, x):
class CALayerV2 (line 186) | class CALayerV2(nn.Module):
method __init__ (line 187) | def __init__(self, n_feat, reduction=16, wn=None):
method forward (line 200) | def forward(self, x):
class DALayer (line 207) | class DALayer(nn.Module):
method __init__ (line 208) | def __init__(self, channel, reduction, wn):
method forward (line 215) | def forward(self, x):
class CALayer (line 223) | class CALayer(nn.Module):
method __init__ (line 224) | def __init__(self, channel, reduction, wn):
method forward (line 236) | def forward(self, x):
class RCAB (line 243) | class RCAB(nn.Module):
method __init__ (line 244) | def __init__(
method forward (line 266) | def forward(self, x):
class ResidualGroup (line 273) | class ResidualGroup(nn.Module):
method __init__ (line 274) | def __init__(self, n_feat, n_resblocks, da=False):
method forward (line 292) | def forward(self, x):
function make_layer_idx (line 302) | def make_layer_idx(block, n_layers):
class LRSCRCAB (line 309) | class LRSCRCAB(nn.Module):
method __init__ (line 310) | def __init__(
method forward (line 332) | def forward(self, x):
class LRSCPYRCAB (line 339) | class LRSCPYRCAB(nn.Module):
method __init__ (line 340) | def __init__(
method forward (line 366) | def forward(self, x):
class LRSCResidualGroup (line 372) | class LRSCResidualGroup(nn.Module):
method __init__ (line 373) | def __init__(self, n_feat, n_resblocks, da=False, idx=0):
method forward (line 392) | def forward(self, x):
class LRSCPSResidualGroup (line 400) | class LRSCPSResidualGroup(nn.Module):
method __init__ (line 401) | def __init__(self, n_feat, n_resblocks, da=False, idx=0):
method forward (line 421) | def forward(self, x):
class LRSCPyResidualGroup (line 430) | class LRSCPyResidualGroup(nn.Module):
method __init__ (line 431) | def __init__(self, n_feat, n_resblocks, da=False, idx=0):
method forward (line 451) | def forward(self, x):
class LRSCWideActResBlock (line 458) | class LRSCWideActResBlock(nn.Module):
method __init__ (line 459) | def __init__(self, nf=64, idx=0):
method forward (line 482) | def forward(self, x):
class LRSCPyWideActResBlock (line 488) | class LRSCPyWideActResBlock(nn.Module):
method __init__ (line 489) | def __init__(self, nf=64, idx=0):
method forward (line 515) | def forward(self, x):
class LRSCPyWideActResGroup (line 523) | class LRSCPyWideActResGroup(nn.Module):
method __init__ (line 524) | def __init__(self, nf, n_resblocks, idx=0):
method forward (line 539) | def forward(self, x):
class LRSCWideActResGroup (line 548) | class LRSCWideActResGroup(nn.Module):
method __init__ (line 549) | def __init__(self, nf, n_resblocks, idx=0):
method forward (line 564) | def forward(self, x):
class PYRCAB (line 577) | class PYRCAB(nn.Module):
method __init__ (line 578) | def __init__(
method forward (line 603) | def forward(self, x):
class PyResidualGroup (line 609) | class PyResidualGroup(nn.Module):
method __init__ (line 610) | def __init__(self, n_feat, n_resblocks, da=False):
method forward (line 632) | def forward(self, x):
class WideActResBlock (line 637) | class WideActResBlock(nn.Module):
method __init__ (line 638) | def __init__(self, nf=64):
method forward (line 658) | def forward(self, x):
class PSWideActResBlock (line 664) | class PSWideActResBlock(nn.Module):
method __init__ (line 665) | def __init__(self, nf=64):
method forward (line 685) | def forward(self, x):
class PyWideActResBlock (line 691) | class PyWideActResBlock(nn.Module):
method __init__ (line 692) | def __init__(self, nf=64):
method forward (line 717) | def forward(self, x):
function flow_warp (line 723) | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', ali...
FILE: code/synthetic/bsrt/model/bsrt.py
function make_model (line 27) | def make_model(args, parent=False):
class BasicModule (line 66) | class BasicModule(nn.Module):
method __init__ (line 70) | def __init__(self):
method forward (line 80) | def forward(self, tensor_input):
class SpyNet (line 84) | class SpyNet(nn.Module):
method __init__ (line 92) | def __init__(self, load_path=None, return_levels=[5]):
method preprocess (line 110) | def preprocess(self, tensor_input):
method process (line 114) | def process(self, ref, supp, w, h, w_floor, h_floor):
method forward (line 160) | def forward(self, ref, supp):
class FlowGuidedPCDAlign (line 176) | class FlowGuidedPCDAlign(nn.Module):
method __init__ (line 181) | def __init__(self, nf=64, groups=8):
method forward (line 209) | def forward(self, nbr_fea_l, nbr_fea_warped_l, ref_fea_l, flows_l):
class CrossNonLocal_Fusion (line 246) | class CrossNonLocal_Fusion(nn.Module):
method __init__ (line 249) | def __init__(self, nf=64, out_feat=96, nframes=5, center=2):
method forward (line 265) | def forward(self, aligned_fea):
class BSRT (line 287) | class BSRT(nn.Module):
method __init__ (line 288) | def __init__(self, args, nframes=8, img_size=64, patch_size=1, in_chan...
method _init_weights (line 452) | def _init_weights(self, m):
method no_weight_decay (line 462) | def no_weight_decay(self):
method no_weight_decay_keywords (line 466) | def no_weight_decay_keywords(self):
method _upsample_add (line 469) | def _upsample_add(self, x, y):
method check_image_size (line 472) | def check_image_size(self, x):
method pre_forward_features (line 479) | def pre_forward_features(self, x):
method forward_features (line 498) | def forward_features(self, x):
method forward (line 516) | def forward(self, x, print_time=False):
method get_ref_flows (line 593) | def get_ref_flows(self, x):
FILE: code/synthetic/bsrt/model/checkpoint.py
function detach_variable (line 5) | def detach_variable(inputs):
function check_backward_validity (line 18) | def check_backward_validity(inputs):
class CheckpointFunction (line 23) | class CheckpointFunction(torch.autograd.Function):
method forward (line 25) | def forward(ctx, run_function, length, *args):
method backward (line 34) | def backward(ctx, *output_grads):
FILE: code/synthetic/bsrt/model/common.py
function default_conv (line 9) | def default_conv(in_channels, out_channels, kernel_size, bias=True):
class MeanShift (line 15) | class MeanShift(nn.Conv2d):
method __init__ (line 16) | def __init__(
class BasicBlock (line 27) | class BasicBlock(nn.Sequential):
method __init__ (line 28) | def __init__(
class ResBlock (line 41) | class ResBlock(nn.Module):
method __init__ (line 42) | def __init__(
method forward (line 58) | def forward(self, x):
class Upsampler (line 65) | class Upsampler(nn.Sequential):
method __init__ (line 66) | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
class UpOnly (line 95) | class UpOnly(nn.Sequential):
method __init__ (line 96) | def __init__(self, scale):
function lanczos_kernel (line 114) | def lanczos_kernel(dx, a=3, N=None, dtype=None, device=None):
function lanczos_shift (line 160) | def lanczos_shift(img, shift, p=5, a=3):
FILE: code/synthetic/bsrt/model/ebsr.py
function make_model (line 27) | def make_model(args, parent=False):
class BasicModule (line 58) | class BasicModule(nn.Module):
method __init__ (line 62) | def __init__(self):
method forward (line 72) | def forward(self, tensor_input):
class SpyNet (line 76) | class SpyNet(nn.Module):
method __init__ (line 84) | def __init__(self, load_path=None, return_levels=[5]):
method preprocess (line 102) | def preprocess(self, tensor_input):
method process (line 106) | def process(self, ref, supp, w, h, w_floor, h_floor):
method forward (line 152) | def forward(self, ref, supp):
class PCD_Align (line 167) | class PCD_Align(nn.Module):
method __init__ (line 172) | def __init__(self, nf=64, groups=8, wn=None):
method forward (line 206) | def forward(self, nbr_fea_l, ref_fea_l):
class FlowGuidedPCDAlign (line 245) | class FlowGuidedPCDAlign(nn.Module):
method __init__ (line 250) | def __init__(self, nf=64, groups=8):
method forward (line 277) | def forward(self, nbr_fea_l, nbr_fea_warped_l, ref_fea_l, flows_l):
class CrossNonLocal_Fusion (line 315) | class CrossNonLocal_Fusion(nn.Module):
method __init__ (line 318) | def __init__(self, nf=64, out_feat=96, nframes=5, center=2):
method forward (line 334) | def forward(self, aligned_fea):
class EBSR (line 356) | class EBSR(nn.Module):
method __init__ (line 360) | def __init__(self, args, nframes=8, img_size=64, patch_size=1, in_chan...
method _init_weights (line 536) | def _init_weights(self, m):
method no_weight_decay (line 546) | def no_weight_decay(self):
method no_weight_decay_keywords (line 550) | def no_weight_decay_keywords(self):
method _upsample_add (line 553) | def _upsample_add(self, x, y):
method check_image_size (line 556) | def check_image_size(self, x):
method pre_forward_features (line 563) | def pre_forward_features(self, x):
method forward_features (line 582) | def forward_features(self, x):
method forward (line 600) | def forward(self, x, print_time=False):
method get_ref_flows (line 679) | def get_ref_flows(self, x):
method get_flow_2frames (line 694) | def get_flow_2frames(self, x):
method get_aligned_image_2frames (line 714) | def get_aligned_image_2frames(self, x, flows_backward, flows_forward):
method get_aligned_feature_2frames (line 735) | def get_aligned_feature_2frames(self, x):
FILE: code/synthetic/bsrt/model/non_local/network.py
class Network (line 8) | class Network(nn.Module):
method __init__ (line 9) | def __init__(self):
method forward (line 43) | def forward(self, x):
method forward_with_nl_map (line 57) | def forward_with_nl_map(self, x):
FILE: code/synthetic/bsrt/model/non_local/non_local_concatenation.py
class _NonLocalBlockND (line 6) | class _NonLocalBlockND(nn.Module):
method __init__ (line 7) | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_...
method forward (line 68) | def forward(self, x, return_nl_map=False):
class NONLocalBlock1D (line 109) | class NONLocalBlock1D(_NonLocalBlockND):
method __init__ (line 110) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
class NONLocalBlock2D (line 117) | class NONLocalBlock2D(_NonLocalBlockND):
method __init__ (line 118) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
class NONLocalBlock3D (line 125) | class NONLocalBlock3D(_NonLocalBlockND):
method __init__ (line 126) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
FILE: code/synthetic/bsrt/model/non_local/non_local_cross_dot_product.py
class _NonLocalBlockND (line 6) | class _NonLocalBlockND(nn.Module):
method __init__ (line 7) | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_...
method forward (line 63) | def forward(self, x, ref, return_nl_map=False):
class NONLocalBlock1D (line 93) | class NONLocalBlock1D(_NonLocalBlockND):
method __init__ (line 94) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
class NONLocalBlock2D (line 101) | class NONLocalBlock2D(_NonLocalBlockND):
method __init__ (line 102) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
class NONLocalBlock3D (line 109) | class NONLocalBlock3D(_NonLocalBlockND):
method __init__ (line 110) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
FILE: code/synthetic/bsrt/model/non_local/non_local_dot_product.py
class _NonLocalBlockND (line 6) | class _NonLocalBlockND(nn.Module):
method __init__ (line 7) | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_...
method forward (line 63) | def forward(self, x, return_nl_map=False):
class NONLocalBlock1D (line 93) | class NONLocalBlock1D(_NonLocalBlockND):
method __init__ (line 94) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
class NONLocalBlock2D (line 101) | class NONLocalBlock2D(_NonLocalBlockND):
method __init__ (line 102) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
class NONLocalBlock3D (line 109) | class NONLocalBlock3D(_NonLocalBlockND):
method __init__ (line 110) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
FILE: code/synthetic/bsrt/model/non_local/non_local_embedded_gaussian.py
class _NonLocalBlockND (line 6) | class _NonLocalBlockND(nn.Module):
method __init__ (line 7) | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_...
method forward (line 70) | def forward(self, x, return_nl_map=False):
class NONLocalBlock1D (line 99) | class NONLocalBlock1D(_NonLocalBlockND):
method __init__ (line 100) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
class NONLocalBlock2D (line 107) | class NONLocalBlock2D(_NonLocalBlockND):
method __init__ (line 108) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
class NONLocalBlock3D (line 115) | class NONLocalBlock3D(_NonLocalBlockND):
method __init__ (line 116) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
FILE: code/synthetic/bsrt/model/non_local/non_local_gaussian.py
class _NonLocalBlockND (line 6) | class _NonLocalBlockND(nn.Module):
method __init__ (line 7) | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_...
method forward (line 57) | def forward(self, x, return_nl_map=False):
class NONLocalBlock1D (line 95) | class NONLocalBlock1D(_NonLocalBlockND):
method __init__ (line 96) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
class NONLocalBlock2D (line 103) | class NONLocalBlock2D(_NonLocalBlockND):
method __init__ (line 104) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
class NONLocalBlock3D (line 111) | class NONLocalBlock3D(_NonLocalBlockND):
method __init__ (line 112) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
FILE: code/synthetic/bsrt/model/swin_util.py
class Mlp (line 17) | class Mlp(nn.Module):
method __init__ (line 18) | def __init__(self, in_features, hidden_features=None, out_features=Non...
method forward (line 27) | def forward(self, x):
function window_partition (line 36) | def window_partition(x, window_size):
function window_reverse (line 51) | def window_reverse(windows, window_size, H, W):
class WindowAttention (line 68) | class WindowAttention(nn.Module):
method __init__ (line 82) | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scal...
method forward (line 117) | def forward(self, x, mask=None):
method extra_repr (line 152) | def extra_repr(self) -> str:
method flops (line 155) | def flops(self, N):
function calculate_mask (line 169) | def calculate_mask(x_size, window_size, shift_size):
class SwinTransformerBlock (line 193) | class SwinTransformerBlock(nn.Module):
method __init__ (line 212) | def __init__(self, dim, input_resolution, num_heads, window_size=7, sh...
method forward (line 240) | def forward(self, x, x_size):
method extra_repr (line 281) | def extra_repr(self) -> str:
method flops (line 285) | def flops(self):
class PatchMerging (line 300) | class PatchMerging(nn.Module):
method __init__ (line 309) | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
method forward (line 316) | def forward(self, x):
method extra_repr (line 339) | def extra_repr(self) -> str:
method flops (line 342) | def flops(self):
class BasicLayer (line 349) | class BasicLayer(nn.Module):
method __init__ (line 369) | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
method forward (line 397) | def forward(self, x, x_size):
method extra_repr (line 408) | def extra_repr(self) -> str:
method flops (line 411) | def flops(self):
class RSTB (line 420) | class RSTB(nn.Module):
method __init__ (line 443) | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
method forward (line 485) | def forward(self, x, x_size):
method flops (line 490) | def flops(self):
class PatchEmbed (line 501) | class PatchEmbed(nn.Module):
method __init__ (line 512) | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=9...
method forward (line 530) | def forward(self, x, use_norm=True):
method flops (line 536) | def flops(self):
class PatchUnEmbed (line 544) | class PatchUnEmbed(nn.Module):
method __init__ (line 555) | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=9...
method forward (line 568) | def forward(self, x, x_size):
method flops (line 573) | def flops(self):
FILE: code/synthetic/bsrt/model/utils/interp_methods.py
function set_framework_dependencies (line 17) | def set_framework_dependencies(x):
function support_sz (line 28) | def support_sz(sz):
function cubic (line 35) | def cubic(x):
function lanczos2 (line 45) | def lanczos2(x):
function lanczos3 (line 51) | def lanczos3(x):
function linear (line 57) | def linear(x):
function box (line 63) | def box(x):
FILE: code/synthetic/bsrt/model/utils/psconv.py
class PyConv2d (line 4) | class PyConv2d(nn.Module):
method __init__ (line 24) | def __init__(self, in_channels, out_channels, pyconv_kernels, pyconv_g...
method forward (line 36) | def forward(self, x):
class PSConv2d (line 45) | class PSConv2d(nn.Module):
method __init__ (line 46) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
method forward (line 69) | def forward(self, x):
class PSGConv2d (line 76) | class PSGConv2d(nn.Module):
method __init__ (line 77) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
method forward (line 102) | def forward(self, x):
FILE: code/synthetic/bsrt/model/utils/resize_right.py
class NoneClass (line 6) | class NoneClass:
function resize (line 29) | def resize(input, scale_factors=None, out_shape=None,
class ResizeLayer (line 80) | class ResizeLayer(nnModuleWrapped):
method __init__ (line 81) | def __init__(self, in_shape, scale_factors=None, out_shape=None,
method forward (line 132) | def forward(self, input):
function prepare_weights_and_field_of_view_1d (line 147) | def prepare_weights_and_field_of_view_1d(dim, scale_factor, in_sz, out_sz,
function apply_weights (line 174) | def apply_weights(input, field_of_view, weights, dim, n_dims, fw):
function set_scale_and_out_sz (line 211) | def set_scale_and_out_sz(in_shape, out_shape, scale_factors, fw):
function get_projected_grid (line 251) | def get_projected_grid(in_sz, out_sz, scale_factor, fw, device=None):
function get_field_of_view (line 266) | def get_field_of_view(projected_grid, cur_support_sz, in_sz, fw, eps):
function get_weights (line 288) | def get_weights(interp_method, projected_grid, field_of_view):
function apply_antialiasing_if_needed (line 301) | def apply_antialiasing_if_needed(interp_method, support_sz, scale_factor,
function fw_ceil (line 315) | def fw_ceil(x, fw):
function fw_cat (line 322) | def fw_cat(x, fw):
function fw_swapaxes (line 329) | def fw_swapaxes(x, ax_1, ax_2, fw):
function fw_set_device (line 335) | def fw_set_device(x, device, fw):
FILE: code/synthetic/bsrt/scripts/cal_mean_std.py
function main (line 9) | def main():
FILE: code/synthetic/bsrt/scripts/download_burstsr_dataset.py
function download_burstsr_dataset (line 8) | def download_burstsr_dataset(download_path):
function main (line 59) | def main():
FILE: code/synthetic/bsrt/scripts/evaluate_burstsr_val.py
class SimpleBaseline (line 8) | class SimpleBaseline:
method __init__ (line 9) | def __init__(self):
method __call__ (line 12) | def __call__(self, burst):
function main (line 19) | def main():
FILE: code/synthetic/bsrt/scripts/save_results_synburst_val.py
class SimpleBaseline (line 9) | class SimpleBaseline:
method __init__ (line 10) | def __init__(self):
method __call__ (line 13) | def __call__(self, burst):
function main (line 20) | def main():
FILE: code/synthetic/bsrt/scripts/test_burstsr_dataset.py
function main (line 11) | def main():
FILE: code/synthetic/bsrt/scripts/test_synthetic_bursts.py
function main (line 11) | def main():
FILE: code/synthetic/bsrt/test.py
function ttaup (line 23) | def ttaup(burst):
function ttadown (line 32) | def ttadown(bursts):
function main (line 39) | def main():
function main_worker (line 43) | def main_worker(local_rank, nprocs, args):
FILE: code/synthetic/bsrt/test_synburst.py
function ttaup (line 29) | def ttaup(burst):
function ttadown (line 35) | def ttadown(bursts):
function main (line 42) | def main():
function main_worker (line 46) | def main_worker(local_rank, nprocs, args):
FILE: code/synthetic/bsrt/trainer.py
class Trainer (line 47) | class Trainer():
method __init__ (line 48) | def __init__(self, args, train_loader, train_sampler, valid_loader, my...
method train (line 118) | def train(self):
method test (line 233) | def test(self, print_time=False):
method save_model (line 319) | def save_model(self, filename):
method prepare (line 331) | def prepare(self, *args):
method terminate (line 341) | def terminate(self):
FILE: code/synthetic/bsrt/utility.py
function reduce_mean (line 25) | def reduce_mean(tensor, nprocs):
function gradient (line 31) | def gradient(data):
function smooth_grad_1st (line 37) | def smooth_grad_1st(flo, image, alpha):
function smooth_loss (line 48) | def smooth_loss(flow, img):
function setup (line 53) | def setup(rank, world_size):
function cleanup (line 80) | def cleanup():
function mkdir (line 84) | def mkdir(path):
class timer (line 89) | class timer():
method __init__ (line 90) | def __init__(self):
method tic (line 94) | def tic(self):
method toc (line 97) | def toc(self, restart=False):
method hold (line 102) | def hold(self):
method release (line 105) | def release(self):
method reset (line 111) | def reset(self):
class checkpoint (line 115) | class checkpoint():
method __init__ (line 116) | def __init__(self, args):
method get_path (line 153) | def get_path(self, *subdir):
method save (line 156) | def save(self, trainer, epoch, is_best=False):
method add_log (line 165) | def add_log(self, log):
method write_log (line 168) | def write_log(self, log, refresh=False):
method done (line 175) | def done(self):
method plot_psnr (line 178) | def plot_psnr(self, epoch):
method begin_background (line 197) | def begin_background(self):
method end_background (line 214) | def end_background(self):
method save_results (line 219) | def save_results(self, dataset, filename, save_list, scale):
function quantize (line 233) | def quantize(img, rgb_range):
function calc_psnr (line 238) | def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
function make_optimizer (line 257) | def make_optimizer(args, target):
function write_gray_to_tfboard (line 313) | def write_gray_to_tfboard(img):
function bayer_unify (line 333) | def bayer_unify(raw, input_pattern, target_pattern, mode) -> np.ndarray:
function bayer_aug (line 365) | def bayer_aug(raw, flip_h=False, flip_w=False, transpose=False, input_pa...
FILE: code/synthetic/bsrt/utils/data_format_utils.py
function numpy_to_torch (line 6) | def numpy_to_torch(a: np.ndarray):
function torch_to_numpy (line 10) | def torch_to_numpy(a: torch.Tensor):
function torch_to_npimage (line 14) | def torch_to_npimage(a: torch.Tensor, unnormalize=True):
function npimage_to_torch (line 23) | def npimage_to_torch(a, normalize=True, input_bgr=True):
function convert_dict (line 34) | def convert_dict(base_dict, batch_sz):
FILE: code/synthetic/bsrt/utils/debayer.py
class Debayer3x3 (line 5) | class Debayer3x3(torch.nn.Module):
method __init__ (line 29) | def __init__(self):
method forward (line 71) | def forward(self, x):
class Debayer2x2 (line 91) | class Debayer2x2(torch.nn.Module):
method __init__ (line 99) | def __init__(self):
method forward (line 115) | def forward(self, x):
class DebayerSplit (line 133) | class DebayerSplit(torch.nn.Module):
method __init__ (line 140) | def __init__(self):
method forward (line 151) | def forward(self, x):
FILE: code/synthetic/bsrt/utils/interp_methods.py
function set_framework_dependencies (line 17) | def set_framework_dependencies(x):
function support_sz (line 28) | def support_sz(sz):
function cubic (line 35) | def cubic(x):
function lanczos2 (line 45) | def lanczos2(x):
function lanczos3 (line 51) | def lanczos3(x):
function linear (line 57) | def linear(x):
function box (line 63) | def box(x):
FILE: code/synthetic/bsrt/utils/metrics.py
class MSSSIMLoss (line 15) | class MSSSIMLoss(nn.Module):
method __init__ (line 16) | def __init__(self, boundary_ignore=None):
method forward (line 21) | def forward(self, pred, gt, valid=None):
class CharbonnierLoss (line 33) | class CharbonnierLoss(nn.Module):
method __init__ (line 34) | def __init__(self, boundary_ignore=None):
method forward (line 39) | def forward(self, pred, gt, valid=None):
class L1 (line 51) | class L1(nn.Module):
method __init__ (line 52) | def __init__(self, boundary_ignore=None):
method forward (line 56) | def forward(self, pred, gt, valid=None):
class L2 (line 78) | class L2(nn.Module):
method __init__ (line 79) | def __init__(self, boundary_ignore=None):
method forward (line 84) | def forward(self, pred, gt, valid=None):
class PSNR (line 110) | class PSNR(nn.Module):
method __init__ (line 111) | def __init__(self, boundary_ignore=None, max_value=1.0):
method psnr (line 116) | def psnr(self, pred, gt, valid=None):
method forward (line 123) | def forward(self, pred, gt, valid=None):
class AlignedL1 (line 136) | class AlignedL1(nn.Module):
method __init__ (line 137) | def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None):
method forward (line 145) | def forward(self, pred, gt, burst_input):
class AlignedL2 (line 186) | class AlignedL2(nn.Module):
method __init__ (line 187) | def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None):
method forward (line 196) | def forward(self, pred, gt, burst_input):
class AlignedPSNR (line 243) | class AlignedPSNR(nn.Module):
method __init__ (line 244) | def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None, m...
method psnr (line 249) | def psnr(self, pred, gt, burst_input):
method forward (line 256) | def forward(self, pred, gt, burst_input):
class AlignedSSIM (line 265) | class AlignedSSIM(nn.Module):
method __init__ (line 266) | def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None):
method _ssim (line 274) | def _ssim(self, pred, gt, burst_input):
method forward (line 313) | def forward(self, pred, gt, burst_input):
class AlignedLPIPS (line 319) | class AlignedLPIPS(nn.Module):
method __init__ (line 320) | def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None):
method _lpips (line 329) | def _lpips(self, pred, gt, burst_input):
method forward (line 363) | def forward(self, pred, gt, burst_input):
FILE: code/synthetic/bsrt/utils/postprocessing_functions.py
class SimplePostProcess (line 7) | class SimplePostProcess:
method __init__ (line 8) | def __init__(self, gains=True, ccm=True, gamma=True, smoothstep=True, ...
method process (line 15) | def process(self, image, meta_info):
function process_linear_image_rgb (line 20) | def process_linear_image_rgb(image, meta_info, gains=True, ccm=True, gam...
class BurstSRPostProcess (line 40) | class BurstSRPostProcess:
method __init__ (line 41) | def __init__(self, no_white_balance=False, gamma=True, smoothstep=True...
method process (line 47) | def process(self, image, meta_info, external_norm_factor=None):
function process_burstsr_image_rgb (line 53) | def process_burstsr_image_rgb(im, meta_info, return_np=False, external_n...
FILE: code/synthetic/bsrt/utils/resize_right.py
class NoneClass (line 6) | class NoneClass:
function resize (line 29) | def resize(input, scale_factors=None, out_shape=None,
class ResizeLayer (line 80) | class ResizeLayer(nnModuleWrapped):
method __init__ (line 81) | def __init__(self, in_shape, scale_factors=None, out_shape=None,
method forward (line 132) | def forward(self, input):
function prepare_weights_and_field_of_view_1d (line 147) | def prepare_weights_and_field_of_view_1d(dim, scale_factor, in_sz, out_sz,
function apply_weights (line 174) | def apply_weights(input, field_of_view, weights, dim, n_dims, fw):
function set_scale_and_out_sz (line 211) | def set_scale_and_out_sz(in_shape, out_shape, scale_factors, fw):
function get_projected_grid (line 251) | def get_projected_grid(in_sz, out_sz, scale_factor, fw, device=None):
function get_field_of_view (line 266) | def get_field_of_view(projected_grid, cur_support_sz, in_sz, fw, eps):
function get_weights (line 288) | def get_weights(interp_method, projected_grid, field_of_view):
function apply_antialiasing_if_needed (line 301) | def apply_antialiasing_if_needed(interp_method, support_sz, scale_factor,
function fw_ceil (line 315) | def fw_ceil(x, fw):
function fw_cat (line 322) | def fw_cat(x, fw):
function fw_swapaxes (line 329) | def fw_swapaxes(x, ax_1, ax_2, fw):
function fw_set_device (line 335) | def fw_set_device(x, device, fw):
FILE: code/synthetic/bsrt/utils/spatial_color_alignment.py
function gauss_1d (line 6) | def gauss_1d(sz, sigma, center, end_pad=0, density=False):
function gauss_2d (line 15) | def gauss_2d(sz, sigma, center, end_pad=(0, 0), density=False):
function get_gaussian_kernel (line 29) | def get_gaussian_kernel(sd):
function apply_kernel (line 38) | def apply_kernel(im, ksz, gauss_kernel):
function match_colors (line 48) | def match_colors(im_ref, im_q, im_test, ksz, gauss_kernel):
FILE: code/synthetic/bsrt/utils/stn.py
class SpatialTransformer (line 6) | class SpatialTransformer(nn.Module):
method __init__ (line 12) | def __init__(self, size, mode='bilinear'):
method forward (line 31) | def forward(self, src, flow):
FILE: code/synthetic/bsrt/utils/warp.py
function warp (line 6) | def warp(feat, flow, mode='bilinear', padding_mode='zeros'):
Condensed preview — 184 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (1,148K chars).
[
{
"path": ".gitignore",
"chars": 1799,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
},
{
"path": "LICENSE",
"chars": 1068,
"preview": "MIT License\n\nCopyright (c) 2022 Megvii Inc.\n\nPermission is hereby granted, free of charge, to any person obtaining a cop"
},
{
"path": "README.md",
"chars": 5434,
"preview": "## BSRT: Improving Burst Super-Resolution with Swin Transformer and Flow-Guided Deformable Alignment (CVPRW 2022)\n[![PWC"
},
{
"path": "code/real/bsrt/README.md",
"chars": 1618,
"preview": "# BSRT: Improving Burst Super-Resolution with Swin Transformer and Flow-Guided Deformable Alignment (Real-World)\n\n## Dep"
},
{
"path": "code/real/bsrt/data_processing/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "code/real/bsrt/data_processing/camera_pipeline.py",
"chars": 7659,
"preview": "import torch\nimport random\nimport math\nimport cv2 as cv\nimport numpy as np\nimport utils.data_format_utils as df_utils\n\"\""
},
{
"path": "code/real/bsrt/data_processing/synthetic_burst_generation.py",
"chars": 10314,
"preview": "import torch\nimport random\nimport cv2\nimport numpy as np\nimport torch.nn.functional as F\nfrom data_processing.camera_pip"
},
{
"path": "code/real/bsrt/datasets/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "code/real/bsrt/datasets/burstsr_dataset.py",
"chars": 16721,
"preview": "import os\nimport torch\nimport cv2\nimport numpy as np\nimport pickle as pkl\nimport torch.nn.functional as F\nimport random\n"
},
{
"path": "code/real/bsrt/datasets/burstsr_test_dataset.py",
"chars": 4984,
"preview": "import os\nimport torch\nimport torch.nn.functional as F\nimport random\nfrom .burstsr_dataset import SamsungRAWImage, flatt"
},
{
"path": "code/real/bsrt/datasets/data_sampler.py",
"chars": 2384,
"preview": "\"\"\"\nModified from torch.utils.data.distributed.DistributedSampler\nSupport enlarging the dataset for *iter-oriented* trai"
},
{
"path": "code/real/bsrt/datasets/realworld_burst_test_set.py",
"chars": 1263,
"preview": "import torch\nimport cv2\nimport numpy as np\nimport pickle as pkl\n\n\nclass RealWorldBurstTest(torch.utils.data.Dataset):\n "
},
{
"path": "code/real/bsrt/datasets/synthetic_burst_test_set.py",
"chars": 1429,
"preview": "import torch\nimport cv2\nimport numpy as np\nimport pickle as pkl\n\n\nclass SyntheticBurstTest(torch.utils.data.Dataset):\n "
},
{
"path": "code/real/bsrt/datasets/synthetic_burst_train_set.py",
"chars": 5235,
"preview": "import torch\nimport numpy as np\nfrom PIL import Image\nfrom data_processing.synthetic_burst_generation import rgb2rawburs"
},
{
"path": "code/real/bsrt/datasets/synthetic_burst_val_set.py",
"chars": 2499,
"preview": "import os\nimport torch\nimport cv2\nimport numpy as np\nimport pickle as pkl\n\n\nclass SyntheticBurstVal(torch.utils.data.Dat"
},
{
"path": "code/real/bsrt/datasets/zurich_raw2rgb_dataset.py",
"chars": 1531,
"preview": "import torch\nimport os\nimport numpy as np\nfrom cv2 import imread\n\n\nclass ZurichRAW2RGB(torch.utils.data.Dataset):\n \"\""
},
{
"path": "code/real/bsrt/demo.sh",
"chars": 1040,
"preview": "#!/usr/bin/env bash\n\n\npython main.py --n_GPUs 8 --print_every 20 --lr 0.00004 --decay 40-80 --save bsrt_tiny --model BSR"
},
{
"path": "code/real/bsrt/loss/Charbonnier.py",
"chars": 531,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass CharbonnierLoss(nn.Module):\n \"\"\"L1 charbonn"
},
{
"path": "code/real/bsrt/loss/__init__.py",
"chars": 5209,
"preview": "import os\nfrom importlib import import_module\n\nimport matplotlib\nmatplotlib.use('Agg')\nimport matplotlib.pyplot as plt\n\n"
},
{
"path": "code/real/bsrt/loss/adversarial.py",
"chars": 4435,
"preview": "import utility\nfrom types import SimpleNamespace\n\nfrom model import common\nfrom loss import discriminator\n\nimport torch\n"
},
{
"path": "code/real/bsrt/loss/discriminator.py",
"chars": 2021,
"preview": "from model import common\n\nimport torch.nn as nn\n\nclass Discriminator(nn.Module):\n '''\n output is not normalize"
},
{
"path": "code/real/bsrt/loss/filter.py",
"chars": 558,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Filter(nn.Module):\n def __init__(self, args"
},
{
"path": "code/real/bsrt/loss/hist_entropy.py",
"chars": 364,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass HistEntropy(nn.Module):\n def __init__(self,"
},
{
"path": "code/real/bsrt/loss/mssim.py",
"chars": 4651,
"preview": "import torch\nimport torch.nn.functional as F\nfrom math import exp\nimport numpy as np\n\n\ndef gaussian(window_size, sigma):"
},
{
"path": "code/real/bsrt/loss/vgg.py",
"chars": 1167,
"preview": "from model import common\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision.models a"
},
{
"path": "code/real/bsrt/main.py",
"chars": 3413,
"preview": "import torch\nimport random\nimport numpy as np\nfrom torch.utils.data import DataLoader\nimport os\n\nimport utility\nimport m"
},
{
"path": "code/real/bsrt/model/DCNv2/LICENSE",
"chars": 1520,
"preview": "BSD 3-Clause License\n\nCopyright (c) 2019, Charles Shang\nAll rights reserved.\n\nRedistribution and use in source and binar"
},
{
"path": "code/real/bsrt/model/DCNv2/README.md",
"chars": 2226,
"preview": "## Deformable Convolutional Networks V2 with Pytorch 1.0\n\n### Build\n```bash\n ./make.sh # build\n python tes"
},
{
"path": "code/real/bsrt/model/DCNv2/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "code/real/bsrt/model/DCNv2/dcn_v2.py",
"chars": 17562,
"preview": "#!/usr/bin/env python\nfrom __future__ import absolute_import, division, print_function\n\nimport math\n\nimport torch\nfrom t"
},
{
"path": "code/real/bsrt/model/DCNv2/files.txt",
"chars": 999,
"preview": "/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/_ext.cpython-37m-x86_64-linux-gnu"
},
{
"path": "code/real/bsrt/model/DCNv2/make.sh",
"chars": 50,
"preview": "#!/usr/bin/env bash\npython setup.py build develop\n"
},
{
"path": "code/real/bsrt/model/DCNv2/setup.py",
"chars": 1878,
"preview": "#!/usr/bin/env python\n\nimport glob\nimport os\n\nimport torch\nfrom setuptools import find_packages, setup\nfrom torch.utils."
},
{
"path": "code/real/bsrt/model/DCNv2/src/cpu/dcn_v2_cpu.cpp",
"chars": 10924,
"preview": "#include <vector>\n#include \"cpu/dcn_v2_im2col_cpu.h\"\n\n#include <ATen/ATen.h>\n//#include <ATen/cuda/CUDAContext.h>\n\n#incl"
},
{
"path": "code/real/bsrt/model/DCNv2/src/cpu/dcn_v2_im2col_cpu.cpp",
"chars": 19948,
"preview": "#include \"dcn_v2_im2col_cpu.h\"\n#include <cstdio>\n#include <algorithm>\n#include <cstring>\n\n#include <ATen/ATen.h>\n//#incl"
},
{
"path": "code/real/bsrt/model/DCNv2/src/cpu/dcn_v2_im2col_cpu.h",
"chars": 5105,
"preview": "\n/*!\n ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************\n *\n * COPYRIGHT\n *\n * All contrib"
},
{
"path": "code/real/bsrt/model/DCNv2/src/cpu/dcn_v2_psroi_pooling_cpu.cpp",
"chars": 17007,
"preview": "/*!\n * Copyright (c) 2017 Microsoft\n * Licensed under The MIT License [see LICENSE for details]\n * \\file deformable_psro"
},
{
"path": "code/real/bsrt/model/DCNv2/src/cpu/vision.h",
"chars": 2665,
"preview": "#pragma once\n#include <torch/extension.h>\n\nat::Tensor\ndcn_v2_cpu_forward(const at::Tensor &input,\n co"
},
{
"path": "code/real/bsrt/model/DCNv2/src/cuda/dcn_v2_cuda.cu",
"chars": 16079,
"preview": "#include <vector>\n#include \"cuda/dcn_v2_im2col_cuda.h\"\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#inclu"
},
{
"path": "code/real/bsrt/model/DCNv2/src/cuda/dcn_v2_im2col_cuda.cu",
"chars": 20335,
"preview": "#include \"dcn_v2_im2col_cuda.h\"\n#include <cstdio>\n#include <algorithm>\n#include <cstring>\n\n#include <ATen/ATen.h>\n#inclu"
},
{
"path": "code/real/bsrt/model/DCNv2/src/cuda/dcn_v2_im2col_cuda.h",
"chars": 5226,
"preview": "\n/*!\n ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************\n *\n * COPYRIGHT\n *\n * All contrib"
},
{
"path": "code/real/bsrt/model/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda.cu",
"chars": 16288,
"preview": "/*!\n * Copyright (c) 2017 Microsoft\n * Licensed under The MIT License [see LICENSE for details]\n * \\file deformable_psro"
},
{
"path": "code/real/bsrt/model/DCNv2/src/cuda/vision.h",
"chars": 2694,
"preview": "#pragma once\n#include <torch/extension.h>\n#include <ATen/div_rtn.h>\nat::Tensor\ndcn_v2_cuda_forward(const at::Tensor &inp"
},
{
"path": "code/real/bsrt/model/DCNv2/src/dcn_v2.h",
"chars": 7715,
"preview": "#pragma once\n\n#include \"cpu/vision.h\"\n\n#ifdef WITH_CUDA\n#include \"cuda/vision.h\"\n#endif\n\nat::Tensor\ndcn_v2_forward(const"
},
{
"path": "code/real/bsrt/model/DCNv2/src/vision.cpp",
"chars": 405,
"preview": "\n#include \"dcn_v2.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n m.def(\"dcn_v2_forward\", &dcn_v2_forward, \"dcn_v2_forw"
},
{
"path": "code/real/bsrt/model/DCNv2/test.py",
"chars": 8506,
"preview": "#!/usr/bin/env python\nfrom __future__ import absolute_import\nfrom __future__ import print_function\nfrom __future__ impor"
},
{
"path": "code/real/bsrt/model/__init__.py",
"chars": 11735,
"preview": "import os\nfrom importlib import import_module\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel as P\nimport t"
},
{
"path": "code/real/bsrt/model/arch_util.py",
"chars": 28695,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.init as init\nimport torch.nn.functional as F\nfrom model import common"
},
{
"path": "code/real/bsrt/model/bsrt.py",
"chars": 25956,
"preview": "import functools\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport model.arch_util as arch_util\n"
},
{
"path": "code/real/bsrt/model/checkpoint.py",
"chars": 1508,
"preview": "import torch\nimport warnings\n\n\ndef detach_variable(inputs):\n if isinstance(inputs, tuple):\n out = []\n f"
},
{
"path": "code/real/bsrt/model/common.py",
"chars": 6635,
"preview": "import math\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef default_conv(in"
},
{
"path": "code/real/bsrt/model/non_local/network.py",
"chars": 2271,
"preview": "from torch import nn\n# from lib.non_local_concatenation import NONLocalBlock2D\n# from lib.non_local_gaussian import NONL"
},
{
"path": "code/real/bsrt/model/non_local/non_local_concatenation.py",
"chars": 5512,
"preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass _NonLocalBlockND(nn.Module):\n def __in"
},
{
"path": "code/real/bsrt/model/non_local/non_local_cross_dot_product.py",
"chars": 5102,
"preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass _NonLocalBlockND(nn.Module):\n def __in"
},
{
"path": "code/real/bsrt/model/non_local/non_local_dot_product.py",
"chars": 5087,
"preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass _NonLocalBlockND(nn.Module):\n def __in"
},
{
"path": "code/real/bsrt/model/non_local/non_local_embedded_gaussian.py",
"chars": 5241,
"preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass _NonLocalBlockND(nn.Module):\n def __in"
},
{
"path": "code/real/bsrt/model/non_local/non_local_gaussian.py",
"chars": 4915,
"preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass _NonLocalBlockND(nn.Module):\n def __in"
},
{
"path": "code/real/bsrt/model/swin_util.py",
"chars": 38884,
"preview": "# -----------------------------------------------------------------------------------\n# SwinIR: Image Restoration Using "
},
{
"path": "code/real/bsrt/model/utils/interp_methods.py",
"chars": 1711,
"preview": "from math import pi\n\ntry:\n import torch\nexcept ImportError:\n torch = None\n\ntry:\n import numpy\nexcept ImportErro"
},
{
"path": "code/real/bsrt/model/utils/psconv.py",
"chars": 5730,
"preview": "import torch\nimport torch.nn as nn\n\nclass PyConv2d(nn.Module):\n \"\"\"PyConv2d with padding (general case). Applies a 2D"
},
{
"path": "code/real/bsrt/model/utils/resize_right.py",
"chars": 14806,
"preview": "import warnings\nfrom math import ceil\nimport model.utils.interp_methods as interp_methods\n\n\nclass NoneClass:\n pass\n\nt"
},
{
"path": "code/real/bsrt/option.py",
"chars": 8351,
"preview": "import argparse\n\nparser = argparse.ArgumentParser(description='EDSR and MDSR')\n\nparser.add_argument('--n_resblocks', typ"
},
{
"path": "code/real/bsrt/pwcnet/LICENSE",
"chars": 35120,
"preview": "GNU GENERAL PUBLIC LICENSE\n Version 3, 29 June 2007\n\n Copyright (C) 2007 Free Software Foundation,"
},
{
"path": "code/real/bsrt/pwcnet/README.md",
"chars": 3483,
"preview": "# pytorch-pwc\nThis is a personal reimplementation of PWC-Net [1] using PyTorch. Should you be making use of this work, p"
},
{
"path": "code/real/bsrt/pwcnet/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "code/real/bsrt/pwcnet/comparison/comparison.py",
"chars": 1148,
"preview": "#!/usr/bin/env python\n\nimport math\nimport moviepy\nimport moviepy.editor\nimport numpy\nimport PIL\nimport PIL.Image\nimport "
},
{
"path": "code/real/bsrt/pwcnet/correlation/README.md",
"chars": 435,
"preview": "This is an adaptation of the <a href=\"https://github.com/lmb-freiburg/flownet2\">FlowNet2 implementation</a> in order to "
},
{
"path": "code/real/bsrt/pwcnet/correlation/correlation.py",
"chars": 13478,
"preview": "#!/usr/bin/env python\n\nimport torch\n\nimport cupy\nimport re\n# from torch.cuda.amp import custom_fwd, custom_bwd\n\nkernel_C"
},
{
"path": "code/real/bsrt/pwcnet/download.bash",
"chars": 242,
"preview": "#!/bin/bash\n\nwget --verbose --continue --timestamping http://content.sniklaus.com/github/pytorch-pwc/network-chairs-thin"
},
{
"path": "code/real/bsrt/pwcnet/images/README.md",
"chars": 85,
"preview": "The used example originates from the MPI Sintel dataset: http://sintel.is.tue.mpg.de/"
},
{
"path": "code/real/bsrt/pwcnet/pwcnet.py",
"chars": 15042,
"preview": "# Based on run.py from PWCNet\nimport torch\n\nimport getopt\nimport math\nimport numpy\nimport PIL.Image\nimport sys\nfrom torc"
},
{
"path": "code/real/bsrt/pwcnet/requirements.txt",
"chars": 52,
"preview": "cupy>=5.0.0\nnumpy>=1.15.0\nPillow>=5.0.0\ntorch>=1.3.0"
},
{
"path": "code/real/bsrt/pwcnet/run.py",
"chars": 14455,
"preview": "#!/usr/bin/env python\n\nimport torch\n\nimport getopt\nimport math\nimport numpy\nimport os\nimport PIL\nimport PIL.Image\nimport"
},
{
"path": "code/real/bsrt/requirements.txt",
"chars": 55,
"preview": "matplotlib\nimageio\nopencv-python\ntensorboardX\ntqdm\ntimm"
},
{
"path": "code/real/bsrt/scripts/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "code/real/bsrt/scripts/cal_mean_std.py",
"chars": 637,
"preview": "import torch\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom datasets.burstsr_dataset import BurstSRDataset, flatten_raw_"
},
{
"path": "code/real/bsrt/scripts/demo.sh",
"chars": 91,
"preview": "set -ex\nrlaunch --cpu=4 --gpu=1 --memory=10240 -- python ./scripts/evaluate_burstsr_val.py\n"
},
{
"path": "code/real/bsrt/scripts/download_burstsr_dataset.py",
"chars": 2276,
"preview": "import os\nimport urllib.request\nimport zipfile\nimport shutil\nimport argparse\n\n\ndef download_burstsr_dataset(download_pat"
},
{
"path": "code/real/bsrt/scripts/evaluate.sh",
"chars": 89,
"preview": "set -ex\nrlaunch --cpu=4 --gpu=1 --memory=10240 -- python scripts/evaluate_burstsr_val.py\n"
},
{
"path": "code/real/bsrt/scripts/evaluate_burstsr_val.py",
"chars": 1593,
"preview": "import torch.nn.functional as F\nfrom datasets.burstsr_dataset import BurstSRDataset\nfrom utils.metrics import AlignedPSN"
},
{
"path": "code/real/bsrt/scripts/save_results_synburst_val.py",
"chars": 1214,
"preview": "import torch.nn.functional as F\nimport cv2\nfrom datasets.synthetic_burst_val_set import SyntheticBurstVal\nimport torch\ni"
},
{
"path": "code/real/bsrt/scripts/test_burstsr_dataset.py",
"chars": 2117,
"preview": "import torch.nn.functional as F\nimport cv2\nfrom datasets.burstsr_dataset import BurstSRDataset\nfrom torch.utils.data.dat"
},
{
"path": "code/real/bsrt/scripts/test_synthetic_bursts.py",
"chars": 1938,
"preview": "import torch.nn.functional as F\nimport cv2\nfrom datasets.synthetic_burst_train_set import SyntheticBurst\nfrom torch.util"
},
{
"path": "code/real/bsrt/test.py",
"chars": 1630,
"preview": "import torch.nn.functional as F\nimport cv2\n\nimport torch\nimport numpy as np\nimport os\nfrom tqdm import tqdm\n\n\nfrom datas"
},
{
"path": "code/real/bsrt/test_real.py",
"chars": 3447,
"preview": "\nimport cv2\nimport torch\nimport numpy as np\nimport os\nfrom tqdm import tqdm\nimport random\nimport utility\nfrom option imp"
},
{
"path": "code/real/bsrt/trainer.py",
"chars": 12759,
"preview": "import os\nimport sys\nfrom decimal import Decimal\nimport cv2\nimport utility\nimport torchvision.utils as tvutils\nimport to"
},
{
"path": "code/real/bsrt/utility.py",
"chars": 11200,
"preview": "import math\nimport time\nimport datetime\nfrom multiprocessing import Process\nfrom multiprocessing import Queue\n\nimport ma"
},
{
"path": "code/real/bsrt/utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "code/real/bsrt/utils/data_format_utils.py",
"chars": 946,
"preview": "import numpy as np\nimport torch\nimport cv2 as cv\n\n\ndef numpy_to_torch(a: np.ndarray):\n return torch.from_numpy(a).flo"
},
{
"path": "code/real/bsrt/utils/debayer.py",
"chars": 5077,
"preview": "import torch\nimport torch.nn\nimport torch.nn.functional\n\nclass Debayer3x3(torch.nn.Module):\n '''Demosaicing of Bayer "
},
{
"path": "code/real/bsrt/utils/interp_methods.py",
"chars": 1711,
"preview": "from math import pi\n\ntry:\n import torch\nexcept ImportError:\n torch = None\n\ntry:\n import numpy\nexcept ImportErro"
},
{
"path": "code/real/bsrt/utils/metrics.py",
"chars": 15504,
"preview": "import math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport utils.spatial_color_alignment as s"
},
{
"path": "code/real/bsrt/utils/postprocessing_functions.py",
"chars": 2930,
"preview": "import torch\nimport numpy as np\nimport utils.data_format_utils as df_utils\nfrom data_processing.camera_pipeline import a"
},
{
"path": "code/real/bsrt/utils/resize_right.py",
"chars": 14776,
"preview": "import warnings\nfrom math import ceil\nimport interp_methods\n\n\nclass NoneClass:\n pass\n\ntry:\n import torch\n from "
},
{
"path": "code/real/bsrt/utils/spatial_color_alignment.py",
"chars": 3202,
"preview": "import math\nimport torch\nimport torch.nn.functional as F\n\n\ndef gauss_1d(sz, sigma, center, end_pad=0, density=False):\n "
},
{
"path": "code/real/bsrt/utils/stn.py",
"chars": 1825,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass SpatialTransformer(nn.Module):\n \"\"\"\n [S"
},
{
"path": "code/real/bsrt/utils/warp.py",
"chars": 1075,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef warp(feat, flow, mode='bilinear', padding_mode="
},
{
"path": "code/real/bsrt/validate.py",
"chars": 2813,
"preview": "\nimport cv2\nimport torch\nimport numpy as np\nimport os\nfrom tqdm import tqdm\nimport random\nimport utility\nfrom option imp"
},
{
"path": "code/synthetic/bsrt/README.md",
"chars": 1394,
"preview": "# BSRT: Improving Burst Super-Resolution with Swin Transformer and Flow-Guided Deformable Alignment (Synthetic)\n\n## Depe"
},
{
"path": "code/synthetic/bsrt/data_processing/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "code/synthetic/bsrt/data_processing/camera_pipeline.py",
"chars": 7659,
"preview": "import torch\nimport random\nimport math\nimport cv2 as cv\nimport numpy as np\nimport utils.data_format_utils as df_utils\n\"\""
},
{
"path": "code/synthetic/bsrt/data_processing/synthetic_burst_generation.py",
"chars": 10314,
"preview": "import torch\nimport random\nimport cv2\nimport numpy as np\nimport torch.nn.functional as F\nfrom data_processing.camera_pip"
},
{
"path": "code/synthetic/bsrt/datasets/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "code/synthetic/bsrt/datasets/burstsr_dataset.py",
"chars": 16729,
"preview": "import os\nimport torch\nimport cv2\nimport numpy as np\nimport pickle as pkl\nimport torch.nn.functional as F\nimport random\n"
},
{
"path": "code/synthetic/bsrt/datasets/burstsr_test_dataset.py",
"chars": 4984,
"preview": "import os\nimport torch\nimport torch.nn.functional as F\nimport random\nfrom .burstsr_dataset import SamsungRAWImage, flatt"
},
{
"path": "code/synthetic/bsrt/datasets/data_sampler.py",
"chars": 2384,
"preview": "\"\"\"\nModified from torch.utils.data.distributed.DistributedSampler\nSupport enlarging the dataset for *iter-oriented* trai"
},
{
"path": "code/synthetic/bsrt/datasets/realworld_burst_test_set.py",
"chars": 1263,
"preview": "import torch\nimport cv2\nimport numpy as np\nimport pickle as pkl\n\n\nclass RealWorldBurstTest(torch.utils.data.Dataset):\n "
},
{
"path": "code/synthetic/bsrt/datasets/synthetic_burst_test_set.py",
"chars": 1429,
"preview": "import torch\nimport cv2\nimport numpy as np\nimport pickle as pkl\n\n\nclass SyntheticBurstTest(torch.utils.data.Dataset):\n "
},
{
"path": "code/synthetic/bsrt/datasets/synthetic_burst_train_set.py",
"chars": 5235,
"preview": "import torch\nimport numpy as np\nfrom PIL import Image\nfrom data_processing.synthetic_burst_generation import rgb2rawburs"
},
{
"path": "code/synthetic/bsrt/datasets/synthetic_burst_val_set.py",
"chars": 2499,
"preview": "import os\nimport torch\nimport cv2\nimport numpy as np\nimport pickle as pkl\n\n\nclass SyntheticBurstVal(torch.utils.data.Dat"
},
{
"path": "code/synthetic/bsrt/datasets/zurich_raw2rgb_dataset.py",
"chars": 1531,
"preview": "import torch\nimport os\nimport numpy as np\nfrom cv2 import imread\n\n\nclass ZurichRAW2RGB(torch.utils.data.Dataset):\n \"\""
},
{
"path": "code/synthetic/bsrt/demo.sh",
"chars": 877,
"preview": "#!/usr/bin/env bash\n\n\npython main.py --n_GPUs 8 --print_every 40 --lr 0.0001 --decay 100-200 --save bsrt_tiny --model BS"
},
{
"path": "code/synthetic/bsrt/loss/Charbonnier.py",
"chars": 531,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass CharbonnierLoss(nn.Module):\n \"\"\"L1 charbonn"
},
{
"path": "code/synthetic/bsrt/loss/__init__.py",
"chars": 5209,
"preview": "import os\nfrom importlib import import_module\n\nimport matplotlib\nmatplotlib.use('Agg')\nimport matplotlib.pyplot as plt\n\n"
},
{
"path": "code/synthetic/bsrt/loss/adversarial.py",
"chars": 4435,
"preview": "import utility\nfrom types import SimpleNamespace\n\nfrom model import common\nfrom loss import discriminator\n\nimport torch\n"
},
{
"path": "code/synthetic/bsrt/loss/discriminator.py",
"chars": 2021,
"preview": "from model import common\n\nimport torch.nn as nn\n\nclass Discriminator(nn.Module):\n '''\n output is not normalize"
},
{
"path": "code/synthetic/bsrt/loss/filter.py",
"chars": 558,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Filter(nn.Module):\n def __init__(self, args"
},
{
"path": "code/synthetic/bsrt/loss/hist_entropy.py",
"chars": 364,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass HistEntropy(nn.Module):\n def __init__(self,"
},
{
"path": "code/synthetic/bsrt/loss/mssim.py",
"chars": 4651,
"preview": "import torch\nimport torch.nn.functional as F\nfrom math import exp\nimport numpy as np\n\n\ndef gaussian(window_size, sigma):"
},
{
"path": "code/synthetic/bsrt/loss/vgg.py",
"chars": 1167,
"preview": "from model import common\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision.models a"
},
{
"path": "code/synthetic/bsrt/main.py",
"chars": 3820,
"preview": "import torch\nimport random\nimport numpy as np\nfrom torch.utils.data import DataLoader\nfrom torchvision import transforms"
},
{
"path": "code/synthetic/bsrt/model/DCNv2/LICENSE",
"chars": 1520,
"preview": "BSD 3-Clause License\n\nCopyright (c) 2019, Charles Shang\nAll rights reserved.\n\nRedistribution and use in source and binar"
},
{
"path": "code/synthetic/bsrt/model/DCNv2/README.md",
"chars": 2226,
"preview": "## Deformable Convolutional Networks V2 with Pytorch 1.0\n\n### Build\n```bash\n ./make.sh # build\n python tes"
},
{
"path": "code/synthetic/bsrt/model/DCNv2/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "code/synthetic/bsrt/model/DCNv2/dcn_v2.py",
"chars": 17562,
"preview": "#!/usr/bin/env python\nfrom __future__ import absolute_import, division, print_function\n\nimport math\n\nimport torch\nfrom t"
},
{
"path": "code/synthetic/bsrt/model/DCNv2/files.txt",
"chars": 999,
"preview": "/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/_ext.cpython-37m-x86_64-linux-gnu"
},
{
"path": "code/synthetic/bsrt/model/DCNv2/make.sh",
"chars": 50,
"preview": "#!/usr/bin/env bash\npython setup.py build develop\n"
},
{
"path": "code/synthetic/bsrt/model/DCNv2/setup.py",
"chars": 1878,
"preview": "#!/usr/bin/env python\n\nimport glob\nimport os\n\nimport torch\nfrom setuptools import find_packages, setup\nfrom torch.utils."
},
{
"path": "code/synthetic/bsrt/model/DCNv2/src/cpu/dcn_v2_cpu.cpp",
"chars": 10924,
"preview": "#include <vector>\n#include \"cpu/dcn_v2_im2col_cpu.h\"\n\n#include <ATen/ATen.h>\n//#include <ATen/cuda/CUDAContext.h>\n\n#incl"
},
{
"path": "code/synthetic/bsrt/model/DCNv2/src/cpu/dcn_v2_im2col_cpu.cpp",
"chars": 19948,
"preview": "#include \"dcn_v2_im2col_cpu.h\"\n#include <cstdio>\n#include <algorithm>\n#include <cstring>\n\n#include <ATen/ATen.h>\n//#incl"
},
{
"path": "code/synthetic/bsrt/model/DCNv2/src/cpu/dcn_v2_im2col_cpu.h",
"chars": 5105,
"preview": "\n/*!\n ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************\n *\n * COPYRIGHT\n *\n * All contrib"
},
{
"path": "code/synthetic/bsrt/model/DCNv2/src/cpu/dcn_v2_psroi_pooling_cpu.cpp",
"chars": 17007,
"preview": "/*!\n * Copyright (c) 2017 Microsoft\n * Licensed under The MIT License [see LICENSE for details]\n * \\file deformable_psro"
},
{
"path": "code/synthetic/bsrt/model/DCNv2/src/cpu/vision.h",
"chars": 2665,
"preview": "#pragma once\n#include <torch/extension.h>\n\nat::Tensor\ndcn_v2_cpu_forward(const at::Tensor &input,\n co"
},
{
"path": "code/synthetic/bsrt/model/DCNv2/src/cuda/dcn_v2_cuda.cu",
"chars": 16079,
"preview": "#include <vector>\n#include \"cuda/dcn_v2_im2col_cuda.h\"\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#inclu"
},
{
"path": "code/synthetic/bsrt/model/DCNv2/src/cuda/dcn_v2_im2col_cuda.cu",
"chars": 20335,
"preview": "#include \"dcn_v2_im2col_cuda.h\"\n#include <cstdio>\n#include <algorithm>\n#include <cstring>\n\n#include <ATen/ATen.h>\n#inclu"
},
{
"path": "code/synthetic/bsrt/model/DCNv2/src/cuda/dcn_v2_im2col_cuda.h",
"chars": 5226,
"preview": "\n/*!\n ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************\n *\n * COPYRIGHT\n *\n * All contrib"
},
{
"path": "code/synthetic/bsrt/model/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda.cu",
"chars": 16288,
"preview": "/*!\n * Copyright (c) 2017 Microsoft\n * Licensed under The MIT License [see LICENSE for details]\n * \\file deformable_psro"
},
{
"path": "code/synthetic/bsrt/model/DCNv2/src/cuda/vision.h",
"chars": 2694,
"preview": "#pragma once\n#include <torch/extension.h>\n#include <ATen/div_rtn.h>\nat::Tensor\ndcn_v2_cuda_forward(const at::Tensor &inp"
},
{
"path": "code/synthetic/bsrt/model/DCNv2/src/dcn_v2.h",
"chars": 7715,
"preview": "#pragma once\n\n#include \"cpu/vision.h\"\n\n#ifdef WITH_CUDA\n#include \"cuda/vision.h\"\n#endif\n\nat::Tensor\ndcn_v2_forward(const"
},
{
"path": "code/synthetic/bsrt/model/DCNv2/src/vision.cpp",
"chars": 405,
"preview": "\n#include \"dcn_v2.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n m.def(\"dcn_v2_forward\", &dcn_v2_forward, \"dcn_v2_forw"
},
{
"path": "code/synthetic/bsrt/model/DCNv2/test.py",
"chars": 8506,
"preview": "#!/usr/bin/env python\nfrom __future__ import absolute_import\nfrom __future__ import print_function\nfrom __future__ impor"
},
{
"path": "code/synthetic/bsrt/model/__init__.py",
"chars": 11564,
"preview": "import os\nfrom importlib import import_module\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel as P\nimport t"
},
{
"path": "code/synthetic/bsrt/model/arch_util.py",
"chars": 28695,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.init as init\nimport torch.nn.functional as F\nfrom model import common"
},
{
"path": "code/synthetic/bsrt/model/bsrt.py",
"chars": 25969,
"preview": "import functools\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport model.arch_util as arch_util\n"
},
{
"path": "code/synthetic/bsrt/model/checkpoint.py",
"chars": 1508,
"preview": "import torch\nimport warnings\n\n\ndef detach_variable(inputs):\n if isinstance(inputs, tuple):\n out = []\n f"
},
{
"path": "code/synthetic/bsrt/model/common.py",
"chars": 6635,
"preview": "import math\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef default_conv(in"
},
{
"path": "code/synthetic/bsrt/model/ebsr.py",
"chars": 34021,
"preview": "import functools\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport model.arch_util as arch_util\n"
},
{
"path": "code/synthetic/bsrt/model/non_local/network.py",
"chars": 2271,
"preview": "from torch import nn\n# from lib.non_local_concatenation import NONLocalBlock2D\n# from lib.non_local_gaussian import NONL"
},
{
"path": "code/synthetic/bsrt/model/non_local/non_local_concatenation.py",
"chars": 5512,
"preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass _NonLocalBlockND(nn.Module):\n def __in"
},
{
"path": "code/synthetic/bsrt/model/non_local/non_local_cross_dot_product.py",
"chars": 5102,
"preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass _NonLocalBlockND(nn.Module):\n def __in"
},
{
"path": "code/synthetic/bsrt/model/non_local/non_local_dot_product.py",
"chars": 5087,
"preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass _NonLocalBlockND(nn.Module):\n def __in"
},
{
"path": "code/synthetic/bsrt/model/non_local/non_local_embedded_gaussian.py",
"chars": 5241,
"preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass _NonLocalBlockND(nn.Module):\n def __in"
},
{
"path": "code/synthetic/bsrt/model/non_local/non_local_gaussian.py",
"chars": 4915,
"preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass _NonLocalBlockND(nn.Module):\n def __in"
},
{
"path": "code/synthetic/bsrt/model/swin_util.py",
"chars": 23426,
"preview": "# -----------------------------------------------------------------------------------\n# SwinIR: Image Restoration Using "
},
{
"path": "code/synthetic/bsrt/model/utils/interp_methods.py",
"chars": 1711,
"preview": "from math import pi\n\ntry:\n import torch\nexcept ImportError:\n torch = None\n\ntry:\n import numpy\nexcept ImportErro"
},
{
"path": "code/synthetic/bsrt/model/utils/psconv.py",
"chars": 5730,
"preview": "import torch\nimport torch.nn as nn\n\nclass PyConv2d(nn.Module):\n \"\"\"PyConv2d with padding (general case). Applies a 2D"
},
{
"path": "code/synthetic/bsrt/model/utils/resize_right.py",
"chars": 14806,
"preview": "import warnings\nfrom math import ceil\nimport model.utils.interp_methods as interp_methods\n\n\nclass NoneClass:\n pass\n\nt"
},
{
"path": "code/synthetic/bsrt/option.py",
"chars": 8240,
"preview": "import argparse\n\nparser = argparse.ArgumentParser(description='EDSR and MDSR')\n\nparser.add_argument('--n_resblocks', typ"
},
{
"path": "code/synthetic/bsrt/requirements.txt",
"chars": 46,
"preview": "matplotlib\nimageio\nopencv-python\ntensorboardX\n"
},
{
"path": "code/synthetic/bsrt/scripts/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "code/synthetic/bsrt/scripts/cal_mean_std.py",
"chars": 637,
"preview": "import torch\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom datasets.burstsr_dataset import BurstSRDataset, flatten_raw_"
},
{
"path": "code/synthetic/bsrt/scripts/demo.sh",
"chars": 91,
"preview": "set -ex\nrlaunch --cpu=4 --gpu=1 --memory=10240 -- python ./scripts/evaluate_burstsr_val.py\n"
},
{
"path": "code/synthetic/bsrt/scripts/download_burstsr_dataset.py",
"chars": 2276,
"preview": "import os\nimport urllib.request\nimport zipfile\nimport shutil\nimport argparse\n\n\ndef download_burstsr_dataset(download_pat"
},
{
"path": "code/synthetic/bsrt/scripts/evaluate.sh",
"chars": 89,
"preview": "set -ex\nrlaunch --cpu=4 --gpu=1 --memory=10240 -- python scripts/evaluate_burstsr_val.py\n"
},
{
"path": "code/synthetic/bsrt/scripts/evaluate_burstsr_val.py",
"chars": 1593,
"preview": "import torch.nn.functional as F\nfrom datasets.burstsr_dataset import BurstSRDataset\nfrom utils.metrics import AlignedPSN"
},
{
"path": "code/synthetic/bsrt/scripts/save_results_synburst_val.py",
"chars": 1214,
"preview": "import torch.nn.functional as F\nimport cv2\nfrom datasets.synthetic_burst_val_set import SyntheticBurstVal\nimport torch\ni"
},
{
"path": "code/synthetic/bsrt/scripts/test_burstsr_dataset.py",
"chars": 2117,
"preview": "import torch.nn.functional as F\nimport cv2\nfrom datasets.burstsr_dataset import BurstSRDataset\nfrom torch.utils.data.dat"
},
{
"path": "code/synthetic/bsrt/scripts/test_synthetic_bursts.py",
"chars": 1938,
"preview": "import torch.nn.functional as F\nimport cv2\nfrom datasets.synthetic_burst_train_set import SyntheticBurst\nfrom torch.util"
},
{
"path": "code/synthetic/bsrt/test.py",
"chars": 2142,
"preview": "\nimport cv2\nimport torch\nimport numpy as np\nimport os\nfrom tqdm import tqdm\nimport random\nimport utility\nfrom option imp"
},
{
"path": "code/synthetic/bsrt/test_synburst.py",
"chars": 3518,
"preview": "\nimport cv2\nimport torch\nimport numpy as np\nimport os\nfrom tqdm import tqdm\nimport random\nimport utility\nfrom option imp"
},
{
"path": "code/synthetic/bsrt/trainer.py",
"chars": 12997,
"preview": "import os, sys\nfrom decimal import Decimal\nimport cv2\nimport utility\nimport random\n\nimport torch\nfrom tensorboardX impor"
},
{
"path": "code/synthetic/bsrt/utility.py",
"chars": 12261,
"preview": "import math\nimport time\nimport datetime\nfrom multiprocessing import Process\nfrom multiprocessing import Queue\nimport tor"
},
{
"path": "code/synthetic/bsrt/utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "code/synthetic/bsrt/utils/data_format_utils.py",
"chars": 946,
"preview": "import numpy as np\nimport torch\nimport cv2 as cv\n\n\ndef numpy_to_torch(a: np.ndarray):\n return torch.from_numpy(a).flo"
},
{
"path": "code/synthetic/bsrt/utils/debayer.py",
"chars": 5077,
"preview": "import torch\nimport torch.nn\nimport torch.nn.functional\n\nclass Debayer3x3(torch.nn.Module):\n '''Demosaicing of Bayer "
},
{
"path": "code/synthetic/bsrt/utils/interp_methods.py",
"chars": 1711,
"preview": "from math import pi\n\ntry:\n import torch\nexcept ImportError:\n torch = None\n\ntry:\n import numpy\nexcept ImportErro"
},
{
"path": "code/synthetic/bsrt/utils/metrics.py",
"chars": 15979,
"preview": "import math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport utils.spatial_color_alignment as s"
},
{
"path": "code/synthetic/bsrt/utils/postprocessing_functions.py",
"chars": 2930,
"preview": "import torch\nimport numpy as np\nimport utils.data_format_utils as df_utils\nfrom data_processing.camera_pipeline import a"
},
{
"path": "code/synthetic/bsrt/utils/resize_right.py",
"chars": 14776,
"preview": "import warnings\nfrom math import ceil\nimport interp_methods\n\n\nclass NoneClass:\n pass\n\ntry:\n import torch\n from "
},
{
"path": "code/synthetic/bsrt/utils/spatial_color_alignment.py",
"chars": 3202,
"preview": "import math\nimport torch\nimport torch.nn.functional as F\n\n\ndef gauss_1d(sz, sigma, center, end_pad=0, density=False):\n "
},
{
"path": "code/synthetic/bsrt/utils/stn.py",
"chars": 1825,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass SpatialTransformer(nn.Module):\n \"\"\"\n [S"
},
{
"path": "code/synthetic/bsrt/utils/warp.py",
"chars": 1075,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef warp(feat, flow, mode='bilinear', padding_mode="
},
{
"path": "requirements.txt",
"chars": 46,
"preview": "matplotlib\nimageio\nopencv-python\ntensorboardX\n"
}
]
// ... and 1 more files (download for full content)
About this extraction
This page contains the full source code of the Algolzw/BSRT GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 184 files (1.1 MB), approximately 286.0k tokens, and a symbol index with 1285 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.