Repository: Algolzw/BSRT Branch: main Commit: da8d5154478a Files: 184 Total size: 1.1 MB Directory structure: gitextract_kbz2i0m2/ ├── .gitignore ├── LICENSE ├── README.md ├── code/ │ ├── real/ │ │ └── bsrt/ │ │ ├── README.md │ │ ├── data_processing/ │ │ │ ├── __init__.py │ │ │ ├── camera_pipeline.py │ │ │ └── synthetic_burst_generation.py │ │ ├── datasets/ │ │ │ ├── __init__.py │ │ │ ├── burstsr_dataset.py │ │ │ ├── burstsr_test_dataset.py │ │ │ ├── data_sampler.py │ │ │ ├── realworld_burst_test_set.py │ │ │ ├── synthetic_burst_test_set.py │ │ │ ├── synthetic_burst_train_set.py │ │ │ ├── synthetic_burst_val_set.py │ │ │ └── zurich_raw2rgb_dataset.py │ │ ├── demo.sh │ │ ├── loss/ │ │ │ ├── Charbonnier.py │ │ │ ├── __init__.py │ │ │ ├── adversarial.py │ │ │ ├── discriminator.py │ │ │ ├── filter.py │ │ │ ├── hist_entropy.py │ │ │ ├── mssim.py │ │ │ └── vgg.py │ │ ├── main.py │ │ ├── model/ │ │ │ ├── DCNv2/ │ │ │ │ ├── LICENSE │ │ │ │ ├── README.md │ │ │ │ ├── __init__.py │ │ │ │ ├── dcn_v2.py │ │ │ │ ├── files.txt │ │ │ │ ├── make.sh │ │ │ │ ├── setup.py │ │ │ │ ├── src/ │ │ │ │ │ ├── cpu/ │ │ │ │ │ │ ├── dcn_v2_cpu.cpp │ │ │ │ │ │ ├── dcn_v2_im2col_cpu.cpp │ │ │ │ │ │ ├── dcn_v2_im2col_cpu.h │ │ │ │ │ │ ├── dcn_v2_psroi_pooling_cpu.cpp │ │ │ │ │ │ └── vision.h │ │ │ │ │ ├── cuda/ │ │ │ │ │ │ ├── dcn_v2_cuda.cu │ │ │ │ │ │ ├── dcn_v2_im2col_cuda.cu │ │ │ │ │ │ ├── dcn_v2_im2col_cuda.h │ │ │ │ │ │ ├── dcn_v2_psroi_pooling_cuda.cu │ │ │ │ │ │ └── vision.h │ │ │ │ │ ├── dcn_v2.h │ │ │ │ │ └── vision.cpp │ │ │ │ └── test.py │ │ │ ├── __init__.py │ │ │ ├── arch_util.py │ │ │ ├── bsrt.py │ │ │ ├── checkpoint.py │ │ │ ├── common.py │ │ │ ├── non_local/ │ │ │ │ ├── network.py │ │ │ │ ├── non_local_concatenation.py │ │ │ │ ├── non_local_cross_dot_product.py │ │ │ │ ├── non_local_dot_product.py │ │ │ │ ├── non_local_embedded_gaussian.py │ │ │ │ └── non_local_gaussian.py │ │ │ ├── swin_util.py │ │ │ └── utils/ │ │ │ ├── interp_methods.py │ │ │ ├── psconv.py │ │ │ └── 
resize_right.py │ │ ├── option.py │ │ ├── pwcnet/ │ │ │ ├── LICENSE │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── comparison/ │ │ │ │ └── comparison.py │ │ │ ├── correlation/ │ │ │ │ ├── README.md │ │ │ │ └── correlation.py │ │ │ ├── download.bash │ │ │ ├── images/ │ │ │ │ └── README.md │ │ │ ├── out.flo │ │ │ ├── pwcnet.py │ │ │ ├── requirements.txt │ │ │ └── run.py │ │ ├── requirements.txt │ │ ├── scripts/ │ │ │ ├── __init__.py │ │ │ ├── cal_mean_std.py │ │ │ ├── demo.sh │ │ │ ├── download_burstsr_dataset.py │ │ │ ├── evaluate.sh │ │ │ ├── evaluate_burstsr_val.py │ │ │ ├── save_results_synburst_val.py │ │ │ ├── test_burstsr_dataset.py │ │ │ └── test_synthetic_bursts.py │ │ ├── test.py │ │ ├── test_real.py │ │ ├── trainer.py │ │ ├── utility.py │ │ ├── utils/ │ │ │ ├── __init__.py │ │ │ ├── data_format_utils.py │ │ │ ├── debayer.py │ │ │ ├── interp_methods.py │ │ │ ├── metrics.py │ │ │ ├── postprocessing_functions.py │ │ │ ├── resize_right.py │ │ │ ├── spatial_color_alignment.py │ │ │ ├── stn.py │ │ │ └── warp.py │ │ └── validate.py │ └── synthetic/ │ └── bsrt/ │ ├── README.md │ ├── data_processing/ │ │ ├── __init__.py │ │ ├── camera_pipeline.py │ │ └── synthetic_burst_generation.py │ ├── datasets/ │ │ ├── __init__.py │ │ ├── burstsr_dataset.py │ │ ├── burstsr_test_dataset.py │ │ ├── data_sampler.py │ │ ├── realworld_burst_test_set.py │ │ ├── synthetic_burst_test_set.py │ │ ├── synthetic_burst_train_set.py │ │ ├── synthetic_burst_val_set.py │ │ └── zurich_raw2rgb_dataset.py │ ├── demo.sh │ ├── loss/ │ │ ├── Charbonnier.py │ │ ├── __init__.py │ │ ├── adversarial.py │ │ ├── discriminator.py │ │ ├── filter.py │ │ ├── hist_entropy.py │ │ ├── mssim.py │ │ └── vgg.py │ ├── main.py │ ├── model/ │ │ ├── DCNv2/ │ │ │ ├── LICENSE │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── dcn_v2.py │ │ │ ├── files.txt │ │ │ ├── make.sh │ │ │ ├── setup.py │ │ │ ├── src/ │ │ │ │ ├── cpu/ │ │ │ │ │ ├── dcn_v2_cpu.cpp │ │ │ │ │ ├── dcn_v2_im2col_cpu.cpp │ │ │ │ │ ├── 
dcn_v2_im2col_cpu.h │ │ │ │ │ ├── dcn_v2_psroi_pooling_cpu.cpp │ │ │ │ │ └── vision.h │ │ │ │ ├── cuda/ │ │ │ │ │ ├── dcn_v2_cuda.cu │ │ │ │ │ ├── dcn_v2_im2col_cuda.cu │ │ │ │ │ ├── dcn_v2_im2col_cuda.h │ │ │ │ │ ├── dcn_v2_psroi_pooling_cuda.cu │ │ │ │ │ └── vision.h │ │ │ │ ├── dcn_v2.h │ │ │ │ └── vision.cpp │ │ │ └── test.py │ │ ├── __init__.py │ │ ├── arch_util.py │ │ ├── bsrt.py │ │ ├── checkpoint.py │ │ ├── common.py │ │ ├── ebsr.py │ │ ├── non_local/ │ │ │ ├── network.py │ │ │ ├── non_local_concatenation.py │ │ │ ├── non_local_cross_dot_product.py │ │ │ ├── non_local_dot_product.py │ │ │ ├── non_local_embedded_gaussian.py │ │ │ └── non_local_gaussian.py │ │ ├── swin_util.py │ │ └── utils/ │ │ ├── interp_methods.py │ │ ├── psconv.py │ │ └── resize_right.py │ ├── option.py │ ├── requirements.txt │ ├── scripts/ │ │ ├── __init__.py │ │ ├── cal_mean_std.py │ │ ├── demo.sh │ │ ├── download_burstsr_dataset.py │ │ ├── evaluate.sh │ │ ├── evaluate_burstsr_val.py │ │ ├── save_results_synburst_val.py │ │ ├── test_burstsr_dataset.py │ │ └── test_synthetic_bursts.py │ ├── test.py │ ├── test_synburst.py │ ├── trainer.py │ ├── utility.py │ └── utils/ │ ├── __init__.py │ ├── data_format_utils.py │ ├── debayer.py │ ├── interp_methods.py │ ├── metrics.py │ ├── postprocessing_functions.py │ ├── resize_right.py │ ├── spatial_color_alignment.py │ ├── stn.py │ └── warp.py └── requirements.txt ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script 
from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2022 Megvii Inc. 
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ ## BSRT: Improving Burst Super-Resolution with Swin Transformer and Flow-Guided Deformable Alignment (CVPRW 2022) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bsrt-improving-burst-super-resolution-with/burst-image-super-resolution-on-burstsr)](https://paperswithcode.com/sota/burst-image-super-resolution-on-burstsr?p=bsrt-improving-burst-super-resolution-with) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bsrt-improving-burst-super-resolution-with/burst-image-super-resolution-on)](https://paperswithcode.com/sota/burst-image-super-resolution-on?p=bsrt-improving-burst-super-resolution-with)![visitors](https://visitor-badge.glitch.me/badge?page_id=Algolzw/BSRT) #### [BSRT](https://arxiv.org/abs/2204.08332), the winner of the NTIRE 2022 Burst Super-Resolution Challenge Real-World Track. 
You can also find our winner method in NTIRE 2021 Burst Super-Resolution Challenge [here](https://github.com/Algolzw/EBSR). > This work addresses the Burst Super-Resolution (BurstSR) task using a new architecture, which requires restoring a high-quality image from a sequence of noisy, misaligned, and low-resolution RAW bursts. To overcome the challenges in BurstSR, we propose a **B**urst **S**uper-**R**esolution **T**ransformer (**BSRT**), which can significantly improve the capability of extracting inter-frame information and reconstruction. To achieve this goal, we propose a Pyramid Flow-Guided Deformable Convolution Network (Pyramid FG-DCN) and incorporate Swin Transformer Blocks and Groups as our main backbone. More specifically, we combine optical flows and deformable convolutions, hence our BSRT can handle misalignment and aggregate the potential texture information in multi-frames more efficiently. In addition, our Transformer-based structure can capture long-range dependency to further improve the performance. The evaluation on both synthetic and real-world tracks demonstrates that our approach achieves a new state-of-the-art in BurstSR task. Further, our BSRT wins the championship in the NTIRE2022 Burst Super-Resolution Challenge. 
#### Comparison with State-of-the-art Burst Super-Resolution Methods ![ts](figs/ts.png) ## Overview Architecture ![overview.png](figs/overview.png) ## Dependencies - OS: Ubuntu 18.04 - Python: Python 3.7 - nvidia : - cuda: 10.1 - cudnn: 7.6.1 - Other reference requirements ## Quick Start 1.Create a conda virtual environment and activate it ```python3 conda create -n pytorch_1.6 python=3.7 source activate pytorch_1.6 ``` 2.Install PyTorch and torchvision following the official instructions ```python3 conda install pytorch==1.6.0 torchvision==0.7.0 cudatoolkit=10.1 -c pytorch ``` 3.Install build requirements ```python3 pip3 install -r requirements.txt ``` 4.Install DCN ```python3 cd DCNv2 python3 setup.py build develop # build python3 test.py # run examples and check ``` ## Training We provide all pretrained model weights [here](https://drive.google.com/file/d/1Bv1ZwoE3s8trhG--wjB0Yt6WJIQPpvsn/view?usp=sharing). #### For Synthetic data ```python3 cd code/synthetic/bsrt # Modify the root path of training dataset and model etc. # The number of GPUs should be more than 1 python main.py --n_GPUs 8 --print_every 40 --lr 0.0001 --decay 150-300 --save bsrt_tiny --model BSRT --fp16 --model_level S --swinfeature --batch_size 32 --burst_size 14 --patch_size 256 ``` #### For Real-World data ```python3 cd code/real/bsrt # Modify the root path of training dataset and model etc. # The number of GPUs should be more than 1 python main.py --n_GPUs 8 --print_every 20 --lr 0.00005 --decay 40-80 --save bsrt_tiny --model BSRT --fp16 --model_level S --swinfeature --batch_size 8 --burst_size 14 --patch_size 80 --pre_train ../../synthetic/train_log/bsrt/real_models/bsrt_tiny/bsrt_best_epoch.pth ``` The pretrained PWC-Net model can be downloaded [here](https://drive.google.com/file/d/1dD6vB9QN3qwmOBi3AGKzJbbSojwDDlgV/view?usp=sharing). 
## Test #### For Synthetic data ```python3 # Modify the path of test dataset and the path of the trained model python test_synburst.py --n_GPUs 1 --model BSRT --model_level S --swinfeature --burst_size 14 --patch_size 384 --pre_train ../train_log/bsrt/real_models/bsrt_tiny/bsrt_best_epoch.pth --root /data/dataset/ntire21/burstsr/synthetic ``` #### For Real-World data ```python3 # Modify the path of test dataset and the path of the trained model python test_real.py --n_GPUs 1 --model BSRT --model_level S --swinfeature --batch_size 1 --burst_size 14 --patch_size 80 --pre_train ../train_log/bsrt/real_models/bsrt_tiny/bsrt_best_epoch.pth --root /data/dataset/ntire21/burstsr/real ``` ## Results ### Comparison on Synthetic dataset ![cmp_syn.png](figs/cmp_syn.png) ### Comparison on Real-World dataset ![cmp_real.png](figs/cmp_real.png) ## Citations If our code helps your research or work, please consider citing our paper. The following is a BibTeX reference. ``` @inproceedings{luo2022bsrt, title={BSRT: Improving Burst Super-Resolution with Swin Transformer and Flow-Guided Deformable Alignment}, author={Luo, Ziwei and Li, Youwei and Cheng, Shen and Yu, Lei and Wu, Qi and Wen, Zhihong and Fan, Haoqiang and Sun, Jian and Liu, Shuaicheng}, booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, pages={998--1008}, year={2022} } ``` ## Contact email: ziwei.ro@gmail.com ================================================ FILE: code/real/bsrt/README.md ================================================ # BSRT: Improving Burst Super-Resolution with Swin Transformer and Flow-Guided Deformable Alignment (Real-World) ## Dependencies - OS: Ubuntu 18.04 - Python: Python 3.7 - nvidia : - cuda: 10.1 - cudnn: 7.6.1 - Other reference requirements ## Quick Start 1.Create a conda virtual environment and activate it ```python3 conda create -n pytorch_1.6 python=3.7 source activate pytorch_1.6 ``` 2.Install PyTorch and torchvision following the official
instructions ```python3 conda install pytorch==1.6.0 torchvision==0.7.0 cudatoolkit=10.1 -c pytorch ``` 3.Install build requirements ```python3 pip3 install -r requirements.txt ``` 4.Install DCN ```python3 cd DCNv2 python3 setup.py build develop # build python3 test.py # run examples and check ``` ## Training The pretrained PWC-Net model can be downloaded [here](https://drive.google.com/file/d/1dD6vB9QN3qwmOBi3AGKzJbbSojwDDlgV/view?usp=sharing). ```python3 # Modify the root path of training dataset and model etc. # The number of GPUs should be more than 1 python main.py --n_GPUs 8 --print_every 20 --lr 0.00004 --decay 40-80 --save bsrt_tiny --model BSRT --fp16 --model_level S --swinfeature --batch_size 8 --burst_size 14 --patch_size 80 --pre_train ../../synthetic/train_log/bsrt/real_models/bsrt_tiny/bsrt_best_epoch.pth ``` ## Test ```python3 # Modify the path of test dataset and the path of the trained model python test_real.py --n_GPUs 1 --model BSRT --model_level S --swinfeature --batch_size 1 --burst_size 14 --patch_size 80 --pre_train ../train_log/bsrt/real_models/bsrt_tiny/bsrt_best_epoch.pth --root /data/dataset/ntire21/burstsr/real ``` ================================================ FILE: code/real/bsrt/data_processing/__init__.py ================================================ ================================================ FILE: code/real/bsrt/data_processing/camera_pipeline.py ================================================ import torch import random import math import cv2 as cv import numpy as np import utils.data_format_utils as df_utils """ Based on http://timothybrooks.com/tech/unprocessing Functions for forward and inverse camera pipeline. All functions input a torch float tensor of shape (c, h, w). Additionally, some also support batch operations, i.e. inputs of shape (b, c, h, w) """ def random_ccm(): """Generates random RGB -> Camera color correction matrices.""" # Takes a random convex combination of XYZ -> Camera CCMs.
xyz2cams = [[[1.0234, -0.2969, -0.2266], [-0.5625, 1.6328, -0.0469], [-0.0703, 0.2188, 0.6406]], [[0.4913, -0.0541, -0.0202], [-0.613, 1.3513, 0.2906], [-0.1564, 0.2151, 0.7183]], [[0.838, -0.263, -0.0639], [-0.2887, 1.0725, 0.2496], [-0.0627, 0.1427, 0.5438]], [[0.6596, -0.2079, -0.0562], [-0.4782, 1.3016, 0.1933], [-0.097, 0.1581, 0.5181]]] num_ccms = len(xyz2cams) xyz2cams = torch.tensor(xyz2cams) weights = torch.FloatTensor(num_ccms, 1, 1).uniform_(0.0, 1.0) weights_sum = weights.sum() xyz2cam = (xyz2cams * weights).sum(dim=0) / weights_sum # Multiplies with RGB -> XYZ to get RGB -> Camera CCM. rgb2xyz = torch.tensor([[0.4124564, 0.3575761, 0.1804375], [0.2126729, 0.7151522, 0.0721750], [0.0193339, 0.1191920, 0.9503041]]) rgb2cam = torch.mm(xyz2cam, rgb2xyz) # Normalizes each row. rgb2cam = rgb2cam / rgb2cam.sum(dim=-1, keepdims=True) return rgb2cam def random_gains(): """Generates random gains for brightening and white balance.""" # RGB gain represents brightening. rgb_gain = 1.0 / random.gauss(mu=0.8, sigma=0.1) # Red and blue gains represent white balance. red_gain = random.uniform(1.9, 2.4) blue_gain = random.uniform(1.5, 1.9) return rgb_gain, red_gain, blue_gain def apply_smoothstep(image): """Apply global tone mapping curve.""" image_out = 3 * image**2 - 2 * image**3 return image_out def invert_smoothstep(image): """Approximately inverts a global tone mapping curve.""" image = image.clamp(0.0, 1.0) return 0.5 - torch.sin(torch.asin(1.0 - 2.0 * image) / 3.0) def gamma_expansion(image): """Converts from gamma to linear space.""" # Clamps to prevent numerical instability of gradients near zero. return image.clamp(1e-8) ** 2.2 def gamma_compression(image): """Converts from linear to gammaspace.""" # Clamps to prevent numerical instability of gradients near zero. 
return image.clamp(1e-8) ** (1.0 / 2.2) def apply_ccm(image, ccm): """Applies a color correction matrix.""" assert image.dim() == 3 and image.shape[0] == 3 shape = image.shape image = image.view(3, -1) ccm = ccm.to(image.device).type_as(image) image = torch.mm(ccm, image) return image.view(shape) def apply_gains(image, rgb_gain, red_gain, blue_gain): """Inverts gains while safely handling saturated pixels.""" assert image.dim() == 3 and image.shape[0] in [3, 4] if image.shape[0] == 3: gains = torch.tensor([red_gain, 1.0, blue_gain]) * rgb_gain else: gains = torch.tensor([red_gain, 1.0, 1.0, blue_gain]) * rgb_gain gains = gains.view(-1, 1, 1) gains = gains.to(image.device).type_as(image) return (image * gains).clamp(0.0, 1.0) def safe_invert_gains(image, rgb_gain, red_gain, blue_gain): """Inverts gains while safely handling saturated pixels.""" assert image.dim() == 3 and image.shape[0] == 3 gains = torch.tensor([1.0 / red_gain, 1.0, 1.0 / blue_gain]) / rgb_gain gains = gains.view(-1, 1, 1) # Prevents dimming of saturated pixels by smoothly masking gains near white. 
gray = image.mean(dim=0, keepdims=True) inflection = 0.9 mask = ((gray - inflection).clamp(0.0) / (1.0 - inflection)) ** 2.0 safe_gains = torch.max(mask + (1.0 - mask) * gains, gains) return image * safe_gains def mosaic(image, mode='rggb'): """Extracts RGGB Bayer planes from an RGB image.""" shape = image.shape if image.dim() == 3: image = image.unsqueeze(0) if mode == 'rggb': red = image[:, 0, 0::2, 0::2] green_red = image[:, 1, 0::2, 1::2] green_blue = image[:, 1, 1::2, 0::2] blue = image[:, 2, 1::2, 1::2] image = torch.stack((red, green_red, green_blue, blue), dim=1) elif mode == 'grbg': green_red = image[:, 1, 0::2, 0::2] red = image[:, 0, 0::2, 1::2] blue = image[:, 2, 0::2, 1::2] green_blue = image[:, 1, 1::2, 1::2] image = torch.stack((green_red, red, blue, green_blue), dim=1) if len(shape) == 3: return image.view((4, shape[-2] // 2, shape[-1] // 2)) else: return image.view((-1, 4, shape[-2] // 2, shape[-1] // 2)) def demosaic(image): assert isinstance(image, torch.Tensor) image = image.clamp(0.0, 1.0) * 255 if image.dim() == 4: num_images = image.dim() batch_input = True else: num_images = 1 batch_input = False image = image.unsqueeze(0) # Generate single channel input for opencv im_sc = torch.zeros((num_images, image.shape[-2] * 2, image.shape[-1] * 2, 1)) im_sc[:, ::2, ::2, 0] = image[:, 0, :, :] im_sc[:, ::2, 1::2, 0] = image[:, 1, :, :] im_sc[:, 1::2, ::2, 0] = image[:, 2, :, :] im_sc[:, 1::2, 1::2, 0] = image[:, 3, :, :] im_sc = im_sc.numpy().astype(np.uint8) out = [] for im in im_sc: # cv.imwrite('frames/tmp.png', im) im_dem_np = cv.cvtColor(im, cv.COLOR_BAYER_BG2RGB)#_VNG) # Convert to torch image im_t = df_utils.npimage_to_torch(im_dem_np, input_bgr=False) out.append(im_t) if batch_input: return torch.stack(out, dim=0) else: return out[0] def random_noise_levels(): """Generates random noise levels from a log-log linear distribution.""" log_min_shot_noise = math.log(0.0001) log_max_shot_noise = math.log(0.012) log_shot_noise = 
random.uniform(log_min_shot_noise, log_max_shot_noise) shot_noise = math.exp(log_shot_noise) line = lambda x: 2.18 * x + 1.20 log_read_noise = line(log_shot_noise) + random.gauss(mu=0.0, sigma=0.26) read_noise = math.exp(log_read_noise) return shot_noise, read_noise def add_noise(image, shot_noise=0.01, read_noise=0.0005): """Adds random shot (proportional to image) and read (independent) noise.""" variance = image * shot_noise + read_noise noise = torch.FloatTensor(image.shape).normal_().to(image.device)*variance.sqrt() return image + noise def process_linear_image_rgb(image, meta_info, return_np=False): image = apply_gains(image, meta_info['rgb_gain'], meta_info['red_gain'], meta_info['blue_gain']) image = apply_ccm(image, meta_info['cam2rgb']) if meta_info['gamma']: image = gamma_compression(image) if meta_info['smoothstep']: image = apply_smoothstep(image) image = image.clamp(0.0, 1.0) if return_np: image = df_utils.torch_to_npimage(image) return image def process_linear_image_raw(image, meta_info): image = apply_gains(image, meta_info['rgb_gain'], meta_info['red_gain'], meta_info['blue_gain']) image = demosaic(image) image = apply_ccm(image, meta_info['cam2rgb']) if meta_info['gamma']: image = gamma_compression(image) if meta_info['smoothstep']: image = apply_smoothstep(image) return image.clamp(0.0, 1.0) ================================================ FILE: code/real/bsrt/data_processing/synthetic_burst_generation.py ================================================ import torch import random import cv2 import numpy as np import torch.nn.functional as F from data_processing.camera_pipeline import * from utils.data_format_utils import torch_to_numpy, numpy_to_torch def random_crop(frames, crop_sz): """ Extract a random crop of size crop_sz from the input frames. If the crop_sz is larger than the input image size, then the largest possible crop of same aspect ratio as crop_sz will be extracted from frames, and upsampled to crop_sz. 
""" if not isinstance(crop_sz, (tuple, list)): crop_sz = (crop_sz, crop_sz) crop_sz = torch.tensor(crop_sz).float() shape = frames.shape # Select scale_factor. Ensure the crop fits inside the image max_scale_factor = torch.tensor(shape[-2:]).float() / crop_sz max_scale_factor = max_scale_factor.min().item() if max_scale_factor < 1.0: scale_factor = max_scale_factor else: scale_factor = 1.0 # Extract the crop orig_crop_sz = (crop_sz * scale_factor).floor() assert orig_crop_sz[-2] <= shape[-2] and orig_crop_sz[-1] <= shape[-1], 'Bug in crop size estimation!' r1 = random.randint(0, shape[-2] - orig_crop_sz[-2]) c1 = random.randint(0, shape[-1] - orig_crop_sz[-1]) r2 = r1 + orig_crop_sz[0].int().item() c2 = c1 + orig_crop_sz[1].int().item() frames_crop = frames[:, r1:r2, c1:c2] # Resize to crop_sz if scale_factor < 1.0: frames_crop = F.interpolate(frames_crop.unsqueeze(0), size=crop_sz.int().tolist(), mode='bilinear', align_corners=False).squeeze(0) return frames_crop def rgb2rawburst(image, burst_size, downsample_factor=1, burst_transformation_params=None, image_processing_params=None, interpolation_type='bilinear'): """ Generates a synthetic LR RAW burst from the input image. The input sRGB image is first converted to linear sensor space using an inverse camera pipeline. A LR burst is then generated by applying random transformations defined by burst_transformation_params to the input image, and downsampling it by the downsample_factor. The generated burst is then mosaicekd and corrputed by random noise. 
""" if image_processing_params is None: image_processing_params = {} _defaults = {'random_ccm': True, 'random_gains': True, 'smoothstep': True, 'gamma': True, 'add_noise': True} for k, v in _defaults.items(): if k not in image_processing_params: image_processing_params[k] = v # Sample camera pipeline params if image_processing_params['random_ccm']: rgb2cam = random_ccm() else: rgb2cam = torch.eye(3).float() cam2rgb = rgb2cam.inverse() # Sample gains if image_processing_params['random_gains']: rgb_gain, red_gain, blue_gain = random_gains() else: rgb_gain, red_gain, blue_gain = (1.0, 1.0, 1.0) # Approximately inverts global tone mapping. use_smoothstep = image_processing_params['smoothstep'] if use_smoothstep: image = invert_smoothstep(image) # Inverts gamma compression. use_gamma = image_processing_params['gamma'] if use_gamma: image = gamma_expansion(image) # Inverts color correction. image = apply_ccm(image, rgb2cam) # Approximately inverts white balance and brightening. image = safe_invert_gains(image, rgb_gain, red_gain, blue_gain) # Clip saturated pixels. image = image.clamp(0.0, 1.0) # Generate LR burst image_burst_rgb, flow_vectors = single2lrburst(image, burst_size=burst_size, downsample_factor=downsample_factor, transformation_params=burst_transformation_params, interpolation_type=interpolation_type) # mosaic image_burst = mosaic(image_burst_rgb.clone()) # Add noise if image_processing_params['add_noise']: shot_noise_level, read_noise_level = random_noise_levels() image_burst = add_noise(image_burst, shot_noise_level, read_noise_level) else: shot_noise_level = 0 read_noise_level = 0 # Clip saturated pixels. 
image_burst = image_burst.clamp(0.0, 1.0) meta_info = {'rgb2cam': rgb2cam, 'cam2rgb': cam2rgb, 'rgb_gain': rgb_gain, 'red_gain': red_gain, 'blue_gain': blue_gain, 'smoothstep': use_smoothstep, 'gamma': use_gamma, 'shot_noise_level': shot_noise_level, 'read_noise_level': read_noise_level} return image_burst, image, image_burst_rgb, flow_vectors, meta_info def get_tmat(image_shape, translation, theta, shear_values, scale_factors): """ Generates a transformation matrix corresponding to the input transformation parameters """ im_h, im_w = image_shape t_mat = np.identity(3) t_mat[0, 2] = translation[0] t_mat[1, 2] = translation[1] t_rot = cv2.getRotationMatrix2D((im_w * 0.5, im_h * 0.5), theta, 1.0) t_rot = np.concatenate((t_rot, np.array([0.0, 0.0, 1.0]).reshape(1, 3))) t_shear = np.array([[1.0, shear_values[0], -shear_values[0] * 0.5 * im_w], [shear_values[1], 1.0, -shear_values[1] * 0.5 * im_h], [0.0, 0.0, 1.0]]) t_scale = np.array([[scale_factors[0], 0.0, 0.0], [0.0, scale_factors[1], 0.0], [0.0, 0.0, 1.0]]) t_mat = t_scale @ t_rot @ t_shear @ t_mat t_mat = t_mat[:2, :] return t_mat def single2lrburst(image, burst_size, downsample_factor=1, transformation_params=None, interpolation_type='bilinear'): """ Generates a burst of size burst_size from the input image by applying random transformations defined by transformation_params, and downsampling the resulting burst by downsample_factor. 
""" if interpolation_type == 'bilinear': interpolation = cv2.INTER_LINEAR elif interpolation_type == 'lanczos': interpolation = cv2.INTER_LANCZOS4 else: raise ValueError normalize = False if isinstance(image, torch.Tensor): if image.max() < 2.0: image = image * 255.0 normalize = True image = torch_to_numpy(image).astype(np.uint8) burst = [] sample_pos_inv_all = [] rvs, cvs = torch.meshgrid([torch.arange(0, image.shape[0]), torch.arange(0, image.shape[1])]) sample_grid = torch.stack((cvs, rvs, torch.ones_like(cvs)), dim=-1).float() for i in range(burst_size): if i == 0: # For base image, do not apply any random transformations. We only translate the image to center the # sampling grid shift = (downsample_factor / 2.0) - 0.5 translation = (shift, shift) theta = 0.0 shear_factor = (0.0, 0.0) scale_factor = (1.0, 1.0) else: # Sample random image transformation parameters max_translation = transformation_params.get('max_translation', 0.0) if max_translation <= 0.01: shift = (downsample_factor / 2.0) - 0.5 translation = (shift, shift) else: translation = (random.uniform(-max_translation, max_translation), random.uniform(-max_translation, max_translation)) max_rotation = transformation_params.get('max_rotation', 0.0) theta = random.uniform(-max_rotation, max_rotation) max_shear = transformation_params.get('max_shear', 0.0) shear_x = random.uniform(-max_shear, max_shear) shear_y = random.uniform(-max_shear, max_shear) shear_factor = (shear_x, shear_y) max_ar_factor = transformation_params.get('max_ar_factor', 0.0) ar_factor = np.exp(random.uniform(-max_ar_factor, max_ar_factor)) max_scale = transformation_params.get('max_scale', 0.0) scale_factor = np.exp(random.uniform(-max_scale, max_scale)) scale_factor = (scale_factor, scale_factor * ar_factor) output_sz = (image.shape[1], image.shape[0]) # Generate a affine transformation matrix corresponding to the sampled parameters t_mat = get_tmat((image.shape[0], image.shape[1]), translation, theta, shear_factor, scale_factor) 
t_mat_tensor = torch.from_numpy(t_mat) # Apply the sampled affine transformation image_t = cv2.warpAffine(image, t_mat, output_sz, flags=interpolation, borderMode=cv2.BORDER_CONSTANT) t_mat_tensor_3x3 = torch.cat((t_mat_tensor.float(), torch.tensor([0.0, 0.0, 1.0]).view(1, 3)), dim=0) t_mat_tensor_inverse = t_mat_tensor_3x3.inverse()[:2, :].contiguous() sample_pos_inv = torch.mm(sample_grid.view(-1, 3), t_mat_tensor_inverse.t().float()).view( *sample_grid.shape[:2], -1) if transformation_params.get('border_crop') is not None: border_crop = transformation_params.get('border_crop') image_t = image_t[border_crop:-border_crop, border_crop:-border_crop, :] sample_pos_inv = sample_pos_inv[border_crop:-border_crop, border_crop:-border_crop, :] # Downsample the image image_t = cv2.resize(image_t, None, fx=1.0 / downsample_factor, fy=1.0 / downsample_factor, interpolation=interpolation) sample_pos_inv = cv2.resize(sample_pos_inv.numpy(), None, fx=1.0 / downsample_factor, fy=1.0 / downsample_factor, interpolation=interpolation) sample_pos_inv = torch.from_numpy(sample_pos_inv).permute(2, 0, 1).contiguous() if normalize: image_t = numpy_to_torch(image_t).float() / 255.0 else: image_t = numpy_to_torch(image_t).float() burst.append(image_t) sample_pos_inv_all.append(sample_pos_inv / downsample_factor) burst_images = torch.stack(burst) sample_pos_inv_all = torch.stack(sample_pos_inv_all) # Compute the flow vectors to go from the i'th burst image to the base image flow_vectors = sample_pos_inv_all - sample_pos_inv_all[:, :1, ...] 
class SamsungRAWImage:
    """Container for a Samsung RAW burst frame loaded from the BurstSR dataset.

    Wraps the 4-channel packed Bayer data (RGGB) together with the camera
    metadata (black level, white-balance gains, color matrix, EXIF) stored
    alongside it on disk.
    """

    @staticmethod
    def load(path):
        """Load the RAW image and its pickled metadata from directory `path`.

        Expects `path` to contain `im_raw.png` (16-bit PNG, read unchanged)
        and `meta_info.pkl`.
        """
        im_raw = cv2.imread('{}/im_raw.png'.format(path), cv2.IMREAD_UNCHANGED)
        # HWC -> CHW; raw values fit in int16 (sensor range is 10-bit, see norm_factor).
        im_raw = np.transpose(im_raw, (2, 0, 1)).astype(np.int16)
        im_raw = torch.from_numpy(im_raw)
        # NOTE(review): the file handle passed to pkl.load is never closed explicitly.
        meta_data = pkl.load(open('{}/meta_info.pkl'.format(path), "rb", -1))

        return SamsungRAWImage(im_raw, meta_data['black_level'], meta_data['cam_wb'],
                               meta_data['daylight_wb'], meta_data['color_matrix'], meta_data['exif_data'],
                               meta_data.get('crop_info', None), meta_data.get('im_preview', None))

    def __init__(self, im_raw, black_level, cam_wb, daylight_wb, color_matrix, exif_data, crop_info=None,
                 im_preview=None):
        self.im_raw = im_raw                # (4, H, W) packed RGGB tensor
        self.black_level = black_level      # per-channel black level (length 4)
        self.cam_wb = cam_wb                # as-shot white-balance gains
        self.daylight_wb = daylight_wb      # daylight white-balance gains
        self.color_matrix = color_matrix
        self.exif_data = exif_data
        self.crop_info = crop_info
        self.im_preview = im_preview        # optional full-resolution preview (2x spatial size)

        # Samsung sensor data is 10-bit, hence the 1023 normalization constant.
        self.norm_factor = 1023.0

    def get_all_meta_data(self):
        """Return the calibration metadata as a plain (picklable) dict."""
        return {'black_level': self.black_level, 'cam_wb': self.cam_wb, 'daylight_wb': self.daylight_wb,
                'color_matrix': self.color_matrix.tolist()}

    def get_exposure_time(self):
        """Exposure time in seconds, parsed from EXIF."""
        return self.exif_data['Image ExposureTime'].values[0].decimal()

    def get_noise_profile(self):
        """Return the per-channel noise profile as a (3, 2) array.

        Tag 0xC761 is presumably the DNG NoiseProfile tag — TODO confirm.
        """
        noise = self.exif_data['Image Tag 0xC761'].values
        noise = [n[0] for n in noise]
        noise = np.array(noise).reshape(3, 2)
        return noise

    def get_f_number(self):
        """Aperture f-number, parsed from EXIF."""
        return self.exif_data['Image FNumber'].values[0].decimal()

    def get_iso(self):
        """ISO speed rating, parsed from EXIF."""
        return self.exif_data['Image ISOSpeedRatings'].values[0]

    def get_image_data(self, substract_black_level=False, white_balance=False, normalize=False):
        """Return the RAW data as float, optionally black-level corrected,
        white balanced, and normalized to [0, 1]."""
        im_raw = self.im_raw.float()

        if substract_black_level:
            im_raw = im_raw - torch.tensor(self.black_level).view(4, 1, 1)

        if white_balance:
            im_raw = im_raw * torch.tensor(self.cam_wb).view(4, 1, 1)

        if normalize:
            im_raw = im_raw / self.norm_factor

        return im_raw

    def shape(self):
        """Shape tuple of the packed RAW image: (4, H, W)."""
        shape = (4, self.im_raw.shape[1], self.im_raw.shape[2])
        return shape

    def crop_image(self, r1, r2, c1, c2):
        """Crop the RAW data in place to rows [r1, r2) and cols [c1, c2)."""
        self.im_raw = self.im_raw[:, r1:r2, c1:c2]

    def get_crop(self, r1, r2, c1, c2):
        """Return a new SamsungRAWImage cropped to rows [r1, r2) and cols [c1, c2).

        The preview (if any) is cropped at 2x coordinates since it is stored
        at twice the packed resolution.
        NOTE(review): crop_info is not propagated to the returned object.
        """
        im_raw = self.im_raw[:, r1:r2, c1:c2]

        if self.im_preview is not None:
            im_preview = self.im_preview[2*r1:2*r2, 2*c1:2*c2]
        else:
            im_preview = None

        return SamsungRAWImage(im_raw, self.black_level, self.cam_wb, self.daylight_wb, self.color_matrix,
                               self.exif_data, im_preview=im_preview)

    def postprocess(self, return_np=True, norm_factor=None):
        """Produce a rough RGB visualization of the RAW data.

        Subtracts the black level, applies the camera white balance,
        normalizes (by the image max when norm_factor is None), averages the
        two green channels, and clamps to [0, 1]. Returns a uint8 HWC array
        when return_np is True, otherwise a float CHW tensor.
        """
        # Convert to rgb
        # im = torch.from_numpy(self.im_raw.astype(np.float32))
        im = self.im_raw
        im = (im - torch.tensor(self.black_level).view(4, 1, 1)) * torch.tensor(self.cam_wb).view(4, 1, 1)

        if norm_factor is None:
            im = im / im.max()
        else:
            im = im / norm_factor

        # Collapse RGGB -> RGB by averaging the two green samples.
        im = torch.stack((im[0], (im[1] + im[2])/2, im[3]), dim=0)
        # im = torch.stack((im[0], im[1], im[3]), dim=0)
        im_out = im.clamp(0.0, 1.0)

        if return_np:
            im_out = im_out.permute(1, 2, 0).numpy() * 255.0
            im_out = im_out.astype(np.uint8)
        return im_out
len(black_level) == 4: black_level = [black_level[0], black_level[1], black_level[3]] self.black_level = black_level if len(cam_wb) == 4: cam_wb = [cam_wb[0], cam_wb[1], cam_wb[3]] self.cam_wb = cam_wb if len(daylight_wb) == 4: daylight_wb = [daylight_wb[0], daylight_wb[1], daylight_wb[3]] self.daylight_wb = daylight_wb self.rgb_xyz_matrix = rgb_xyz_matrix self.xyz_srgb_matrix = torch.tensor([3.2404542, -1.5371385, -0.4985314, -0.9692660, 1.8760108, 0.0415560, 0.0556434, -0.2040259, 1.0572252]).view(3, 3) self.exif_data = exif_data self.crop_info = crop_info self.norm_factor = 16383 def shape(self): shape = (3, self.im_raw.shape[1], self.im_raw.shape[2]) return shape def get_all_meta_data(self): return {'black_level': self.black_level, 'cam_wb': self.cam_wb, 'daylight_wb': self.daylight_wb, 'rgb_xyz_matrix': self.rgb_xyz_matrix.tolist(), 'crop_info': self.crop_info, 'norm_factor': self.norm_factor} def get_exposure_time(self): return self.exif_data['EXIF ExposureTime'].values[0].decimal() def get_f_number(self): return self.exif_data['EXIF FNumber'].values[0].decimal() def get_iso(self): return self.exif_data['EXIF ISOSpeedRatings'].values[0] def get_image_data(self, substract_black_level=False, white_balance=False, normalize=False): im_raw = self.im_raw.float() if substract_black_level: im_raw = im_raw - torch.tensor(self.black_level).view(3, 1, 1) if white_balance: im_raw = im_raw * torch.tensor(self.cam_wb).view(3, 1, 1) / 1024.0 if normalize: im_raw = im_raw / self.norm_factor return im_raw def set_image_data(self, im_data): self.im_raw = im_data def crop_image(self, r1, r2, c1, c2): self.im_raw = self.im_raw[:, r1:r2, c1:c2] def get_crop(self, r1, r2, c1, c2): im_raw = self.im_raw[:, r1:r2, c1:c2] return CanonImage(im_raw, self.black_level, self.cam_wb, self.daylight_wb, self.rgb_xyz_matrix, self.exif_data, self.crop_info) def set_crop_info(self, crop_info): self.crop_info = crop_info def resize(self, size=None, scale_factor=None): self.im_raw = 
def load_txt(path):
    """Read the text file at `path` and return its lines with trailing
    whitespace (including the newline) stripped."""
    with open(path, 'r') as handle:
        return [line.rstrip() for line in handle]
_get_gt_image(self, burst_id): canon_im = CanonImage.load('{}/{}/canon'.format(self.root, self.burst_list[burst_id]), split=self.split) return canon_im def get_burst(self, burst_id, im_ids, info=None): frames = [self._get_raw_image(burst_id, i) for i in im_ids] gt = self._get_gt_image(burst_id) if info is None: info = self.get_burst_info(burst_id) return frames, gt, info def _sample_images(self): burst_size = 14 ids = random.sample(range(1, burst_size), k=self.burst_size - 1) ids = [0, ] + ids return ids def __len__(self): return len(self.burst_list) def __getitem__(self, index): # Sample the images in the burst, in case a burst_size < 14 is used. im_ids = self._sample_images() # Read the burst images along with HR ground truth frames, gt, meta_info = self.get_burst(index, im_ids) # Extract crop if needed if frames[0].shape()[-1] != self.crop_sz: if getattr(self, 'center_crop', False): r1 = (frames[0].shape()[-2] - self.crop_sz) // 2 c1 = (frames[0].shape()[-1] - self.crop_sz) // 2 else: r1 = random.randint(0, frames[0].shape()[-2] - self.crop_sz) c1 = random.randint(0, frames[0].shape()[-1] - self.crop_sz) r2 = r1 + self.crop_sz c2 = c1 + self.crop_sz scale_factor = gt.shape()[-1] // frames[0].shape()[-1] frames = [im.get_crop(r1, r2, c1, c2) for im in frames] gt = gt.get_crop(scale_factor * r1, scale_factor * r2, scale_factor * c1, scale_factor * c2) # Load the RAW image data burst_image_data = [im.get_image_data(normalize=True, substract_black_level=self.substract_black_level, white_balance=self.white_balance) for im in frames] # Convert to tensor gt_image_data = gt.get_image_data(normalize=True, white_balance=self.white_balance, substract_black_level=self.substract_black_level) if self.random_flip: burst_image_data = [flatten_raw_image(im) for im in burst_image_data] pad = [0, 0, 0, 0] if random.random() > 0.5: burst_image_data = [im.flip([1, ])[:, 1:-1].contiguous() for im in burst_image_data] gt_image_data = gt_image_data.flip([2, ])[:, :, 2:-2].contiguous() 
pad[1] = 1 if random.random() > 0.5: burst_image_data = [im.flip([0, ])[1:-1, :].contiguous() for im in burst_image_data] gt_image_data = gt_image_data.flip([1, ])[:, 2:-2, :].contiguous() pad[3] = 1 burst_image_data = [pack_raw_image(im) for im in burst_image_data] burst_image_data = [F.pad(im.unsqueeze(0), pad, mode='replicate').squeeze(0) for im in burst_image_data] gt_image_data = F.pad(gt_image_data.unsqueeze(0), [4 * p for p in pad], mode='replicate').squeeze(0) burst_image_meta_info = frames[0].get_all_meta_data() burst_image_meta_info['black_level_subtracted'] = self.substract_black_level burst_image_meta_info['while_balance_applied'] = self.white_balance burst_image_meta_info['norm_factor'] = frames[0].norm_factor gt_image_meta_info = gt.get_all_meta_data() burst = torch.stack(burst_image_data, dim=0) burst_exposure = frames[0].get_exposure_time() canon_exposure = gt.get_exposure_time() burst_f_number = frames[0].get_f_number() canon_f_number = gt.get_f_number() burst_iso = frames[0].get_iso() canon_iso = gt.get_iso() # Normalize the GT image to account for differences in exposure, ISO etc light_factor_burst = burst_exposure * burst_iso / (burst_f_number ** 2) light_factor_canon = canon_exposure * canon_iso / (canon_f_number ** 2) exp_scale_factor = (light_factor_burst / light_factor_canon) gt_image_data = gt_image_data * exp_scale_factor gt_image_meta_info['black_level_subtracted'] = self.substract_black_level gt_image_meta_info['while_balance_applied'] = self.white_balance gt_image_meta_info['norm_factor'] = gt.norm_factor / exp_scale_factor burst_image_meta_info['exposure'] = burst_exposure burst_image_meta_info['f_number'] = burst_f_number burst_image_meta_info['iso'] = burst_iso gt_image_meta_info['exposure'] = canon_exposure gt_image_meta_info['f_number'] = canon_f_number gt_image_meta_info['iso'] = canon_iso burst = burst.float() frame_gt = gt_image_data.float() meta_info_burst = burst_image_meta_info meta_info_gt = gt_image_meta_info del 
# Row/column offsets of the four RGGB Bayer samples within each 2x2 cell.
_BAYER_OFFSETS = ((0, 0), (0, 1), (1, 0), (1, 1))


def pack_raw_image(im_raw):
    """Pack a flat Bayer mosaic (H, W) into a 4-channel image (4, H/2, W/2).

    Works on both numpy arrays and torch tensors; dtype (and device, for
    tensors) is preserved.
    """
    out_shape = (4, im_raw.shape[0] // 2, im_raw.shape[1] // 2)
    if isinstance(im_raw, np.ndarray):
        im_out = np.zeros_like(im_raw, shape=out_shape)
    elif isinstance(im_raw, torch.Tensor):
        im_out = torch.zeros(out_shape, dtype=im_raw.dtype).to(im_raw.device)
    else:
        raise Exception
    for ch, (row, col) in enumerate(_BAYER_OFFSETS):
        im_out[ch, :, :] = im_raw[row::2, col::2]
    return im_out


def flatten_raw_image(im_raw_4ch):
    """Inverse of pack_raw_image: spread a (4, H, W) image back into a flat
    (2H, 2W) Bayer mosaic."""
    out_shape = (im_raw_4ch.shape[1] * 2, im_raw_4ch.shape[2] * 2)
    if isinstance(im_raw_4ch, np.ndarray):
        im_out = np.zeros_like(im_raw_4ch, shape=out_shape)
    elif isinstance(im_raw_4ch, torch.Tensor):
        im_out = torch.zeros(out_shape, dtype=im_raw_4ch.dtype).to(im_raw_4ch.device)
    else:
        raise Exception
    for ch, (row, col) in enumerate(_BAYER_OFFSETS):
        im_out[row::2, col::2] = im_raw_4ch[ch, :, :]
    return im_out


def pack_raw_image_batch(im_raw):
    """Batched pack: (B, N, 1, H, W) tensor -> (B, N, 4, H/2, W/2)."""
    out_shape = (im_raw.shape[0], im_raw.shape[1], 4,
                 im_raw.shape[3] // 2, im_raw.shape[4] // 2)
    im_out = torch.zeros(out_shape, dtype=im_raw.dtype).to(im_raw.device)
    for ch, (row, col) in enumerate(_BAYER_OFFSETS):
        im_out[:, :, ch, :, :] = im_raw[:, :, 0, row::2, col::2]
    return im_out


def flatten_raw_image_batch(im_raw_4ch):
    """Batched flatten: (B, N, 4, H, W) tensor -> (B, N, 1, 2H, 2W)."""
    out_shape = (im_raw_4ch.shape[0], im_raw_4ch.shape[1], 1,
                 im_raw_4ch.shape[3] * 2, im_raw_4ch.shape[4] * 2)
    im_out = torch.zeros(out_shape, dtype=im_raw_4ch.dtype).to(im_raw_4ch.device)
    for ch, (row, col) in enumerate(_BAYER_OFFSETS):
        im_out[:, :, 0, row::2, col::2] = im_raw_4ch[:, :, ch, :, :]
    return im_out
im_ids] if info is None: info = self.get_burst_info(burst_id) return frames, info def _sample_images(self): burst_size = 14 ids = random.sample(range(1, burst_size), k=self.burst_size - 1) ids = [0, ] + ids return ids def __len__(self): return len(self.burst_list) def __getitem__(self, index): # Sample the images in the burst, in case a burst_size < 14 is used. im_ids = self._sample_images() # Read the burst images along with HR ground truth frames, meta_info = self.get_burst(index, im_ids) # Extract crop if needed if frames[0].shape()[-1] != self.crop_sz: if getattr(self, 'center_crop', False): r1 = (frames[0].shape()[-2] - self.crop_sz) // 2 c1 = (frames[0].shape()[-1] - self.crop_sz) // 2 else: r1 = random.randint(0, frames[0].shape()[-2] - self.crop_sz) c1 = random.randint(0, frames[0].shape()[-1] - self.crop_sz) r2 = r1 + self.crop_sz c2 = c1 + self.crop_sz frames = [im.get_crop(r1, r2, c1, c2) for im in frames] # Load the RAW image data burst_image_data = [im.get_image_data(normalize=True, substract_black_level=self.substract_black_level, white_balance=self.white_balance) for im in frames] if self.random_flip: burst_image_data = [flatten_raw_image(im) for im in burst_image_data] pad = [0, 0, 0, 0] if random.random() > 0.5: burst_image_data = [im.flip([1, ])[:, 1:-1].contiguous() for im in burst_image_data] pad[1] = 1 if random.random() > 0.5: burst_image_data = [im.flip([0, ])[1:-1, :].contiguous() for im in burst_image_data] pad[3] = 1 burst_image_data = [pack_raw_image(im) for im in burst_image_data] burst_image_data = [F.pad(im.unsqueeze(0), pad, mode='replicate').squeeze(0) for im in burst_image_data] burst_image_meta_info = frames[0].get_all_meta_data() burst_image_meta_info['black_level_subtracted'] = self.substract_black_level burst_image_meta_info['while_balance_applied'] = self.white_balance burst_image_meta_info['norm_factor'] = frames[0].norm_factor burst = torch.stack(burst_image_data, dim=0) burst_exposure = frames[0].get_exposure_time() 
class DistIterSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note:: Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
        ratio (optional): virtually enlarges the dataset by this factor so the
            dataloader is restarted less often during iter-oriented training.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, ratio=100):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Per-replica sample count of the (ratio-times enlarged) epoch;
        # total_size is therefore evenly divisible across replicas.
        self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch so every replica draws
        # the same permutation and the rank-based slicing stays disjoint
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = torch.randperm(
            self.total_size, generator=g
        ).tolist()  # Returns a random permutation of integers from 0 to n - 1

        # Fold the enlarged index range back onto real dataset indices.
        dsize = len(self.dataset)
        indices = [v % dsize for v in indices]

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        # Must be called before each epoch to reseed the shuffle.
        self.epoch = epoch
class SyntheticBurstTest(torch.utils.data.Dataset):
    """Synthetic burst test set.

    The test bursts were generated with the same synthetic pipeline as
    employed in the SyntheticBurst dataset.
    """

    def __init__(self, root):
        self.root = root
        self.burst_list = list(range(92))  # 92 fixed test bursts
        self.burst_size = 14

    def __len__(self):
        return len(self.burst_list)

    def _read_burst_image(self, index, image_id):
        # 16-bit PNG holds the mosaicked RAW frame; rescale to [0, 1].
        path = '{}/{:04d}/im_raw_{:02d}.png'.format(self.root, index, image_id)
        raw = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        return torch.from_numpy(raw.astype(np.float32)).permute(2, 0, 1).float() / (2**14)

    def __getitem__(self, index):
        """Return the synthetic burst `index`.

        returns:
            burst: LR RAW burst tensor; the 4 channels correspond to the
                'R', 'G', 'G' and 'B' values of the RGGB Bayer mosaic.
            meta_info: dict with meta information about the burst.
        """
        frames = [self._read_burst_image(index, i) for i in range(self.burst_size)]
        meta_info = {}
        meta_info['burst_name'] = '{:04d}'.format(index)
        return torch.stack(frames, 0), meta_info
[1] Unprocessing Images for Learned Raw Denoising, Brooks, Tim and Mildenhall, Ben and Xue, Tianfan and Chen, Jiawen and Sharlet, Dillon and Barron, Jonathan T, CVPR 2019 """ def __init__(self, base_dataset, burst_size=8, crop_sz=384, transform=tfm.ToTensor()): self.base_dataset = base_dataset self.burst_size = burst_size self.crop_sz = crop_sz self.transform = transform self.downsample_factor = 4 self.burst_transformation_params = {'max_translation': 24.0, 'max_rotation': 1.0, 'max_shear': 0.0, 'max_scale': 0.0, 'border_crop': 24} self.image_processing_params = {'random_ccm': True, 'random_gains': True, 'smoothstep': True, 'gamma': True, 'add_noise': True} self.interpolation_type = 'bilinear' def __len__(self): return len(self.base_dataset) def __getitem__(self, index): """ Generates a synthetic burst args: index: Index of the image in the base_dataset used to generate the burst returns: burst: Generated LR RAW burst, a torch tensor of shape [burst_size, 4, self.crop_sz / (2*self.downsample_factor), self.crop_sz / (2*self.downsample_factor)] The 4 channels correspond to 'R', 'G', 'G', and 'B' values in the RGGB bayer mosaick. The extra factor 2 in the denominator (2*self.downsample_factor) corresponds to the mosaicking operation. frame_gt: The HR RGB ground truth in the linear sensor space, a torch tensor of shape [3, self.crop_sz, self.crop_sz] flow_vectors: The ground truth flow vectors between a burst image and the base image (i.e. the first image in the burst). The flow_vectors can be used to warp the burst images to the base frame, using the 'warp' function in utils.warp package. flow_vectors is torch tensor of shape [burst_size, 2, self.crop_sz / self.downsample_factor, self.crop_sz / self.downsample_factor]. Note that the flow_vectors are in the LR RGB space, before mosaicking. Hence it has twice the number of rows and columns, compared to the output burst. 
class SyntheticBurstVal(torch.utils.data.Dataset):
    """Synthetic burst validation set introduced in [1].

    The validation bursts were generated using a synthetic data generation
    pipeline. The dataset can be downloaded from
    https://data.vision.ee.ethz.ch/bhatg/SyntheticBurstVal.zip

    [1] Deep Burst Super-Resolution. Goutam Bhat, Martin Danelljan,
        Luc Van Gool, and Radu Timofte. CVPR 2021
    """

    def __init__(self, root=None, initialize=True):
        """ args:
            root - Path to root dataset directory
            initialize - boolean indicating whether to load the meta-data for the dataset
        """
        # NOTE(review): the `initialize` argument is unused and is shadowed
        # by the no-op method of the same name below.
        self.root = os.path.join(root, 'val')
        self.burst_list = list(range(300))  # 300 fixed validation bursts
        self.burst_size = 14

    def initialize(self):
        pass

    def __len__(self):
        return len(self.burst_list)

    def _read_burst_image(self, index, image_id):
        # 16-bit PNG holds the mosaicked RAW frame; rescale to [0, 1].
        path = '{}/bursts/{:04d}/im_raw_{:02d}.png'.format(self.root, index, image_id)
        raw = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        return torch.from_numpy(raw.astype(np.float32)).permute(2, 0, 1).float() / (2**14)

    def _read_gt_image(self, index):
        gt = cv2.imread('{}/gt/{:04d}/im_rgb.png'.format(self.root, index), cv2.IMREAD_UNCHANGED)
        return (torch.from_numpy(gt.astype(np.float32)) / 2 ** 14).permute(2, 0, 1).float()

    def _read_meta_info(self, index):
        with open('{}/gt/{:04d}/meta_info.pkl'.format(self.root, index), "rb") as input_file:
            return pkl.load(input_file)

    def __getitem__(self, index):
        """Return (burst, gt, meta_info) for burst `index`.

        burst: LR RAW burst, a torch tensor of shape [14, 4, 48, 48]; the 4
            channels correspond to the 'R', 'G', 'G' and 'B' values of the
            RGGB Bayer mosaic.
        gt: ground-truth linear image.
        meta_info: meta info about the burst which can be used to convert gt
            to sRGB space.
        """
        name = '{:04d}'.format(index)
        frames = [self._read_burst_image(index, i) for i in range(self.burst_size)]
        gt = self._read_gt_image(index)
        meta_info = self._read_meta_info(index)
        meta_info['burst_name'] = name
        return torch.stack(frames, 0), gt, meta_info
You can download the full dataset (22 GB) from http://people.ee.ethz.ch/~ihnatova/pynet.html#dataset. Alternatively, you can only download the Canon RGB images (5.5 GB) from https://data.vision.ee.ethz.ch/bhatg/zurich-raw-to-rgb.zip """ def __init__(self, root, split='train'): super().__init__() if split in ['train', 'test']: self.img_pth = os.path.join(root, split, 'canon') else: raise Exception('Unknown split {}'.format(split)) self.image_list = self._get_image_list(split) def _get_image_list(self, split): if split == 'train': image_list = ['{:d}.jpg'.format(i) for i in range(46839)] elif split == 'test': # image_list = ['{:d}.jpg'.format(int(i)) for i in np.linspace(1, 1200, 200)] image_list = ['{:d}.jpg'.format(i) for i in range(1200)] else: raise Exception return image_list def _get_image(self, im_id): path = os.path.join(self.img_pth, self.image_list[im_id]) img = imread(path) return img def get_image(self, im_id): frame = self._get_image(im_id) return frame def __len__(self): return len(self.image_list) def __getitem__(self, index): frame = self._get_image(index) return frame ================================================ FILE: code/real/bsrt/demo.sh ================================================ #!/usr/bin/env bash python main.py --n_GPUs 8 --print_every 20 --lr 0.00004 --decay 40-80 --save bsrt_tiny --model BSRT --fp16 --model_level S --swinfeature --batch_size 8 --burst_size 14 --patch_size 80 --pre_train ../../synthetic/train_log/bsrt/real_models/bsrt_tiny/bsrt_best_epoch.pth # python main.py --n_GPUs 8 --print_every 20 --lr 0.00004 --decay 40-80 --save bsrt_large --model BSRT --fp16 --model_level L --swinfeature --batch_size 8 --burst_size 14 --patch_size 48 --pre_train ../../synthetic/train_log/bsrt/real_models/bsrt_large/bsrt_best_epoch.pth # python test_real.py --n_GPUs 1 --model BSRT --model_level S --swinfeature --batch_size 1 --burst_size 14 --patch_size 80 --pre_train ../train_log/bsrt/real_models/bsrt_tiny/bsrtbest_epoch.pth --root 
class CharbonnierLoss(nn.Module):
    """Charbonnier penalty: sqrt((X - Y)^2 + epsilon^2).

    A smooth, everywhere-differentiable approximation of the L1 loss that is
    well behaved near zero. With reduce=True the mean over all elements is
    returned; otherwise the per-element error map.
    """

    def __init__(self, epsilon=1e-3, reduce=True):
        super(CharbonnierLoss, self).__init__()
        # Store epsilon squared so forward() adds it directly under the sqrt.
        self.eps = epsilon * epsilon
        self.reduce = reduce

    def forward(self, X, Y):
        residual = X - Y
        error = torch.sqrt(residual * residual + self.eps)
        return torch.mean(error) if self.reduce else error
    def forward(self, sr, hr):
        """Compute the weighted sum of all configured loss terms.

        args:
            sr: super-resolved prediction
            hr: high-resolution ground truth

        Side effect: accumulates each effective (weighted) loss value into
        the last row of self.log so it can be averaged and plotted later.
        """
        losses = []
        for i, l in enumerate(self.loss):
            if l['function'] is not None:
                loss = l['function'](sr, hr)
                effective_loss = l['weight'] * loss
                losses.append(effective_loss)
                self.log[-1, i] += effective_loss.item()
            elif l['type'] == 'DIS':
                # The discriminator entry has no callable of its own; log the
                # loss the preceding GAN entry stored on its function object.
                self.log[-1, i] += self.loss[i - 1]['function'].loss

        loss_sum = sum(losses)
        if len(self.loss) > 1:
            # Trailing 'Total' slot tracks the combined loss.
            self.log[-1, -1] += loss_sum.item()

        return loss_sum
class Adversarial(nn.Module):
    """Adversarial loss wrapper around an internal discriminator.

    Supported ``gan_type`` values: 'GAN', 'SNGAN', 'WGAN', 'WGAN_GP' and
    'RGAN' (relativistic, from ESRGAN).  Each forward() call performs
    ``gan_k`` discriminator updates on (fake.detach(), real) and then
    returns the generator loss computed on ``fake``.  The mean
    discriminator loss is kept in ``self.loss`` for external logging.
    """

    def __init__(self, args, gan_type):
        super(Adversarial, self).__init__()
        self.gan_type = gan_type
        self.gan_k = args.gan_k
        self.dis = discriminator.Discriminator(args)
        # if gan_type == 'WGAN_GP':
        if True:
            # Discriminator optimizer settings,
            # see https://arxiv.org/pdf/1704.00028.pdf pp.4
            optim_dict = {
                'optimizer': 'ADAM',
                'betas': (0.5, 0.9),
                'epsilon': 1e-8,
                'lr': 1e-5,
                'weight_decay': args.weight_decay,
                'decay': args.decay,
                'gamma': args.gamma
            }
            optim_args = SimpleNamespace(**optim_dict)
        else:
            optim_args = args

        self.optimizer = utility.make_optimizer(optim_args, self.dis)

    def forward(self, fake, real):
        # ---- discriminator updates (generator gradients blocked) ----
        self.loss = 0
        fake_detach = fake.detach()  # do not backpropagate through G
        for _ in range(self.gan_k):
            self.optimizer.zero_grad()
            # d: B x 1 tensor
            d_fake = self.dis(fake_detach)
            d_real = self.dis(real)
            retain_graph = False
            if self.gan_type in ['GAN', 'SNGAN']:
                loss_d = self.bce(d_real, d_fake)
            elif self.gan_type.find('WGAN') >= 0:
                loss_d = (d_fake - d_real).mean()
                if self.gan_type.find('GP') >= 0:
                    # Gradient penalty on a random interpolate between real
                    # and fake (WGAN-GP).
                    # BUGFIX: the original drew
                    # torch.rand_like(fake).view(-1, 1, 1, 1), which has
                    # B*C*H*W entries and cannot broadcast against a
                    # (B, C, H, W) batch; draw one epsilon per sample.
                    epsilon = torch.rand(
                        fake_detach.size(0), 1, 1, 1,
                        device=fake_detach.device, dtype=fake_detach.dtype
                    )
                    hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
                    hat.requires_grad = True
                    d_hat = self.dis(hat)
                    gradients = torch.autograd.grad(
                        outputs=d_hat.sum(), inputs=hat,
                        retain_graph=True, create_graph=True, only_inputs=True
                    )[0]
                    gradients = gradients.view(gradients.size(0), -1)
                    gradient_norm = gradients.norm(2, dim=1)
                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
                    loss_d += gradient_penalty

            # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
            elif self.gan_type == 'RGAN':
                better_real = d_real - d_fake.mean(dim=0, keepdim=True)
                better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
                loss_d = self.bce(better_real, better_fake)
                # Keep the graph alive: d_real/d_fake are reused below.
                retain_graph = True

            # Discriminator update
            self.loss += loss_d.item()
            loss_d.backward(retain_graph=retain_graph)
            self.optimizer.step()

            if self.gan_type == 'WGAN':
                # Vanilla WGAN: weight clipping enforces the Lipschitz bound.
                for p in self.dis.parameters():
                    p.data.clamp_(-1, 1)

        self.loss /= self.gan_k

        # ---- generator loss (gradients flow into ``fake``) ----
        d_fake_bp = self.dis(fake)  # for backpropagation, use fake as it is
        if self.gan_type in ['GAN', 'SNGAN']:
            label_real = torch.ones_like(d_fake_bp)
            loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)
        elif self.gan_type.find('WGAN') >= 0:
            loss_g = -d_fake_bp.mean()
        elif self.gan_type == 'RGAN':
            better_real = d_real.detach() - d_fake_bp.mean(dim=0, keepdim=True)
            better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True).detach()
            loss_g = self.bce(better_fake, better_real)

        # Generator loss
        return loss_g

    def state_dict(self, *args, **kwargs):
        # NOTE(review): merging discriminator and optimizer state dicts
        # relies on their keys never colliding — confirm before extending.
        state_discriminator = self.dis.state_dict(*args, **kwargs)
        state_optimizer = self.optimizer.state_dict()

        return dict(**state_discriminator, **state_optimizer)

    def bce(self, real, fake):
        """Two-sided BCE-with-logits: real logits -> 1, fake logits -> 0."""
        label_real = torch.ones_like(real)
        label_fake = torch.zeros_like(fake)
        bce_real = F.binary_cross_entropy_with_logits(real, label_real)
        bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)
        bce_loss = bce_real + bce_fake
        return bce_loss

# Some references
# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
# OR
# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
class Filter(nn.Module):
    """High-frequency loss: L1 distance between Laplacian-filtered images.

    Each channel is convolved independently (depthwise) with a fixed 3x3
    high-pass kernel, and the L1 distance between the filtered inputs is
    returned.  The filter is frozen — it contributes no trainable
    parameters to the loss module.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        kernel = torch.tensor([[1.0, 4.0, 1.0],
                               [4.0, -20.0, 4.0],
                               [1.0, 4.0, 1.0]])
        # BUGFIX: the original nn.Conv2d(args.n_colors, args.n_colors, 3, 3)
        # passed 3 as *stride* (the 4th positional argument) and broadcast
        # the single 3x3 kernel over every (out, in) channel pair, summing
        # channels together; the conv also kept a trainable bias.  Use a
        # frozen stride-1 depthwise conv so each channel is filtered
        # independently.
        self.conv = nn.Conv2d(args.n_colors, args.n_colors, 3,
                              stride=1, groups=args.n_colors, bias=False)
        with torch.no_grad():
            # Depthwise weight shape is (C, 1, 3, 3); broadcast the kernel.
            self.conv.weight.copy_(kernel.expand_as(self.conv.weight))
        self.conv.weight.requires_grad_(False)
        self.loss = nn.L1Loss()

    def forward(self, x, y):
        """Return L1(filter(x), filter(y))."""
        preds_x = self.conv(x)
        preds_y = self.conv(y)
        return self.loss(preds_x, preds_y)
def create_window(window_size, channel=1):
    """Build a (channel, 1, ws, ws) Gaussian window for grouped conv2d."""
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    # Outer product of the 1-D Gaussian with itself gives the 2-D window.
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window


def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
    """Structural similarity between two (N, C, H, W) image batches.

    Returns the mean SSIM (scalar when size_average, else per image);
    when ``full`` is set, also returns the contrast-sensitivity term
    used by MS-SSIM.
    """
    # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
    if val_range is None:
        # Heuristically infer the dynamic range L from img1's values.
        if torch.max(img1) > 128:
            max_val = 255
        else:
            max_val = 1

        if torch.min(img1) < -0.5:
            min_val = -1
        else:
            min_val = 0
        L = max_val - min_val
    else:
        L = val_range

    padd = 0
    (_, channel, height, width) = img1.size()
    if window is None:
        # The window may not exceed the image; build it on img1's device.
        real_size = min(window_size, height, width)
        window = create_window(real_size, channel=channel).to(img1.device)

    # Local means via grouped (per-channel) Gaussian filtering.
    mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
    mu2 = F.conv2d(img2, window, padding=padd, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local variances and covariance (E[x^2] - E[x]^2 form).
    sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2

    # Stabilization constants scaled by the dynamic range L.
    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2

    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2
    cs = torch.mean(v1 / v2)  # contrast sensitivity

    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

    if size_average:
        ret = ssim_map.mean()
    else:
        # Per-image mean over channel and both spatial dims.
        ret = ssim_map.mean(1).mean(1).mean(1)

    if full:
        return ret, cs
    return ret
def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=None):
    """Multi-scale SSIM over five dyadic scales (standard MS-SSIM weights)."""
    device = img1.device
    weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
    levels = weights.size()[0]
    ssims = []
    mcs = []
    for _ in range(levels):
        sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)

        # Relu normalize (not compliant with original definition)
        if normalize == "relu":
            ssims.append(torch.relu(sim))
            mcs.append(torch.relu(cs))
        else:
            ssims.append(sim)
            mcs.append(cs)

        # Halve the resolution for the next scale.
        img1 = F.avg_pool2d(img1, (2, 2))
        img2 = F.avg_pool2d(img2, (2, 2))

    ssims = torch.stack(ssims)
    mcs = torch.stack(mcs)

    # Simple normalize (not compliant with original definition)
    # TODO: remove support for normalize == True (kept for backward support)
    if normalize == "simple" or normalize == True:
        ssims = (ssims + 1) / 2
        mcs = (mcs + 1) / 2

    pow1 = mcs ** weights
    pow2 = ssims ** weights

    # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
    # NOTE(review): pow2[-1] is broadcast into all levels-1 entries of
    # pow1[:-1] before the product, so the final-scale SSIM is effectively
    # raised to the power (levels-1); the textbook MS-SSIM form is
    # prod(pow1[:-1]) * pow2[-1] — confirm against the reference before
    # changing, since this matches the upstream PyTorch port.
    output = torch.prod(pow1[:-1] * pow2[-1])
    return output


# Classes to re-use window
class SSIM(torch.nn.Module):
    """Stateful SSIM wrapper that caches the Gaussian window."""

    def __init__(self, window_size=11, size_average=True, val_range=None):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.val_range = val_range

        # Assume 1 channel for SSIM
        self.channel = 1
        self.window = create_window(window_size)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        # Reuse the cached window unless channel count or dtype changed.
        if channel == self.channel and self.window.dtype == img1.dtype:
            window = self.window
        else:
            window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
            self.window = window
            self.channel = channel

        return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)


class MSSSIM(torch.nn.Module):
    """Module wrapper around msssim(); the window is rebuilt per call."""

    def __init__(self, window_size=11, size_average=True, channel=3):
        super(MSSSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = channel  # NOTE(review): stored but unused in forward

    def forward(self, img1, img2):
        # TODO: store window between calls if possible
        return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
class VGG(nn.Module):
    """Perceptual loss: MSE between VGG19 features of SR and HR images.

    ``conv_index`` selects the feature depth: '22' keeps the first 8
    layers, '54' the first 35 (conventionally relu2_2 / relu5_4).  All
    parameters are frozen — the loss network is never trained.
    """

    def __init__(self, conv_index, rgb_range=1):
        super(VGG, self).__init__()
        vgg_features = models.vgg19(pretrained=True).features
        modules = [m for m in vgg_features]
        if conv_index.find('22') >= 0:
            self.vgg = nn.Sequential(*modules[:8])
        elif conv_index.find('54') >= 0:
            self.vgg = nn.Sequential(*modules[:35])

        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        # ImageNet normalization layer — currently unused in forward().
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, sr, hr):
        """Return MSE between VGG features of ``sr`` and ``hr``.

        NOTE(review): the repeat(1, 3, 1, 1) below assumes single-channel
        inputs tiled to 3 channels for VGG — confirm callers pass
        1-channel tensors.  sub_mean is commented out, so features are
        computed on unnormalized inputs.
        """
        def _forward(x):
            # x = self.sub_mean(x)
            x = self.vgg(x)
            return x

        sr = sr.repeat(1, 3, 1, 1)
        hr = hr.repeat(1, 3, 1, 1)
        vgg_sr = _forward(sr)
        # Target features need no gradient.
        with torch.no_grad():
            vgg_hr = _forward(hr.detach())

        loss = F.mse_loss(vgg_sr, vgg_hr)
        return loss
def main_worker(local_rank, nprocs, args):
    """Per-GPU training worker launched by mp.spawn.

    ``local_rank`` indexes this process's GPU; ``nprocs`` is the world
    size.  Builds the BurstSR train/val datasets and loaders (distributed
    samplers when nprocs > 1), the model and the loss, then runs the
    Trainer loop until it signals termination.
    """
    # print(local_rank)
    if checkpoint.ok:
        args.local_rank = local_rank
        # Distinct-but-deterministic seed per rank (rank 0 uses seed 1).
        init_seeds(local_rank+1)
        cudnn.benchmark = True
        utility.setup(local_rank, nprocs)
        torch.cuda.set_device(local_rank)
        # The global batch size is split evenly across processes.
        batch_size = int(args.batch_size / nprocs)

        train_data = BurstSRDataset(root=args.root, burst_size=args.burst_size, crop_sz=args.patch_size, random_flip=True, center_crop=True, split='train')
        # Validation uses a fixed burst size (14) and 80px crops.
        valid_data = BurstSRDataset(root=args.root, burst_size=14, crop_sz=80, split='val')

        if local_rank <= 0:
            print(f"train data: {len(train_data)}, test data: {len(valid_data)}")

        if nprocs > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
            valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_data, shuffle=False)
            # NOTE(review): num_workers=args.batch_size looks like a slip —
            # the trailing "# args.cpus" comment suggests args.cpus was
            # intended; confirm before relying on worker counts.
            train_loader = DataLoader(dataset=train_data, batch_size=batch_size, num_workers=args.batch_size, pin_memory=True, drop_last=True, sampler=train_sampler)  # args.cpus
            valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size, num_workers=args.batch_size, pin_memory=True, drop_last=True, sampler=valid_sampler)  # args.cpus
        else:
            train_sampler = None
            train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, num_workers=8, shuffle=True, pin_memory=True, drop_last=True)  # args.cpus
            valid_loader = DataLoader(dataset=valid_data, batch_size=args.batch_size, num_workers=4, shuffle=False, pin_memory=True, drop_last=True)  # args.cpus

        _model = model.Model(args, checkpoint)
        # No loss module is needed when only testing.
        _loss = loss.Loss(args, checkpoint) if not args.test_only else None
        t = Trainer(args, train_loader, train_sampler, valid_loader, _model, _loss, checkpoint)
        while not t.terminate():
            t.train()

        # Release references so spawned workers shut down cleanly.
        del _model
        del _loss
        del train_loader
        del valid_loader

        # checkpoint.done()
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
================================================ FILE: code/real/bsrt/model/DCNv2/README.md ================================================ ## Deformable Convolutional Networks V2 with Pytorch 1.0 ### Build ```bash ./make.sh # build python test.py # run examples and gradient check ``` ### An Example - deformable conv ```python from dcn_v2 import DCN input = torch.randn(2, 64, 128, 128).cuda() # wrap all things (offset and mask) in DCN dcn = DCN(64, 64, kernel_size=(3,3), stride=1, padding=1, deformable_groups=2).cuda() output = dcn(input) print(output.shape) ``` - deformable roi pooling ```python from dcn_v2 import DCNPooling input = torch.randn(2, 32, 64, 64).cuda() batch_inds = torch.randint(2, (20, 1)).cuda().float() x = torch.randint(256, (20, 1)).cuda().float() y = torch.randint(256, (20, 1)).cuda().float() w = torch.randint(64, (20, 1)).cuda().float() h = torch.randint(64, (20, 1)).cuda().float() rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) # mdformable pooling (V2) # wrap all things (offset and mask) in DCNPooling dpooling = DCNPooling(spatial_scale=1.0 / 4, pooled_size=7, output_dim=32, no_trans=False, group_size=1, trans_std=0.1).cuda() dout = dpooling(input, rois) ``` ### Note Now the master branch is for pytorch 1.0 (new ATen API), you can switch back to pytorch 0.4 with, ```bash git checkout pytorch_0.4 ``` ### Known Issues: - [x] Gradient check w.r.t offset (solved) - [ ] Backward is not reentrant (minor) This is an adaption of the official [Deformable-ConvNets](https://github.com/msracver/Deformable-ConvNets/tree/master/DCNv2_op). I have ran the gradient check for many times with DOUBLE type. Every tensor **except offset** passes. However, when I set the offset to 0.5, it passes. I'm still wondering what cause this problem. Is it because some non-differential points? Update: all gradient check passes with double precision. Another issue is that it raises `RuntimeError: Backward is not reentrant`. 
class _DCNv2(Function):
    """autograd.Function bridging to the compiled DCNv2 extension (_backend).

    forward/backward delegate entirely to the C++/CUDA kernels; inputs are
    cast to float32 for AMP safety via custom_fwd.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    # @amp.float_function
    def forward(
        ctx, input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups
    ):
        # Stash the conv hyper-parameters on ctx for backward.
        ctx.stride = _pair(stride)
        ctx.padding = _pair(padding)
        ctx.dilation = _pair(dilation)
        ctx.kernel_size = _pair(weight.shape[2:4])
        ctx.deformable_groups = deformable_groups
        output = _backend.dcn_v2_forward(
            input,
            weight,
            bias,
            offset,
            mask,
            ctx.kernel_size[0],
            ctx.kernel_size[1],
            ctx.stride[0],
            ctx.stride[1],
            ctx.padding[0],
            ctx.padding[1],
            ctx.dilation[0],
            ctx.dilation[1],
            ctx.deformable_groups,
        )
        ctx.save_for_backward(input, offset, mask, weight, bias)
        return output

    @staticmethod
    @once_differentiable
    @custom_bwd
    # @amp.float_function
    def backward(ctx, grad_output):
        input, offset, mask, weight, bias = ctx.saved_tensors
        grad_input, grad_offset, grad_mask, grad_weight, grad_bias = _backend.dcn_v2_backward(
            input,
            weight,
            bias,
            offset,
            mask,
            grad_output,
            ctx.kernel_size[0],
            ctx.kernel_size[1],
            ctx.stride[0],
            ctx.stride[1],
            ctx.padding[0],
            ctx.padding[1],
            ctx.dilation[0],
            ctx.dilation[1],
            ctx.deformable_groups,
        )
        # Trailing Nones correspond to the non-tensor forward() arguments
        # (stride, padding, dilation, deformable_groups).
        return grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None

    @staticmethod
    def symbolic(
        g, input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups
    ):
        from torch.nn.modules.utils import _pair

        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        # as of trt 7, the dcn operation will be translated again by modifying the onnx file
        # so the exporting code is kept to resemble the forward()
        return g.op(
            "DCNv2_2",
            input,
            offset,
            mask,
            weight,
            bias,
            stride_i=stride,
            padding_i=padding,
            dilation_i=dilation,
            deformable_groups_i=deformable_groups,
        )


dcn_v2_conv = _DCNv2.apply
class DCN_sep(DCNv2):
    '''Deformable conv whose offsets and masks are predicted from a
    separate feature map (``fea``) rather than from the conv input.'''

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                 dilation=1, deformable_groups=1):
        super(DCN_sep, self).__init__(in_channels, out_channels, kernel_size,
                                      stride, padding, dilation, deformable_groups)

        # Three maps per sampling location: x-offset, y-offset and mask.
        channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
        self.conv_offset_mask = nn.Conv2d(
            self.in_channels, channels_,
            kernel_size=self.kernel_size, stride=self.stride,
            padding=self.padding, bias=True)
        self.init_offset()

    def init_offset(self):
        # Zero init => the layer starts as a plain (non-deformed) conv.
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, input, fea):
        '''input: input features for deformable conv
        fea: other features used for generating offsets and mask'''
        out = self.conv_offset_mask(fea)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        # offset = torch.clamp(offset, -100, 100)

        # Diagnostics only: very large offsets usually indicate divergence.
        offset_mean = torch.mean(torch.abs(offset))
        if offset_mean > 250:
            # BUGFIX: the message used to say "larger than 100" while the
            # guard actually fires at 250.
            print('Offset mean is {}, larger than 250.'.format(offset_mean))
            # return None
        # offset[offset>=150] = 1e-3
        # offset = offset.clamp(-50, 50)

        mask = torch.sigmoid(mask)
        return dcn_v2_conv(input, offset, mask, self.weight, self.bias,
                           self.stride, self.padding, self.dilation,
                           self.deformable_groups)
class FlowGuidedDCN(DCNv2):
    '''Deformable conv whose offsets are a bounded residual on top of a
    precomputed optical flow.'''

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                 dilation=1, deformable_groups=1):
        super(FlowGuidedDCN, self).__init__(in_channels, out_channels, kernel_size,
                                            stride, padding, dilation, deformable_groups)

        # Three maps per sampling location: x-offset, y-offset and mask.
        channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
        self.conv_offset_mask = nn.Conv2d(
            in_channels, channels_, kernel_size, stride, padding, bias=True)
        self.init_offset()

    def init_offset(self):
        # Zero init => sampling starts exactly on the flow-warped grid.
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, input, fea, flows):
        '''input: input features for deformable conv: N, C, H, W.
        fea: other features used for generating offsets and mask: N, C, H, W.
        flows: N, 2, H, W.
        '''
        out = self.conv_offset_mask(fea)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        # Residual offsets bounded to [-10, 10], then added to the flow.
        # flip(1) swaps the two flow channels (presumably (x, y) -> (y, x)
        # to match the kernel offset layout — confirm against the backend).
        offset = torch.tanh(torch.cat((o1, o2), dim=1)) * 10  # max_residue_magnitude
        offset = offset + flows.flip(1).repeat(1, offset.size(1)//2, 1, 1)

        offset_mean = torch.mean(torch.abs(offset))
        if offset_mean > 250:
            # BUGFIX: the message used to say "larger than 100" while the
            # guard actually fires at 250.
            print('FlowGuidedDCN: Offset mean is {}, larger than 250.'.format(offset_mean))
            # offset = offset.clamp(-50, 50)
            # return None

        mask = torch.sigmoid(mask)
        return dcn_v2_conv(input, offset, mask, self.weight, self.bias,
                           self.stride, self.padding, self.dilation,
                           self.deformable_groups)
class InsideFlowGuidedDCN(DCNv2):
    '''Flow-guided deformable conv whose offset/mask head sees the warped
    features, the reference features and the flow itself.'''

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, deformable_groups=1):
        super(InsideFlowGuidedDCN, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, deformable_groups)

        # Three maps per sampling location: x-offset, y-offset and mask.
        channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
        # Offset/mask head input is [warped, ref, flow]: 2*C + 2 channels.
        self.conv_offset_mask = nn.Sequential(
            nn.Conv2d(in_channels*2+2, out_channels, kernel_size, stride, padding, bias=True),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=True),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(out_channels, channels_, kernel_size, stride, padding, bias=True)
        )
        self.reset_parameters()
        self.init_offset()

    def reset_parameters(self):
        # Identical to the parent's init (the override is also invoked by
        # DCNv2.__init__); calling it again here is redundant but harmless.
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1.0 / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        self.bias.data.zero_()

    def init_offset(self):
        # Zero-init only the head's last conv so training starts from the
        # plain flow-warped sampling grid.
        self.conv_offset_mask[-1].weight.data.zero_()
        self.conv_offset_mask[-1].bias.data.zero_()

    def forward(self, input, warped, ref, flows):
        '''input: input features for deformable conv: N, C, H, W.
        fea: other features used for generating offsets and mask: N, C, H, W.
        flows: N, 2, H, W.
        '''
        out = self.conv_offset_mask(torch.cat([warped, ref, flows], dim=1))
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        # Residual offsets bounded to [-10, 10], added on top of the flow.
        offset = torch.tanh(torch.cat((o1, o2), dim=1)) * 10  # max_residue_magnitude
        offset = offset + flows.flip(1).repeat(1, offset.size(1)//2, 1, 1)

        offset_mean = torch.mean(torch.abs(offset))
        if offset_mean > 250:
            # NOTE(review): message text says 100 but the guard fires at
            # 250; unlike the sibling classes, this one also clamps.
            print('InsideFlowGuidedDCN: Offset mean is {}, larger than 100.'.format(offset_mean))
            print('flow mean is {}'.format(torch.abs(flows).mean()))
            offset = offset.clamp(-50, 50)
            # return None

        mask = torch.sigmoid(mask)
        return dcn_v2_conv(input, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, self.deformable_groups)
class DCNv2Pooling(nn.Module):
    """Module wrapper around the deformable PS-RoI pooling backend op.

    When ``no_trans`` is set, the offset argument is replaced by an empty
    tensor and the op behaves as plain position-sensitive RoI pooling.
    """

    def __init__(
        self,
        spatial_scale,
        pooled_size,
        output_dim,
        no_trans,
        group_size=1,
        part_size=None,
        sample_per_part=4,
        trans_std=0.0,
    ):
        super(DCNv2Pooling, self).__init__()
        self.spatial_scale = spatial_scale
        self.pooled_size = pooled_size
        self.output_dim = output_dim
        self.no_trans = no_trans
        self.group_size = group_size
        # part_size defaults to the pooled output size.
        self.part_size = pooled_size if part_size is None else part_size
        self.sample_per_part = sample_per_part
        self.trans_std = trans_std

    def forward(self, input, rois, offset):
        assert input.shape[1] == self.output_dim
        if self.no_trans:
            offset = input.new()  # empty placeholder tensor for the backend
        return dcn_v2_pooling(
            input,
            rois,
            offset,
            self.spatial_scale,
            self.pooled_size,
            self.output_dim,
            self.no_trans,
            self.group_size,
            self.part_size,
            self.sample_per_part,
            self.trans_std,
        )
    def forward(self, input, rois):
        """Two-stage deformable pooling: plain RoI-align first, then use
        its features to predict per-bin offsets and a modulation mask."""
        offset = input.new()  # empty tensor when no offsets are used

        if not self.no_trans:

            # do roi_align first
            n = rois.shape[0]
            roi = dcn_v2_pooling(
                input,
                rois,
                offset,
                self.spatial_scale,
                self.pooled_size,
                self.output_dim,
                True,  # no trans
                self.group_size,
                self.part_size,
                self.sample_per_part,
                self.trans_std,
            )

            # build mask and offset
            offset_mask = self.offset_mask_fc(roi.view(n, -1))
            offset_mask = offset_mask.view(n, 3, self.pooled_size, self.pooled_size)
            o1, o2, mask = torch.chunk(offset_mask, 3, dim=1)
            offset = torch.cat((o1, o2), dim=1)
            mask = torch.sigmoid(mask)

            # do pooling with offset and mask
            return (
                dcn_v2_pooling(
                    input,
                    rois,
                    offset,
                    self.spatial_scale,
                    self.pooled_size,
                    self.output_dim,
                    self.no_trans,
                    self.group_size,
                    self.part_size,
                    self.sample_per_part,
                    self.trans_std,
                )
                * mask
            )
        # only roi_align
        return dcn_v2_pooling(
            input,
            rois,
            offset,
            self.spatial_scale,
            self.pooled_size,
            self.output_dim,
            self.no_trans,
            self.group_size,
            self.part_size,
            self.sample_per_part,
            self.trans_std,
        )
/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/not-zip-safe /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/top_level.txt /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/__pycache__/_ext.cpython-37.pyc ================================================ FILE: code/real/bsrt/model/DCNv2/make.sh ================================================ #!/usr/bin/env bash python setup.py build develop ================================================ FILE: code/real/bsrt/model/DCNv2/setup.py ================================================ #!/usr/bin/env python import glob import os import torch from setuptools import find_packages, setup from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension requirements = ["torch", "torchvision"] def get_extensions(): this_dir = os.path.dirname(os.path.abspath(__file__)) extensions_dir = os.path.join(this_dir, "src") main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) os.environ["CC"] = "g++" sources = main_file + source_cpu extension = CppExtension extra_compile_args = {"cxx": []} define_macros = [] if True: extension = CUDAExtension sources += source_cuda define_macros += [("WITH_CUDA", None)] extra_compile_args["nvcc"] = [ "-DCUDA_HAS_FP16=1", "-D__CUDA_NO_HALF_OPERATORS__", "-D__CUDA_NO_HALF_CONVERSIONS__", "-D__CUDA_NO_HALF2_OPERATORS__", ] else: # raise NotImplementedError('Cuda is not available') pass sources = [os.path.join(extensions_dir, s) for s in sources] include_dirs = [extensions_dir] ext_modules = [ extension( "_ext", sources, include_dirs=include_dirs, define_macros=define_macros, extra_compile_args=extra_compile_args, ) ] return ext_modules setup( name="DCNv2", version="0.1", author="charlesshang", 
url="https://github.com/charlesshang/DCNv2",
    description="deformable convolutional networks",
    packages=find_packages(exclude=("configs", "tests")),
    # install_requires=requirements,
    ext_modules=get_extensions(),
    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)



================================================
FILE: code/real/bsrt/model/DCNv2/src/cpu/dcn_v2_cpu.cpp
================================================
// NOTE(review): header names inside angle brackets were stripped by the text
// extraction (the bare "#include" lines below originally named headers such
// as <vector>); restore them before compiling. The same stripping removed
// template arguments further down (e.g. "std::vector" lost its element type
// and ".data()" calls lost their <scalar_t> argument).
#include
#include "cpu/dcn_v2_im2col_cpu.h"
#include
//#include
#include
//#include
//#include

//extern THCState *state;

// author: Charles Shang
// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
// modified from the CUDA version for CPU use by Daniel K. Suhendro

// CPU forward of modulated deformable convolution (DCNv2):
// per batch element, output = bias (broadcast over the output grid via GEMM)
// + weight x im2col(input, offset, mask).
at::Tensor
dcn_v2_cpu_forward(const at::Tensor &input,
                   const at::Tensor &weight,
                   const at::Tensor &bias,
                   const at::Tensor &offset,
                   const at::Tensor &mask,
                   const int kernel_h,
                   const int kernel_w,
                   const int stride_h,
                   const int stride_w,
                   const int pad_h,
                   const int pad_w,
                   const int dilation_h,
                   const int dilation_w,
                   const int deformable_group)
{
    // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask));
    /*AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
    AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
    AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
    AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
    AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");*/

    const int batch = input.size(0);
    const int channels = input.size(1);
    const int height = input.size(2);
    const int width = input.size(3);

    const int channels_out = weight.size(0);
    const int channels_kernel = weight.size(1);
    const int kernel_h_ = weight.size(2);
    const int kernel_w_ = weight.size(3);

    // printf("Kernels: %d %d %d %d\n", kernel_h_, kernel_w_, kernel_w, kernel_h);
    // printf("Channels: %d %d\n", channels, channels_kernel);
    // printf("Channels: %d %d\n", channels_out, channels_kernel);

    // NOTE(review): the format arguments repeat kernel_h_/kernel_w_ instead of
    // kernel_h/kernel_w, so on failure this message prints the wrong values.
    AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
               "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_);

    AT_ASSERTM(channels == channels_kernel,
               "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel);

    // standard convolution output geometry
    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

    // ones: used to broadcast the bias over the output grid through a GEMM
    auto ones = at::ones({height_out, width_out}, input.options());
    // columns: im2col scratch buffer for a single batch element
    auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
    auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());

    using scalar_t = float;
    for (int b = 0; b < batch; b++)
    {
        auto input_n = input.select(0, b);
        auto offset_n = offset.select(0, b);
        auto mask_n = mask.select(0, b);
        auto output_n = output.select(0, b);

        // Do Bias first:
        // M,N,K are dims of matrix A and B
        // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
        // (N x 1) (1 x M)
        long m_ = channels_out;
        long n_ = height_out * width_out;
        long k_ = 1;
        THFloatBlas_gemm('t', 'n', n_, m_, k_, 1.0f,
                         ones.contiguous().data(), k_,
                         bias.contiguous().data(), k_, 0.0f,
                         output_n.data(), n_);

        // gather deformable, mask-modulated samples into column form
        modulated_deformable_im2col_cpu(input_n.data(),
                                        offset_n.data(),
                                        mask_n.data(),
                                        1, channels, height, width,
                                        height_out, width_out, kernel_h, kernel_w,
                                        pad_h, pad_w, stride_h, stride_w,
                                        dilation_h, dilation_w, deformable_group,
                                        columns.data());

        //(k * m) x (m * n)
        // Y = WC   (accumulated on top of the bias already in output_n)
        long m = channels_out;
        long n = height_out * width_out;
        long k = channels * kernel_h * kernel_w;
        THFloatBlas_gemm('n', 'n', n, m, k, 1.0f,
                         columns.data(), n,
                         weight.data(), k, 1.0f,
                         output_n.data(), n);
    }
    return output;
}

// CPU backward of modulated deformable convolution; returns
// {grad_input, grad_offset, grad_mask, grad_weight, grad_bias}.
// NOTE(review): "std::vector" lost its <at::Tensor> element type to the
// extraction; the parameter list continues on the next chunk line.
std::vector
dcn_v2_cpu_backward(const at::Tensor &input,
                    const at::Tensor &weight,
                    const at::Tensor &bias,
                    const at::Tensor &offset,
                    const at::Tensor &mask,
                    const at::Tensor &grad_output,
                    int
kernel_h, int kernel_w,
                    int stride_h, int stride_w,
                    int pad_h, int pad_w,
                    int dilation_h, int dilation_w,
                    int deformable_group)
{
    THArgCheck(input.is_contiguous(), 1, "input tensor has to be contiguous");
    THArgCheck(weight.is_contiguous(), 2, "weight tensor has to be contiguous");

    /*AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
    AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
    AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
    AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
    AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");*/

    const int batch = input.size(0);
    const int channels = input.size(1);
    const int height = input.size(2);
    const int width = input.size(3);

    const int channels_out = weight.size(0);
    const int channels_kernel = weight.size(1);
    const int kernel_h_ = weight.size(2);
    const int kernel_w_ = weight.size(3);

    AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
               "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_);

    AT_ASSERTM(channels == channels_kernel,
               "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel);

    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

    // ones/output are allocated but effectively unused here (the bias GEMV
    // that consumed `ones` is commented out below)
    auto ones = at::ones({height_out, width_out}, input.options());
    auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
    auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());

    auto grad_input = at::zeros_like(input);
    auto grad_weight = at::zeros_like(weight);
    auto grad_bias = at::zeros_like(bias);
    auto grad_offset = at::zeros_like(offset);
    auto grad_mask = at::zeros_like(mask);

    using scalar_t = float;

    for (int b = 0; b < batch; b++)
    {
        auto input_n = input.select(0, b);
        auto offset_n = offset.select(0, b);
        auto mask_n = mask.select(0, b);
        auto grad_output_n = grad_output.select(0, b);
        auto grad_input_n = grad_input.select(0, b);
        auto grad_offset_n = grad_offset.select(0, b);
        auto grad_mask_n = grad_mask.select(0, b);

        // columns = weight^T x grad_output  (back-prop through the Y = WC GEMM)
        long m = channels * kernel_h * kernel_w;
        long n = height_out * width_out;
        long k = channels_out;

        THFloatBlas_gemm('n', 't', n, m, k, 1.0f,
                         grad_output_n.data(), n,
                         weight.data(), m, 0.0f,
                         columns.data(), n);

        // gradient w.r.t. input coordinate data
        modulated_deformable_col2im_coord_cpu(columns.data(),
                                              input_n.data(),
                                              offset_n.data(),
                                              mask_n.data(),
                                              1, channels, height, width,
                                              height_out, width_out, kernel_h, kernel_w,
                                              pad_h, pad_w, stride_h, stride_w,
                                              dilation_h, dilation_w, deformable_group,
                                              grad_offset_n.data(),
                                              grad_mask_n.data());
        // gradient w.r.t. input data
        modulated_deformable_col2im_cpu(columns.data(),
                                        offset_n.data(),
                                        mask_n.data(),
                                        1, channels, height, width,
                                        height_out, width_out, kernel_h, kernel_w,
                                        pad_h, pad_w, stride_h, stride_w,
                                        dilation_h, dilation_w, deformable_group,
                                        grad_input_n.data());

        // gradient w.r.t. weight, dWeight should accumulate across the batch and group
        modulated_deformable_im2col_cpu(input_n.data(),
                                        offset_n.data(),
                                        mask_n.data(),
                                        1, channels, height, width,
                                        height_out, width_out, kernel_h, kernel_w,
                                        pad_h, pad_w, stride_h, stride_w,
                                        dilation_h, dilation_w, deformable_group,
                                        columns.data());

        long m_ = channels_out;
        long n_ = channels * kernel_h * kernel_w;
        long k_ = height_out * width_out;

        THFloatBlas_gemm('t', 'n', n_, m_, k_, 1.0f,
                         columns.data(), k_,
                         grad_output_n.data(), k_, 1.0f,
                         grad_weight.data(), n_);

        // gradient w.r.t. bias
        // NOTE(review): the bias-gradient GEMV is commented out, so this
        // function always returns grad_bias == zeros — the bias never learns
        // on the CPU path. Confirm against the CUDA version before enabling.
        // long m_ = channels_out;
        // long k__ = height_out * width_out;
        // THFloatBlas_gemv('t', k_, m_, 1.0f,
        //                  grad_output_n.data(), k_,
        //                  ones.data(), 1, 1.0f,
        //                  grad_bias.data(), 1);
    }

    return {
        grad_input, grad_offset, grad_mask, grad_weight, grad_bias
    };
}



================================================
FILE: code/real/bsrt/model/DCNv2/src/cpu/dcn_v2_im2col_cpu.cpp
================================================
#include "dcn_v2_im2col_cpu.h"
// NOTE(review): angle-bracket header names stripped by extraction here too.
#include
#include
#include
#include
//#include
#include
//#include
//#include

// modified from the CUDA version for CPU use by Daniel K. Suhendro

/*#define CUDA_KERNEL_LOOP(i, n)                        \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
       i < (n);                                       \
       i += blockDim.x * gridDim.x)

const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N)
{
  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}*/

// Bilinear interpolation of bottom_data at fractional (h, w);
// out-of-range corner taps contribute zero.
float dmcn_im2col_bilinear_cpu(const float *bottom_data, const int data_width,
                               const int height, const int width, float h, float w)
{
    int h_low = floor(h);
    int w_low = floor(w);
    int h_high = h_low + 1;
    int w_high = w_low + 1;

    float lh = h - h_low;
    float lw = w - w_low;
    float hh = 1 - lh, hw = 1 - lw;

    float v1 = 0;
    if (h_low >= 0 && w_low >= 0)
        v1 = bottom_data[h_low * data_width + w_low];
    float v2 = 0;
    if (h_low >= 0 && w_high <= width - 1)
        v2 = bottom_data[h_low * data_width + w_high];
    float v3 = 0;
    if (h_high <= height - 1 && w_low >= 0)
        v3 = bottom_data[h_high * data_width + w_low];
    float v4 = 0;
    if (h_high <= height - 1 && w_high <= width - 1)
        v4 = bottom_data[h_high * data_width + w_high];

    float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

    float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
    return val;
}

// Weight of the gradient flowing back to integer pixel (h, w) from the
// bilinear sample taken at (argmax_h, argmax_w); zero if the sample is
// entirely outside the image.
float dmcn_get_gradient_weight_cpu(float argmax_h, float argmax_w,
                                   const int h, const int w, const int height, const int width)
{
    if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
    {
        //empty
        return 0;
    }

    int argmax_h_low = floor(argmax_h);
    int argmax_w_low = floor(argmax_w);
    int argmax_h_high = argmax_h_low + 1;
    int argmax_w_high = argmax_w_low + 1;

    // (h, w) can coincide with at most one of the four interpolation corners
    float weight = 0;
    if (h == argmax_h_low && w == argmax_w_low)
        weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
    if (h == argmax_h_low && w == argmax_w_high)
        weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
    if (h == argmax_h_high && w == argmax_w_low)
        weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
    if (h == argmax_h_high && w == argmax_w_high)
        weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
    return weight;
}

// Derivative of a bilinear sample w.r.t. its sampling coordinate;
// bp_dir == 0 differentiates along h, bp_dir == 1 along w.
float dmcn_get_coordinate_weight_cpu(float argmax_h, float argmax_w,
                                     const int height, const int width, const float *im_data,
                                     const int data_width, const int bp_dir)
{
    if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
    {
        //empty
        return 0;
    }

    int argmax_h_low = floor(argmax_h);
    int argmax_w_low = floor(argmax_w);
    int argmax_h_high = argmax_h_low + 1;
    int argmax_w_high = argmax_w_low + 1;

    float weight = 0;

    if (bp_dir == 0)
    {
        if (argmax_h_low >= 0 && argmax_w_low >= 0)
            weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
        if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
            weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
        if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
            weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
        if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
            weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
    }
    else if (bp_dir == 1)
    {
        if (argmax_h_low >= 0 && argmax_w_low >= 0)
            weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
        if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
            weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
        if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
            weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
        if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
            weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
    }
    return weight;
}

// Serial CPU replacement for the CUDA im2col kernel: expands each deformable,
// mask-modulated sampling window into the column buffer.
// NOTE(review): a stretch of this function was eaten by the text extraction
// (everything between the '<' of the for-loop condition and a later '>'),
// including the loop header tail and the offset/mask pointer setup. It must
// be restored from the CUDA source before this file can compile.
void modulated_deformable_im2col_cpu_kernel(const int n, const float *data_im, const float *data_offset, const float *data_mask,
                                            const int height, const int width, const int kernel_h, const int kernel_w,
                                            const int pad_h, const int pad_w,
                                            const int stride_h, const int stride_w,
                                            const int dilation_h, const int dilation_w,
                                            const int channel_per_deformable_group,
                                            const int batch_size, const int num_channels, const int deformable_group,
                                            const int height_col, const int width_col,
                                            float *data_col)
{
    // launch channels * batch_size * height_col * width_col cores
    for(int index=0; index(0);
        const float h_im = h_in + i * dilation_h + offset_h;
        const float w_im = w_in + j * dilation_w + offset_w;
        //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) {
            //const float map_h = i * dilation_h + offset_h;
            //const float map_w = j * dilation_w + offset_w;
            //const int cur_height = height - h_in;
            //const int cur_width = width - w_in;
            //val = dmcn_im2col_bilinear_cpu(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
            val = dmcn_im2col_bilinear_cpu(data_im_ptr, width, height, width, h_im, w_im);
        }
        *data_col_ptr = val * mask;
        // data_col_ptr += batch_size * height_col * width_col;
        data_col_ptr += height_col * width_col;
    }
    }
    }
}

// Scatters column-space gradients back onto the input image
// (adjoint of the im2col gather above).
void modulated_deformable_col2im_cpu_kernel(const int n, const float *data_col, const float *data_offset, const float *data_mask,
                                            const int channels, const int height, const int width,
                                            const int kernel_h, const int kernel_w,
                                            const int pad_h, const int pad_w,
                                            const int stride_h, const int stride_w,
                                            const int dilation_h, const int dilation_w,
                                            const int channel_per_deformable_group,
                                            const int batch_size, const int deformable_group,
                                            const int height_col, const int width_col,
                                            float *grad_im)
{
    for(int index =
0; index < n; index++)
    {
        // decode (j, i, c) kernel-tap coordinates from the flat index
        const int j = (index / width_col / height_col / batch_size) % kernel_w;
        const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
        const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
        // compute the start and end of the output
        const int deformable_group_index = c / channel_per_deformable_group;

        int w_out = index % width_col;
        int h_out = (index / width_col) % height_col;
        int b = (index / width_col / height_col) % batch_size;
        int w_in = w_out * stride_w - pad_w;
        int h_in = h_out * stride_h - pad_h;

        const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
        const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
        const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
        const float offset_h = data_offset_ptr[data_offset_h_ptr];
        const float offset_w = data_offset_ptr[data_offset_w_ptr];
        const float mask = data_mask_ptr[data_mask_hw_ptr];
        const float cur_inv_h_data = h_in + i * dilation_h + offset_h;
        const float cur_inv_w_data = w_in + j * dilation_w + offset_w;

        const float cur_top_grad = data_col[index] * mask;
        const int cur_h = (int)cur_inv_h_data;
        const int cur_w = (int)cur_inv_w_data;
        // scatter into the neighbourhood; the |.| < 1 test keeps only the
        // pixels actually touched by the bilinear kernel
        for (int dy = -2; dy <= 2; dy++)
        {
            for (int dx = -2; dx <= 2; dx++)
            {
                if (cur_h + dy >= 0 && cur_h + dy < height &&
                    cur_w + dx >= 0 && cur_w + dx < width &&
                    abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
                    abs(cur_inv_w_data - (cur_w + dx)) < 1)
                {
                    int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
                    float weight = dmcn_get_gradient_weight_cpu(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
                    //atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
                    *(grad_im + cur_bottom_grad_pos) += weight * cur_top_grad;
                }
            }
        }
    }
}

// Gradients of the output w.r.t. the predicted offsets (grad_offset) and
// modulation mask (grad_mask).
void modulated_deformable_col2im_coord_cpu_kernel(const int n, const float *data_col, const float *data_im,
                                                  const float *data_offset, const float *data_mask,
                                                  const int channels, const int height, const int width,
                                                  const int kernel_h, const int kernel_w,
                                                  const int pad_h, const int pad_w,
                                                  const int stride_h, const int stride_w,
                                                  const int dilation_h, const int dilation_w,
                                                  const int channel_per_deformable_group,
                                                  const int batch_size, const int offset_channels, const int deformable_group,
                                                  const int height_col, const int width_col,
                                                  float *grad_offset, float *grad_mask)
{
    for(int index = 0; index < n; index++)
    {
        float val = 0, mval = 0;
        int w = index % width_col;
        int h = (index / width_col) % height_col;
        int c = (index / width_col / height_col) % offset_channels;
        int b = (index / width_col / height_col) / offset_channels;
        // compute the start and end of the output
        const int deformable_group_index = c / (2 * kernel_h * kernel_w);
        const int col_step = kernel_h * kernel_w;
        int cnt = 0;
        const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
        const float *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
        const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
        const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;

        const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;

        // accumulate over all input channels served by this offset channel
        for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
        {
            const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
            const int bp_dir =
offset_c % 2; // even offset channels are h-offsets, odd are w-offsets

            int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
            int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
            int w_out = col_pos % width_col;
            int h_out = (col_pos / width_col) % height_col;
            int w_in = w_out * stride_w - pad_w;
            int h_in = h_out * stride_h - pad_h;
            const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
            const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
            const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
            const float offset_h = data_offset_ptr[data_offset_h_ptr];
            const float offset_w = data_offset_ptr[data_offset_w_ptr];
            const float mask = data_mask_ptr[data_mask_hw_ptr];
            float inv_h = h_in + i * dilation_h + offset_h;
            float inv_w = w_in + j * dilation_w + offset_w;
            if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
            {
                // sample fell outside the image: sentinel kills the
                // coordinate weight below
                inv_h = inv_w = -2;
            }
            else
            {
                mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear_cpu(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
            }
            const float weight = dmcn_get_coordinate_weight_cpu(
                inv_h, inv_w,
                height, width, data_im_ptr + cnt * height * width, width, bp_dir);
            val += weight * data_col_ptr[col_pos] * mask;
            cnt += 1;
        }
        // KERNEL_ASSIGN(grad_offset[index], offset_req, val);
        grad_offset[index] = val;
        if (offset_c % 2 == 0)
            // mask gradient is only written once per (h, w) pair, on the
            // even (h-offset) channel
            // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
            grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
    }
}

// Host-side wrapper: computes launch extents and runs the serial im2col kernel.
void modulated_deformable_im2col_cpu(const float* data_im, const float* data_offset, const float* data_mask,
                                     const int batch_size, const int channels, const int height_im, const int width_im,
                                     const int height_col, const int width_col, const int kernel_h, const int kernel_w,
                                     const int pad_h, const int pad_w, const int stride_h, const int stride_w,
                                     const int dilation_h, const int dilation_w,
                                     const int deformable_group, float* data_col)
{
    // num_axes should be smaller than block size
    const int channel_per_deformable_group = channels / deformable_group;
    const int num_kernels = channels * batch_size * height_col * width_col;

    modulated_deformable_im2col_cpu_kernel(
        num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kernel_w,
        pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
        batch_size, channels, deformable_group, height_col, width_col, data_col);

    /*cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
    {
        printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
    }*/
}

// Host-side wrapper for the input-gradient (col2im) kernel.
void modulated_deformable_col2im_cpu(const float* data_col, const float* data_offset, const float* data_mask,
                                     const int batch_size, const int channels, const int height_im, const int width_im,
                                     const int height_col, const int width_col, const int kernel_h, const int kernel_w,
                                     const int pad_h, const int pad_w, const int stride_h, const int stride_w,
                                     const int dilation_h, const int dilation_w,
                                     const int deformable_group, float* grad_im){

    const int channel_per_deformable_group = channels / deformable_group;
    const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;

    // NOTE(review): pad_h is passed twice below where (pad_h, pad_w) is
    // expected — looks like a copy-paste bug; gradients will be wrong
    // whenever pad_h != pad_w. Confirm against the CUDA version.
    modulated_deformable_col2im_cpu_kernel(
        num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im,
        kernel_h, kernel_w, pad_h, pad_h, stride_h, stride_w,
        dilation_h, dilation_w, channel_per_deformable_group,
        batch_size, deformable_group, height_col, width_col, grad_im);

    /*cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
    {
        printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
    }*/
}

// Host-side wrapper for the offset/mask-gradient kernel
// (signature continues on the next chunk line).
void modulated_deformable_col2im_coord_cpu(const float* data_col, const float* data_im, const float* data_offset, const float* data_mask,
                                           const int
batch_size, const int channels, const int height_im, const int width_im, const int height_col, const int width_col, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const int deformable_group, float* grad_offset, float* grad_mask) { const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; modulated_deformable_col2im_coord_cpu_kernel( num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, grad_offset, grad_mask); /*cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); }*/ } ================================================ FILE: code/real/bsrt/model/DCNv2/src/cpu/dcn_v2_im2col_cpu.h ================================================ /*! ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** * * COPYRIGHT * * All contributions by the University of California: * Copyright (c) 2014-2017 The Regents of the University of California (Regents) * All rights reserved. * * All other contributions: * Copyright (c) 2014-2017, the respective contributors * All rights reserved. * * Caffe uses a shared copyright model: each contributor holds copyright over * their contributions to Caffe. The project versioning records all such * contribution and copyright details. If a contributor wants to further mark * their specific copyright on a particular contribution, they should indicate * their copyright solely in the commit message of the change when it is * committed. 
* * LICENSE * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * * CONTRIBUTION AGREEMENT * * By contributing to the BVLC/caffe repository through pull-request, comment, * or otherwise, the contributor releases their content to the * license and copyright terms herein. * ***************** END Caffe Copyright Notice and Disclaimer ******************** * * Copyright (c) 2018 Microsoft * Licensed under The MIT License [see LICENSE for details] * \file modulated_deformable_im2col.h * \brief Function definitions of converting an image to * column matrix based on kernel, padding, dilation, and offset. * These functions are mainly used in deformable convolution operators. 
* \ref: https://arxiv.org/abs/1811.11168 * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu */ /***************** Adapted by Charles Shang *********************/ // modified from the CUDA version for CPU use by Daniel K. Suhendro #ifndef DCN_V2_IM2COL_CPU #define DCN_V2_IM2COL_CPU #ifdef __cplusplus extern "C" { #endif void modulated_deformable_im2col_cpu(const float *data_im, const float *data_offset, const float *data_mask, const int batch_size, const int channels, const int height_im, const int width_im, const int height_col, const int width_col, const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const int deformable_group, float *data_col); void modulated_deformable_col2im_cpu(const float *data_col, const float *data_offset, const float *data_mask, const int batch_size, const int channels, const int height_im, const int width_im, const int height_col, const int width_col, const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const int deformable_group, float *grad_im); void modulated_deformable_col2im_coord_cpu(const float *data_col, const float *data_im, const float *data_offset, const float *data_mask, const int batch_size, const int channels, const int height_im, const int width_im, const int height_col, const int width_col, const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const int deformable_group, float *grad_offset, float *grad_mask); #ifdef __cplusplus } #endif #endif ================================================ FILE: code/real/bsrt/model/DCNv2/src/cpu/dcn_v2_psroi_pooling_cpu.cpp ================================================ /*! 
 * Copyright (c) 2017 Microsoft
 * Licensed under The MIT License [see LICENSE for details]
 * \file deformable_psroi_pooling.cu
 * \brief
 * \author Yi Li, Guodong Zhang, Jifeng Dai
*/
/***************** Adapted by Charles Shang *********************/
// modified from the CUDA version for CPU use by Daniel K. Suhendro

// NOTE(review): the extraction stripped angle-bracket text throughout this
// file — header names after #include, "template <typename T>" headers
// (now bare "template"), and "static_cast<T>" (now "static_cast") must all
// be restored before compiling.
#include
#include
#include
#include
//#include
#include
//#include
//#include

/*#define CUDA_KERNEL_LOOP(i, n)                        \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
       i < (n);                                       \
       i += blockDim.x * gridDim.x)

const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N)
{
  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}*/

// Bilinear interpolation of `data` at fractional (x, y).
// Unlike dmcn_im2col_bilinear_cpu, callers clamp coordinates first, so no
// bounds checks are done here.
template
T bilinear_interp_cpu(
    const T *data,
    const T x,
    const T y,
    const int width,
    const int height)
{
    int x1 = floor(x);
    int x2 = ceil(x);
    int y1 = floor(y);
    int y2 = ceil(y);
    T dist_x = static_cast(x - x1);
    T dist_y = static_cast(y - y1);
    T value11 = data[y1 * width + x1];
    T value12 = data[y2 * width + x1];
    T value21 = data[y1 * width + x2];
    T value22 = data[y2 * width + x2];
    T value = (1 - dist_x) * (1 - dist_y) * value11 +
              (1 - dist_x) * dist_y * value12 +
              dist_x * (1 - dist_y) * value21 +
              dist_x * dist_y * value22;
    return value;
}

// Serial CPU port of the deformable PSROI pooling forward kernel:
// averages sample_per_part^2 bilinear samples per output bin, with the bin
// optionally translated by bottom_trans (scaled by trans_std).
template
void DeformablePSROIPoolForwardKernelCpu(
    const int count,
    const T *bottom_data,
    const T spatial_scale,
    const int channels,
    const int height, const int width,
    const int pooled_height, const int pooled_width,
    const T *bottom_rois, const T *bottom_trans,
    const int no_trans,
    const T trans_std,
    const int sample_per_part,
    const int output_dim,
    const int group_size,
    const int part_size,
    const int num_classes,
    const int channels_each_class,
    T *top_data,
    T *top_count)
{
    for(int index = 0; index < count; index++)
    {
        // The output is in order (n, ctop, ph, pw)
        int pw = index % pooled_width;
        int ph = (index / pooled_width) % pooled_height;
        int ctop = (index / pooled_width / pooled_height) % output_dim;
        int n = index / pooled_width / pooled_height / output_dim;

        // [start, end) interval for spatial sampling
        const T *offset_bottom_rois = bottom_rois + n * 5;
        int roi_batch_ind = offset_bottom_rois[0];
        T roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
        T roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
        T roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
        T roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;

        // Force too small ROIs to be 1x1
        T roi_width = std::max(roi_end_w - roi_start_w, T(0.1)); //avoid 0
        T roi_height = std::max(roi_end_h - roi_start_h, T(0.1));

        // Compute w and h at bottom
        T bin_size_h = roi_height / static_cast(pooled_height);
        T bin_size_w = roi_width / static_cast(pooled_width);

        T sub_bin_size_h = bin_size_h / static_cast(sample_per_part);
        T sub_bin_size_w = bin_size_w / static_cast(sample_per_part);

        int part_h = floor(static_cast(ph) / pooled_height * part_size);
        int part_w = floor(static_cast(pw) / pooled_width * part_size);
        int class_id = ctop / channels_each_class;
        // per-bin learned translation, zero when no_trans
        T trans_x = no_trans ? static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;
        T trans_y = no_trans ? static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;

        T wstart = static_cast(pw) * bin_size_w + roi_start_w;
        wstart += trans_x * roi_width;
        T hstart = static_cast(ph) * bin_size_h + roi_start_h;
        hstart += trans_y * roi_height;

        T sum = 0;
        // NOTE(review): this inner `count` (number of in-bounds samples)
        // shadows the kernel's `count` parameter (total output elements);
        // legal C++ — the loop condition still sees the parameter — but
        // confusing and worth renaming.
        int count = 0;
        int gw = floor(static_cast(pw) * group_size / pooled_width);
        int gh = floor(static_cast(ph) * group_size / pooled_height);
        gw = std::min(std::max(gw, 0), group_size - 1);
        gh = std::min(std::max(gh, 0), group_size - 1);

        const T *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
        for (int ih = 0; ih < sample_per_part; ih++)
        {
            for (int iw = 0; iw < sample_per_part; iw++)
            {
                T w = wstart + iw * sub_bin_size_w;
                T h = hstart + ih * sub_bin_size_h;
                // bilinear interpolation
                if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
                {
                    continue;
                }
                w = std::min(std::max(w, T(0.)), width - T(1.));
                h = std::min(std::max(h, T(0.)), height - T(1.));
                int c = (ctop * group_size + gh) * group_size + gw;
                T val = bilinear_interp_cpu(offset_bottom_data + c * height * width, w, h, width, height);
                sum += val;
                count++;
            }
        }
        // empty bins produce 0, not NaN
        top_data[index] = count == 0 ? static_cast(0) : sum / count;
        top_count[index] = count;
    }
}

// Backward accumulation kernel (definition continues beyond this chunk).
template
void DeformablePSROIPoolBackwardAccKernelCpu(
    const int count,
    const T *top_diff,
    const T *top_count,
    const int num_rois,
    const T spatial_scale,
    const int channels,
    const int height, const int width,
    const int pooled_height, const int pooled_width,
    const int output_dim,
    T *bottom_data_diff, T *bottom_trans_diff,
    const T *bottom_data,
    const T *bottom_rois,
    const T *bottom_trans,
    const int no_trans,
    const T trans_std,
    const int sample_per_part,
    const int group_size,
    const int part_size,
    const int num_classes,
    const int channels_each_class)
{
    for(int index = 0; index < count; index++)
    {
        // The output is in order (n, ctop, ph, pw)
        int pw = index % pooled_width;
        int ph = (index / pooled_width) % pooled_height;
        int ctop = (index / pooled_width / pooled_height) % output_dim;
        int n = index / pooled_width / pooled_height / output_dim;

        // [start, end) interval for spatial sampling
        const T *offset_bottom_rois = bottom_rois + n * 5;
        int roi_batch_ind = offset_bottom_rois[0];
        T roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
        T roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
        T roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
        T roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;

        // Force too small ROIs to be 1x1
        T roi_width = std::max(roi_end_w - roi_start_w, T(0.1)); //avoid 0
        T roi_height = std::max(roi_end_h - roi_start_h, T(0.1));

        // Compute w and h at bottom
        T bin_size_h = roi_height / static_cast(pooled_height);
        T bin_size_w = roi_width / static_cast(pooled_width);

        T sub_bin_size_h = bin_size_h / static_cast(sample_per_part);
        T sub_bin_size_w = bin_size_w / static_cast(sample_per_part);

        int part_h = floor(static_cast(ph) / pooled_height * part_size);
        int part_w = floor(static_cast(pw) / pooled_width * part_size);
        int class_id = ctop / channels_each_class;
        T trans_x = no_trans ?
static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; T trans_y = no_trans ? static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std; T wstart = static_cast(pw) * bin_size_w + roi_start_w; wstart += trans_x * roi_width; T hstart = static_cast(ph) * bin_size_h + roi_start_h; hstart += trans_y * roi_height; if (top_count[index] <= 0) { continue; } T diff_val = top_diff[index] / top_count[index]; const T *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width; T *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width; int gw = floor(static_cast(pw) * group_size / pooled_width); int gh = floor(static_cast(ph) * group_size / pooled_height); gw = std::min(std::max(gw, 0), group_size - 1); gh = std::min(std::max(gh, 0), group_size - 1); for (int ih = 0; ih < sample_per_part; ih++) { for (int iw = 0; iw < sample_per_part; iw++) { T w = wstart + iw * sub_bin_size_w; T h = hstart + ih * sub_bin_size_h; // bilinear interpolation if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) { continue; } w = std::min(std::max(w, T(0.)), width - T(1.)); h = std::min(std::max(h, T(0.)), height - T(1.)); int c = (ctop * group_size + gh) * group_size + gw; // backward on feature int x0 = floor(w); int x1 = ceil(w); int y0 = floor(h); int y1 = ceil(h); T dist_x = w - x0, dist_y = h - y0; T q00 = (1 - dist_x) * (1 - dist_y); T q01 = (1 - dist_x) * dist_y; T q10 = dist_x * (1 - dist_y); T q11 = dist_x * dist_y; int bottom_index_base = c * height * width; /*atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val); atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val); atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val); atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + 
x1, q11 * diff_val);*/ *(offset_bottom_data_diff + bottom_index_base + y0 * width + x0) += q00 * diff_val; *(offset_bottom_data_diff + bottom_index_base + y1 * width + x0) += q01 * diff_val; *(offset_bottom_data_diff + bottom_index_base + y0 * width + x1) += q10 * diff_val; *(offset_bottom_data_diff + bottom_index_base + y1 * width + x1) += q11 * diff_val; if (no_trans) { continue; } T U00 = offset_bottom_data[bottom_index_base + y0 * width + x0]; T U01 = offset_bottom_data[bottom_index_base + y1 * width + x0]; T U10 = offset_bottom_data[bottom_index_base + y0 * width + x1]; T U11 = offset_bottom_data[bottom_index_base + y1 * width + x1]; T diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val; diff_x *= roi_width; T diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val; diff_y *= roi_height; /*atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x); atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);*/ *(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w) += diff_x; *(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w) += diff_y; } } } } std::tuple dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input, const at::Tensor &bbox, const at::Tensor &trans, const int no_trans, const float spatial_scale, const int output_dim, const int group_size, const int pooled_size, const int part_size, const int sample_per_part, const float trans_std) { /*AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); AT_ASSERTM(bbox.type().is_cuda(), "rois must be a CUDA tensor"); AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor");*/ const int batch = input.size(0); const int channels = input.size(1); const int height = input.size(2); 
const int width = input.size(3); const int channels_trans = no_trans ? 2 : trans.size(1); const int num_bbox = bbox.size(0); AT_ASSERTM(channels == output_dim, "input channels and output channels must equal"); auto pooled_height = pooled_size; auto pooled_width = pooled_size; auto out = at::empty({num_bbox, output_dim, pooled_height, pooled_width}, input.options()); long out_size = num_bbox * output_dim * pooled_height * pooled_width; auto top_count = at::zeros({num_bbox, output_dim, pooled_height, pooled_width}, input.options()); const int num_classes = no_trans ? 1 : channels_trans / 2; const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; //cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if (out.numel() == 0) { //THCudaCheck(cudaGetLastError()); return std::make_tuple(out, top_count); } /*dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L)); dim3 block(512);*/ AT_DISPATCH_FLOATING_TYPES(input.type(), "dcn_v2_psroi_pooling_cpu_forward", [&] { DeformablePSROIPoolForwardKernelCpu( out_size, input.contiguous().data(), spatial_scale, channels, height, width, pooled_height, pooled_width, bbox.contiguous().data(), trans.contiguous().data(), no_trans, trans_std, sample_per_part, output_dim, group_size, part_size, num_classes, channels_each_class, out.data(), top_count.data()); }); //THCudaCheck(cudaGetLastError()); return std::make_tuple(out, top_count); } std::tuple dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad, const at::Tensor &input, const at::Tensor &bbox, const at::Tensor &trans, const at::Tensor &top_count, const int no_trans, const float spatial_scale, const int output_dim, const int group_size, const int pooled_size, const int part_size, const int sample_per_part, const float trans_std) { /*AT_ASSERTM(out_grad.type().is_cuda(), "out_grad must be a CUDA tensor"); AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); AT_ASSERTM(bbox.type().is_cuda(), "bbox must be a CUDA tensor"); 
AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor"); AT_ASSERTM(top_count.type().is_cuda(), "top_count must be a CUDA tensor");*/ const int batch = input.size(0); const int channels = input.size(1); const int height = input.size(2); const int width = input.size(3); const int channels_trans = no_trans ? 2 : trans.size(1); const int num_bbox = bbox.size(0); AT_ASSERTM(channels == output_dim, "input channels and output channels must equal"); auto pooled_height = pooled_size; auto pooled_width = pooled_size; long out_size = num_bbox * output_dim * pooled_height * pooled_width; const int num_classes = no_trans ? 1 : channels_trans / 2; const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; auto input_grad = at::zeros({batch, channels, height, width}, out_grad.options()); auto trans_grad = at::zeros_like(trans); if (input_grad.numel() == 0) { //THCudaCheck(cudaGetLastError()); return std::make_tuple(input_grad, trans_grad); } /*dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L)); dim3 block(512); cudaStream_t stream = at::cuda::getCurrentCUDAStream();*/ AT_DISPATCH_FLOATING_TYPES(out_grad.type(), "dcn_v2_psroi_pooling_cpu_backward", [&] { DeformablePSROIPoolBackwardAccKernelCpu( out_size, out_grad.contiguous().data(), top_count.contiguous().data(), num_bbox, spatial_scale, channels, height, width, pooled_height, pooled_width, output_dim, input_grad.contiguous().data(), trans_grad.contiguous().data(), input.contiguous().data(), bbox.contiguous().data(), trans.contiguous().data(), no_trans, trans_std, sample_per_part, group_size, part_size, num_classes, channels_each_class); }); //THCudaCheck(cudaGetLastError()); return std::make_tuple(input_grad, trans_grad); } ================================================ FILE: code/real/bsrt/model/DCNv2/src/cpu/vision.h ================================================ #pragma once #include at::Tensor dcn_v2_cpu_forward(const at::Tensor &input, const at::Tensor &weight, const at::Tensor 
&bias, const at::Tensor &offset, const at::Tensor &mask, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, const int dilation_h, const int dilation_w, const int deformable_group); std::vector dcn_v2_cpu_backward(const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, const at::Tensor &offset, const at::Tensor &mask, const at::Tensor &grad_output, int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int deformable_group); std::tuple dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input, const at::Tensor &bbox, const at::Tensor &trans, const int no_trans, const float spatial_scale, const int output_dim, const int group_size, const int pooled_size, const int part_size, const int sample_per_part, const float trans_std); std::tuple dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad, const at::Tensor &input, const at::Tensor &bbox, const at::Tensor &trans, const at::Tensor &top_count, const int no_trans, const float spatial_scale, const int output_dim, const int group_size, const int pooled_size, const int part_size, const int sample_per_part, const float trans_std); ================================================ FILE: code/real/bsrt/model/DCNv2/src/cuda/dcn_v2_cuda.cu ================================================ #include #include "cuda/dcn_v2_im2col_cuda.h" #include #include #include #include #include #include #include #include #include #include THCState *state = at::globalContext().lazyInitCUDA(); static cublasOperation_t _cublasOpFromChar(char op) { switch (op) { case 'n': case 'N': return CUBLAS_OP_N; case 't': case 'T': return CUBLAS_OP_T; case 'c': case 'C': return CUBLAS_OP_C; } AT_ERROR( "_cublasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`"); } static void _cublasAdjustLdLevel2(int64_t m, int64_t n, int64_t* lda) { // Note: leading dimensions generally are checked that they are > 0 // and 
at least as big the result requires (even if the value won't // be used). // Q: Why does Level3 check trans but this doesn't? // A: In level 2, the sizes (m, n) specify the size of A // (independent of trans value). In level 3. the sizes (m, n, k) // specify the sizes of op(A), op(B) where op depend on trans // values. if (n <= 1) *lda = std::max(m, 1); } // author: Charles Shang // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu // [batch gemm] // https://github.com/pytorch/pytorch/blob/master/aten/src/THC/generic/THCTensorMathBlas.cu __global__ void createBatchGemmBuffer(const float **input_b, float **output_b, float **columns_b, const float **ones_b, const float **weight_b, const float **bias_b, float *input, float *output, float *columns, float *ones, float *weight, float *bias, const int input_stride, const int output_stride, const int columns_stride, const int ones_stride, const int num_batches) { const int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < num_batches) { input_b[idx] = input + idx * input_stride; output_b[idx] = output + idx * output_stride; columns_b[idx] = columns + idx * columns_stride; ones_b[idx] = ones + idx * ones_stride; // share weights and bias within a Mini-Batch weight_b[idx] = weight; bias_b[idx] = bias; } } at::Tensor dcn_v2_cuda_forward(const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, const at::Tensor &offset, const at::Tensor &mask, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, const int dilation_h, const int dilation_w, const int deformable_group) { using scalar_t = float; // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask)); AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor"); AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor"); AT_ASSERTM(offset.type().is_cuda(), 
"offset must be a CUDA tensor"); AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor"); const int batch = input.size(0); const int channels = input.size(1); const int height = input.size(2); const int width = input.size(3); const int channels_out = weight.size(0); const int channels_kernel = weight.size(1); const int kernel_h_ = weight.size(2); const int kernel_w_ = weight.size(3); // printf("Kernels: %d %d %d %d\n", kernel_h_, kernel_w_, kernel_w, kernel_h); // printf("Channels: %d %d\n", channels, channels_kernel); // printf("Channels: %d %d\n", channels_out, channels_kernel); AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w, "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_); AT_ASSERTM(channels == channels_kernel, "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel); const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; auto ones = at::ones({batch, height_out, width_out}, input.options()); auto columns = at::empty({batch, channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options()); auto output = at::empty({batch, channels_out, height_out, width_out}, input.options()); // prepare for batch-wise computing, which is significantly faster than instance-wise computing // when batch size is large. 
// launch batch threads int matrices_size = batch * sizeof(float *); auto input_b = static_cast(THCudaMalloc(state, matrices_size)); auto output_b = static_cast(THCudaMalloc(state, matrices_size)); auto columns_b = static_cast(THCudaMalloc(state, matrices_size)); auto ones_b = static_cast(THCudaMalloc(state, matrices_size)); auto weight_b = static_cast(THCudaMalloc(state, matrices_size)); auto bias_b = static_cast(THCudaMalloc(state, matrices_size)); const int block = 128; const int grid = (batch + block - 1) / block; createBatchGemmBuffer<<>>( input_b, output_b, columns_b, ones_b, weight_b, bias_b, input.data_ptr(), output.data_ptr(), columns.data_ptr(), ones.data_ptr(), weight.data_ptr(), bias.data_ptr(), channels * width * height, channels_out * width_out * height_out, channels * kernel_h * kernel_w * height_out * width_out, height_out * width_out, batch); long m_ = channels_out; long n_ = height_out * width_out; long k_ = 1; THCudaBlas_SgemmBatched(state, 't', 'n', n_, m_, k_, 1.0f, ones_b, k_, bias_b, k_, 0.0f, output_b, n_, batch); modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(), input.data_ptr(), offset.data_ptr(), mask.data_ptr(), batch, channels, height, width, height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, deformable_group, columns.data_ptr()); long m = channels_out; long n = height_out * width_out; long k = channels * kernel_h * kernel_w; THCudaBlas_SgemmBatched(state, 'n', 'n', n, m, k, 1.0f, (const float **)columns_b, n, weight_b, k, 1.0f, output_b, n, batch); THCudaFree(state, input_b); THCudaFree(state, output_b); THCudaFree(state, columns_b); THCudaFree(state, ones_b); THCudaFree(state, weight_b); THCudaFree(state, bias_b); return output; } __global__ void createBatchGemmBufferBackward( float **grad_output_b, float **columns_b, float **ones_b, float **weight_b, float **grad_weight_b, float **grad_bias_b, float *grad_output, float *columns, float *ones, float *weight, float 
*grad_weight, float *grad_bias, const int grad_output_stride, const int columns_stride, const int ones_stride, const int num_batches) { const int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < num_batches) { grad_output_b[idx] = grad_output + idx * grad_output_stride; columns_b[idx] = columns + idx * columns_stride; ones_b[idx] = ones + idx * ones_stride; // share weights and bias within a Mini-Batch weight_b[idx] = weight; grad_weight_b[idx] = grad_weight; grad_bias_b[idx] = grad_bias; } } std::vector dcn_v2_cuda_backward(const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, const at::Tensor &offset, const at::Tensor &mask, const at::Tensor &grad_output, int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int deformable_group) { THArgCheck(input.is_contiguous(), 1, "input tensor has to be contiguous"); THArgCheck(weight.is_contiguous(), 2, "weight tensor has to be contiguous"); AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor"); AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor"); AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor"); AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor"); const int batch = input.size(0); const int channels = input.size(1); const int height = input.size(2); const int width = input.size(3); const int channels_out = weight.size(0); const int channels_kernel = weight.size(1); const int kernel_h_ = weight.size(2); const int kernel_w_ = weight.size(3); AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w, "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_); AT_ASSERTM(channels == channels_kernel, "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel); const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 
1; const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; auto ones = at::ones({height_out, width_out}, input.options()); auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options()); auto output = at::empty({batch, channels_out, height_out, width_out}, input.options()); auto grad_input = at::zeros_like(input); auto grad_weight = at::zeros_like(weight); auto grad_bias = at::zeros_like(bias); auto grad_offset = at::zeros_like(offset); auto grad_mask = at::zeros_like(mask); using scalar_t = float; for (int b = 0; b < batch; b++) { auto input_n = input.select(0, b); auto offset_n = offset.select(0, b); auto mask_n = mask.select(0, b); auto grad_output_n = grad_output.select(0, b); auto grad_input_n = grad_input.select(0, b); auto grad_offset_n = grad_offset.select(0, b); auto grad_mask_n = grad_mask.select(0, b); long m = channels * kernel_h * kernel_w; long n = height_out * width_out; long k = channels_out; THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f, grad_output_n.data_ptr(), n, weight.data_ptr(), m, 0.0f, columns.data_ptr(), n); // gradient w.r.t. input coordinate data modulated_deformable_col2im_coord_cuda(c10::cuda::getCurrentCUDAStream(), columns.data_ptr(), input_n.data_ptr(), offset_n.data_ptr(), mask_n.data_ptr(), 1, channels, height, width, height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, deformable_group, grad_offset_n.data_ptr(), grad_mask_n.data_ptr()); // gradient w.r.t. input data modulated_deformable_col2im_cuda(c10::cuda::getCurrentCUDAStream(), columns.data_ptr(), offset_n.data_ptr(), mask_n.data_ptr(), 1, channels, height, width, height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, deformable_group, grad_input_n.data_ptr()); // gradient w.r.t. 
weight, dWeight should accumulate across the batch and group modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(), input_n.data_ptr(), offset_n.data_ptr(), mask_n.data_ptr(), 1, channels, height, width, height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, deformable_group, columns.data_ptr()); long m_ = channels_out; long n_ = channels * kernel_h * kernel_w; long k_ = height_out * width_out; THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f, columns.data_ptr(), k_, grad_output_n.data_ptr(), k_, 1.0f, grad_weight.data_ptr(), n_); cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); cublasOperation_t op = _cublasOpFromChar('t'); _cublasAdjustLdLevel2(k_, m_, &k_); scalar_t* grad_output_n_float = grad_output_n.data_ptr(); scalar_t* one_float = ones.data_ptr(); scalar_t alpha = 1.0; scalar_t beta = 1.0; cublasSgemv(handle, op, k_, m_, &alpha, grad_output_n_float,k_, one_float,1, &beta, grad_bias.data_ptr(), 1); } return { grad_input, grad_offset, grad_mask, grad_weight, grad_bias }; } ================================================ FILE: code/real/bsrt/model/DCNv2/src/cuda/dcn_v2_im2col_cuda.cu ================================================ #include "dcn_v2_im2col_cuda.h" #include #include #include #include #include #include #include #include #define CUDA_KERNEL_LOOP(i, n) \ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ i < (n); \ i += blockDim.x * gridDim.x) const int CUDA_NUM_THREADS = 1024; inline int GET_BLOCKS(const int N) { return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; } __device__ float dmcn_im2col_bilinear_cuda(const float *bottom_data, const int data_width, const int height, const int width, float h, float w) { int h_low = floor(h); int w_low = floor(w); int h_high = h_low + 1; int w_high = w_low + 1; float lh = h - h_low; float lw = w - w_low; float hh = 1 - lh, hw = 1 - lw; float v1 = 0; if (h_low >= 0 && w_low >= 0) v1 = bottom_data[h_low * data_width + w_low]; float 
v2 = 0; if (h_low >= 0 && w_high <= width - 1) v2 = bottom_data[h_low * data_width + w_high]; float v3 = 0; if (h_high <= height - 1 && w_low >= 0) v3 = bottom_data[h_high * data_width + w_low]; float v4 = 0; if (h_high <= height - 1 && w_high <= width - 1) v4 = bottom_data[h_high * data_width + w_high]; float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); return val; } __device__ float dmcn_get_gradient_weight_cuda(float argmax_h, float argmax_w, const int h, const int w, const int height, const int width) { if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) { //empty return 0; } int argmax_h_low = floor(argmax_h); int argmax_w_low = floor(argmax_w); int argmax_h_high = argmax_h_low + 1; int argmax_w_high = argmax_w_low + 1; float weight = 0; if (h == argmax_h_low && w == argmax_w_low) weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); if (h == argmax_h_low && w == argmax_w_high) weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); if (h == argmax_h_high && w == argmax_w_low) weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); if (h == argmax_h_high && w == argmax_w_high) weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); return weight; } __device__ float dmcn_get_coordinate_weight_cuda(float argmax_h, float argmax_w, const int height, const int width, const float *im_data, const int data_width, const int bp_dir) { if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) { //empty return 0; } int argmax_h_low = floor(argmax_h); int argmax_w_low = floor(argmax_w); int argmax_h_high = argmax_h_low + 1; int argmax_w_high = argmax_w_low + 1; float weight = 0; if (bp_dir == 0) { if (argmax_h_low >= 0 && argmax_w_low >= 0) weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; if (argmax_h_low >= 0 && argmax_w_high <= width - 1) weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + 
argmax_w_high]; if (argmax_h_high <= height - 1 && argmax_w_low >= 0) weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; } else if (bp_dir == 1) { if (argmax_h_low >= 0 && argmax_w_low >= 0) weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; if (argmax_h_low >= 0 && argmax_w_high <= width - 1) weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; if (argmax_h_high <= height - 1 && argmax_w_low >= 0) weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; } return weight; } __global__ void modulated_deformable_im2col_gpu_kernel(const int n, const float *data_im, const float *data_offset, const float *data_mask, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const int channel_per_deformable_group, const int batch_size, const int num_channels, const int deformable_group, const int height_col, const int width_col, float *data_col) { // launch channels * batch_size * height_col * width_col cores CUDA_KERNEL_LOOP(index, n) { // NOTE(CharlesShang): different from Dai Jifeng's MXNet implementation, col_buffer is of shape (c*kw*kh, N, oh, ow) // here columns is of shape (N, c*kw*kh, oh * ow), need to adapt axis // index index of output matrix const int w_col = index % width_col; const int h_col = (index / width_col) % height_col; // const int b_col = (index / width_col / height_col) % batch_size; const int b_col = (index / width_col / height_col / num_channels) % 
batch_size; // const int c_im = (index / width_col / height_col) / batch_size; const int c_im = (index / width_col / height_col) % num_channels; // const int c_col = c_im * kernel_h * kernel_w; const int c_col = c_im * kernel_h * kernel_w; // compute deformable group index const int deformable_group_index = c_im / channel_per_deformable_group; const int h_in = h_col * stride_h - pad_h; const int w_in = w_col * stride_w - pad_w; // float *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; float *data_col_ptr = data_col + ((b_col * num_channels * kernel_w * kernel_h + c_col) * height_col + h_col) * width_col + w_col; //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; const float *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; const float *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; const float *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; for (int i = 0; i < kernel_h; ++i) { for (int j = 0; j < kernel_w; ++j) { const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; const float offset_h = data_offset_ptr[data_offset_h_ptr]; const float offset_w = data_offset_ptr[data_offset_w_ptr]; const float mask = data_mask_ptr[data_mask_hw_ptr]; float val = static_cast(0); const float h_im = h_in + i * dilation_h + offset_h; const float w_im = w_in + j * dilation_w + offset_w; //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) { //const float map_h = i * dilation_h + offset_h; 
//const float map_w = j * dilation_w + offset_w; //const int cur_height = height - h_in; //const int cur_width = width - w_in; //val = dmcn_im2col_bilinear_cuda(data_im_ptr, width, cur_height, cur_width, map_h, map_w); val = dmcn_im2col_bilinear_cuda(data_im_ptr, width, height, width, h_im, w_im); } *data_col_ptr = val * mask; // data_col_ptr += batch_size * height_col * width_col; data_col_ptr += height_col * width_col; } } } } __global__ void modulated_deformable_col2im_gpu_kernel(const int n, const float *data_col, const float *data_offset, const float *data_mask, const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const int channel_per_deformable_group, const int batch_size, const int deformable_group, const int height_col, const int width_col, float *grad_im) { CUDA_KERNEL_LOOP(index, n) { const int j = (index / width_col / height_col / batch_size) % kernel_w; const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; // compute the start and end of the output const int deformable_group_index = c / channel_per_deformable_group; int w_out = index % width_col; int h_out = (index / width_col) % height_col; int b = (index / width_col / height_col) % batch_size; int w_in = w_out * stride_w - pad_w; int h_in = h_out * stride_h - pad_h; const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + 
w_out; const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; const float offset_h = data_offset_ptr[data_offset_h_ptr]; const float offset_w = data_offset_ptr[data_offset_w_ptr]; const float mask = data_mask_ptr[data_mask_hw_ptr]; const float cur_inv_h_data = h_in + i * dilation_h + offset_h; const float cur_inv_w_data = w_in + j * dilation_w + offset_w; const float cur_top_grad = data_col[index] * mask; const int cur_h = (int)cur_inv_h_data; const int cur_w = (int)cur_inv_w_data; for (int dy = -2; dy <= 2; dy++) { for (int dx = -2; dx <= 2; dx++) { if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && abs(cur_inv_w_data - (cur_w + dx)) < 1) { int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; float weight = dmcn_get_gradient_weight_cuda(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); } } } } } __global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, const float *data_col, const float *data_im, const float *data_offset, const float *data_mask, const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const int channel_per_deformable_group, const int batch_size, const int offset_channels, const int deformable_group, const int height_col, const int width_col, float *grad_offset, float *grad_mask) { CUDA_KERNEL_LOOP(index, n) { float val = 0, mval = 0; int w = index % width_col; int h = (index / width_col) % height_col; int c = (index / width_col / height_col) % offset_channels; int b = (index / width_col / height_col) / offset_channels; // compute the start and end of the output const int deformable_group_index = c / (2 * kernel_h * kernel_w); const int 
col_step = kernel_h * kernel_w; int cnt = 0; const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; const float *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) { const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; const int bp_dir = offset_c % 2; int j = (col_pos / width_col / height_col / batch_size) % kernel_w; int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; int w_out = col_pos % width_col; int h_out = (col_pos / width_col) % height_col; int w_in = w_out * stride_w - pad_w; int h_in = h_out * stride_h - pad_h; const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); const float offset_h = data_offset_ptr[data_offset_h_ptr]; const float offset_w = data_offset_ptr[data_offset_w_ptr]; const float mask = data_mask_ptr[data_mask_hw_ptr]; float inv_h = h_in + i * dilation_h + offset_h; float inv_w = w_in + j * dilation_w + offset_w; if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) { inv_h = inv_w = -2; } else { mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear_cuda(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); } 
const float weight = dmcn_get_coordinate_weight_cuda( inv_h, inv_w, height, width, data_im_ptr + cnt * height * width, width, bp_dir); val += weight * data_col_ptr[col_pos] * mask; cnt += 1; } // KERNEL_ASSIGN(grad_offset[index], offset_req, val); grad_offset[index] = val; if (offset_c % 2 == 0) // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; } } void modulated_deformable_im2col_cuda(cudaStream_t stream, const float* data_im, const float* data_offset, const float* data_mask, const int batch_size, const int channels, const int height_im, const int width_im, const int height_col, const int width_col, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const int deformable_group, float* data_col) { // num_axes should be smaller than block size const int channel_per_deformable_group = channels / deformable_group; const int num_kernels = channels * batch_size * height_col * width_col; modulated_deformable_im2col_gpu_kernel <<>>( num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, batch_size, channels, deformable_group, height_col, width_col, data_col); cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); } } void modulated_deformable_col2im_cuda(cudaStream_t stream, const float* data_col, const float* data_offset, const float* data_mask, const int batch_size, const int channels, const int height_im, const int width_im, const int height_col, const int width_col, const int kernel_h, const int 
kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const int deformable_group, float* grad_im){ const int channel_per_deformable_group = channels / deformable_group; const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; modulated_deformable_col2im_gpu_kernel <<>>( num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im, kernel_h, kernel_w, pad_h, pad_h, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, batch_size, deformable_group, height_col, width_col, grad_im); cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); } } void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, const float* data_col, const float* data_im, const float* data_offset, const float* data_mask, const int batch_size, const int channels, const int height_im, const int width_im, const int height_col, const int width_col, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const int deformable_group, float* grad_offset, float* grad_mask) { const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; modulated_deformable_col2im_coord_gpu_kernel <<>>( num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, grad_offset, grad_mask); cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { printf("error in modulated_deformable_col2im_coord_cuda: %s\n", 
cudaGetErrorString(err)); } } ================================================ FILE: code/real/bsrt/model/DCNv2/src/cuda/dcn_v2_im2col_cuda.h ================================================ /*! ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** * * COPYRIGHT * * All contributions by the University of California: * Copyright (c) 2014-2017 The Regents of the University of California (Regents) * All rights reserved. * * All other contributions: * Copyright (c) 2014-2017, the respective contributors * All rights reserved. * * Caffe uses a shared copyright model: each contributor holds copyright over * their contributions to Caffe. The project versioning records all such * contribution and copyright details. If a contributor wants to further mark * their specific copyright on a particular contribution, they should indicate * their copyright solely in the commit message of the change when it is * committed. * * LICENSE * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * * CONTRIBUTION AGREEMENT * * By contributing to the BVLC/caffe repository through pull-request, comment, * or otherwise, the contributor releases their content to the * license and copyright terms herein. * ***************** END Caffe Copyright Notice and Disclaimer ******************** * * Copyright (c) 2018 Microsoft * Licensed under The MIT License [see LICENSE for details] * \file modulated_deformable_im2col.h * \brief Function definitions of converting an image to * column matrix based on kernel, padding, dilation, and offset. * These functions are mainly used in deformable convolution operators. 
* \ref: https://arxiv.org/abs/1811.11168 * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu */ /***************** Adapted by Charles Shang *********************/ #ifndef DCN_V2_IM2COL_CUDA #define DCN_V2_IM2COL_CUDA #ifdef __cplusplus extern "C" { #endif void modulated_deformable_im2col_cuda(cudaStream_t stream, const float *data_im, const float *data_offset, const float *data_mask, const int batch_size, const int channels, const int height_im, const int width_im, const int height_col, const int width_col, const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const int deformable_group, float *data_col); void modulated_deformable_col2im_cuda(cudaStream_t stream, const float *data_col, const float *data_offset, const float *data_mask, const int batch_size, const int channels, const int height_im, const int width_im, const int height_col, const int width_col, const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const int deformable_group, float *grad_im); void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, const float *data_col, const float *data_im, const float *data_offset, const float *data_mask, const int batch_size, const int channels, const int height_im, const int width_im, const int height_col, const int width_col, const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const int deformable_group, float *grad_offset, float *grad_mask); #ifdef __cplusplus } #endif #endif ================================================ FILE: code/real/bsrt/model/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda.cu ================================================ /*! 
* Copyright (c) 2017 Microsoft * Licensed under The MIT License [see LICENSE for details] * \file deformable_psroi_pooling.cu * \brief * \author Yi Li, Guodong Zhang, Jifeng Dai */ /***************** Adapted by Charles Shang *********************/ #include #include #include #include #include #include #include #include #include #define CUDA_KERNEL_LOOP(i, n) \ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ i < (n); \ i += blockDim.x * gridDim.x) const int CUDA_NUM_THREADS = 1024; inline int GET_BLOCKS(const int N) { return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; } template __device__ T bilinear_interp_cuda( const T *data, const T x, const T y, const int width, const int height) { int x1 = floor(x); int x2 = ceil(x); int y1 = floor(y); int y2 = ceil(y); T dist_x = static_cast(x - x1); T dist_y = static_cast(y - y1); T value11 = data[y1 * width + x1]; T value12 = data[y2 * width + x1]; T value21 = data[y1 * width + x2]; T value22 = data[y2 * width + x2]; T value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22; return value; } template __global__ void DeformablePSROIPoolForwardKernelCuda( const int count, const T *bottom_data, const T spatial_scale, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const T *bottom_rois, const T *bottom_trans, const int no_trans, const T trans_std, const int sample_per_part, const int output_dim, const int group_size, const int part_size, const int num_classes, const int channels_each_class, T *top_data, T *top_count) { CUDA_KERNEL_LOOP(index, count) { // The output is in order (n, ctop, ph, pw) int pw = index % pooled_width; int ph = (index / pooled_width) % pooled_height; int ctop = (index / pooled_width / pooled_height) % output_dim; int n = index / pooled_width / pooled_height / output_dim; // [start, end) interval for spatial sampling const T *offset_bottom_rois = bottom_rois + 
n * 5; int roi_batch_ind = offset_bottom_rois[0]; T roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; T roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; T roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; T roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; // Force too small ROIs to be 1x1 T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0 T roi_height = max(roi_end_h - roi_start_h, 0.1); // Compute w and h at bottom T bin_size_h = roi_height / static_cast(pooled_height); T bin_size_w = roi_width / static_cast(pooled_width); T sub_bin_size_h = bin_size_h / static_cast(sample_per_part); T sub_bin_size_w = bin_size_w / static_cast(sample_per_part); int part_h = floor(static_cast(ph) / pooled_height * part_size); int part_w = floor(static_cast(pw) / pooled_width * part_size); int class_id = ctop / channels_each_class; T trans_x = no_trans ? static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; T trans_y = no_trans ? 
static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std; T wstart = static_cast(pw) * bin_size_w + roi_start_w; wstart += trans_x * roi_width; T hstart = static_cast(ph) * bin_size_h + roi_start_h; hstart += trans_y * roi_height; T sum = 0; int count = 0; int gw = floor(static_cast(pw) * group_size / pooled_width); int gh = floor(static_cast(ph) * group_size / pooled_height); gw = min(max(gw, 0), group_size - 1); gh = min(max(gh, 0), group_size - 1); const T *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width; for (int ih = 0; ih < sample_per_part; ih++) { for (int iw = 0; iw < sample_per_part; iw++) { T w = wstart + iw * sub_bin_size_w; T h = hstart + ih * sub_bin_size_h; // bilinear interpolation if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) { continue; } w = min(max(w, 0.), width - 1.); h = min(max(h, 0.), height - 1.); int c = (ctop * group_size + gh) * group_size + gw; T val = bilinear_interp_cuda(offset_bottom_data + c * height * width, w, h, width, height); sum += val; count++; } } top_data[index] = count == 0 ? 
static_cast(0) : sum / count; top_count[index] = count; } } template __global__ void DeformablePSROIPoolBackwardAccKernelCuda( const int count, const T *top_diff, const T *top_count, const int num_rois, const T spatial_scale, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const int output_dim, T *bottom_data_diff, T *bottom_trans_diff, const T *bottom_data, const T *bottom_rois, const T *bottom_trans, const int no_trans, const T trans_std, const int sample_per_part, const int group_size, const int part_size, const int num_classes, const int channels_each_class) { CUDA_KERNEL_LOOP(index, count) { // The output is in order (n, ctop, ph, pw) int pw = index % pooled_width; int ph = (index / pooled_width) % pooled_height; int ctop = (index / pooled_width / pooled_height) % output_dim; int n = index / pooled_width / pooled_height / output_dim; // [start, end) interval for spatial sampling const T *offset_bottom_rois = bottom_rois + n * 5; int roi_batch_ind = offset_bottom_rois[0]; T roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; T roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; T roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; T roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; // Force too small ROIs to be 1x1 T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0 T roi_height = max(roi_end_h - roi_start_h, 0.1); // Compute w and h at bottom T bin_size_h = roi_height / static_cast(pooled_height); T bin_size_w = roi_width / static_cast(pooled_width); T sub_bin_size_h = bin_size_h / static_cast(sample_per_part); T sub_bin_size_w = bin_size_w / static_cast(sample_per_part); int part_h = floor(static_cast(ph) / pooled_height * part_size); int part_w = floor(static_cast(pw) / pooled_width * part_size); int class_id = ctop / channels_each_class; T trans_x = no_trans ? 
static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; T trans_y = no_trans ? static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std; T wstart = static_cast(pw) * bin_size_w + roi_start_w; wstart += trans_x * roi_width; T hstart = static_cast(ph) * bin_size_h + roi_start_h; hstart += trans_y * roi_height; if (top_count[index] <= 0) { continue; } T diff_val = top_diff[index] / top_count[index]; const T *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width; T *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width; int gw = floor(static_cast(pw) * group_size / pooled_width); int gh = floor(static_cast(ph) * group_size / pooled_height); gw = min(max(gw, 0), group_size - 1); gh = min(max(gh, 0), group_size - 1); for (int ih = 0; ih < sample_per_part; ih++) { for (int iw = 0; iw < sample_per_part; iw++) { T w = wstart + iw * sub_bin_size_w; T h = hstart + ih * sub_bin_size_h; // bilinear interpolation if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) { continue; } w = min(max(w, 0.), width - 1.); h = min(max(h, 0.), height - 1.); int c = (ctop * group_size + gh) * group_size + gw; // backward on feature int x0 = floor(w); int x1 = ceil(w); int y0 = floor(h); int y1 = ceil(h); T dist_x = w - x0, dist_y = h - y0; T q00 = (1 - dist_x) * (1 - dist_y); T q01 = (1 - dist_x) * dist_y; T q10 = dist_x * (1 - dist_y); T q11 = dist_x * dist_y; int bottom_index_base = c * height * width; atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val); atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val); atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val); atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val); if (no_trans) { continue; } T U00 = 
offset_bottom_data[bottom_index_base + y0 * width + x0]; T U01 = offset_bottom_data[bottom_index_base + y1 * width + x0]; T U10 = offset_bottom_data[bottom_index_base + y0 * width + x1]; T U11 = offset_bottom_data[bottom_index_base + y1 * width + x1]; T diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val; diff_x *= roi_width; T diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val; diff_y *= roi_height; atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x); atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y); } } } } std::tuple dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input, const at::Tensor &bbox, const at::Tensor &trans, const int no_trans, const float spatial_scale, const int output_dim, const int group_size, const int pooled_size, const int part_size, const int sample_per_part, const float trans_std) { AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); AT_ASSERTM(bbox.type().is_cuda(), "rois must be a CUDA tensor"); AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor"); const int batch = input.size(0); const int channels = input.size(1); const int height = input.size(2); const int width = input.size(3); const int channels_trans = no_trans ? 2 : trans.size(1); const int num_bbox = bbox.size(0); AT_ASSERTM(channels == output_dim, "input channels and output channels must equal"); auto pooled_height = pooled_size; auto pooled_width = pooled_size; auto out = at::empty({num_bbox, output_dim, pooled_height, pooled_width}, input.options()); long out_size = num_bbox * output_dim * pooled_height * pooled_width; auto top_count = at::zeros({num_bbox, output_dim, pooled_height, pooled_width}, input.options()); const int num_classes = no_trans ? 
1 : channels_trans / 2; const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if (out.numel() == 0) { THCudaCheck(cudaGetLastError()); return std::make_tuple(out, top_count); } dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L)); dim3 block(512); AT_DISPATCH_FLOATING_TYPES(input.type(), "dcn_v2_psroi_pooling_cuda_forward", [&] { DeformablePSROIPoolForwardKernelCuda<<>>( out_size, input.contiguous().data_ptr(), spatial_scale, channels, height, width, pooled_height, pooled_width, bbox.contiguous().data_ptr(), trans.contiguous().data_ptr(), no_trans, trans_std, sample_per_part, output_dim, group_size, part_size, num_classes, channels_each_class, out.data_ptr(), top_count.data_ptr()); }); THCudaCheck(cudaGetLastError()); return std::make_tuple(out, top_count); } std::tuple dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad, const at::Tensor &input, const at::Tensor &bbox, const at::Tensor &trans, const at::Tensor &top_count, const int no_trans, const float spatial_scale, const int output_dim, const int group_size, const int pooled_size, const int part_size, const int sample_per_part, const float trans_std) { AT_ASSERTM(out_grad.type().is_cuda(), "out_grad must be a CUDA tensor"); AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); AT_ASSERTM(bbox.type().is_cuda(), "bbox must be a CUDA tensor"); AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor"); AT_ASSERTM(top_count.type().is_cuda(), "top_count must be a CUDA tensor"); const int batch = input.size(0); const int channels = input.size(1); const int height = input.size(2); const int width = input.size(3); const int channels_trans = no_trans ? 
2 : trans.size(1); const int num_bbox = bbox.size(0); AT_ASSERTM(channels == output_dim, "input channels and output channels must equal"); auto pooled_height = pooled_size; auto pooled_width = pooled_size; long out_size = num_bbox * output_dim * pooled_height * pooled_width; const int num_classes = no_trans ? 1 : channels_trans / 2; const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; auto input_grad = at::zeros({batch, channels, height, width}, out_grad.options()); auto trans_grad = at::zeros_like(trans); if (input_grad.numel() == 0) { THCudaCheck(cudaGetLastError()); return std::make_tuple(input_grad, trans_grad); } dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L)); dim3 block(512); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); AT_DISPATCH_FLOATING_TYPES(out_grad.type(), "dcn_v2_psroi_pooling_cuda_backward", [&] { DeformablePSROIPoolBackwardAccKernelCuda<<>>( out_size, out_grad.contiguous().data_ptr(), top_count.contiguous().data_ptr(), num_bbox, spatial_scale, channels, height, width, pooled_height, pooled_width, output_dim, input_grad.contiguous().data_ptr(), trans_grad.contiguous().data_ptr(), input.contiguous().data_ptr(), bbox.contiguous().data_ptr(), trans.contiguous().data_ptr(), no_trans, trans_std, sample_per_part, group_size, part_size, num_classes, channels_each_class); }); THCudaCheck(cudaGetLastError()); return std::make_tuple(input_grad, trans_grad); } ================================================ FILE: code/real/bsrt/model/DCNv2/src/cuda/vision.h ================================================ #pragma once #include #include at::Tensor dcn_v2_cuda_forward(const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, const at::Tensor &offset, const at::Tensor &mask, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, const int dilation_h, const int dilation_w, const int deformable_group); std::vector dcn_v2_cuda_backward(const 
at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, const at::Tensor &offset, const at::Tensor &mask, const at::Tensor &grad_output, int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int deformable_group); std::tuple dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input, const at::Tensor &bbox, const at::Tensor &trans, const int no_trans, const float spatial_scale, const int output_dim, const int group_size, const int pooled_size, const int part_size, const int sample_per_part, const float trans_std); std::tuple dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad, const at::Tensor &input, const at::Tensor &bbox, const at::Tensor &trans, const at::Tensor &top_count, const int no_trans, const float spatial_scale, const int output_dim, const int group_size, const int pooled_size, const int part_size, const int sample_per_part, const float trans_std); ================================================ FILE: code/real/bsrt/model/DCNv2/src/dcn_v2.h ================================================ #pragma once #include "cpu/vision.h" #ifdef WITH_CUDA #include "cuda/vision.h" #endif at::Tensor dcn_v2_forward(const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, const at::Tensor &offset, const at::Tensor &mask, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, const int dilation_h, const int dilation_w, const int deformable_group) { if (input.type().is_cuda()) { #ifdef WITH_CUDA return dcn_v2_cuda_forward(input, weight, bias, offset, mask, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, deformable_group); #else AT_ERROR("Not compiled with GPU support"); #endif } else{ return dcn_v2_cpu_forward(input, weight, bias, offset, mask, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, deformable_group); } } std::vector dcn_v2_backward(const at::Tensor 
&input, const at::Tensor &weight, const at::Tensor &bias, const at::Tensor &offset, const at::Tensor &mask, const at::Tensor &grad_output, int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int deformable_group) { if (input.type().is_cuda()) { #ifdef WITH_CUDA return dcn_v2_cuda_backward(input, weight, bias, offset, mask, grad_output, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, deformable_group); #else AT_ERROR("Not compiled with GPU support"); #endif } else{ return dcn_v2_cpu_backward(input, weight, bias, offset, mask, grad_output, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, deformable_group); } } std::tuple dcn_v2_psroi_pooling_forward(const at::Tensor &input, const at::Tensor &bbox, const at::Tensor &trans, const int no_trans, const float spatial_scale, const int output_dim, const int group_size, const int pooled_size, const int part_size, const int sample_per_part, const float trans_std) { if (input.type().is_cuda()) { #ifdef WITH_CUDA return dcn_v2_psroi_pooling_cuda_forward(input, bbox, trans, no_trans, spatial_scale, output_dim, group_size, pooled_size, part_size, sample_per_part, trans_std); #else AT_ERROR("Not compiled with GPU support"); #endif } else{ return dcn_v2_psroi_pooling_cpu_forward(input, bbox, trans, no_trans, spatial_scale, output_dim, group_size, pooled_size, part_size, sample_per_part, trans_std); } } std::tuple dcn_v2_psroi_pooling_backward(const at::Tensor &out_grad, const at::Tensor &input, const at::Tensor &bbox, const at::Tensor &trans, const at::Tensor &top_count, const int no_trans, const float spatial_scale, const int output_dim, const int group_size, const int pooled_size, const int part_size, const int sample_per_part, const float trans_std) { if (input.type().is_cuda()) { #ifdef WITH_CUDA return dcn_v2_psroi_pooling_cuda_backward(out_grad, input, bbox, trans, top_count, no_trans, spatial_scale, output_dim, 
                                                  group_size, pooled_size,
                                                  part_size, sample_per_part,
                                                  trans_std);
#else
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    else{
        return dcn_v2_psroi_pooling_cpu_backward(out_grad, input, bbox, trans,
                                                 top_count, no_trans,
                                                 spatial_scale, output_dim,
                                                 group_size, pooled_size,
                                                 part_size, sample_per_part,
                                                 trans_std);
    }
}


================================================
FILE: code/real/bsrt/model/DCNv2/src/vision.cpp
================================================
#include "dcn_v2.h"

// Python bindings: expose the four CUDA/CPU dispatch entry points of DCNv2
// as a torch C++ extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("dcn_v2_forward", &dcn_v2_forward, "dcn_v2_forward");
    m.def("dcn_v2_backward", &dcn_v2_backward, "dcn_v2_backward");
    m.def("dcn_v2_psroi_pooling_forward", &dcn_v2_psroi_pooling_forward, "dcn_v2_psroi_pooling_forward");
    m.def("dcn_v2_psroi_pooling_backward", &dcn_v2_psroi_pooling_backward, "dcn_v2_psroi_pooling_backward");
}


================================================
FILE: code/real/bsrt/model/DCNv2/test.py
================================================
#!/usr/bin/env python
# Smoke / gradient tests for the compiled DCNv2 extension (all paths call
# .cuda(), so a CUDA device is required to run this script).
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import time
import torch
import torch.nn as nn
from torch.autograd import gradcheck

from dcn_v2 import dcn_v2_conv, DCNv2, DCN
from dcn_v2 import dcn_v2_pooling, DCNv2Pooling, DCNPooling

deformable_groups = 1
N, inC, inH, inW = 2, 2, 4, 4   # batch, channels, height, width for the small tests
outC = 2
kH, kW = 3, 3


def conv_identify(weight, bias):
    """In-place: set (weight, bias) so the conv is an identity mapping —
    1.0 at the kernel centre where in-channel == out-channel, 0 elsewhere."""
    weight.data.zero_()
    bias.data.zero_()
    o, i, h, w = weight.shape
    y = h//2
    x = w//2
    for p in range(i):
        for q in range(o):
            if p == q:
                weight.data[q, p, y, x] = 1.0


def check_zero_offset():
    """With zero-initialized offset/mask convs and identity DCN weights the
    deformable conv must reproduce its input (sigmoid(0) masks contribute a
    0.5 factor, undone below). Continues on the next chunk line."""
    conv_offset = nn.Conv2d(inC, deformable_groups * 2 * kH * kW,
                            kernel_size=(kH, kW),
                            stride=(1, 1),
                            padding=(1, 1),
                            bias=True).cuda()

    conv_mask = nn.Conv2d(inC, deformable_groups * 1 * kH * kW,
                          kernel_size=(kH, kW),
                          stride=(1, 1),
                          padding=(1, 1),
                          bias=True).cuda()

    dcn_v2 = DCNv2(inC, outC, (kH, kW),
                   stride=1, padding=1, dilation=1,
                   deformable_groups=deformable_groups).cuda()
    # --- check_zero_offset (continued): zero the offset/mask producers and make
    # the deformable conv an identity, then verify output == input.
    conv_offset.weight.data.zero_()
    conv_offset.bias.data.zero_()
    conv_mask.weight.data.zero_()
    conv_mask.bias.data.zero_()
    conv_identify(dcn_v2.weight, dcn_v2.bias)

    input = torch.randn(N, inC, inH, inW).cuda()
    offset = conv_offset(input)
    mask = conv_mask(input)
    mask = torch.sigmoid(mask)
    output = dcn_v2(input, offset, mask)
    output *= 2   # undo the 0.5 factor contributed by the sigmoid(0) masks
    d = (input - output).abs().max()
    if d < 1e-10:
        print('Zero offset passed')
    else:
        print('Zero offset failed')
        print(input)
        print(output)


def check_gradient_dconv():
    """Numerical gradient check of dcn_v2_conv via torch.autograd.gradcheck."""
    input = torch.rand(N, inC, inH, inW).cuda() * 0.01
    input.requires_grad = True
    offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW).cuda() * 2
    # offset.data.zero_()
    # offset.data -= 0.5
    offset.requires_grad = True
    # NOTE(review): sigmoid is applied after requires_grad, so the tensor that
    # gradcheck perturbs is the pre-sigmoid leaf.
    mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW).cuda()
    # mask.data.zero_()
    mask.requires_grad = True
    mask = torch.sigmoid(mask)
    weight = torch.randn(outC, inC, kH, kW).cuda()
    weight.requires_grad = True
    bias = torch.rand(outC).cuda()
    bias.requires_grad = True
    stride = 1
    padding = 1
    dilation = 1
    print('check_gradient_dconv: ',
          gradcheck(dcn_v2_conv,
                    (input, offset, mask, weight, bias,
                     stride, padding, dilation, deformable_groups),
                    eps=1e-3, atol=1e-4, rtol=1e-2))


def check_pooling_zero_offset():
    """Sanity-check DCNv2Pooling on two constant ROIs, with and without the
    learned transform branch (zeroed offsets should match plain pooling)."""
    input = torch.randn(2, 16, 64, 64).cuda().zero_()
    input[0, :, 16:26, 16:26] = 1.
    input[1, :, 10:20, 20:30] = 2.
    rois = torch.tensor([
        [0, 65, 65, 103, 103],
        [1, 81, 41, 119, 79],
    ]).cuda().float()
    pooling = DCNv2Pooling(spatial_scale=1.0 / 4,
                           pooled_size=7,
                           output_dim=16,
                           no_trans=True,
                           group_size=1,
                           trans_std=0.0).cuda()
    out = pooling(input, rois, input.new())
    s = ', '.join(['%f' % out[i, :, :, :].mean().item()
                   for i in range(rois.shape[0])])
    print(s)

    dpooling = DCNv2Pooling(spatial_scale=1.0 / 4,
                            pooled_size=7,
                            output_dim=16,
                            no_trans=False,
                            group_size=1,
                            trans_std=0.0).cuda()
    offset = torch.randn(20, 2, 7, 7).cuda().zero_()
    dout = dpooling(input, rois, offset)
    s = ', '.join(['%f' % dout[i, :, :, :].mean().item()
                   for i in range(rois.shape[0])])
    print(s)


def check_gradient_dpooling():
    """Numerical gradient check of dcn_v2_pooling on small random ROIs."""
    input = torch.randn(2, 3, 5, 5).cuda() * 0.01
    N = 4   # NOTE(review): shadows the module-level N on purpose (ROI count)
    batch_inds = torch.randint(2, (N, 1)).cuda().float()
    x = torch.rand((N, 1)).cuda().float() * 15
    y = torch.rand((N, 1)).cuda().float() * 15
    w = torch.rand((N, 1)).cuda().float() * 10
    h = torch.rand((N, 1)).cuda().float() * 10
    rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
    offset = torch.randn(N, 2, 3, 3).cuda()
    input.requires_grad = True
    offset.requires_grad = True

    spatial_scale = 1.0 / 4
    pooled_size = 3
    output_dim = 3
    no_trans = 0
    group_size = 1
    trans_std = 0.0
    sample_per_part = 4
    part_size = pooled_size

    print('check_gradient_dpooling:',
          gradcheck(dcn_v2_pooling,
                    (input, rois, offset,
                     spatial_scale,
                     pooled_size,
                     output_dim,
                     no_trans,
                     group_size,
                     part_size,
                     sample_per_part,
                     trans_std),
                    eps=1e-4))


def example_dconv():
    """End-to-end smoke test of the DCN wrapper (offset/mask computed internally)."""
    input = torch.randn(2, 64, 128, 128).cuda()
    # wrap all things (offset and mask) in DCN
    dcn = DCN(64, 64, kernel_size=(3, 3), stride=1,
              padding=1, deformable_groups=2).cuda()
    # print(dcn.weight.shape, input.shape)
    output = dcn(input)
    targert = output.new(*output.size())   # NOTE(review): 'targert' is a typo for 'target'
    targert.data.uniform_(-0.01, 0.01)
    error = (targert - output).mean()
    error.backward()
    print(output.shape)


def example_dpooling():
    """Smoke test: plain ROI pooling vs deformable pooling, forward + backward."""
    input = torch.randn(2, 32, 64, 64).cuda()
    batch_inds = torch.randint(2, (20, 1)).cuda().float()
    x = torch.randint(256, (20, 1)).cuda().float()
    y = torch.randint(256, (20, 1)).cuda().float()
    w = torch.randint(64, (20, 1)).cuda().float()
    h = torch.randint(64, (20, 1)).cuda().float()
    rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
    offset = torch.randn(20, 2, 7, 7).cuda()
    input.requires_grad = True
    offset.requires_grad = True

    # normal roi_align
    pooling = DCNv2Pooling(spatial_scale=1.0 / 4,
                           pooled_size=7,
                           output_dim=32,
                           no_trans=True,
                           group_size=1,
                           trans_std=0.1).cuda()

    # deformable pooling
    dpooling = DCNv2Pooling(spatial_scale=1.0 / 4,
                            pooled_size=7,
                            output_dim=32,
                            no_trans=False,
                            group_size=1,
                            trans_std=0.1).cuda()

    out = pooling(input, rois, offset)
    dout = dpooling(input, rois, offset)
    print(out.shape)
    print(dout.shape)

    target_out = out.new(*out.size())
    target_out.data.uniform_(-0.01, 0.01)
    target_dout = dout.new(*dout.size())
    target_dout.data.uniform_(-0.01, 0.01)
    e = (target_out - out).mean()
    e.backward()
    e = (target_dout - dout).mean()
    e.backward()


def example_mdpooling():
    """Smoke test: modulated deformable pooling (DCNPooling), forward + backward."""
    input = torch.randn(2, 32, 64, 64).cuda()
    input.requires_grad = True
    batch_inds = torch.randint(2, (20, 1)).cuda().float()
    x = torch.randint(256, (20, 1)).cuda().float()
    y = torch.randint(256, (20, 1)).cuda().float()
    w = torch.randint(64, (20, 1)).cuda().float()
    h = torch.randint(64, (20, 1)).cuda().float()
    rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)

    # mdformable pooling (V2)
    dpooling = DCNPooling(spatial_scale=1.0 / 4,
                          pooled_size=7,
                          output_dim=32,
                          no_trans=False,
                          group_size=1,
                          trans_std=0.1,
                          deform_fc_dim=1024).cuda()

    dout = dpooling(input, rois)
    target = dout.new(*dout.size())
    target.data.uniform_(-0.1, 0.1)
    error = (target - dout).mean()
    error.backward()
    print(dout.shape)


if __name__ == '__main__':

    example_dconv()
    example_dpooling()
    example_mdpooling()

    check_pooling_zero_offset()
    # zero offset check
    if inC == outC:
        check_zero_offset()

    check_gradient_dpooling()
    check_gradient_dconv()
    # """
    # ****** Note: backward is not reentrant error may not be a serious problem,
    # ****** since the max
    # error is less than 1e-7,
    # ****** Still looking for what trigger this problem
    # """


================================================
FILE: code/real/bsrt/model/__init__.py
================================================
import os
from importlib import import_module
import torch
import torch.nn as nn
import torch.nn.parallel as P
import torch.utils.model_zoo
import time


class Model(nn.Module):
    """Wrapper around the network named by args.model: builds it via
    model.<name>.make_model, handles device/precision placement, checkpoint
    load/save, optional DistributedDataParallel wrapping, and the
    chop / self-ensemble inference paths."""

    def __init__(self, args, ckp):
        super(Model, self).__init__()
        self.args = args
        if args.local_rank == 0:
            print("Making model: ", args.model)
            print("Patch size: ", args.patch_size)

        self.scale = args.scale
        self.idx_scale = 0
        self.input_large = (args.model == 'VDSR')
        self.self_ensemble = args.self_ensemble
        self.chop = args.chop
        self.precision = args.precision
        self.cpu = args.cpu
        self.device = torch.device('cpu' if args.cpu else 'cuda:%d' % args.local_rank)
        self.n_GPUs = args.n_GPUs
        self.save_models = args.save_models

        # model.<name>.make_model builds the actual architecture
        module = import_module('model.' + args.model.lower())
        self.model = module.make_model(args).to(self.device)
        if args.precision == 'half':
            self.model.half()

        self.load(
            ckp.get_path('model'),
            pre_train=args.pre_train,
            resume=args.resume,
            cpu=args.cpu
        )
        # time.sleep(3)
        if args.n_GPUs > 1:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                find_unused_parameters=True
            )

        print(self.model, file=ckp.log_file)

    def forward(self, x, idx_scale):
        """Dispatch: plain forward while training; chop and/or self-ensemble
        variants in evaluation depending on args."""
        self.idx_scale = idx_scale
        if hasattr(self.model, 'set_scale'):
            self.model.set_scale(idx_scale)

        if self.training:
            # if self.n_GPUs > 1:
            return self.model(x)
        else:
            if self.chop:
                forward_function = self.forward_chop
            else:
                forward_function = self.model.forward

            if self.self_ensemble:
                return self.forward_x8(x, forward_function=forward_function)
            else:
                # return self.model(x)
                return forward_function(x)

    def save(self, apath, epoch, is_best=False):
        """Write model_latest.pt, plus model_best.pt / model_<epoch>.pt when
        requested. Continues on the next chunk line."""
        save_dirs = [os.path.join(apath, 'model_latest.pt')]

        if is_best:
            save_dirs.append(os.path.join(apath, 'model_best.pt'))
        if self.save_models:
            save_dirs.append(
os.path.join(apath, 'model_{}.pt'.format(epoch)) ) if self.n_GPUs > 1: model = self.model.module else: model = self.model for s in save_dirs: torch.save(self.model.state_dict(), s) def load(self, apath, pre_train='', resume=-1, cpu=False): load_from = None kwargs = {} if cpu: kwargs = {'map_location': lambda storage, loc: storage} if resume == -1: load_from = torch.load( os.path.join(apath, 'model_latest.pt'), **kwargs ) elif resume == 0: if pre_train == 'download': print('Download the model') dir_model = os.path.join('..', 'models') os.makedirs(dir_model, exist_ok=True) load_from = torch.utils.model_zoo.load_url( self.model.url, model_dir=dir_model, **kwargs ) elif pre_train: if self.args.local_rank == 0: print('Load the model from {}'.format(pre_train)) map_location = {'cuda:%d' % 0: 'cuda:%d' % self.args.local_rank} load_from = torch.load(pre_train, map_location=map_location) # print(load_from.keys()) else: load_from = torch.load( os.path.join(apath, 'model_{}.pt'.format(resume)), **kwargs ) if load_from: self.model.load_state_dict(load_from, strict=True) del load_from if self.args.finetune: if self.args.local_rank == 0: print('finetune') for param in self.model.parameters(): param.requires_grad = False for param in self.model.HRconv.parameters(): param.requires_grad = True for param in self.model.conv_last.parameters(): param.requires_grad = True if self.args.finetune_prelayer: if self.args.local_rank == 0: print('finetune_prelayer') if self.args.swinfeature: if self.args.model == 'MBSRT': for param in self.model.pre_layer1.parameters(): param.requires_grad = True for param in self.model.pre_layer2.parameters(): param.requires_grad = True else: for param in self.model.pre_layers.parameters(): param.requires_grad = True else: for param in self.model.feature_extraction.parameters(): param.requires_grad = True for param in self.model.conv_after_pre_layer.parameters(): param.requires_grad = True if self.args.finetune_align: if self.args.local_rank == 0: 
                    print('finetune_align')
                # unfreeze the feature-alignment module
                for param in self.model.align.parameters():
                    param.requires_grad = True

            if self.args.finetune_spynet:
                if self.args.local_rank == 0:
                    print('finetune_spynet')
                for param in self.model.spynet.parameters():
                    param.requires_grad = True

            if self.args.finetune_swin:
                if self.args.local_rank == 0:
                    print('finetune_swin')
                for param in self.model.layers.parameters():
                    param.requires_grad = True
                for param in self.model.conv_after_body.parameters():
                    param.requires_grad = True

            if self.args.finetune_upconv:
                if self.args.local_rank == 0:
                    print('finetune_upconv')
                for param in self.model.upconv1.parameters():
                    param.requires_grad = True
                for param in self.model.upconv2.parameters():
                    param.requires_grad = True
                for param in self.model.skipup1.parameters():
                    param.requires_grad = True
                for param in self.model.skipup2.parameters():
                    param.requires_grad = True

            if self.args.finetune_conv:
                if self.args.local_rank == 0:
                    print('finetune_conv')
                # for param in self.model.conv_first.parameters():
                #     param.requires_grad = True
                # for param in self.model.conv_flow.parameters():
                #     param.requires_grad = True
                # for param in self.model.fea_L2_conv1.parameters():
                #     param.requires_grad = True
                # for param in self.model.fea_L3_conv1.parameters():
                #     param.requires_grad = True
                # for param in self.model.toplayer.parameters():
                #     param.requires_grad = True
                # for param in self.model.smooth1.parameters():
                #     param.requires_grad = True
                # for param in self.model.smooth2.parameters():
                #     param.requires_grad = True
                # for param in self.model.latlayer1.parameters():
                #     param.requires_grad = True
                # for param in self.model.latlayer2.parameters():
                #     param.requires_grad = True
                # for param in self.model.fusion.parameters():
                #     param.requires_grad = True
                # for param in self.model.conv_after_pre_layer.parameters():
                #     param.requires_grad = True
                for param in self.model.conv_after_body.parameters():
                    param.requires_grad = True

    def forward_chop(self, *args, shave=10, min_size=160000):
        """Memory-saving inference: split each input into 4 overlapping quadrants
        (overlap `shave`), run them (recursively when still larger than
        min_size), then stitch the outputs. Continues on the next chunk line."""
        scale = 1 if self.input_large else \
            self.scale[self.idx_scale]
        n_GPUs = min(self.n_GPUs, 4)
        # height, width
        h, w = args[0].size()[-2:]

        # four overlapping quadrants, stacked along the batch dimension
        top = slice(0, h//2 + shave)
        bottom = slice(h - h//2 - shave, h)
        left = slice(0, w//2 + shave)
        right = slice(w - w//2 - shave, w)
        x_chops = [torch.cat([
            a[..., top, left],
            a[..., top, right],
            a[..., bottom, left],
            a[..., bottom, right]
        ]) for a in args]

        y_chops = []
        if h * w < 4 * min_size:
            # small enough: run the 4 quadrants in n_GPUs-sized batches
            for i in range(0, 4, n_GPUs):
                x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops]
                y = P.data_parallel(self.model, *x, range(n_GPUs))
                if not isinstance(y, list):
                    y = [y]
                if not y_chops:
                    y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]
                else:
                    for y_chop, _y in zip(y_chops, y):
                        y_chop.extend(_y.chunk(n_GPUs, dim=0))
        else:
            # still too large: recurse per quadrant
            for p in zip(*x_chops):
                y = self.forward_chop(*p, shave=shave, min_size=min_size)
                if not isinstance(y, list):
                    y = [y]
                if not y_chops:
                    y_chops = [[_y] for _y in y]
                else:
                    for y_chop, _y in zip(y_chops, y):
                        y_chop.append(_y)

        # stitch at output resolution; *_r are negative-index slices that pick
        # the shave-free region out of each (over-sized) quadrant output
        h *= scale
        w *= scale
        top = slice(0, h//2)
        bottom = slice(h - h//2, h)
        bottom_r = slice(h//2 - h, None)
        left = slice(0, w//2)
        right = slice(w - w//2, w)
        right_r = slice(w//2 - w, None)

        # batch size, number of color channels
        b, c = y_chops[0][0].size()[:-2]
        y = [y_chop[0].new(b, c, h, w) for y_chop in y_chops]
        for y_chop, _y in zip(y_chops, y):
            _y[..., top, left] = y_chop[0][..., top, left]
            _y[..., top, right] = y_chop[1][..., top, right_r]
            _y[..., bottom, left] = y_chop[2][..., bottom_r, left]
            _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r]

        if len(y) == 1:
            y = y[0]

        return y

    def forward_x8(self, *args, forward_function=None):
        """Self-ensemble x8: run all flip/transpose variants of the input,
        undo each transform on the outputs and average them.
        Continues on the next chunk line."""
        def _transform(v, op):
            # apply flip 'v'/'h' or transpose 't' via a numpy round-trip
            if self.precision != 'single':
                v = v.float()

            v2np = v.data.cpu().numpy()
            if op == 'v':
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == 'h':
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == 't':
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()

            ret = torch.Tensor(tfnp).to(self.device)
            if self.precision == 'half':
                ret = ret.half()

            return ret

        # build the 8 variants: each pass doubles the list (id, v, h, hv, t, ...)
        list_x = []
        for a in args:
            x = [a]
            for tf in 'v', 'h', 't':
                x.extend([_transform(_x, tf) for _x in x])
            list_x.append(x)

        list_y = []
        for x in zip(*list_x):
            y = forward_function(*x)
            if not isinstance(y, list):
                y = [y]
            if not list_y:
                list_y = [[_y] for _y in y]
            else:
                for _list_y, _y in zip(list_y, y):
                    _list_y.append(_y)

        # invert the transforms: index i encodes which of t/h/v were applied
        for _list_y in list_y:
            for i in range(len(_list_y)):
                if i > 3:
                    _list_y[i] = _transform(_list_y[i], 't')
                if i % 4 > 1:
                    _list_y[i] = _transform(_list_y[i], 'h')
                if (i % 4) % 2 == 1:
                    _list_y[i] = _transform(_list_y[i], 'v')

        y = [torch.cat(_y, dim=0).mean(dim=0, keepdim=True) for _y in list_y]

        if len(y) == 1:
            y = y[0]

        return y


================================================
FILE: code/real/bsrt/model/arch_util.py
================================================
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from model import common
from model.utils.psconv import PSGConv2d as PSConv2d, PyConv2d


def initialize_weights(net_l, scale=1):
    """Kaiming-init Conv2d/Linear weights (scaled, e.g. 0.1 for residual
    blocks), zero their biases, and set BatchNorm2d to identity."""
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # for residual block
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias.data, 0.0)


def make_layer(block, n_layers):
    """Stack n_layers instances of block() in a Sequential."""
    layers = []
    for _ in range(n_layers):
        layers.append(block())
    return nn.Sequential(*layers)


###########################
def conv_layer(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Plain Conv2d with bias; default padding 0."""
    return nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                     padding=padding, bias=True)


class ESA(nn.Module):
    """Enhanced Spatial Attention: computes a sigmoid spatial mask at reduced
    resolution and rescales the input. Continues on the next chunk line."""
    def __init__(self, n_feats, conv=conv_layer):
        super(ESA, self).__init__()
        f = n_feats // 4
        self.conv1 = conv(n_feats, f, kernel_size=1)
        self.conv_f = conv(f, f, kernel_size=1)
        self.conv_max = conv(f, f,
                             kernel_size=3, padding=1)
        self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0)
        self.conv3 = conv(f, f, kernel_size=3, padding=1)
        self.conv3_ = conv(f, f, kernel_size=3, padding=1)
        self.conv4 = conv(f, n_feats, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # strided conv + max-pool shrink the map; attention is computed at low
        # resolution and bilinearly upsampled back to x's spatial size
        c1_ = (self.conv1(x))
        c1 = self.conv2(c1_)
        v_max = F.max_pool2d(c1, kernel_size=7, stride=3)
        v_range = self.relu(self.conv_max(v_max))
        c3 = self.relu(self.conv3(v_range))
        c3 = self.conv3_(c3)
        c3 = F.interpolate(c3, (x.size(2), x.size(3)),
                           mode='bilinear', align_corners=False)
        cf = self.conv_f(c1_)
        c4 = self.conv4(c3+cf)
        m = self.sigmoid(c4)
        return x * m


class DWConv(nn.Module):
    """Depth-wise 3x3 convolution (groups == channels)."""
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x):
        x = self.dwconv(x)
        return x


##########################
class SELayer(nn.Module):
    ''' SE-block: squeeze (global avg pool) + excitation MLP channel weighting.
    NOTE(review): the final Sigmoid is commented out, so the channel weights
    are unbounded here. '''
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            # nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


class ResidualBlock_noBN(nn.Module):
    '''Residual block w/o BN
    ---Conv-ReLU-Conv-+-
     |________________|
    '''

    def __init__(self, nf=64):
        super(ResidualBlock_noBN, self).__init__()
        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        # initialization
        initialize_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        identity = x
        out = F.relu(self.conv1(x), inplace=True)
        out = self.conv2(out)
        return identity + out


class ResidualBlock_SE(nn.Module):
    '''Residual block w/o BN plus an SE branch: identity, conv output and the
    SE-weighted output are concatenated and fused by a 1x1 conv (no residual add).
    ---Conv-ReLU-Conv-+-
     |________________|
    '''

    def __init__(self, nf=64, reduction=16):
        super(ResidualBlock_SE, self).__init__()
        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv3 = nn.Conv2d(3 * nf, nf, 1, padding=0, dilation=1, bias=True)
        self.se = SELayer(nf, reduction)

        # initialization
        initialize_weights([self.conv1, self.conv2, self.conv3], 0.1)

    def forward(self, x):
        identity = x
        basic_out = F.relu(self.conv1(x), inplace=True)
        basic_out = self.conv2(basic_out)
        se_out = self.se(basic_out)
        out = torch.cat((identity, basic_out, se_out), 1)
        out = self.conv3(out)
        return out


class _PositionAttentionModule(nn.Module):
    """ Position attention module (self-attention over spatial positions)."""

    def __init__(self, in_channels, **kwargs):
        super(_PositionAttentionModule, self).__init__()
        self.conv_b = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.conv_c = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.conv_d = nn.Conv2d(in_channels, in_channels, 1)
        self.alpha = nn.Parameter(torch.zeros(1))   # learnable residual weight
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size, _, height, width = x.size()
        feat_b = self.conv_b(x).view(batch_size, -1, height * width).permute(0, 2, 1)
        feat_c = self.conv_c(x).view(batch_size, -1, height * width)
        attention_s = self.softmax(torch.bmm(feat_b, feat_c))
        feat_d = self.conv_d(x).view(batch_size, -1, height * width)
        feat_e = torch.bmm(feat_d, attention_s.permute(0, 2, 1)).view(
            batch_size, -1, height, width)
        out = self.alpha * feat_e + x
        return out


## Spatial Attention (SA) Layer
class SALayer(nn.Module):
    def __init__(self, wn=None):
        super(SALayer,self).__init__()
        self.body = nn.Sequential(
            wn(nn.Conv2d(2, 1, 7, 1, 3, bias=False)),
            nn.Sigmoid()
        )

    def forward(self, x):
        # channel-wise avg + max maps -> 7x7 conv -> sigmoid spatial mask
        avg_f = torch.mean(x, dim=1, keepdim=True)
        max_f = torch.max(x, dim=1, keepdim=True)[0]
        y = torch.cat([avg_f, max_f], dim=1)
        return self.body(y).expand_as(x) * x


## Channel Attention (CA) Layer
class CALayerV2(nn.Module):
    def __init__(self, n_feat, reduction=16, wn=None):
        super(CALayerV2, self).__init__()
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
            wn(nn.Conv2d(n_feat, n_feat//reduction, 1, padding=0, bias=False)),
            nn.ReLU(inplace=True),
            wn(nn.Conv2d(n_feat//reduction, n_feat, 1, padding=0, bias=False)),
            # nn.Sigmoid()
        )

    def forward(self, x):
        # shared MLP over avg- and max-pooled descriptors, summed then sigmoid
        y1 = self.avg_pool(x)
        y2 = self.max_pool(x)
        y1 = self.conv_du(y1)
        y2 = self.conv_du(y2)
        return x * torch.sigmoid(y1+y2)


class DALayer(nn.Module):
    """Dual attention: channel (CALayer) and spatial (SALayer) branches,
    concatenated and fused by a 1x1 conv, with a residual add."""
    def __init__(self, channel, reduction, wn):
        super(DALayer, self).__init__()
        # global average pooling: feature --> point
        self.ca = CALayer(channel, reduction, wn)
        self.sa = SALayer(wn)
        self.conv = wn(nn.Conv2d(channel*2, channel, 1))

    def forward(self, x):
        ca = self.ca(x)
        sa = self.sa(x)
        res = self.conv(torch.cat([ca, sa], dim=1))
        return res + x


## Channel Attention (CA) Layer
class CALayer(nn.Module):
    def __init__(self, channel, reduction, wn):
        super(CALayer, self).__init__()
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
            wn(nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True)),
            nn.ReLU(inplace=True),
            wn(nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True)),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y


## Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
    """Wide-activation residual block: 1x1 expand (x6) -> act -> 1x1 shrink
    (x0.75) -> KxK conv, followed by channel (or dual) attention.
    Continues on the next chunk line."""
    def __init__(
        self, conv, n_feat, kernel_size, reduction, wn,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1, da=False):

        super(RCAB, self).__init__()
        expand = 6
        linear = 0.75
        modules_body = []
        # for i in range(2):
        modules_body.append(wn(nn.Conv2d(n_feat, n_feat*expand, 1, bias=bias)))
        modules_body.append(act)
        modules_body.append(wn(nn.Conv2d(n_feat*expand, int(n_feat*linear), 1, bias=bias)))
        modules_body.append(conv(int(n_feat*linear), n_feat, kernel_size, bias=bias))

        if da:
            modules_body.append(DALayer(n_feat,
reduction, wn)) else: modules_body.append(CALayer(n_feat, reduction, wn)) self.body = nn.Sequential(*modules_body) self.res_scale = res_scale def forward(self, x): res = self.body(x) #res = self.body(x).mul(self.res_scale) res += x return res ## Residual Group (RG) class ResidualGroup(nn.Module): def __init__(self, n_feat, n_resblocks, da=False): super(ResidualGroup, self).__init__() kernel_size = 3 res_scale = 1 reduction = 16 conv = common.default_conv wn = lambda x: torch.nn.utils.weight_norm(x) modules_body = [] modules_body = [ RCAB( conv, n_feat, kernel_size, reduction, wn=wn, bias=True, bn=False, act=nn.ReLU(True), res_scale=res_scale, da=da) \ for _ in range(n_resblocks)] modules_body.append(wn(conv(n_feat, n_feat, kernel_size))) self.body = nn.Sequential(*modules_body) def forward(self, x): res = self.body(x) res += x return res ################################################################ ################################################################ ################################################################ def make_layer_idx(block, n_layers): layers = [] for i in range(n_layers): layers.append(block(idx=i)) return nn.Sequential(*layers) ## Residual Channel Attention Block (RCAB) class LRSCRCAB(nn.Module): def __init__( self, conv, n_feat, kernel_size, reduction, wn, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, da=False, idx=0): super(LRSCRCAB, self).__init__() expand = 6 linear = 0.75 modules_body = [wn(nn.Conv2d(n_feat*(idx+1), n_feat, 1, 1, 0, bias=True))] if idx > 0 else [] # for i in range(2): modules_body.append(wn(nn.Conv2d(n_feat, n_feat*expand, 1, bias=bias))) modules_body.append(act) modules_body.append(wn(nn.Conv2d(n_feat*expand, int(n_feat*linear), 1, bias=bias))) modules_body.append(wn(conv(int(n_feat*linear), n_feat, kernel_size, bias=bias))) if da: modules_body.append(DALayer(n_feat, reduction, wn)) else: modules_body.append(CALayer(n_feat, reduction, wn)) self.body = nn.Sequential(*modules_body) self.res_scale = 
res_scale def forward(self, x): res = self.body(x) res = torch.cat([res, x], dim=1) return res ## Residual Channel Attention Block (RCAB) class LRSCPYRCAB(nn.Module): def __init__( self, conv, n_feat, kernel_size, reduction, wn, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, da=False, idx=0): super(LRSCPYRCAB, self).__init__() expand = 6 linear = 0.75 modules_body = [wn(nn.Conv2d(n_feat*(idx+1), n_feat, 1, 1, 0, bias=True))] if idx > 0 else [] # for i in range(2): modules_body.append(wn(nn.Conv2d(n_feat, n_feat*expand, 1, bias=bias))) modules_body.append(act) modules_body.append(wn(nn.Conv2d(n_feat*expand, int(n_feat*linear), 1, bias=bias))) modules_body.append( PyConv2d(in_channels=int(n_feat*linear), out_channels=[n_feat//4, n_feat//4, n_feat//2], pyconv_kernels=[3, 5, 7], pyconv_groups=[1, 4, 8])) if da: modules_body.append(DALayer(n_feat, reduction, wn)) else: modules_body.append(CALayer(n_feat, reduction, wn)) self.body = nn.Sequential(*modules_body) self.res_scale = res_scale def forward(self, x): res = self.body(x) res = torch.cat([res, x], dim=1) return res ## Long-Range Skip-connect Residual Group (RG) class LRSCResidualGroup(nn.Module): def __init__(self, n_feat, n_resblocks, da=False, idx=0): super(LRSCResidualGroup, self).__init__() kernel_size = 3 res_scale = 1 reduction = 16 conv = common.default_conv wn = lambda x: torch.nn.utils.weight_norm(x) modules_head = [wn(conv(n_feat*(idx+1), n_feat, 1, bias=True))] if idx > 0 else [] modules_body = [ LRSCRCAB( conv, n_feat, kernel_size, reduction, wn=wn, bias=True, bn=False, act=nn.ReLU(True), res_scale=res_scale, da=da, idx=i) \ for i in range(n_resblocks)] modules_body.append(wn(conv(n_feat*(n_resblocks+1), n_feat, kernel_size))) self.head = nn.Sequential(*modules_head) self.body = nn.Sequential(*modules_body) def forward(self, x): res = self.head(x) res = self.body(res) res = torch.cat([res, x], dim=1) return res ## Long-Range Skip-connect Residual Group (RG) class LRSCPSResidualGroup(nn.Module): 
## Long-Range Skip-connect Residual Group using PSConv2d as the spatial conv.
class LRSCPSResidualGroup(nn.Module):
    def __init__(self, n_feat, n_resblocks, da=False, idx=0):
        super(LRSCPSResidualGroup, self).__init__()
        kernel_size = 3
        res_scale = 1
        reduction = 16
        conv = PSConv2d
        wn = lambda x: torch.nn.utils.weight_norm(x)
        # squeeze concatenated outputs of previous groups back to n_feat
        modules_head = [wn(nn.Conv2d(n_feat*(idx+1), n_feat, 1, 1, 0, bias=True))] if idx > 0 else []
        modules_body = [
            LRSCRCAB(
                conv, n_feat, kernel_size, reduction, wn=wn, bias=True,
                bn=False, act=nn.ReLU(True), res_scale=res_scale, da=da, idx=i)
            for i in range(n_resblocks)]
        # fuse the n_feat*(n_resblocks+1) concatenated channels
        modules_tail = [wn(conv(n_feat*(n_resblocks+1), n_feat, kernel_size))]
        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x):
        res = self.head(x)
        res = self.body(res)
        res = self.tail(res)
        res = torch.cat([res, x], dim=1)
        return res


## Long-Range Skip-connect Residual Group built from pyramidal-conv RCABs.
class LRSCPyResidualGroup(nn.Module):
    def __init__(self, n_feat, n_resblocks, da=False, idx=0):
        super(LRSCPyResidualGroup, self).__init__()
        kernel_size = 3
        res_scale = 1
        reduction = 16
        conv = PyConv2d
        wn = lambda x: torch.nn.utils.weight_norm(x)
        modules_head = [wn(nn.Conv2d(n_feat*(idx+1), n_feat, 1, 1, 0, bias=True))] if idx > 0 else []
        modules_body = [
            LRSCPYRCAB(
                conv, n_feat, kernel_size, reduction, wn=wn, bias=True,
                bn=False, act=nn.ReLU(True), res_scale=res_scale, da=da, idx=i)
            for i in range(n_resblocks)]
        modules_tail = [wn(nn.Conv2d(n_feat*(n_resblocks+1), n_feat, 1))]
        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x):
        res = self.head(x)
        res = self.body(res)
        res = self.tail(res)
        res = torch.cat([res, x], dim=1)
        return res


## Wide-activation residual block with long-range skip concatenation.
## Input: nf*(idx+1) channels; output: nf*(idx+2) channels.
class LRSCWideActResBlock(nn.Module):
    def __init__(self, nf=64, idx=0):
        super(LRSCWideActResBlock, self).__init__()
        self.res_scale = 1
        expand = 6     # wide-activation expansion (WDSR-style)
        linear = 0.8   # low-rank linear bottleneck ratio
        kernel_size = 3
        wn = lambda x: torch.nn.utils.weight_norm(x)
        act = nn.ReLU(True)
        head = [wn(nn.Conv2d(nf*(idx+1), nf, 1, bias=True))] if idx > 0 else []
        body = []
        body.append(wn(nn.Conv2d(nf, nf*expand, 1, padding=1//2)))
        body.append(act)
        body.append(wn(nn.Conv2d(nf*expand, int(nf*linear), 1, padding=1//2)))
        body.append(wn(nn.Conv2d(int(nf*linear), nf, kernel_size, padding=kernel_size//2)))
        self.head = nn.Sequential(*head)
        self.body = nn.Sequential(*body)

    def forward(self, x):
        res = self.head(x)
        res = self.body(res)
        res = torch.cat([res, x], dim=1)
        return res


## LRSCWideActResBlock variant whose final spatial conv is a pyramidal conv.
class LRSCPyWideActResBlock(nn.Module):
    def __init__(self, nf=64, idx=0):
        super(LRSCPyWideActResBlock, self).__init__()
        self.res_scale = 1
        expand = 6
        linear = 0.75
        kernel_size = 3
        wn = lambda x: torch.nn.utils.weight_norm(x)
        act = nn.ReLU(True)
        head = [wn(nn.Conv2d(nf*(idx+1), nf, 1, bias=True))] if idx > 0 else []
        body = []
        body.append(wn(nn.Conv2d(nf, nf*expand, 1, padding=1//2)))
        body.append(act)
        body.append(wn(nn.Conv2d(nf*expand, int(nf*linear), 1, padding=1//2)))
        body.append(
            PyConv2d(in_channels=int(nf*linear),
                     out_channels=[nf//4, nf//4, nf//2],
                     pyconv_kernels=[3, 5, 7], pyconv_groups=[1, 4, 8]))
        self.head = nn.Sequential(*head)
        self.body = nn.Sequential(*body)

    def forward(self, x):
        res = self.head(x)
        res = self.body(res)
        res = torch.cat([res, x], dim=1)
        return res


## Long-Range Skip-connect group of LRSCPyWideActResBlock blocks.
class LRSCPyWideActResGroup(nn.Module):
    def __init__(self, nf, n_resblocks, idx=0):
        super(LRSCPyWideActResGroup, self).__init__()
        # Fixed: removed unused locals `kernel_size` and `conv` — the tail
        # below is a plain 1x1 nn.Conv2d, neither value was ever read.
        wn = lambda x: torch.nn.utils.weight_norm(x)
        modules_head = [wn(nn.Conv2d(nf*(idx+1), nf, 1, 1, 0, bias=True))] if idx > 0 else []
        modules_body = [LRSCPyWideActResBlock(nf=nf, idx=i) for i in range(n_resblocks)]
        modules_tail = [wn(nn.Conv2d(nf*(n_resblocks+1), nf, 1))]
        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x):
        res = self.head(x)
        res = self.body(res)
        res = self.tail(res)
        res = torch.cat([res, x], dim=1)
        return res
## Long-Range Skip-connect group of plain LRSCWideActResBlock blocks.
class LRSCWideActResGroup(nn.Module):
    def __init__(self, nf, n_resblocks, idx=0):
        super(LRSCWideActResGroup, self).__init__()
        # Fixed: removed unused locals `kernel_size` and the misleading
        # `conv = PyConv2d` — this group uses only plain nn.Conv2d layers.
        wn = lambda x: torch.nn.utils.weight_norm(x)
        modules_head = [wn(nn.Conv2d(nf*(idx+1), nf, 1, 1, 0, bias=True))] if idx > 0 else []
        modules_body = [LRSCWideActResBlock(nf=nf, idx=i) for i in range(n_resblocks)]
        modules_tail = [wn(nn.Conv2d(nf*(n_resblocks+1), nf, 1))]
        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x):
        res = self.head(x)
        res = self.body(res)
        res = self.tail(res)
        res = torch.cat([res, x], dim=1)
        return res


################################################################
################################################################
################################################################


## Residual Channel Attention Block with a pyramidal spatial conv.
## NOTE: the `conv` parameter is kept for signature parity with RCAB but is
## not used — PyConv2d is instantiated directly.
class PYRCAB(nn.Module):
    def __init__(
            self, conv, n_feat, kernel_size, reduction, wn,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1, da=False):
        super(PYRCAB, self).__init__()
        expand = 6
        linear = 0.75
        modules_body = []
        modules_body.append(wn(nn.Conv2d(n_feat, n_feat*expand, 1, bias=bias)))
        modules_body.append(act)
        modules_body.append(wn(nn.Conv2d(n_feat*expand, int(n_feat*linear), 1, bias=bias)))
        modules_body.append(
            PyConv2d(in_channels=int(n_feat*linear),
                     out_channels=[n_feat//4, n_feat//4, n_feat//2],
                     pyconv_kernels=[3, 5, 7], pyconv_groups=[1, 4, 8],
                     bias=bias))
        if da:
            modules_body.append(DALayer(n_feat, reduction, wn))
        else:
            modules_body.append(CALayer(n_feat, reduction, wn))
        self.body = nn.Sequential(*modules_body)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x)
        res += x  # additive residual (no channel growth here)
        return res


## Residual Group built from PYRCAB blocks, fused by one more PyConv2d.
class PyResidualGroup(nn.Module):
    def __init__(self, n_feat, n_resblocks, da=False):
        super(PyResidualGroup, self).__init__()
        kernel_size = 3
        res_scale = 1
        reduction = 16
        conv = PyConv2d
        wn = lambda x: torch.nn.utils.weight_norm(x)
        modules_body = [
            PYRCAB(
                conv, n_feat, kernel_size, reduction, wn=wn, bias=True,
                bn=False, act=nn.ReLU(True), res_scale=res_scale, da=da)
            for _ in range(n_resblocks)]
        modules_body.append(
            PyConv2d(in_channels=n_feat,
                     out_channels=[n_feat//4, n_feat//4, n_feat//2],
                     pyconv_kernels=[3, 5, 7], pyconv_groups=[1, 4, 8]))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res


## Plain wide-activation residual block (WDSR-B style): 1x1 expand ->
## ReLU -> 1x1 shrink -> 3x3 conv, with an additive skip connection.
class WideActResBlock(nn.Module):
    def __init__(self, nf=64):
        super(WideActResBlock, self).__init__()
        self.res_scale = 1
        body = []
        expand = 6
        linear = 0.8
        kernel_size = 3
        wn = lambda x: torch.nn.utils.weight_norm(x)
        act = nn.ReLU(True)
        body.append(wn(nn.Conv2d(nf, nf*expand, 1, padding=1//2)))
        body.append(act)
        body.append(wn(nn.Conv2d(nf*expand, int(nf*linear), 1, padding=1//2)))
        body.append(wn(nn.Conv2d(int(nf*linear), nf, kernel_size, padding=kernel_size//2)))
        self.body = nn.Sequential(*body)

    def forward(self, x):
        res = self.body(x) * self.res_scale
        res += x
        return res


## Wide-activation block whose spatial conv is a PSConv2d.
class PSWideActResBlock(nn.Module):
    def __init__(self, nf=64):
        super(PSWideActResBlock, self).__init__()
        self.res_scale = 1
        body = []
        expand = 6
        linear = 0.75
        kernel_size = 3
        wn = lambda x: torch.nn.utils.weight_norm(x)
        act = nn.ReLU(True)
        body.append(wn(nn.Conv2d(nf, nf*expand, 1, padding=1//2)))
        body.append(act)
        body.append(wn(nn.Conv2d(nf*expand, int(nf*linear), 1, padding=1//2)))
        body.append(wn(PSConv2d(int(nf*linear), nf, kernel_size, padding=kernel_size//2)))
        self.body = nn.Sequential(*body)

    def forward(self, x):
        res = self.body(x) * self.res_scale
        res += x
        return res


## Wide-activation block whose spatial conv is a pyramidal conv.
class PyWideActResBlock(nn.Module):
    def __init__(self, nf=64):
        super(PyWideActResBlock, self).__init__()
        self.res_scale = 1
        body = []
        expand = 6
        linear = 0.75
        kernel_size = 3
        wn = lambda x: torch.nn.utils.weight_norm(x)
        act = nn.ReLU(True)
        # Fixed: removed unused local `expand_nf` (nf*expand is used inline).
        linear_nf = int(nf * linear)
        body.append(wn(nn.Conv2d(nf, nf*expand, 1, padding=1//2)))
        body.append(act)
        body.append(wn(nn.Conv2d(nf*expand, int(nf*linear), 1, padding=1//2)))
        body.append(
            PyConv2d(in_channels=linear_nf,
                     out_channels=[nf//4, nf//4, nf//2],
                     pyconv_kernels=[3, 5, 7], pyconv_groups=[1, 4, 8]))
        self.body = nn.Sequential(*body)

    def forward(self, x):
        res = self.body(x) * self.res_scale
        res += x
        return res
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True, use_pad_mask=False):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2), normal value.
        interp_mode (str): 'nearest' or 'bilinear' or 'nearest4'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'. Default: 'zeros'.
        align_corners (bool): Before pytorch 1.3, the default value is align_corners=True.
            After pytorch 1.3, the default value is align_corners=False.
            Here, we use the True as default.
        use_pad_mask (bool): only used for PWCNet, x is first padded with ones along the
            channel dimension. The mask is generated according to the grid_sample results
            of the padded dimension. Currently unused (masking code is disabled below).

    Returns:
        Tensor: Warped image or feature map. For interp_mode='nearest4' the four
        nearest-neighbour samples (floor/ceil combinations) are concatenated along
        the channel dimension, so the output has 4*c channels.
    """
    # assert x.size()[-2:] == flow.size()[1:3]  # temporarily turned off for image-wise shift
    n, _, h, w = x.size()
    x = x.float()  # grid_sample requires floating point input

    # Create the base sampling grid directly on x's device/dtype.
    # (Building it with arange().type_as(x) triggered an illegal memory
    # access on TITAN RTX + PyTorch 1.9.1, hence the explicit dtype/device.)
    grid_y, grid_x = torch.meshgrid(
        torch.arange(0, h, dtype=x.dtype, device=x.device),
        torch.arange(0, w, dtype=x.dtype, device=x.device))
    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
    # NOTE: the old `grid.requires_grad = False` and `grid.type_as(x)` were
    # removed — a freshly constructed tensor never requires grad, and the grid
    # already has x's dtype and device.
    vgrid = grid + flow

    # if use_pad_mask:  # for PWCNet
    #     x = F.pad(x, (0,0,0,0,0,1), mode='constant', value=1)

    # scale grid to [-1, 1]
    if interp_mode == 'nearest4':
        # todo: bug, no gradient for flow model in this case!!! but the result is good
        vgrid_x_floor = 2.0 * torch.floor(vgrid[:, :, :, 0]) / max(w - 1, 1) - 1.0
        vgrid_x_ceil = 2.0 * torch.ceil(vgrid[:, :, :, 0]) / max(w - 1, 1) - 1.0
        vgrid_y_floor = 2.0 * torch.floor(vgrid[:, :, :, 1]) / max(h - 1, 1) - 1.0
        vgrid_y_ceil = 2.0 * torch.ceil(vgrid[:, :, :, 1]) / max(h - 1, 1) - 1.0

        output00 = F.grid_sample(x, torch.stack((vgrid_x_floor, vgrid_y_floor), dim=3),
                                 mode='nearest', padding_mode=padding_mode, align_corners=align_corners)
        output01 = F.grid_sample(x, torch.stack((vgrid_x_floor, vgrid_y_ceil), dim=3),
                                 mode='nearest', padding_mode=padding_mode, align_corners=align_corners)
        output10 = F.grid_sample(x, torch.stack((vgrid_x_ceil, vgrid_y_floor), dim=3),
                                 mode='nearest', padding_mode=padding_mode, align_corners=align_corners)
        output11 = F.grid_sample(x, torch.stack((vgrid_x_ceil, vgrid_y_ceil), dim=3),
                                 mode='nearest', padding_mode=padding_mode, align_corners=align_corners)

        return torch.cat([output00, output01, output10, output11], 1)
    else:
        vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
        vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
        vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
        output = F.grid_sample(x, vgrid_scaled, mode=interp_mode,
                               padding_mode=padding_mode, align_corners=align_corners)

        # if use_pad_mask:  # for PWCNet
        #     output = _flow_warp_masking(output)

        # TODO, what if align_corners=False
        return output
grid.requires_grad = False # grid = grid.type_as(x) # vgrid = grid + flow # # scale grid to [-1,1] # vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 # vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 # vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) # output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) # return output ================================================ FILE: code/real/bsrt/model/bsrt.py ================================================ import functools import torch import torch.nn as nn import torch.nn.functional as F import model.arch_util as arch_util from torch.cuda.amp import autocast import model.swin_util as swu import time import os import math from utils.debayer import Debayer3x3 import torchvision.utils as tvutils from datasets.burstsr_dataset import pack_raw_image, flatten_raw_image_batch try: from model.non_local.non_local_cross_dot_product import NONLocalBlock2D as NonLocalCross from model.non_local.non_local_dot_product import NONLocalBlock2D as NonLocal except ImportError: raise ImportError('Failed to import Non_Local module.') try: from model.DCNv2.dcn_v2 import DCN_sep as DCN, FlowGuidedDCN, InsideFlowGuidedDCN except ImportError: raise ImportError('Failed to import DCNv2 module.') def make_model(args, parent=False): nframes = args.burst_size img_size = args.patch_size * 2 patch_size = 1 in_chans = args.burst_channel out_chans = args.n_colors if args.model_level == "S": depths = [6]*1 + [6] * 4 num_heads = [6]*1 + [6] * 4 embed_dim = 60 elif args.model_level == "L": depths = [6]*1 + [8] * 6 num_heads = [6]*1 + [6] * 6 embed_dim = 180 window_size = 8 mlp_ratio = 2 upscale = args.scale[0] non_local = args.non_local use_checkpoint=args.use_checkpoint if args.local_rank <= 0: print("depths: ", depths) return BSRT(args=args,nframes=nframes, img_size=img_size, patch_size=patch_size, in_chans=in_chans, out_chans=out_chans, embed_dim=embed_dim, depths=depths, num_heads=num_heads, 
class BasicModule(nn.Module):
    """Basic Module for SpyNet.

    A five-conv refinement head: it consumes the 8-channel concatenation of
    (reference frame, warped supporting frame, upsampled flow) and predicts a
    2-channel flow residual. Layer indices (0..8) inside ``basic_module`` are
    kept identical to the reference implementation so pretrained SpyNet
    checkpoints load unchanged.
    """

    def __init__(self):
        super(BasicModule, self).__init__()
        # (in_channels, out_channels) for the four conv+ReLU stages
        channel_plan = [(8, 32), (32, 64), (64, 32), (32, 16)]
        layers = []
        for c_in, c_out in channel_plan:
            layers.append(nn.Conv2d(in_channels=c_in, out_channels=c_out,
                                    kernel_size=7, stride=1, padding=3))
            layers.append(nn.ReLU(inplace=False))
        # final projection to the 2-channel flow residual (no activation)
        layers.append(nn.Conv2d(in_channels=16, out_channels=2,
                                kernel_size=7, stride=1, padding=3))
        self.basic_module = nn.Sequential(*layers)

    def forward(self, tensor_input):
        return self.basic_module(tensor_input)
""" def __init__(self, load_path=None, return_levels=[5]): super(SpyNet, self).__init__() self.return_levels = return_levels self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)]) if load_path: if not os.path.exists(load_path): import requests url = 'https://github.com/JingyunLiang/VRT/releases/download/v0.0/spynet_sintel_final-3d2a1287.pth' r = requests.get(url, allow_redirects=True) print(f'downloading SpyNet pretrained model from {url}') os.makedirs(os.path.dirname(load_path), exist_ok=True) open(load_path, 'wb').write(r.content) self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params']) self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) def preprocess(self, tensor_input): tensor_output = (tensor_input - self.mean) / self.std return tensor_output def process(self, ref, supp, w, h, w_floor, h_floor): flow_list = [] ref = [self.preprocess(ref)] supp = [self.preprocess(supp)] # ref = [ref] # supp = [supp] for level in range(5): ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False)) supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False)) flow = ref[0].new_zeros( [ref[0].size(0), 2, int(math.floor(ref[0].size(2) / 2.0)), int(math.floor(ref[0].size(3) / 2.0))]) for level in range(len(ref)): upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0 if upsampled_flow.size(2) != ref[level].size(2): upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate') if upsampled_flow.size(3) != ref[level].size(3): upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate') flow = self.basic_module[level](torch.cat([ ref[level], arch_util.flow_warp( supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'), upsampled_flow 
], 1)) + upsampled_flow if level in self.return_levels: scale = 2**(5-level) # level=5 (scale=1), level=4 (scale=2), level=3 (scale=4), level=2 (scale=8) flow_out = F.interpolate(input=flow, size=(h//scale, w//scale), mode='bilinear', align_corners=False) flow_out[:, 0, :, :] *= float(w//scale) / float(w_floor//scale) flow_out[:, 1, :, :] *= float(h//scale) / float(h_floor//scale) if torch.abs(flow_out).mean() > 200: print(f"level {level}, flow > 200: {torch.abs(flow_out).mean():.4f}") # return None flow_out.clamp(-250, 250) flow_list.insert(0, flow_out) return flow_list def forward(self, ref, supp): assert ref.size() == supp.size() h, w = ref.size(2), ref.size(3) w_floor = math.floor(math.ceil(w / 32.0) * 32.0) h_floor = math.floor(math.ceil(h / 32.0) * 32.0) ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False) supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False) flow_list = self.process(ref, supp, w, h, w_floor, h_floor) return flow_list[0] if len(flow_list) == 1 else flow_list class FlowGuidedPCDAlign(nn.Module): ''' Alignment module using Pyramid, Cascading and Deformable convolution with 3 pyramid levels. 
    def __init__(self, nf=64, groups=8):
        # Three-level pyramid of flow-guided deformable alignments plus a
        # final cascading DCN refinement, following EDVR's PCD module.
        super(FlowGuidedPCDAlign, self).__init__()
        # L3: level 3, 1/4 spatial size
        self.L3_offset_conv1 = nn.Conv2d(nf * 2 + 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L3_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L3_dcnpack = FlowGuidedDCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)
        # L2: level 2, 1/2 spatial size
        self.L2_offset_conv1 = nn.Conv2d(nf * 2 + 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L2_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset
        self.L2_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L2_dcnpack = FlowGuidedDCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)
        self.L2_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea
        # L1: level 1, original spatial size
        self.L1_offset_conv1 = nn.Conv2d(nf * 2 + 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L1_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset
        self.L1_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L1_dcnpack = FlowGuidedDCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)
        self.L1_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea
        # Cascading DCN: one more (plain) deformable refinement at full res
        self.cas_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.cas_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.cas_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, nbr_fea_l, nbr_fea_warped_l, ref_fea_l, flows_l):
        '''align other neighboring frames to the reference frame in the feature level

        nbr_fea_l, ref_fea_l: [L1, L2, L3], each with [B,C,H,W] features
        nbr_fea_warped_l: the neighbor features pre-warped by optical flow,
            used (with the flow itself) to predict residual DCN offsets.
        flows_l: [L1, L2, L3] optical flows, each [B,2,H,W].

        Processing is coarse-to-fine: offsets/features estimated at L3 are
        upsampled (offsets also scaled x2 to match resolution) and fused into
        the next finer level.
        '''
        # L3: coarsest level — offsets from (warped nbr, ref, flow)
        L3_offset = torch.cat([nbr_fea_warped_l[2], ref_fea_l[2], flows_l[2]], dim=1)
        L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset))
        L3_offset = self.lrelu(self.L3_offset_conv2(L3_offset))
        L3_fea = self.lrelu(self.L3_dcnpack(nbr_fea_l[2], L3_offset, flows_l[2]))
        # L2: fuse upsampled L3 offsets (x2 magnitude) with this level's own
        L3_offset = F.interpolate(L3_offset, scale_factor=2, mode='bilinear', align_corners=False)
        L2_offset = torch.cat([nbr_fea_warped_l[1], ref_fea_l[1], flows_l[1]], dim=1)
        L2_offset = self.lrelu(self.L2_offset_conv1(L2_offset))
        L2_offset = self.lrelu(self.L2_offset_conv2(torch.cat([L2_offset, L3_offset*2], dim=1)))
        L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset))
        L2_fea = self.L2_dcnpack(nbr_fea_l[1], L2_offset, flows_l[1])
        L3_fea = F.interpolate(L3_fea, scale_factor=2, mode='bilinear', align_corners=False)
        L2_fea = self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L3_fea], dim=1)))
        # L1: finest level, same fusion pattern
        L2_offset = F.interpolate(L2_offset, scale_factor=2, mode='bilinear', align_corners=False)
        L1_offset = torch.cat([nbr_fea_warped_l[0], ref_fea_l[0], flows_l[0]], dim=1)
        L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset))
        L1_offset = self.lrelu(self.L1_offset_conv2(torch.cat([L1_offset, L2_offset * 2], dim=1)))
        L1_offset = self.lrelu(self.L1_offset_conv3(L1_offset))
        L1_fea = self.L1_dcnpack(nbr_fea_l[0], L1_offset, flows_l[0])
        L2_fea = F.interpolate(L2_fea, scale_factor=2, mode='bilinear', align_corners=False)
        L1_fea = self.L1_fea_conv(torch.cat([L1_fea, L2_fea], dim=1))
        # Cascading: one extra plain-DCN refinement against the reference
        offset = torch.cat([L1_fea, ref_fea_l[0]], dim=1)
        offset = self.lrelu(self.cas_offset_conv1(offset))
        offset = self.lrelu(self.cas_offset_conv2(offset))
        L1_fea = self.cas_dcnpack(L1_fea, offset)

        return L1_fea
class CrossNonLocal_Fusion(nn.Module):
    ''' Cross Non Local fusion module

    Fuses N aligned frame features using per-frame non-local attention:
    a cross (frame->reference) branch and a self-attention branch, whose
    concatenated outputs are merged by a single conv.
    '''

    def __init__(self, nf=64, out_feat=96, nframes=5, center=2):
        super(CrossNonLocal_Fusion, self).__init__()
        self.center = center  # index of the reference frame

        self.non_local_T = nn.ModuleList()
        self.non_local_F = nn.ModuleList()

        for i in range(nframes):
            self.non_local_T.append(NonLocalCross(nf, inter_channels=nf//2, sub_sample=True, bn_layer=False))
            self.non_local_F.append(NonLocal(nf, inter_channels=nf//2, sub_sample=True, bn_layer=False))

        # fusion conv: using 1x1 to save parameters and computation
        self.fea_fusion = nn.Conv2d(nframes * nf*2, out_feat, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, aligned_fea):
        B, N, C, H, W = aligned_fea.size()  # N video frames
        ref = aligned_fea[:, self.center, :, :, :].clone()

        cor_l = []  # cross-attention (frame vs. reference) outputs
        non_l = []  # self-attention outputs
        for i in range(N):
            nbr = aligned_fea[:, i, :, :, :]
            non_l.append(self.non_local_F[i](nbr))
            cor_l.append(self.non_local_T[i](nbr, ref))

        aligned_fea_T = torch.cat(cor_l, dim=1)
        aligned_fea_F = torch.cat(non_l, dim=1)
        aligned_fea = torch.cat([aligned_fea_T, aligned_fea_F], dim=1)

        #### fusion
        fea = self.fea_fusion(aligned_fea)

        return fea


class BSRT(nn.Module):
    """Burst Super-Resolution Transformer.

    Pipeline: shallow feature extraction (optionally Swin blocks) ->
    SpyNet optical flow + flow-guided PCD alignment over an FPN feature
    pyramid -> multi-frame fusion (conv or cross non-local) -> RSTB deep
    feature extraction -> pixel-shuffle upsampling with two skip branches
    built from the raw center frame.
    """

    def __init__(self, args, nframes=8, img_size=64, patch_size=1, in_chans=3, out_chans=3,
                 embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, upscale=4, non_local=False,
                 **kwargs):
        super(BSRT, self).__init__()
        num_in_ch = in_chans
        num_out_ch = out_chans  # NOTE: currently unused; conv_last uses args.n_colors
        num_feat = 64
        groups = 8
        # embed_dim = num_feat
        back_RBs = 5       # NOTE: unused below
        n_resblocks = 6    # NOTE: unused below

        self.args = args
        self.center = 0    # reference frame index within the burst
        self.upscale = upscale
        self.window_size = window_size
        self.non_local = non_local
        self.nframes = nframes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio

        # NOTE(review): hard-coded user-specific checkpoint path; SpyNet
        # downloads the weights here if the file is missing.
        spynet_path='/home/luoziwei/.pretrained_models/spynet_sintel_final-3d2a1287.pth'
        self.spynet = SpyNet(spynet_path, [3, 4, 5])
        # maps each pixel-shuffled 1-channel map to 3 channels for SpyNet
        self.conv_flow = nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1)
        self.flow_ps = nn.PixelShuffle(2)

        # split image into non-overlapping patches
        self.patch_embed = swu.PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # merge non-overlapping patches into image
        self.patch_unembed = swu.PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        #####################################################################################################
        ################################### 1, shallow feature extraction ###################################
        self.conv_first = nn.Conv2d(num_in_ch*(1+2*0), embed_dim, 3, 1, 1, bias=True)

        # # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        if args.swinfeature:
            if self.args.local_rank <= 0:
                print("using swinfeature")
            # depths[0] Swin blocks used as the shallow feature extractor
            self.pre_layers = nn.ModuleList()
            for i_layer in range(depths[0]):
                layer = swu.SwinTransformerBlock(dim=embed_dim,
                                 input_resolution=(patches_resolution[0]//2, patches_resolution[1]//2),
                                 num_heads=num_heads[0], window_size=window_size,
                                 shift_size=0 if (i_layer % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop_rate, attn_drop=attn_drop_rate,
                                 drop_path=dpr[i_layer],
                                 norm_layer=norm_layer)
                self.pre_layers.append(layer)

            self.pre_norm = norm_layer(embed_dim)
        else:
            WARB = functools.partial(arch_util.WideActResBlock, nf=embed_dim)
            self.feature_extraction = arch_util.make_layer(WARB, 5)

        self.conv_after_pre_layer = nn.Conv2d(embed_dim, num_feat*4, 3, 1, 1, bias=True)
        self.mid_ps = nn.PixelShuffle(2)

        # strided convs build the L2 (1/2) and L3 (1/4) pyramid levels
        self.fea_L2_conv1 = nn.Conv2d(num_feat, num_feat*2, 3, 2, 1, bias=True)
        self.fea_L3_conv1 = nn.Conv2d(num_feat*2, num_feat*4, 3, 2, 1, bias=True)

        #####################################################################################################
        ################################### 2, Feature Enhanced PCD Align ###################################
        # Top layers
        self.toplayer = nn.Conv2d(num_feat*4, num_feat, kernel_size=1, stride=1, padding=0)
        # Smooth layers
        self.smooth1 = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1)
        self.smooth2 = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1)
        # Lateral layers
        self.latlayer1 = nn.Conv2d(num_feat*2, num_feat, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(num_feat*1, num_feat, kernel_size=1, stride=1, padding=0)

        # self.align = PCD_Align(nf=num_feat, groups=groups)
        self.align = FlowGuidedPCDAlign(nf=num_feat, groups=groups)

        #####################################################################################################
        ################################### 3, Multi-frame Feature Fusion ##################################
        if self.non_local:
            if self.args.local_rank <= 0:
                print("using non_local")
            self.fusion = CrossNonLocal_Fusion(nf=num_feat, out_feat=embed_dim, nframes=nframes, center=self.center)
        else:
            self.fusion = nn.Conv2d(nframes * num_feat, embed_dim, 1, 1, bias=True)

        #####################################################################################################
        ################################### 4, deep feature extraction ######################################
        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            swu.trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # build Residual Swin Transformer blocks (RSTB); depths[0] was
        # consumed by the shallow extractor, so start at layer 1
        self.layers = nn.ModuleList()
        for i_layer in range(1, self.num_layers):
            layer = swu.RSTB(dim=embed_dim,
                         input_resolution=(patches_resolution[0], patches_resolution[1]),
                         depth=depths[i_layer],
                         num_heads=num_heads[i_layer],
                         window_size=window_size,
                         mlp_ratio=self.mlp_ratio,
                         qkv_bias=qkv_bias, qk_scale=qk_scale,
                         drop=drop_rate, attn_drop=attn_drop_rate,
                         drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
                         norm_layer=norm_layer,
                         downsample=None,
                         use_checkpoint=use_checkpoint,
                         img_size=img_size,
                         patch_size=patch_size
                         )
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)

        # build the last conv layer in deep feature extraction
        self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)

        #####################################################################################################
        ################################ 5, high quality image reconstruction ################################
        self.upconv1 = nn.Conv2d(embed_dim, num_feat * 4, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
        self.pixel_shuffle = nn.PixelShuffle(2)
        self.HRconv = nn.Conv2d(64, 64, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(64, args.n_colors, 3, 1, 1, bias=True)

        #### skip #############
        # two pixel-shuffle skip branches lifting the raw center frame to HR
        self.skip_pixel_shuffle = nn.PixelShuffle(2)
        self.skipup1 = nn.Conv2d(num_in_ch//4, num_feat * 4, 3, 1, 1, bias=True)
        self.skipup2 = nn.Conv2d(num_feat, args.n_colors * 4, 3, 1, 1, bias=True)

        #### activation function
        self.lrelu = nn.LeakyReLU(0.1, inplace=True)
        self.lrelu2 = nn.LeakyReLU(0.1, inplace=True)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # truncated-normal init for Linear, unit init for LayerNorm
        if isinstance(m, nn.Linear):
            swu.trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def _upsample_add(self, x, y):
        # FPN top-down pathway: upsample x by 2 and add the lateral feature y
        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + y

    def check_image_size(self, x):
        # reflect-pad so H and W are multiples of the Swin window size
        _, _, h, w = x.size()
        mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
        mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
        return x

    def pre_forward_features(self, x):
        # shallow feature extraction: Swin blocks or the WARB fallback
        if self.args.swinfeature:
            x_size = (x.shape[-2], x.shape[-1])
            x = self.patch_embed(x, use_norm=True)
            if self.ape:
                x = x + self.absolute_pos_embed
            x = self.pos_drop(x)

            for idx, layer in enumerate(self.pre_layers):
                x = layer(x, x_size)

            x = self.pre_norm(x)
            x = self.patch_unembed(x, x_size)
        else:
            x = self.feature_extraction(x)

        return x

    def forward_features(self, x):
        # deep feature extraction through the RSTB stack
        x_size = (x.shape[-2], x.shape[-1])
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for idx, layer in enumerate(self.layers):
            x = layer(x, x_size)
            # numerical-stability probe: report which layer produced inf/nan
            if torch.any(torch.isinf(x)) or torch.any(torch.isnan(x)):
                print('layer: ', idx)

        x = self.norm(x)  # B L C
        x = self.patch_unembed(x, x_size)

        return x

    @autocast()
    def forward(self, x, print_time=False):
        B, N, C, H, W = x.size()  # N video frames
        x_center = x[:, self.center, :, :, :].contiguous()

        #### skip module ########
        skip1 = self.lrelu2(self.skip_pixel_shuffle(self.skipup1(self.skip_pixel_shuffle(x_center))))
        skip2 = self.skip_pixel_shuffle(self.skipup2(skip1))

        # pixel-shuffle the burst and lift to 3 channels for flow estimation
        x_ = self.conv_flow(self.flow_ps(x.view(B*N, C, H, W))).view(B, N, -1, H*2, W*2)

        # calculate flows
        ref_flows = self.get_ref_flows(x_)

        #### extract LR features
        x = self.lrelu(self.conv_first(x.view(B*N, -1, H, W)))

        L1_fea = self.mid_ps(self.conv_after_pre_layer(self.pre_forward_features(x)))
        _, _, H, W = L1_fea.size()  # H, W now refer to the 2x feature size

        L2_fea = self.lrelu(self.fea_L2_conv1(L1_fea))
        L3_fea = self.lrelu(self.fea_L3_conv1(L2_fea))

        # FPN enhance features
        L3_fea = self.lrelu(self.toplayer(L3_fea))
        L2_fea = self.smooth1(self._upsample_add(L3_fea, self.latlayer1(L2_fea)))
        L1_fea = self.smooth2(self._upsample_add(L2_fea, self.latlayer2(L1_fea)))

        L1_fea = L1_fea.view(B, N, -1, H, W).contiguous()
        L2_fea = L2_fea.view(B, N, -1, H // 2, W // 2).contiguous()
        L3_fea = L3_fea.view(B, N, -1, H // 4, W // 4).contiguous()

        #### PCD align
        # ref feature list
        ref_fea_l = [
            L1_fea[:, self.center, :, :, :].clone(),
            L2_fea[:, self.center, :, :, :].clone(),
            L3_fea[:, self.center, :, :, :].clone()
        ]
        aligned_fea = []
        for i in range(N):
            nbr_fea_l = [
                L1_fea[:, i, :, :, :].clone(),
                L2_fea[:, i, :, :, :].clone(),
                L3_fea[:, i, :, :, :].clone()
            ]
            flows_l = [
                ref_flows[0][:, i, :, :, :].clone(),
                ref_flows[1][:, i, :, :, :].clone(),
                ref_flows[2][:, i, :, :, :].clone()
            ]
            # print(nbr_fea_l[0].shape, flows_l[0].shape)
            # pre-warp neighbor features with the optical flow before DCN
            nbr_warped_l = [
                arch_util.flow_warp(nbr_fea_l[0], flows_l[0].permute(0, 2, 3, 1), 'bilinear'),
                arch_util.flow_warp(nbr_fea_l[1], flows_l[1].permute(0, 2, 3, 1), 'bilinear'),
                arch_util.flow_warp(nbr_fea_l[2], flows_l[2].permute(0, 2, 3, 1), 'bilinear')
            ]
            aligned_fea.append(self.align(nbr_fea_l, nbr_warped_l, ref_fea_l, flows_l))

        aligned_fea = torch.stack(aligned_fea, dim=1)  # [B, N, C, H, W] --> [B, T, C, H, W]

        if not self.non_local:
            # conv fusion expects frames flattened into the channel dim
            aligned_fea = aligned_fea.view(B, -1, H, W)

        x = self.lrelu(self.fusion(aligned_fea))

        x = self.lrelu(self.conv_after_body(self.forward_features(x))) + x

        # two x2 pixel-shuffle stages, each with its skip connection
        x = self.lrelu(self.pixel_shuffle(self.upconv1(x)))
        x = skip1 + x
        x = self.lrelu(self.pixel_shuffle(self.upconv2(x)))
        x = self.lrelu(self.HRconv(x))
        x = self.conv_last(x)

        x = skip2 + x
        return x

    def get_ref_flows(self, x):
        '''Get flow between frames ref and other'''
        b, n, c, h, w = x.size()
        x_nbr = x.reshape(-1, c, h, w)
        # repeat the reference frame so SpyNet sees (ref, nbr) pairs
        x_ref = x[:, self.center:self.center+1, :, :, :].repeat(1, n, 1, 1, 1).reshape(-1, c, h, w)

        # backward
        flows = self.spynet(x_ref, x_nbr)
        # three pyramid levels (full, 1/2, 1/4), each [B, N, 2, H/s, W/s]
        flows_list = [flow.view(b, n, 2, h // (2 ** (i)), w // (2 ** (i)))
                      for flow, i in zip(flows, range(3))]

        return flows_list


# --- FILE: code/real/bsrt/model/checkpoint.py ---
import torch
import warnings


def detach_variable(inputs):
    # Detach each tensor in a tuple while preserving its requires_grad flag
    # (used by gradient checkpointing to rebuild the graph on backward).
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            x = inp.detach()
            x.requires_grad = inp.requires_grad
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)


def check_backward_validity(inputs):
    # Warn when checkpointing is pointless because nothing needs gradients.
    if not any(inp.requires_grad for inp in inputs):
        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
Gradients will be None") class CheckpointFunction(torch.autograd.Function): @staticmethod def forward(ctx, run_function, length, *args): ctx.run_function = run_function ctx.input_tensors = list(args[:length]) ctx.input_params = list(args[length:]) with torch.no_grad(): output_tensors = ctx.run_function(*ctx.input_tensors) return output_tensors @staticmethod def backward(ctx, *output_grads): for i in range(len(ctx.input_tensors)): temp = ctx.input_tensors[i] ctx.input_tensors[i] = temp.detach() ctx.input_tensors[i].requires_grad = temp.requires_grad with torch.enable_grad(): output_tensors = ctx.run_function(*ctx.input_tensors) input_grads = torch.autograd.grad(output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True) return (None, None) + input_grads ================================================ FILE: code/real/bsrt/model/common.py ================================================ import math import numpy as np import torch import torch.nn as nn import torch.nn.functional as F def default_conv(in_channels, out_channels, kernel_size, bias=True): return nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias) class MeanShift(nn.Conv2d): def __init__( self, rgb_range, rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): super(MeanShift, self).__init__(3, 3, kernel_size=1) std = torch.Tensor(rgb_std) self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std for p in self.parameters(): p.requires_grad = False class BasicBlock(nn.Sequential): def __init__( self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, bn=True, act=nn.ReLU(True)): m = [conv(in_channels, out_channels, kernel_size, bias=bias)] if bn: m.append(nn.BatchNorm2d(out_channels)) if act is not None: m.append(act) super(BasicBlock, self).__init__(*m) class ResBlock(nn.Module): def __init__( self, conv, n_feats, kernel_size, 
bias=True, bn=False, act=nn.ReLU(True), res_scale=1): super(ResBlock, self).__init__() m = [] for i in range(2): m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) if bn: m.append(nn.BatchNorm2d(n_feats)) if i == 0: m.append(act) self.body = nn.Sequential(*m) self.res_scale = res_scale def forward(self, x): res = self.body(x).mul(self.res_scale) res += x return res class Upsampler(nn.Sequential): def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): m = [] if (scale & (scale - 1)) == 0: # Is scale = 2^n? for _ in range(int(math.log(scale, 2))): m.append(conv(n_feats, 4 * n_feats, 3, bias)) m.append(nn.PixelShuffle(2)) if bn: m.append(nn.BatchNorm2d(n_feats)) if act == 'relu': m.append(nn.ReLU(True)) elif act == 'prelu': m.append(nn.PReLU(n_feats)) elif scale == 3: m.append(conv(n_feats, 9 * n_feats, 3, bias)) m.append(nn.PixelShuffle(3)) if bn: m.append(nn.BatchNorm2d(n_feats)) if act == 'relu': m.append(nn.ReLU(True)) elif act == 'prelu': m.append(nn.PReLU(n_feats)) else: raise NotImplementedError super(Upsampler, self).__init__(*m) class UpOnly(nn.Sequential): def __init__(self, scale): m = [] if (scale & (scale - 1)) == 0: # Is scale = 2^n? for _ in range(int(math.log(scale, 2))): m.append(nn.PixelShuffle(2)) elif scale == 3: m.append(nn.PixelShuffle(3)) else: raise NotImplementedError super(UpOnly, self).__init__(*m) def lanczos_kernel(dx, a=3, N=None, dtype=None, device=None): ''' Generates 1D Lanczos kernels for translation and interpolation. Args: dx : float, tensor (batch_size, 1), the translation in pixels to shift an image. a : int, number of lobes in the kernel support. If N is None, then the width is the kernel support (length of all lobes), S = 2(a + ceil(dx)) + 1. N : int, width of the kernel. If smaller than S then N is set to S. 
def lanczos_kernel(dx, a=3, N=None, dtype=None, device=None):
    '''
    Generates 1D Lanczos kernels for translation and interpolation.
    Args:
        dx : float, tensor (batch_size, 1), the translation in pixels to shift an image.
        a : int, number of lobes in the kernel support.
            If N is None, then the width is the kernel support (length of all lobes),
            S = 2(a + ceil(dx)) + 1.
        N : int, width of the kernel. If smaller than S then N is set to S.
    Returns:
        k : tensor (batch_size, kernel_width), one 1D Lanczos kernel per shift.
    '''
    if not torch.is_tensor(dx):
        dx = torch.tensor(dx, dtype=dtype, device=device)

    if device is None:
        device = dx.device

    if dtype is None:
        dtype = dx.dtype

    D = dx.abs().ceil().int()
    S = 2 * (a + D) + 1  # width of kernel support

    S_max = S.max() if hasattr(S, 'shape') else S

    if (N is None) or (N < S_max):
        N = S

    Z = (N - S) // 2  # width of zeros beyond kernel support

    start = (-(a + D + Z)).min()
    end = (a + D + Z + 1).max()
    x = torch.arange(start, end, dtype=dtype, device=device).view(1, -1) - dx

    # small epsilon keeps the 0/0 at the kernel centre finite (slightly biases
    # the kernel instead of special-casing x == 0)
    px = (np.pi * x) + 1e-3

    sin_px = torch.sin(px)
    sin_pxa = torch.sin(px / a)

    k = a * sin_px * sin_pxa / px ** 2  # sinc(x) masked by sinc(x/a)

    return k


def lanczos_shift(img, shift, p=5, a=3):
    '''
    Shifts an image by convolving it with a Lanczos kernel.
    Lanczos interpolation is an approximation to ideal sinc interpolation,
    by windowing a sinc kernel with another sinc function extending
    up to a few nunber of its lobes (typically a=3).

    Args:
        img : tensor (batch_size, channels, height, width), the images to be shifted
        shift : tensor (batch_size, 2) of translation parameters (dy, dx)
        p : int, padding width prior to convolution (default=5)
        a : int, number of lobes in the Lanczos interpolation kernel (default=3)
    Returns:
        I_s: tensor (batch_size, channels, height, width), shifted images
    '''
    img = img.transpose(0, 1)

    dtype = img.dtype

    if len(img.shape) == 2:
        img = img[None, None].repeat(1, shift.shape[0], 1, 1)  # batch of one image
    elif len(img.shape) == 3:  # one image per shift
        assert img.shape[0] == shift.shape[0]
        img = img[None, ]

    # Apply padding
    padder = torch.nn.ReflectionPad2d(p)  # reflect pre-padding
    I_padded = padder(img)

    # Create 1D shifting kernels
    y_shift = shift[:, [0]]
    x_shift = shift[:, [1]]

    k_y = (lanczos_kernel(y_shift, a=a, N=None, dtype=dtype)
           .flip(1)  # flip axis of convolution
           )[:, None, :, None]  # expand dims to get shape (batch, channels, y_kernel, 1)
    k_x = (lanczos_kernel(x_shift, a=a, N=None, dtype=dtype)
           .flip(1)
           )[:, None, None, :]  # shape (batch, channels, 1, x_kernel)

    # Apply kernels
    # print(I_padded.shape, k_y.shape)
    # groups == batch: each image in the batch is convolved with its own kernel
    I_s = torch.conv1d(I_padded,
                       groups=k_y.shape[0],
                       weight=k_y,
                       padding=[k_y.shape[2] // 2, 0])  # same padding
    I_s = torch.conv1d(I_s,
                       groups=k_x.shape[0],
                       weight=k_x,
                       padding=[0, k_x.shape[3] // 2])

    I_s = I_s[..., p:-p, p:-p]  # remove padding

    return I_s.transpose(0, 1)  # , k.squeeze()


================================================ FILE: code/real/bsrt/model/non_local/network.py ================================================
from torch import nn
# from lib.non_local_concatenation import NONLocalBlock2D
# from lib.non_local_gaussian import NONLocalBlock2D
# NOTE(review): this `lib.` package path looks stale — the module lives under
# model/non_local/ in this repo; verify the import actually resolves.
from lib.non_local_embedded_gaussian import NONLocalBlock2D
# from lib.non_local_dot_product import NONLocalBlock2D


class Network(nn.Module):
    # Small demo classifier (e.g. 28x28 single-channel input -> 10 classes)
    # interleaving conv stages with non-local blocks.
    def __init__(self):
        super(Network, self).__init__()

        self.conv_1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.nl_1 = NONLocalBlock2D(in_channels=32)
        self.conv_2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.nl_2 = NONLocalBlock2D(in_channels=64)
        self.conv_3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.fc = nn.Sequential(
            nn.Linear(in_features=128*3*3, out_features=256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(in_features=256, out_features=10)
        )

    def forward(self, x):
        batch_size = x.size(0)

        feature_1 = self.conv_1(x)
        nl_feature_1 = self.nl_1(feature_1)

        feature_2 = self.conv_2(nl_feature_1)
        nl_feature_2 = self.nl_2(feature_2)

        output = self.conv_3(nl_feature_2).view(batch_size, -1)
        output = self.fc(output)

        return output

    def forward_with_nl_map(self, x):
        # Same as forward() but also returns the two attention (non-local) maps.
        batch_size = x.size(0)

        feature_1 = self.conv_1(x)
        nl_feature_1, nl_map_1 = self.nl_1(feature_1,
                                           return_nl_map=True)
        feature_2 = self.conv_2(nl_feature_1)
        nl_feature_2, nl_map_2 = self.nl_2(feature_2, return_nl_map=True)

        output = self.conv_3(nl_feature_2).view(batch_size, -1)
        output = self.fc(output)

        return output, [nl_map_1, nl_map_2]


if __name__ == '__main__':
    import torch

    img = torch.randn(3, 1, 28, 28)
    net = Network()
    out = net(img)
    print(out.size())


================================================ FILE: code/real/bsrt/model/non_local/non_local_concatenation.py ================================================
import torch
from torch import nn
from torch.nn import functional as F


class _NonLocalBlockND(nn.Module):
    """Non-local block, concatenation pairwise function variant.

    Pairwise affinity f(theta_i, phi_j) is computed by concatenating the two
    embeddings channel-wise and projecting with a 1x1 conv + ReLU
    (`concat_project`), then normalised by N (number of positions).
    """
    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            # zero-init so the block starts as identity (z = x at init)
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        self.concat_project = nn.Sequential(
            nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
            nn.ReLU()
        )

        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x, return_nl_map=False):
        '''
        :param x: (b, c, t, h, w)
        :param return_nl_map: if True return z, nl_map, else only return z.
        :return:
        '''
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # (b, c, N, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)
        # (b, c, 1, N)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1)

        # broadcast the two embeddings against each other to form all pairs
        h = theta_x.size(2)
        w = phi_x.size(3)
        theta_x = theta_x.repeat(1, 1, 1, w)
        phi_x = phi_x.repeat(1, 1, h, 1)

        concat_feature = torch.cat([theta_x, phi_x], dim=1)
        f = self.concat_project(concat_feature)
        b, _, h, w = f.size()
        f = f.view(b, h, w)

        N = f.size(-1)
        f_div_C = f / N  # 1/N normalisation instead of softmax

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x  # residual connection

        if return_nl_map:
            return z, f_div_C
        return z


class NONLocalBlock1D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=1, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock2D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=2, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock3D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True,):
        super(NONLocalBlock3D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=3, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


if __name__ == '__main__':
    import torch

    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
        img = torch.zeros(2, 3, 20)
        net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())

        img = torch.zeros(2, 3, 20, 20)
        net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())

        img = torch.randn(2, 3, 8, 20, 20)
        net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())


================================================ FILE: code/real/bsrt/model/non_local/non_local_cross_dot_product.py ================================================
import torch
from torch import nn
from torch.nn import functional as F


class _NonLocalBlockND(nn.Module):
    """Non-local block, cross dot-product variant: the query comes from a
    separate `ref` tensor (see forward below, continued past this chunk).
    Note the more aggressive sub-sampling pools here (4x4 instead of 2x2).
    """
    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 4, 4))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(4, 4))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(4))
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            # zero-init so the block starts as identity
            nn.init.constant_(self.W[1].weight, 0)
nn.init.constant_(self.W[1].bias, 0) else: self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0) nn.init.constant_(self.W.weight, 0) nn.init.constant_(self.W.bias, 0) self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0) self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0) if sub_sample: self.g = nn.Sequential(self.g, max_pool_layer) self.phi = nn.Sequential(self.phi, max_pool_layer) def forward(self, x, ref, return_nl_map=False): """ :param x: (b, c, t, h, w) :param return_nl_map: if True return z, nl_map, else only return z. :return: """ batch_size = x.size(0) g_x = self.g(x).view(batch_size, self.inter_channels, -1) g_x = g_x.permute(0, 2, 1) theta_ref = self.theta(ref).view(batch_size, self.inter_channels, -1) theta_ref = theta_ref.permute(0, 2, 1) phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) f = torch.matmul(theta_ref, phi_x) N = f.size(-1) f_div_C = f / N y = torch.matmul(f_div_C, g_x) y = y.permute(0, 2, 1).contiguous() y = y.view(batch_size, self.inter_channels, *x.size()[2:]) W_y = self.W(y) z = W_y + x if return_nl_map: return z, f_div_C return z class NONLocalBlock1D(_NonLocalBlockND): def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): super(NONLocalBlock1D, self).__init__(in_channels, inter_channels=inter_channels, dimension=1, sub_sample=sub_sample, bn_layer=bn_layer) class NONLocalBlock2D(_NonLocalBlockND): def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): super(NONLocalBlock2D, self).__init__(in_channels, inter_channels=inter_channels, dimension=2, sub_sample=sub_sample, bn_layer=bn_layer) class NONLocalBlock3D(_NonLocalBlockND): def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): super(NONLocalBlock3D, self).__init__(in_channels, 
inter_channels=inter_channels, dimension=3, sub_sample=sub_sample, bn_layer=bn_layer) if __name__ == '__main__': import torch for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]: img = torch.zeros(2, 3, 20) net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) out = net(img) print(out.size()) img = torch.zeros(2, 3, 20, 20) net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) out = net(img) print(out.size()) img = torch.randn(2, 3, 8, 20, 20) net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) out = net(img) print(out.size()) ================================================ FILE: code/real/bsrt/model/non_local/non_local_dot_product.py ================================================ import torch from torch import nn from torch.nn import functional as F class _NonLocalBlockND(nn.Module): def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): super(_NonLocalBlockND, self).__init__() assert dimension in [1, 2, 3] self.dimension = dimension self.sub_sample = sub_sample self.in_channels = in_channels self.inter_channels = inter_channels if self.inter_channels is None: self.inter_channels = in_channels // 2 if self.inter_channels == 0: self.inter_channels = 1 if dimension == 3: conv_nd = nn.Conv3d max_pool_layer = nn.MaxPool3d(kernel_size=(1, 4, 4)) bn = nn.BatchNorm3d elif dimension == 2: conv_nd = nn.Conv2d max_pool_layer = nn.MaxPool2d(kernel_size=(4, 4)) bn = nn.BatchNorm2d else: conv_nd = nn.Conv1d max_pool_layer = nn.MaxPool1d(kernel_size=(2)) bn = nn.BatchNorm1d self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0) if bn_layer: self.W = nn.Sequential( conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0), bn(self.in_channels) ) nn.init.constant_(self.W[1].weight, 0) nn.init.constant_(self.W[1].bias, 0) else: self.W = 
conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0) nn.init.constant_(self.W.weight, 0) nn.init.constant_(self.W.bias, 0) self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0) self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0) if sub_sample: self.g = nn.Sequential(self.g, max_pool_layer) self.phi = nn.Sequential(self.phi, max_pool_layer) def forward(self, x, return_nl_map=False): """ :param x: (b, c, t, h, w) :param return_nl_map: if True return z, nl_map, else only return z. :return: """ batch_size = x.size(0) g_x = self.g(x).view(batch_size, self.inter_channels, -1) g_x = g_x.permute(0, 2, 1) theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) theta_x = theta_x.permute(0, 2, 1) phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) f = torch.matmul(theta_x, phi_x) N = f.size(-1) f_div_C = f / N y = torch.matmul(f_div_C, g_x) y = y.permute(0, 2, 1).contiguous() y = y.view(batch_size, self.inter_channels, *x.size()[2:]) W_y = self.W(y) z = W_y + x if return_nl_map: return z, f_div_C return z class NONLocalBlock1D(_NonLocalBlockND): def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): super(NONLocalBlock1D, self).__init__(in_channels, inter_channels=inter_channels, dimension=1, sub_sample=sub_sample, bn_layer=bn_layer) class NONLocalBlock2D(_NonLocalBlockND): def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): super(NONLocalBlock2D, self).__init__(in_channels, inter_channels=inter_channels, dimension=2, sub_sample=sub_sample, bn_layer=bn_layer) class NONLocalBlock3D(_NonLocalBlockND): def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): super(NONLocalBlock3D, self).__init__(in_channels, inter_channels=inter_channels, dimension=3, sub_sample=sub_sample, 
bn_layer=bn_layer) if __name__ == '__main__': import torch for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]: img = torch.zeros(2, 3, 20) net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) out = net(img) print(out.size()) img = torch.zeros(2, 3, 20, 20) net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) out = net(img) print(out.size()) img = torch.randn(2, 3, 8, 20, 20) net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) out = net(img) print(out.size()) ================================================ FILE: code/real/bsrt/model/non_local/non_local_embedded_gaussian.py ================================================ import torch from torch import nn from torch.nn import functional as F class _NonLocalBlockND(nn.Module): def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): """ :param in_channels: :param inter_channels: :param dimension: :param sub_sample: :param bn_layer: """ super(_NonLocalBlockND, self).__init__() assert dimension in [1, 2, 3] self.dimension = dimension self.sub_sample = sub_sample self.in_channels = in_channels self.inter_channels = inter_channels if self.inter_channels is None: self.inter_channels = in_channels // 2 if self.inter_channels == 0: self.inter_channels = 1 if dimension == 3: conv_nd = nn.Conv3d max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) bn = nn.BatchNorm3d elif dimension == 2: conv_nd = nn.Conv2d max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) bn = nn.BatchNorm2d else: conv_nd = nn.Conv1d max_pool_layer = nn.MaxPool1d(kernel_size=(2)) bn = nn.BatchNorm1d self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0) if bn_layer: self.W = nn.Sequential( conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0), bn(self.in_channels) ) nn.init.constant_(self.W[1].weight, 0) 
nn.init.constant_(self.W[1].bias, 0) else: self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0) nn.init.constant_(self.W.weight, 0) nn.init.constant_(self.W.bias, 0) self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0) self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0) if sub_sample: self.g = nn.Sequential(self.g, max_pool_layer) self.phi = nn.Sequential(self.phi, max_pool_layer) def forward(self, x, return_nl_map=False): """ :param x: (b, c, t, h, w) :param return_nl_map: if True return z, nl_map, else only return z. :return: """ batch_size = x.size(0) g_x = self.g(x).view(batch_size, self.inter_channels, -1) g_x = g_x.permute(0, 2, 1) theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) theta_x = theta_x.permute(0, 2, 1) phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) f = torch.matmul(theta_x, phi_x) f_div_C = F.softmax(f, dim=-1) y = torch.matmul(f_div_C, g_x) y = y.permute(0, 2, 1).contiguous() y = y.view(batch_size, self.inter_channels, *x.size()[2:]) W_y = self.W(y) z = W_y + x if return_nl_map: return z, f_div_C return z class NONLocalBlock1D(_NonLocalBlockND): def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): super(NONLocalBlock1D, self).__init__(in_channels, inter_channels=inter_channels, dimension=1, sub_sample=sub_sample, bn_layer=bn_layer) class NONLocalBlock2D(_NonLocalBlockND): def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): super(NONLocalBlock2D, self).__init__(in_channels, inter_channels=inter_channels, dimension=2, sub_sample=sub_sample, bn_layer=bn_layer,) class NONLocalBlock3D(_NonLocalBlockND): def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): super(NONLocalBlock3D, self).__init__(in_channels, 
inter_channels=inter_channels, dimension=3, sub_sample=sub_sample, bn_layer=bn_layer,) if __name__ == '__main__': import torch for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]: img = torch.zeros(2, 3, 20) net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) out = net(img) print(out.size()) img = torch.zeros(2, 3, 20, 20) net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) out = net(img) print(out.size()) img = torch.randn(2, 3, 8, 20, 20) net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) out = net(img) print(out.size()) ================================================ FILE: code/real/bsrt/model/non_local/non_local_gaussian.py ================================================ import torch from torch import nn from torch.nn import functional as F class _NonLocalBlockND(nn.Module): def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): super(_NonLocalBlockND, self).__init__() assert dimension in [1, 2, 3] self.dimension = dimension self.sub_sample = sub_sample self.in_channels = in_channels self.inter_channels = inter_channels if self.inter_channels is None: self.inter_channels = in_channels // 2 if self.inter_channels == 0: self.inter_channels = 1 if dimension == 3: conv_nd = nn.Conv3d max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) bn = nn.BatchNorm3d elif dimension == 2: conv_nd = nn.Conv2d max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) bn = nn.BatchNorm2d else: conv_nd = nn.Conv1d max_pool_layer = nn.MaxPool1d(kernel_size=(2)) bn = nn.BatchNorm1d self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0) if bn_layer: self.W = nn.Sequential( conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0), bn(self.in_channels) ) nn.init.constant_(self.W[1].weight, 0) nn.init.constant_(self.W[1].bias, 0) else: self.W = 
conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0) nn.init.constant_(self.W.weight, 0) nn.init.constant_(self.W.bias, 0) if sub_sample: self.g = nn.Sequential(self.g, max_pool_layer) self.phi = max_pool_layer def forward(self, x, return_nl_map=False): """ :param x: (b, c, t, h, w) :param return_nl_map: if True return z, nl_map, else only return z. :return: """ batch_size = x.size(0) g_x = self.g(x).view(batch_size, self.inter_channels, -1) g_x = g_x.permute(0, 2, 1) theta_x = x.view(batch_size, self.in_channels, -1) theta_x = theta_x.permute(0, 2, 1) if self.sub_sample: phi_x = self.phi(x).view(batch_size, self.in_channels, -1) else: phi_x = x.view(batch_size, self.in_channels, -1) f = torch.matmul(theta_x, phi_x) f_div_C = F.softmax(f, dim=-1) # if self.store_last_batch_nl_map: # self.nl_map = f_div_C y = torch.matmul(f_div_C, g_x) y = y.permute(0, 2, 1).contiguous() y = y.view(batch_size, self.inter_channels, *x.size()[2:]) W_y = self.W(y) z = W_y + x if return_nl_map: return z, f_div_C return z class NONLocalBlock1D(_NonLocalBlockND): def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): super(NONLocalBlock1D, self).__init__(in_channels, inter_channels=inter_channels, dimension=1, sub_sample=sub_sample, bn_layer=bn_layer) class NONLocalBlock2D(_NonLocalBlockND): def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): super(NONLocalBlock2D, self).__init__(in_channels, inter_channels=inter_channels, dimension=2, sub_sample=sub_sample, bn_layer=bn_layer) class NONLocalBlock3D(_NonLocalBlockND): def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): super(NONLocalBlock3D, self).__init__(in_channels, inter_channels=inter_channels, dimension=3, sub_sample=sub_sample, bn_layer=bn_layer) if __name__ == '__main__': import torch for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]: img 
= torch.zeros(2, 3, 20) net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) out = net(img) print(out.size()) img = torch.zeros(2, 3, 20, 20) net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) out = net(img) print(out.size()) img = torch.randn(2, 3, 8, 20, 20) net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) out = net(img) print(out.size()) ================================================ FILE: code/real/bsrt/model/swin_util.py ================================================ # ----------------------------------------------------------------------------------- # SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 # Originally Written by Ze Liu, Modified by Jingyun Liang. # ----------------------------------------------------------------------------------- import math import torch import torch.nn as nn import torch.nn.functional as F # import torch.utils.checkpoint as checkpoint from model.checkpoint import CheckpointFunction as checkpoint from timm.models.layers import DropPath, to_2tuple, trunc_normal_ import time from functools import reduce, lru_cache class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class Mlp_GEGLU(nn.Module): """ Multilayer perceptron with gated linear unit (GEGLU). Ref. "GLU Variants Improve Transformer". 
Args: x: (B, D, H, W, C) Returns: x: (B, D, H, W, C) """ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc11 = nn.Linear(in_features, hidden_features) self.fc12 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.act(self.fc11(x)) * self.fc12(x) x = self.drop(x) x = self.fc2(x) return x def window_partition(x, window_size): """ Args: x: (B, H, W, C) window_size (int): window size Returns: windows: (num_windows*B, window_size, window_size, C) """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows def window_reverse(windows, window_size, H, W): """ Args: windows: (num_windows*B, window_size, window_size, C) window_size (int): Window size H (int): Height of image W (int): Width of image Returns: x: (B, H, W, C) """ B = int(windows.shape[0] / (H * W / window_size / window_size)) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x class WindowAttention(nn.Module): r""" Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: dim (int): Number of input channels. window_size (tuple[int]): The height and width of the window. num_heads (int): Number of attention heads. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set attn_drop (float, optional): Dropout ratio of attention weight. 
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias: one learnable bias
        # per (relative offset, head) pair
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # scale the row offset so that (row, col) pairs map to unique flat indices
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None

        Returns:
            Attended features of shape (num_windows*B, N, C).
        """
        B_, N, C = x.shape
        # One projection produces q, k, v stacked along dim 0 after the permute.
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        # Look up the learned bias for every query/key offset pair.
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # Broadcast the per-window shift mask over batch and heads.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops


# Cached: x_size/window_size/shift_size are hashable ints/tuples, so repeated
# forward passes at the same resolution reuse the same mask tensor.
@lru_cache()
def calculate_mask(x_size, window_size, shift_size):
    # calculate attention mask for SW-MSA
    # Tokens from different pre-shift regions must not attend to each other;
    # label each region with a distinct id, then mask pairs with differing ids.
    H, W = x_size
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
    h_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    w_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

    mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, window_size * window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    # -100 is effectively -inf after softmax; 0 leaves the logit untouched.
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    return attn_mask
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.use_checkpoint = use_checkpoint
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # NOTE(review): the registered-buffer mask from upstream SwinIR is
        # disabled here; this variant recomputes the mask per forward via the
        # lru_cache'd module-level calculate_mask, which supports varying x_size.
        # if self.shift_size > 0:
        #     attn_mask = self.calculate_mask(self.input_resolution)
        # else:
        #     attn_mask = None
        # self.register_buffer("attn_mask", attn_mask)

    def forward(self, x, x_size):
        # x: (B, H*W, C) token sequence; x_size: (H, W) of the current input,
        # which may differ from self.input_resolution at test time.
        H, W = x_size
        B, L, C = x.shape
        # assert L == H * W, "input feature has wrong size"

        # if self.input_resolution != x_size:
        #     self.input_resolution = x_size
        #     if self.attn_mask is not None:
        #         self.attn_mask = self.calculate_mask(x_size).to(x.device)

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
        # if self.input_resolution == x_size:
        #     if self.use_checkpoint:
        #         attn_windows = checkpoint.apply(self.attn, x_windows, self.attn_mask)  # nW*B, window_size*window_size, C
        #     else:
        #         attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
        # else:
        #     if self.use_checkpoint:
        #         attn_windows = checkpoint.apply(self.attn, x_windows, self.calculate_mask(x_size).to(x.device))
        #     else:
        #         attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
        attn_mask = calculate_mask(x_size, self.window_size, self.shift_size).to(x.device)
        attn_windows = self.attn(x_windows, mask=attn_mask)

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN (residual + stochastic depth around both attention and MLP)
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
x = x.view(B, H, W, C) x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C x = self.norm(x) x = self.reduction(x) return x def extra_repr(self) -> str: return f"input_resolution={self.input_resolution}, dim={self.dim}" def flops(self): H, W = self.input_resolution flops = H * W * self.dim flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim return flops class BasicLayer(nn.Module): """ A basic Swin Transformer layer for one stage. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resolution. depth (int): Number of blocks. num_heads (int): Number of attention heads. window_size (int): Local window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
""" def __init__(self, dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): super().__init__() self.dim = dim self.input_resolution = input_resolution self.depth = depth self.use_checkpoint = False # build blocks self.blocks = nn.ModuleList([ SwinTransformerBlock(dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, use_checkpoint=use_checkpoint) for i in range(depth)]) # patch merging layer if downsample is not None: self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) else: self.downsample = None def forward(self, x, x_size): for i, blk in enumerate(self.blocks): if self.use_checkpoint: # x = checkpoint.checkpoint(blk, x, x_size) x = checkpoint.apply(blk, 2, x, x_size) else: x = blk(x, x_size) if self.downsample is not None: x = self.downsample(x) return x def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" def flops(self): flops = 0 for blk in self.blocks: flops += blk.flops() if self.downsample is not None: flops += self.downsample.flops() return flops class RSTB(nn.Module): """Residual Swin Transformer Block (RSTB). Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resolution. depth (int): Number of blocks. num_heads (int): Number of attention heads. window_size (int): Local window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 
    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 img_size=224, patch_size=4, resi_connection='1conv'):
        super(RSTB, self).__init__()
        # print(f'dim: {dim}, input_resolution: {input_resolution}, depth: {depth}, num_heads: {num_heads}, window_size: {window_size}, img_size: {img_size}. patch_size: {patch_size}')

        self.dim = dim
        self.input_resolution = input_resolution

        # The residual group is a stack of Swin blocks operating on tokens.
        self.residual_group = BasicLayer(dim=dim,
                                         input_resolution=input_resolution,
                                         depth=depth,
                                         num_heads=num_heads,
                                         window_size=window_size,
                                         mlp_ratio=mlp_ratio,
                                         qkv_bias=qkv_bias, qk_scale=qk_scale,
                                         drop=drop, attn_drop=attn_drop,
                                         drop_path=drop_path,
                                         norm_layer=norm_layer,
                                         downsample=downsample,
                                         use_checkpoint=use_checkpoint)

        # Convolution applied in image space between the token stack and the
        # residual connection.
        if resi_connection == '1conv':
            self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory
            self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                      nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
                                      nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                      nn.Conv2d(dim // 4, dim, 3, 1, 1))

        # Token <-> image converters (in_chans=0 is unused by PatchEmbed's
        # forward, which only reshapes).
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
            norm_layer=None)

        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
            norm_layer=None)

    def forward(self, x, x_size):
        # tokens -> Swin blocks -> image -> conv -> tokens, plus residual.
        x = self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
        return x

    def flops(self):
        flops = 0
        flops += self.residual_group.flops()
        H, W = self.input_resolution
        flops += H * W * self.dim * self.dim * 9
        flops += self.patch_embed.flops()
        flops += self.patch_unembed.flops()
        return flops
Default: None """ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] self.img_size = img_size self.patch_size = patch_size self.patches_resolution = patches_resolution self.num_patches = patches_resolution[0] * patches_resolution[1] self.in_chans = in_chans self.embed_dim = embed_dim if norm_layer is not None: self.norm = norm_layer(embed_dim) else: self.norm = None def forward(self, x, use_norm=True): x = x.flatten(2).transpose(1, 2) # B Ph*Pw C if use_norm and self.norm is not None: x = self.norm(x) return x def flops(self): flops = 0 H, W = self.img_size if self.norm is not None: flops += H * W * self.embed_dim return flops class PatchUnEmbed(nn.Module): r""" Image to Patch Unembedding Args: img_size (int): Image size. Default: 224. patch_size (int): Patch token size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. norm_layer (nn.Module, optional): Normalization layer. Default: None """ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] self.img_size = img_size self.patch_size = patch_size self.patches_resolution = patches_resolution self.num_patches = patches_resolution[0] * patches_resolution[1] self.in_chans = in_chans self.embed_dim = embed_dim def forward(self, x, x_size): B, HW, C = x.shape x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C return x def flops(self): flops = 0 return flops class Upsample(nn.Sequential): """Upsample module. Args: scale (int): Scale factor. Supported scales: 2^n and 3. 
class Upsample(nn.Sequential):
    """Pixel-shuffle upsampling module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: if ``scale`` is neither a power of two nor 3.
    """

    def __init__(self, scale, num_feat):
        layers = []
        if (scale & (scale - 1)) == 0:
            # scale = 2^n: stack log2(scale) conv + shuffle-by-2 stages
            for _ in range(int(math.log(scale, 2))):
                layers.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                layers.append(nn.PixelShuffle(2))
        elif scale == 3:
            layers.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            layers.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*layers)


class UpsampleOneStep(nn.Sequential):
    """UpsampleOneStep module (the difference with Upsample is that it always
    only has 1conv + 1pixelshuffle). Used in lightweight SR to save parameters.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
        num_out_ch (int): Number of output channels.
        input_resolution (tuple[int] | None): Used only by flops().
    """

    def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
        self.num_feat = num_feat
        self.input_resolution = input_resolution
        stages = [
            nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1),
            nn.PixelShuffle(scale),
        ]
        super(UpsampleOneStep, self).__init__(*stages)

    def flops(self):
        H, W = self.input_resolution
        return H * W * self.num_feat * 3 * 9
    def __init__(self, img_size=64, patch_size=1, in_chans=3,
                 embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
                 **kwargs):
        # NOTE(review): depths/num_heads are mutable list defaults; they are only
        # read here so this is harmless, but tuples would be safer.
        super(SwinIR, self).__init__()
        num_in_ch = in_chans
        num_out_ch = in_chans
        num_feat = 64
        self.img_range = img_range
        if in_chans == 3:
            # dataset channel means used to center RGB inputs before the network
            rgb_mean = (0.4488, 0.4371, 0.4040)
            self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
        else:
            self.mean = torch.zeros(1, 1, 1, 1)
        self.upscale = upscale
        self.upsampler = upsampler
        self.window_size = window_size

        #####################################################################################################
        ################################### 1, shallow feature extraction ###################################
        self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)

        #####################################################################################################
        ################################### 2, deep feature extraction ######################################
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution
        # print('patches_resolution: ', patches_resolution)

        # merge non-overlapping patches into image
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build Residual Swin Transformer blocks (RSTB)
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = RSTB(dim=embed_dim,
                         input_resolution=(patches_resolution[0],
                                           patches_resolution[1]),
                         depth=depths[i_layer],
                         num_heads=num_heads[i_layer],
                         window_size=window_size,
                         mlp_ratio=self.mlp_ratio,
                         qkv_bias=qkv_bias, qk_scale=qk_scale,
                         drop=drop_rate, attn_drop=attn_drop_rate,
                         drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
                         norm_layer=norm_layer,
                         downsample=None,
                         use_checkpoint=use_checkpoint,
                         img_size=img_size,
                         patch_size=patch_size,
                         resi_connection=resi_connection
                         )
            self.layers.append(layer)
        self.norm = norm_layer(self.num_features)

        # build the last conv layer in deep feature extraction
        if resi_connection == '1conv':
            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory
            self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                                 nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                                 nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))

        #####################################################################################################
        ################################ 3, high quality image reconstruction ################################
        if self.upsampler == 'pixelshuffle':
            # for classical SR
            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                                                      nn.LeakyReLU(inplace=True))
            self.upsample = Upsample(upscale, num_feat)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        elif self.upsampler == 'pixelshuffledirect':
            # for lightweight SR (to save parameters)
            self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
                                            (patches_resolution[0], patches_resolution[1]))
        elif self.upsampler == 'nearest+conv':
            # for real-world SR (less artifacts)
            assert self.upscale == 4, 'only support x4 now.'
            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                                                      nn.LeakyReLU(inplace=True))
            self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
            self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            # for image denoising and JPEG compression artifact reduction
            self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Linear layers: truncated-normal weights, zero bias.
        # LayerNorm: unit weight, zero bias.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        # parameter names excluded from weight decay by the optimizer setup
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def check_image_size(self, x):
        # Reflect-pad H and W up to multiples of window_size so window
        # partitioning always divides evenly; forward() crops the result back.
        _, _, h, w = x.size()
        mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
        mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
        return x

    def forward_features(self, x):
        # Deep feature extraction: image -> tokens -> RSTB stack -> image.
        x_size = (x.shape[2], x.shape[3])
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x, x_size)

        x = self.norm(x)  # B L C
        x = self.patch_unembed(x, x_size)

        return x

    def forward(self, x):
        H, W = x.shape[2:]
        x = self.check_image_size(x)

        self.mean = self.mean.type_as(x)
        x = (x - self.mean) * self.img_range

        if self.upsampler == 'pixelshuffle':
            # for classical SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.conv_before_upsample(x)
            x = self.conv_last(self.upsample(x))
        elif self.upsampler == 'pixelshuffledirect':
            # for lightweight SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.upsample(x)
        elif self.upsampler == 'nearest+conv':
            # for real-world SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.conv_before_upsample(x)
            x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
            x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
            x = self.conv_last(self.lrelu(self.conv_hr(x)))
        else:
            # for image denoising and JPEG compression artifact reduction
            x_first = self.conv_first(x)
            res = self.conv_after_body(self.forward_features(x_first)) + x_first
            x = x + self.conv_last(res)

        x = x / self.img_range + self.mean

        # crop away the padding added by check_image_size (scaled by upscale)
        return x[:, :, :H*self.upscale, :W*self.upscale]

    def flops(self):
        # NOTE(review): self.upsample only exists (with a flops() method) for
        # the 'pixelshuffledirect' upsampler — other configurations would raise
        # here; confirm this is only called in that setup.
        flops = 0
        H, W = self.patches_resolution
        flops += H * W * 3 * self.embed_dim * 9
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += H * W * 3 * self.embed_dim * self.embed_dim
        flops += self.upsample.flops()
        return flops


if __name__ == '__main__':
    # Smoke test: build a lightweight x2 model and run one forward pass.
    upscale = 4
    window_size = 8
    height = (1024 // upscale // window_size + 1) * window_size
    width = (720 // upscale // window_size + 1) * window_size
    model = SwinIR(upscale=2, img_size=(height, width),
                   window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
                   embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
    print(model)
    print(height, width, model.flops() / 1e9)

    x = torch.randn((1, 3, height, width))
    x = model(x)
    print(x.shape)
def set_framework_dependencies(x):
    """Return (framework_module, dtype-cast fn, float32 eps) for array ``x``.

    Dispatches between NumPy and PyTorch based on the type of ``x`` so that the
    interpolation kernels below work on either framework.
    """
    # FIX: the module-level try/except sets `numpy = None` when NumPy is not
    # installed; the original `type(x) is numpy.ndarray` then raised
    # AttributeError on None. Guard the check so the torch-only path works.
    if numpy is not None and type(x) is numpy.ndarray:
        to_dtype = lambda a: a          # numpy bool arrays multiply fine as-is
        fw = numpy
    else:
        to_dtype = lambda a: a.to(x.dtype)  # torch needs an explicit dtype cast
        fw = torch
    eps = fw.finfo(fw.float32).eps
    return fw, to_dtype, eps


def support_sz(sz):
    """Decorator attaching the kernel's support size as ``f.support_sz``."""
    def wrapper(f):
        f.support_sz = sz
        return f
    return wrapper


@support_sz(4)
def cubic(x):
    """Keys cubic interpolation kernel (a = -0.5); support [-2, 2]."""
    fw, to_dtype, eps = set_framework_dependencies(x)
    absx = fw.abs(x)
    absx2 = absx ** 2
    absx3 = absx ** 3
    return ((1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) +
            (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) *
            to_dtype((1. < absx) & (absx <= 2.)))


@support_sz(4)
def lanczos2(x):
    """Lanczos kernel with a = 2; eps keeps 0/0 at x=0 well-defined."""
    fw, to_dtype, eps = set_framework_dependencies(x)
    return (((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) /
             ((pi**2 * x**2 / 2) + eps)) * to_dtype(abs(x) < 2))


@support_sz(6)
def lanczos3(x):
    """Lanczos kernel with a = 3."""
    fw, to_dtype, eps = set_framework_dependencies(x)
    return (((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) /
             ((pi**2 * x**2 / 3) + eps)) * to_dtype(abs(x) < 3))


@support_sz(2)
def linear(x):
    """Triangle (bilinear) kernel; support [-1, 1]."""
    fw, to_dtype, eps = set_framework_dependencies(x)
    return ((x + 1) * to_dtype((-1 <= x) & (x < 0)) +
            (1 - x) * to_dtype((0 <= x) & (x <= 1)))


@support_sz(1)
def box(x):
    """Box (nearest-neighbor) kernel; support [-1, 1]."""
    fw, to_dtype, eps = set_framework_dependencies(x)
    return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1))
class PyConv2d(nn.Module):
    """PyConv2d with padding (general case). Applies a 2D PyConv over an input
    signal composed of several input planes.

    Each pyramid level is an independent grouped convolution with its own
    kernel size; outputs are concatenated along the channel dimension.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (list): Number of channels for each pyramid level produced by the convolution
        pyconv_kernels (list): Spatial size of the kernel for each pyramid level
        pyconv_groups (list): Number of blocked connections from input channels to output channels for each pyramid level
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``False``

    Example::

        >>> # PyConv with two pyramid levels, kernels: 3x3, 5x5
        >>> m = PyConv2d(in_channels=64, out_channels=[32, 32], pyconv_kernels=[3, 5], pyconv_groups=[1, 4])
        >>> input = torch.randn(4, 64, 56, 56)
        >>> output = m(input)
    """

    def __init__(self, in_channels, out_channels, pyconv_kernels, pyconv_groups, stride=1, dilation=1, bias=False):
        super(PyConv2d, self).__init__()
        assert len(out_channels) == len(pyconv_kernels) == len(pyconv_groups)

        # One grouped conv per pyramid level; padding k//2 keeps spatial size.
        self.pyconv_levels = nn.ModuleList([
            nn.Conv2d(in_channels, ch, kernel_size=k, stride=stride,
                      padding=k // 2, groups=g, dilation=dilation, bias=bias)
            for ch, k, g in zip(out_channels, pyconv_kernels, pyconv_groups)
        ])

    def forward(self, x):
        return torch.cat([level(x) for level in self.pyconv_levels], 1)
= out_channels // parts for i in range(parts): self.mask[i * _out_channels: (i + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, : , :] = 1 self.mask[(i + parts//2)%parts * _out_channels: ((i + parts//2)%parts + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, :, :] = 1 self.conv.weight.data[self.mask] = 0 self.conv.weight.register_hook(backward_hook) self.weight = self.conv.weight self.bias = self.conv.bias def forward(self, x): x1, x2 = x.chunk(2, dim=1) x_shift = self.gwconv_shift(torch.cat((x2, x1), dim=1)) return self.gwconv(x) + self.conv(x) + x_shift # PSConv-based Group Convolution class PSGConv2d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, parts=4, bias=False): super(PSGConv2d, self).__init__() self.gwconv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups=groups * parts, bias=bias) self.gwconv_shift = nn.Conv2d(in_channels, out_channels, kernel_size, stride, 2 * padding, 2 * dilation, groups=groups * parts, bias=bias) self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=bias) def backward_hook(grad): out = grad.clone() out[self.mask] = 0 return out self.mask = torch.zeros(self.conv.weight.shape).bool().cuda() _in_channels = in_channels // (groups * parts) _out_channels = out_channels // (groups * parts) for i in range(parts): for j in range(groups): self.mask[(i + j * groups) * _out_channels: (i + j * groups + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, : , :] = 1 self.mask[((i + parts // 2) % parts + j * groups) * _out_channels: ((i + parts // 2) % parts + j * groups + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, :, :] = 1 self.conv.weight.data[self.mask] = 0 self.conv.weight.register_hook(backward_hook) self.groups = groups self.weight = self.conv.weight self.bias = self.conv.bias def forward(self, x): x_split = (z.chunk(2, dim=1) for z in 
x.chunk(self.groups, dim=1)) x_merge = torch.cat(tuple(torch.cat((x2, x1), dim=1) for (x1, x2) in x_split), dim=1) x_shift = self.gwconv_shift(x_merge) gx = self.gwconv(x) cx = self.conv(x) # print(x.shape, gx.shape, cx.shape, x_merge.shape, x_shift.shape) return gx + cx + x_shift ================================================ FILE: code/real/bsrt/model/utils/resize_right.py ================================================ import warnings from math import ceil import model.utils.interp_methods as interp_methods class NoneClass: pass try: import torch from torch import nn nnModuleWrapped = nn.Module except ImportError: warnings.warn('No PyTorch found, will work only with Numpy') torch = None nnModuleWrapped = NoneClass try: import numpy except ImportError: warnings.warn('No Numpy found, will work only with PyTorch') numpy = None if numpy is None and torch is None: raise ImportError("Must have either Numpy or PyTorch but both not found") def resize(input, scale_factors=None, out_shape=None, interp_method=interp_methods.cubic, support_sz=None, antialiasing=True): # get properties of the input tensor in_shape, n_dims = input.shape, input.ndim # fw stands for framework that can be either numpy or torch, # determined by the input type fw = numpy if type(input) is numpy.ndarray else torch eps = fw.finfo(fw.float32).eps # set missing scale factors or output shapem one according to another, # scream if both missing scale_factors, out_shape = set_scale_and_out_sz(in_shape, out_shape, scale_factors, fw) # sort indices of dimensions according to scale of each dimension. # since we are going dim by dim this is efficient sorted_filtered_dims_and_scales = [(dim, scale_factors[dim]) for dim in sorted(range(n_dims), key=lambda ind: scale_factors[ind]) if scale_factors[dim] != 1.] 
# unless support size is specified by the user, it is an attribute # of the interpolation method if support_sz is None: support_sz = interp_method.support_sz # when using pytorch, we need to know what is the input tensor device if fw is torch: device = input.device # output begins identical to input and changes with each iteration output = input # iterate over dims for dim, scale_factor in sorted_filtered_dims_and_scales: # get 1d set of weights and fields of view for each output location # along this dim field_of_view, weights = prepare_weights_and_field_of_view_1d( dim, scale_factor, in_shape[dim], out_shape[dim], interp_method, support_sz, antialiasing, fw, eps, device) # multiply the weights by the values in the field of view and # aggreagate output = apply_weights(output, field_of_view, weights, dim, n_dims, fw) return output class ResizeLayer(nnModuleWrapped): def __init__(self, in_shape, scale_factors=None, out_shape=None, interp_method=interp_methods.cubic, support_sz=None, antialiasing=True): super(ResizeLayer, self).__init__() # fw stands for framework, that can be either numpy or torch. since # this is a torch layer, only one option in this case. fw = torch eps = fw.finfo(fw.float32).eps # set missing scale factors or output shapem one according to another, # scream if both missing scale_factors, out_shape = set_scale_and_out_sz(in_shape, out_shape, scale_factors, fw) # unless support size is specified by the user, it is an attribute # of the interpolation method if support_sz is None: support_sz = interp_method.support_sz self.n_dims = len(in_shape) # sort indices of dimensions according to scale of each dimension. # since we are going dim by dim this is efficient self.sorted_filtered_dims_and_scales = [(dim, scale_factors[dim]) for dim in sorted(range(self.n_dims), key=lambda ind: scale_factors[ind]) if scale_factors[dim] != 1.] 
# iterate over dims field_of_view_list = [] weights_list = [] for dim, scale_factor in self.sorted_filtered_dims_and_scales: # get 1d set of weights and fields of view for each output # location along this dim field_of_view, weights = prepare_weights_and_field_of_view_1d( dim, scale_factor, in_shape[dim], out_shape[dim], interp_method, support_sz, antialiasing, fw, eps, input.device) # keep weights and fields of views for all dims weights_list.append(nn.Parameter(weights, requires_grad=False)) field_of_view_list.append(nn.Parameter(field_of_view, requires_grad=False)) self.field_of_view = nn.ParameterList(field_of_view_list) self.weights = nn.ParameterList(weights_list) self.in_shape = in_shape def forward(self, input): # output begins identical to input and changes with each iteration output = input for (dim, scale_factor), field_of_view, weights in zip( self.sorted_filtered_dims_and_scales, self.field_of_view, self.weights): # multiply the weights by the values in the field of view and # aggreagate output = apply_weights(output, field_of_view, weights, dim, self.n_dims, torch) return output def prepare_weights_and_field_of_view_1d(dim, scale_factor, in_sz, out_sz, interp_method, support_sz, antialiasing, fw, eps, device=None): # If antialiasing is taking place, we modify the window size and the # interpolation method (see inside function) interp_method, cur_support_sz = apply_antialiasing_if_needed( interp_method, support_sz, scale_factor, antialiasing) # STEP 1- PROJECTED GRID: The non-integer locations of the projection of # output pixel locations to the input tensor projected_grid = get_projected_grid(in_sz, out_sz, scale_factor, fw, device) # STEP 2- FIELDS OF VIEW: for each output pixels, map the input pixels # that influence it field_of_view = get_field_of_view(projected_grid, cur_support_sz, in_sz, fw, eps) # STEP 3- CALCULATE WEIGHTS: Match a set of weights to the pixels in the # field of view for each output pixel weights = get_weights(interp_method, 
projected_grid, field_of_view) return field_of_view, weights def apply_weights(input, field_of_view, weights, dim, n_dims, fw): # STEP 4- APPLY WEIGHTS: Each output pixel is calculated by multiplying # its set of weights with the pixel values in its field of view. # We now multiply the fields of view with their matching weights. # We do this by tensor multiplication and broadcasting. # this step is separated to a different function, so that it can be # repeated with the same calculated weights and fields. # for this operations we assume the resized dim is the first one. # so we transpose and will transpose back after multiplying tmp_input = fw_swapaxes(input, dim, 0, fw) # field_of_view is a tensor of order 2: for each output (1d location # along cur dim)- a list of 1d neighbors locations. # note that this whole operations is applied to each dim separately, # this is why it is all in 1d. # neighbors = tmp_input[field_of_view] is a tensor of order image_dims+1: # for each output pixel (this time indicated in all dims), these are the # values of the neighbors in the 1d field of view. note that we only # consider neighbors along the current dim, but such set exists for every # multi-dim location, hence the final tensor order is image_dims+1. neighbors = tmp_input[field_of_view] # weights is an order 2 tensor: for each output location along 1d- a list # of weighs matching the field of view. we augment it with ones, for # broadcasting, so that when multiplies some tensor the weights affect # only its first dim. 
tmp_weights = fw.reshape(weights, (*weights.shape, * [1] * (n_dims - 1))) # now we simply multiply the weights with the neighbors, and then sum # along the field of view, to get a single value per out pixel tmp_output = (neighbors * tmp_weights).sum(1) # we transpose back the resized dim to its original position return fw_swapaxes(tmp_output, 0, dim, fw) def set_scale_and_out_sz(in_shape, out_shape, scale_factors, fw): # eventually we must have both scale-factors and out-sizes for all in/out # dims. however, we support many possible partial arguments if scale_factors is None and out_shape is None: raise ValueError("either scale_factors or out_shape should be " "provided") if out_shape is not None: # if out_shape has less dims than in_shape, we defaultly resize the # first dims for numpy and last dims for torch out_shape = (list(out_shape) + list(in_shape[:-len(out_shape)]) if fw is numpy else list(in_shape[:-len(out_shape)]) + list(out_shape)) if scale_factors is None: # if no scale given, we calculate it as the out to in ratio # (not recomended) scale_factors = [out_sz / in_sz for out_sz, in_sz in zip(out_shape, in_shape)] if scale_factors is not None: # by default, if a single number is given as scale, we assume resizing # two dims (most common are images with 2 spatial dims) scale_factors = (scale_factors if isinstance(scale_factors, (list, tuple)) else [scale_factors, scale_factors]) # if less scale_factors than in_shape dims, we defaultly resize the # first dims for numpy and last dims for torch scale_factors = (list(scale_factors) + [1] * (len(in_shape) - len(scale_factors)) if fw is numpy else [1] * (len(in_shape) - len(scale_factors)) + list(scale_factors)) if out_shape is None: # when no out_shape given, it is calculated by multiplying the # scale by the in_shape (not recomended) out_shape = [ceil(scale_factor * in_sz) for scale_factor, in_sz in zip(scale_factors, in_shape)] # next line intentionally after out_shape determined for stability scale_factors = 
[float(sf) for sf in scale_factors] return scale_factors, out_shape def get_projected_grid(in_sz, out_sz, scale_factor, fw, device=None): # we start by having the ouput coordinates which are just integer locations out_coordinates = fw.arange(out_sz) # if using torch we need to match the grid tensor device to the input device out_coordinates = fw_set_device(out_coordinates, device, fw) # This is projecting the ouput pixel locations in 1d to the input tensor, # as non-integer locations. # the following fomrula is derived in the paper # "From Discrete to Continuous Convolutions" by Shocher et al. return (out_coordinates / scale_factor + (in_sz - 1) / 2 - (out_sz - 1) / (2 * scale_factor)) def get_field_of_view(projected_grid, cur_support_sz, in_sz, fw, eps): # for each output pixel, map which input pixels influence it, in 1d. # we start by calculating the leftmost neighbor, using half of the window # size (eps is for when boundary is exact int) left_boundaries = fw_ceil(projected_grid - cur_support_sz / 2 - eps, fw) # then we simply take all the pixel centers in the field by counting # window size pixels from the left boundary ordinal_numbers = fw.arange(ceil(cur_support_sz - eps)) # in case using torch we need to match the device ordinal_numbers = fw_set_device(ordinal_numbers, projected_grid.device, fw) field_of_view = left_boundaries[:, None] + ordinal_numbers # next we do a trick instead of padding, we map the field of view so that # it would be like mirror padding, without actually padding # (which would require enlarging the input tensor) mirror = fw_cat((fw.arange(in_sz), fw.arange(in_sz - 1, -1, step=-1)), fw) field_of_view = mirror[fw.remainder(field_of_view, mirror.shape[0])] field_of_view = fw_set_device(field_of_view,projected_grid.device, fw) return field_of_view def get_weights(interp_method, projected_grid, field_of_view): # the set of weights per each output pixels is the result of the chosen # interpolation method applied to the distances between 
projected grid # locations and the pixel-centers in the field of view (distances are # directed, can be positive or negative) weights = interp_method(projected_grid[:, None] - field_of_view) # we now carefully normalize the weights to sum to 1 per each output pixel sum_weights = weights.sum(1, keepdims=True) sum_weights[sum_weights == 0] = 1 return weights / sum_weights def apply_antialiasing_if_needed(interp_method, support_sz, scale_factor, antialiasing): # antialiasing is "stretching" the field of view according to the scale # factor (only for downscaling). this is low-pass filtering. this # requires modifying both the interpolation (stretching the 1d # function and multiplying by the scale-factor) and the window size. if scale_factor >= 1.0 or not antialiasing: return interp_method, support_sz cur_interp_method = (lambda arg: scale_factor * interp_method(scale_factor * arg)) cur_support_sz = support_sz / scale_factor return cur_interp_method, cur_support_sz def fw_ceil(x, fw): if fw is numpy: return fw.int_(fw.ceil(x)) else: return x.ceil().long() def fw_cat(x, fw): if fw is numpy: return fw.concatenate(x) else: return fw.cat(x) def fw_swapaxes(x, ax_1, ax_2, fw): if fw is numpy: return fw.swapaxes(x, ax_1, ax_2) else: return x.transpose(ax_1, ax_2) def fw_set_device(x, device, fw): if fw is numpy: return x else: return x.to(device) ================================================ FILE: code/real/bsrt/option.py ================================================ import argparse parser = argparse.ArgumentParser(description='EDSR and MDSR') parser.add_argument('--n_resblocks', type=int, default=16, help='number of residual blocks') parser.add_argument('--n_feats', type=int, default=64, help='number of feature maps') parser.add_argument('--n_colors', type=int, default=3, help='number of color channels to use') parser.add_argument('--lr', type=float, default=1e-4, help='learning rate') parser.add_argument('--burst_size', type=int, default=14, help='burst size, max 
14') parser.add_argument('--burst_channel', type=int, default=4, help='RAW channel, default:4') parser.add_argument('--swinfeature', action='store_true', help='use swin transformer to extract features') parser.add_argument('--model_level', type=str, default='S', help='S: small, L: large') ################## fine-tune ################## parser.add_argument('--finetune', action='store_true', help='finetune model') parser.add_argument('--finetune_align', action='store_true', help='finetune alignment module') parser.add_argument('--finetune_swin', action='store_true', help='finetune swin trans module') parser.add_argument('--finetune_conv', action='store_true', help='finetune rest convs') parser.add_argument('--finetune_prelayer', action='store_true', help='finetune finetune pre feature extract layer') parser.add_argument('--finetune_upconv', action='store_true', help='finetune finetune up conv layer') parser.add_argument('--finetune_spynet', action='store_true', help='finetune finetune up conv layer') # Hardware specifications parser.add_argument('--n_threads', type=int, default=6, help='number of threads for data loading') parser.add_argument('--cpu', action='store_true', help='use cpu only') parser.add_argument('--n_GPUs', type=int, default=1, help='number of GPUs') parser.add_argument('--seed', type=int, default=1, help='random seed') parser.add_argument('--local_rank', type=int, default=-1, help='proc index') parser.add_argument('--fp16', action='store_true', help='use fp16 only') parser.add_argument('--use_checkpoint', action='store_true', help='use use_checkpoint in swin transformer') # Data specifications parser.add_argument('--root', type=str, default='/data/dataset/ntire21/burstsr/real', help='dataset directory') parser.add_argument('--val_root', type=str, default='../test_set', help='dataset directory') parser.add_argument('--mode', type=str, default='train', help='demo image directory') parser.add_argument('--scale', type=str, default='4', help='super 
resolution scale') parser.add_argument('--patch_size', type=int, default=256, help='output patch size') parser.add_argument('--rgb_range', type=int, default=1, help='maximum value of RGB') parser.add_argument('--chop', action='store_true', help='enable memory-efficient forward') parser.add_argument('--no_augment', action='store_true', help='do not use data augmentation') # Model specifications parser.add_argument('--model', default='LRSC_EDVR', help='model name') parser.add_argument('--act', type=str, default='relu', help='activation function') parser.add_argument('--pre_train', type=str, default='', help='pre-trained model directory') parser.add_argument('--extend', type=str, default='.', help='pre-trained model directory') parser.add_argument('--res_scale', type=float, default=1, help='residual scaling') parser.add_argument('--shift_mean', default=True, help='subtract pixel mean from the input') parser.add_argument('--dilation', action='store_true', help='use dilated convolution') parser.add_argument('--precision', type=str, default='single', choices=('single', 'half'), help='FP precision for test (single | half)') # Option for Residual channel attention network (RCAN) parser.add_argument('--n_resgroups', type=int, default=20, help='number of residual groups') parser.add_argument('--reduction', type=int, default=16, help='number of feature maps reduction') parser.add_argument('--DA', action='store_true', help='use Dual Attention') parser.add_argument('--CA', action='store_true', help='use Channel Attention') parser.add_argument('--non_local', action='store_true', help='use Dual Attention') # Training specifications parser.add_argument('--reset', action='store_true', help='reset the training') parser.add_argument('--test_every', type=int, default=1000, help='do test per every N batches') parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train') parser.add_argument('--batch_size', type=int, default=8, help='input batch size for 
training') parser.add_argument('--split_batch', type=int, default=1, help='split the batch into smaller chunks') parser.add_argument('--self_ensemble', action='store_true', help='use self-ensemble method for test') parser.add_argument('--test_only', action='store_true', help='set this option to test the model') parser.add_argument('--gan_k', type=int, default=1, help='k value for adversarial loss') # Optimization specifications parser.add_argument('--decay', type=str, default='40-80', help='learning rate decay type') parser.add_argument('--gamma', type=float, default=0.5, help='learning rate decay factor for step decay') parser.add_argument('--optimizer', default='ADAM', choices=('SGD', 'ADAM', 'RMSprop'), help='optimizer to use (SGD | ADAM | RMSprop)') parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum') parser.add_argument('--betas', type=tuple, default=(0.9, 0.999), help='ADAM beta') parser.add_argument('--epsilon', type=float, default=1e-8, help='ADAM epsilon for numerical stability') parser.add_argument('--weight_decay', type=float, default=0, help='weight decay') parser.add_argument('--gclip', type=float, default=0, help='gradient clipping threshold (0 = no clipping)') # Loss specifications parser.add_argument('--loss', type=str, default='1*L1', help='loss function configuration') parser.add_argument('--skip_threshold', type=float, default='1e8', help='skipping batch that has large error') # Log specifications parser.add_argument('--save', type=str, default='test', help='file name to save') parser.add_argument('--load', type=str, default='', help='file name to load') parser.add_argument('--resume', type=int, default=0, help='resume from specific checkpoint') parser.add_argument('--save_models', action='store_true', help='save all intermediate models') parser.add_argument('--print_every', type=int, default=10, help='how many batches to wait before logging training status') parser.add_argument('--save_results', action='store_true', 
help='save output results') parser.add_argument('--save_gt', action='store_true', help='save low-resolution and high-resolution images together') args = parser.parse_args() args.scale = list(map(lambda x: int(x), args.scale.split('+'))) if args.epochs == 0: args.epochs = 1e8 for arg in vars(args): if vars(args)[arg] == 'True': vars(args)[arg] = True elif vars(args)[arg] == 'False': vars(args)[arg] = False ================================================ FILE: code/real/bsrt/pwcnet/LICENSE ================================================ GNU GENERAL PUBLIC LICENSE Version 3, 29 June 2007 Copyright (C) 2007 Free Software Foundation, Inc. Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. Preamble The GNU General Public License is a free, copyleft license for software and other kinds of works. The licenses for most software and other practical works are designed to take away your freedom to share and change the works. By contrast, the GNU General Public License is intended to guarantee your freedom to share and change all versions of a program--to make sure it remains free software for all its users. We, the Free Software Foundation, use the GNU General Public License for most of our software; it applies also to any other work released this way by its authors. You can apply it to your programs, too. When we speak of free software, we are referring to freedom, not price. Our General Public Licenses are designed to make sure that you have the freedom to distribute copies of free software (and charge for them if you wish), that you receive source code or can get it if you want it, that you can change the software or use pieces of it in new free programs, and that you know you can do these things. To protect your rights, we need to prevent others from denying you these rights or asking you to surrender the rights. 
Therefore, you have certain responsibilities if you distribute copies of the software, or if you modify it: responsibilities to respect the freedom of others. For example, if you distribute copies of such a program, whether gratis or for a fee, you must pass on to the recipients the same freedoms that you received. You must make sure that they, too, receive or can get the source code. And you must show them these terms so they know their rights. Developers that use the GNU GPL protect your rights with two steps: (1) assert copyright on the software, and (2) offer you this License giving you legal permission to copy, distribute and/or modify it. For the developers' and authors' protection, the GPL clearly explains that there is no warranty for this free software. For both users' and authors' sake, the GPL requires that modified versions be marked as changed, so that their problems will not be attributed erroneously to authors of previous versions. Some devices are designed to deny users access to install or run modified versions of the software inside them, although the manufacturer can do so. This is fundamentally incompatible with the aim of protecting users' freedom to change the software. The systematic pattern of such abuse occurs in the area of products for individuals to use, which is precisely where it is most unacceptable. Therefore, we have designed this version of the GPL to prohibit the practice for those products. If such problems arise substantially in other domains, we stand ready to extend this provision to those domains in future versions of the GPL, as needed to protect the freedom of users. Finally, every program is threatened constantly by software patents. States should not allow patents to restrict development and use of software on general-purpose computers, but in those that do, we wish to avoid the special danger that patents applied to a free program could make it effectively proprietary. 
To prevent this, the GPL assures that patents cannot be used to render the program non-free. The precise terms and conditions for copying, distribution and modification follow. TERMS AND CONDITIONS 0. Definitions. "This License" refers to version 3 of the GNU General Public License. "Copyright" also means copyright-like laws that apply to other kinds of works, such as semiconductor masks. "The Program" refers to any copyrightable work licensed under this License. Each licensee is addressed as "you". "Licensees" and "recipients" may be individuals or organizations. To "modify" a work means to copy from or adapt all or part of the work in a fashion requiring copyright permission, other than the making of an exact copy. The resulting work is called a "modified version" of the earlier work or a work "based on" the earlier work. A "covered work" means either the unmodified Program or a work based on the Program. To "propagate" a work means to do anything with it that, without permission, would make you directly or secondarily liable for infringement under applicable copyright law, except executing it on a computer or modifying a private copy. Propagation includes copying, distribution (with or without modification), making available to the public, and in some countries other activities as well. To "convey" a work means any kind of propagation that enables other parties to make or receive copies. Mere interaction with a user through a computer network, with no transfer of a copy, is not conveying. An interactive user interface displays "Appropriate Legal Notices" to the extent that it includes a convenient and prominently visible feature that (1) displays an appropriate copyright notice, and (2) tells the user that there is no warranty for the work (except to the extent that warranties are provided), that licensees may convey the work under this License, and how to view a copy of this License. 
If the interface presents a list of user commands or options, such as a menu, a prominent item in the list meets this criterion. 1. Source Code. The "source code" for a work means the preferred form of the work for making modifications to it. "Object code" means any non-source form of a work. A "Standard Interface" means an interface that either is an official standard defined by a recognized standards body, or, in the case of interfaces specified for a particular programming language, one that is widely used among developers working in that language. The "System Libraries" of an executable work include anything, other than the work as a whole, that (a) is included in the normal form of packaging a Major Component, but which is not part of that Major Component, and (b) serves only to enable use of the work with that Major Component, or to implement a Standard Interface for which an implementation is available to the public in source code form. A "Major Component", in this context, means a major essential component (kernel, window system, and so on) of the specific operating system (if any) on which the executable work runs, or a compiler used to produce the work, or an object code interpreter used to run it. The "Corresponding Source" for a work in object code form means all the source code needed to generate, install, and (for an executable work) run the object code and to modify the work, including scripts to control those activities. However, it does not include the work's System Libraries, or general-purpose tools or generally available free programs which are used unmodified in performing those activities but which are not part of the work. 
For example, Corresponding Source includes interface definition files associated with source files for the work, and the source code for shared libraries and dynamically linked subprograms that the work is specifically designed to require, such as by intimate data communication or control flow between those subprograms and other parts of the work. The Corresponding Source need not include anything that users can regenerate automatically from other parts of the Corresponding Source. The Corresponding Source for a work in source code form is that same work. 2. Basic Permissions. All rights granted under this License are granted for the term of copyright on the Program, and are irrevocable provided the stated conditions are met. This License explicitly affirms your unlimited permission to run the unmodified Program. The output from running a covered work is covered by this License only if the output, given its content, constitutes a covered work. This License acknowledges your rights of fair use or other equivalent, as provided by copyright law. You may make, run and propagate covered works that you do not convey, without conditions so long as your license otherwise remains in force. You may convey covered works to others for the sole purpose of having them make modifications exclusively for you, or provide you with facilities for running those works, provided that you comply with the terms of this License in conveying all material for which you do not control copyright. Those thus making or running the covered works for you must do so exclusively on your behalf, under your direction and control, on terms that prohibit them from making any copies of your copyrighted material outside their relationship with you. Conveying under any other circumstances is permitted solely under the conditions stated below. Sublicensing is not allowed; section 10 makes it unnecessary. 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 
No covered work shall be deemed part of an effective technological measure under any applicable law fulfilling obligations under article 11 of the WIPO copyright treaty adopted on 20 December 1996, or similar laws prohibiting or restricting circumvention of such measures. When you convey a covered work, you waive any legal power to forbid circumvention of technological measures to the extent such circumvention is effected by exercising rights under this License with respect to the covered work, and you disclaim any intention to limit operation or modification of the work as a means of enforcing, against the work's users, your or third parties' legal rights to forbid circumvention of technological measures. 4. Conveying Verbatim Copies. You may convey verbatim copies of the Program's source code as you receive it, in any medium, provided that you conspicuously and appropriately publish on each copy an appropriate copyright notice; keep intact all notices stating that this License and any non-permissive terms added in accord with section 7 apply to the code; keep intact all notices of the absence of any warranty; and give all recipients a copy of this License along with the Program. You may charge any price or no price for each copy that you convey, and you may offer support or warranty protection for a fee. 5. Conveying Modified Source Versions. You may convey a work based on the Program, or the modifications to produce it from the Program, in the form of source code under the terms of section 4, provided that you also meet all of these conditions: a) The work must carry prominent notices stating that you modified it, and giving a relevant date. b) The work must carry prominent notices stating that it is released under this License and any conditions added under section 7. This requirement modifies the requirement in section 4 to "keep intact all notices". 
c) You must license the entire work, as a whole, under this License to anyone who comes into possession of a copy. This License will therefore apply, along with any applicable section 7 additional terms, to the whole of the work, and all its parts, regardless of how they are packaged. This License gives no permission to license the work in any other way, but it does not invalidate such permission if you have separately received it. d) If the work has interactive user interfaces, each must display Appropriate Legal Notices; however, if the Program has interactive interfaces that do not display Appropriate Legal Notices, your work need not make them do so. A compilation of a covered work with other separate and independent works, which are not by their nature extensions of the covered work, and which are not combined with it such as to form a larger program, in or on a volume of a storage or distribution medium, is called an "aggregate" if the compilation and its resulting copyright are not used to limit the access or legal rights of the compilation's users beyond what the individual works permit. Inclusion of a covered work in an aggregate does not cause this License to apply to the other parts of the aggregate. 6. Conveying Non-Source Forms. You may convey a covered work in object code form under the terms of sections 4 and 5, provided that you also convey the machine-readable Corresponding Source under the terms of this License, in one of these ways: a) Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by the Corresponding Source fixed on a durable physical medium customarily used for software interchange. 
b) Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by a written offer, valid for at least three years and valid for as long as you offer spare parts or customer support for that product model, to give anyone who possesses the object code either (1) a copy of the Corresponding Source for all the software in the product that is covered by this License, on a durable physical medium customarily used for software interchange, for a price no more than your reasonable cost of physically performing this conveying of source, or (2) access to copy the Corresponding Source from a network server at no charge. c) Convey individual copies of the object code with a copy of the written offer to provide the Corresponding Source. This alternative is allowed only occasionally and noncommercially, and only if you received the object code with such an offer, in accord with subsection 6b. d) Convey the object code by offering access from a designated place (gratis or for a charge), and offer equivalent access to the Corresponding Source in the same way through the same place at no further charge. You need not require recipients to copy the Corresponding Source along with the object code. If the place to copy the object code is a network server, the Corresponding Source may be on a different server (operated by you or a third party) that supports equivalent copying facilities, provided you maintain clear directions next to the object code saying where to find the Corresponding Source. Regardless of what server hosts the Corresponding Source, you remain obligated to ensure that it is available for as long as needed to satisfy these requirements. e) Convey the object code using peer-to-peer transmission, provided you inform other peers where the object code and Corresponding Source of the work are being offered to the general public at no charge under subsection 6d. 
A separable portion of the object code, whose source code is excluded from the Corresponding Source as a System Library, need not be included in conveying the object code work. A "User Product" is either (1) a "consumer product", which means any tangible personal property which is normally used for personal, family, or household purposes, or (2) anything designed or sold for incorporation into a dwelling. In determining whether a product is a consumer product, doubtful cases shall be resolved in favor of coverage. For a particular product received by a particular user, "normally used" refers to a typical or common use of that class of product, regardless of the status of the particular user or of the way in which the particular user actually uses, or expects or is expected to use, the product. A product is a consumer product regardless of whether the product has substantial commercial, industrial or non-consumer uses, unless such uses represent the only significant mode of use of the product. "Installation Information" for a User Product means any methods, procedures, authorization keys, or other information required to install and execute modified versions of a covered work in that User Product from a modified version of its Corresponding Source. The information must suffice to ensure that the continued functioning of the modified object code is in no case prevented or interfered with solely because modification has been made. If you convey an object code work under this section in, or with, or specifically for use in, a User Product, and the conveying occurs as part of a transaction in which the right of possession and use of the User Product is transferred to the recipient in perpetuity or for a fixed term (regardless of how the transaction is characterized), the Corresponding Source conveyed under this section must be accompanied by the Installation Information. 
But this requirement does not apply if neither you nor any third party retains the ability to install modified object code on the User Product (for example, the work has been installed in ROM). The requirement to provide Installation Information does not include a requirement to continue to provide support service, warranty, or updates for a work that has been modified or installed by the recipient, or for the User Product in which it has been modified or installed. Access to a network may be denied when the modification itself materially and adversely affects the operation of the network or violates the rules and protocols for communication across the network. Corresponding Source conveyed, and Installation Information provided, in accord with this section must be in a format that is publicly documented (and with an implementation available to the public in source code form), and must require no special password or key for unpacking, reading or copying. 7. Additional Terms. "Additional permissions" are terms that supplement the terms of this License by making exceptions from one or more of its conditions. Additional permissions that are applicable to the entire Program shall be treated as though they were included in this License, to the extent that they are valid under applicable law. If additional permissions apply only to part of the Program, that part may be used separately under those permissions, but the entire Program remains governed by this License without regard to the additional permissions. When you convey a copy of a covered work, you may at your option remove any additional permissions from that copy, or from any part of it. (Additional permissions may be written to require their own removal in certain cases when you modify the work.) You may place additional permissions on material, added by you to a covered work, for which you have or can give appropriate copyright permission. 
Notwithstanding any other provision of this License, for material you add to a covered work, you may (if authorized by the copyright holders of that material) supplement the terms of this License with terms: a) Disclaiming warranty or limiting liability differently from the terms of sections 15 and 16 of this License; or b) Requiring preservation of specified reasonable legal notices or author attributions in that material or in the Appropriate Legal Notices displayed by works containing it; or c) Prohibiting misrepresentation of the origin of that material, or requiring that modified versions of such material be marked in reasonable ways as different from the original version; or d) Limiting the use for publicity purposes of names of licensors or authors of the material; or e) Declining to grant rights under trademark law for use of some trade names, trademarks, or service marks; or f) Requiring indemnification of licensors and authors of that material by anyone who conveys the material (or modified versions of it) with contractual assumptions of liability to the recipient, for any liability that these contractual assumptions directly impose on those licensors and authors. All other non-permissive additional terms are considered "further restrictions" within the meaning of section 10. If the Program as you received it, or any part of it, contains a notice stating that it is governed by this License along with a term that is a further restriction, you may remove that term. If a license document contains a further restriction but permits relicensing or conveying under this License, you may add to a covered work material governed by the terms of that license document, provided that the further restriction does not survive such relicensing or conveying. 
If you add terms to a covered work in accord with this section, you must place, in the relevant source files, a statement of the additional terms that apply to those files, or a notice indicating where to find the applicable terms. Additional terms, permissive or non-permissive, may be stated in the form of a separately written license, or stated as exceptions; the above requirements apply either way. 8. Termination. You may not propagate or modify a covered work except as expressly provided under this License. Any attempt otherwise to propagate or modify it is void, and will automatically terminate your rights under this License (including any patent licenses granted under the third paragraph of section 11). However, if you cease all violation of this License, then your license from a particular copyright holder is reinstated (a) provisionally, unless and until the copyright holder explicitly and finally terminates your license, and (b) permanently, if the copyright holder fails to notify you of the violation by some reasonable means prior to 60 days after the cessation. Moreover, your license from a particular copyright holder is reinstated permanently if the copyright holder notifies you of the violation by some reasonable means, this is the first time you have received notice of violation of this License (for any work) from that copyright holder, and you cure the violation prior to 30 days after your receipt of the notice. Termination of your rights under this section does not terminate the licenses of parties who have received copies or rights from you under this License. If your rights have been terminated and not permanently reinstated, you do not qualify to receive new licenses for the same material under section 10. 9. Acceptance Not Required for Having Copies. You are not required to accept this License in order to receive or run a copy of the Program. 
Ancillary propagation of a covered work occurring solely as a consequence of using peer-to-peer transmission to receive a copy likewise does not require acceptance. However, nothing other than this License grants you permission to propagate or modify any covered work. These actions infringe copyright if you do not accept this License. Therefore, by modifying or propagating a covered work, you indicate your acceptance of this License to do so. 10. Automatic Licensing of Downstream Recipients. Each time you convey a covered work, the recipient automatically receives a license from the original licensors, to run, modify and propagate that work, subject to this License. You are not responsible for enforcing compliance by third parties with this License. An "entity transaction" is a transaction transferring control of an organization, or substantially all assets of one, or subdividing an organization, or merging organizations. If propagation of a covered work results from an entity transaction, each party to that transaction who receives a copy of the work also receives whatever licenses to the work the party's predecessor in interest had or could give under the previous paragraph, plus a right to possession of the Corresponding Source of the work from the predecessor in interest, if the predecessor has it or can get it with reasonable efforts. You may not impose any further restrictions on the exercise of the rights granted or affirmed under this License. For example, you may not impose a license fee, royalty, or other charge for exercise of rights granted under this License, and you may not initiate litigation (including a cross-claim or counterclaim in a lawsuit) alleging that any patent claim is infringed by making, using, selling, offering for sale, or importing the Program or any portion of it. 11. Patents. A "contributor" is a copyright holder who authorizes use under this License of the Program or a work on which the Program is based. 
The work thus licensed is called the contributor's "contributor version". A contributor's "essential patent claims" are all patent claims owned or controlled by the contributor, whether already acquired or hereafter acquired, that would be infringed by some manner, permitted by this License, of making, using, or selling its contributor version, but do not include claims that would be infringed only as a consequence of further modification of the contributor version. For purposes of this definition, "control" includes the right to grant patent sublicenses in a manner consistent with the requirements of this License. Each contributor grants you a non-exclusive, worldwide, royalty-free patent license under the contributor's essential patent claims, to make, use, sell, offer for sale, import and otherwise run, modify and propagate the contents of its contributor version. In the following three paragraphs, a "patent license" is any express agreement or commitment, however denominated, not to enforce a patent (such as an express permission to practice a patent or covenant not to sue for patent infringement). To "grant" such a patent license to a party means to make such an agreement or commitment not to enforce a patent against the party. If you convey a covered work, knowingly relying on a patent license, and the Corresponding Source of the work is not available for anyone to copy, free of charge and under the terms of this License, through a publicly available network server or other readily accessible means, then you must either (1) cause the Corresponding Source to be so available, or (2) arrange to deprive yourself of the benefit of the patent license for this particular work, or (3) arrange, in a manner consistent with the requirements of this License, to extend the patent license to downstream recipients. 
"Knowingly relying" means you have actual knowledge that, but for the patent license, your conveying the covered work in a country, or your recipient's use of the covered work in a country, would infringe one or more identifiable patents in that country that you have reason to believe are valid. If, pursuant to or in connection with a single transaction or arrangement, you convey, or propagate by procuring conveyance of, a covered work, and grant a patent license to some of the parties receiving the covered work authorizing them to use, propagate, modify or convey a specific copy of the covered work, then the patent license you grant is automatically extended to all recipients of the covered work and works based on it. A patent license is "discriminatory" if it does not include within the scope of its coverage, prohibits the exercise of, or is conditioned on the non-exercise of one or more of the rights that are specifically granted under this License. You may not convey a covered work if you are a party to an arrangement with a third party that is in the business of distributing software, under which you make payment to the third party based on the extent of your activity of conveying the work, and under which the third party grants, to any of the parties who would receive the covered work from you, a discriminatory patent license (a) in connection with copies of the covered work conveyed by you (or copies made from those copies), or (b) primarily for and in connection with specific products or compilations that contain the covered work, unless you entered into that arrangement, or that patent license was granted, prior to 28 March 2007. Nothing in this License shall be construed as excluding or limiting any implied license or other defenses to infringement that may otherwise be available to you under applicable patent law. 12. No Surrender of Others' Freedom. 
If conditions are imposed on you (whether by court order, agreement or otherwise) that contradict the conditions of this License, they do not excuse you from the conditions of this License. If you cannot convey a covered work so as to satisfy simultaneously your obligations under this License and any other pertinent obligations, then as a consequence you may not convey it at all. For example, if you agree to terms that obligate you to collect a royalty for further conveying from those to whom you convey the Program, the only way you could satisfy both those terms and this License would be to refrain entirely from conveying the Program. 13. Use with the GNU Affero General Public License. Notwithstanding any other provision of this License, you have permission to link or combine any covered work with a work licensed under version 3 of the GNU Affero General Public License into a single combined work, and to convey the resulting work. The terms of this License will continue to apply to the part which is the covered work, but the special requirements of the GNU Affero General Public License, section 13, concerning interaction through a network will apply to the combination as such. 14. Revised Versions of this License. The Free Software Foundation may publish revised and/or new versions of the GNU General Public License from time to time. Such new versions will be similar in spirit to the present version, but may differ in detail to address new problems or concerns. Each version is given a distinguishing version number. If the Program specifies that a certain numbered version of the GNU General Public License "or any later version" applies to it, you have the option of following the terms and conditions either of that numbered version or of any later version published by the Free Software Foundation. If the Program does not specify a version number of the GNU General Public License, you may choose any version ever published by the Free Software Foundation. 
If the Program specifies that a proxy can decide which future versions of the GNU General Public License can be used, that proxy's public statement of acceptance of a version permanently authorizes you to choose that version for the Program. Later license versions may give you additional or different permissions. However, no additional obligations are imposed on any author or copyright holder as a result of your choosing to follow a later version. 15. Disclaimer of Warranty. THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 16. Limitation of Liability. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 17. Interpretation of Sections 15 and 16. 
If the disclaimer of warranty and limitation of liability provided above cannot be given local legal effect according to their terms, reviewing courts shall apply local law that most closely approximates an absolute waiver of all civil liability in connection with the Program, unless a warranty or assumption of liability accompanies a copy of the Program in return for a fee. END OF TERMS AND CONDITIONS How to Apply These Terms to Your New Programs If you develop a new program, and you want it to be of the greatest possible use to the public, the best way to achieve this is to make it free software which everyone can redistribute and change under these terms. To do so, attach the following notices to the program. It is safest to attach them to the start of each source file to most effectively state the exclusion of warranty; and each file should have at least the "copyright" line and a pointer to where the full notice is found. {one line to give the program's name and a brief idea of what it does.} Copyright (C) {year} {name of author} This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . Also add information on how to contact you by electronic and paper mail. If the program does terminal interaction, make it output a short notice like this when it starts in an interactive mode: {project} Copyright (C) {year} {fullname} This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 
This is free software, and you are welcome to redistribute it under certain conditions; type `show c' for details. The hypothetical commands `show w' and `show c' should show the appropriate parts of the General Public License. Of course, your program's commands might be different; for a GUI interface, you would use an "about box". You should also get your employer (if you work as a programmer) or school, if any, to sign a "copyright disclaimer" for the program, if necessary. For more information on this, and how to apply and follow the GNU GPL, see . The GNU General Public License does not permit incorporating your program into proprietary programs. If your program is a subroutine library, you may consider it more useful to permit linking proprietary applications with the library. If this is what you want to do, use the GNU Lesser General Public License instead of this License. But first, please read . ================================================ FILE: code/real/bsrt/pwcnet/README.md ================================================ # pytorch-pwc This is a personal reimplementation of PWC-Net [1] using PyTorch. Should you be making use of this work, please cite the paper accordingly. Also, make sure to adhere to the licensing terms of the authors. Should you be making use of this particular implementation, please acknowledge it appropriately [2]. Paper For the original version of this work, please see: https://github.com/NVlabs/PWC-Net
Another optical flow implementation from me: https://github.com/sniklaus/pytorch-liteflownet
And another optical flow implementation from me: https://github.com/sniklaus/pytorch-unflow
Yet another optical flow implementation from me: https://github.com/sniklaus/pytorch-spynet ## background The authors of PWC-Net are thankfully already providing a reference implementation in PyTorch. However, its initial version did not reach the performance of the original Caffe version. This is why I created this repository, in which I replicated the performance of the official Caffe version by utilizing its weights. The official PyTorch implementation has adopted my approach of using the Caffe weights since then, which is why they are all performing equally well now. Many people have reported issues with CUDA when trying to get the official PyTorch version to run though, while my reimplementation does not seem to be subject to such problems. ## setup To download the pre-trained models, run `bash download.bash`. These originate from the original authors; I just converted them to PyTorch. The correlation layer is implemented in CUDA using CuPy, which is why CuPy is a required dependency. It can be installed using `pip install cupy` or alternatively using one of the provided binary packages as outlined in the CuPy repository. ## usage To run it on your own pair of images, use the following command. You can choose between two models, please make sure to see their paper / the code for more details. ``` python run.py --model default --first ./images/first.png --second ./images/second.png --out ./out.flo ``` I am afraid that I cannot guarantee that this reimplementation is correct. However, it produced results identical to the Caffe implementation of the original authors in the examples that I tried. Please feel free to contribute to this repository by submitting issues and pull requests. ## comparison

Comparison

## license As stated in the licensing terms of the authors of the paper, the models are free for non-commercial share-alike purpose. Please make sure to further consult their licensing terms. ## references ``` [1] @inproceedings{Sun_CVPR_2018, author = {Deqing Sun and Xiaodong Yang and Ming-Yu Liu and Jan Kautz}, title = {{PWC-Net}: {CNNs} for Optical Flow Using Pyramid, Warping, and Cost Volume}, booktitle = {IEEE Conference on Computer Vision and Pattern Recognition}, year = {2018} } ``` ``` [2] @misc{pytorch-pwc, author = {Simon Niklaus}, title = {A Reimplementation of {PWC-Net} Using {PyTorch}}, year = {2018}, howpublished = {\url{https://github.com/sniklaus/pytorch-pwc}} } ``` ================================================ FILE: code/real/bsrt/pwcnet/__init__.py ================================================ ================================================ FILE: code/real/bsrt/pwcnet/comparison/comparison.py ================================================ #!/usr/bin/env python import math import moviepy import moviepy.editor import numpy import PIL import PIL.Image import PIL.ImageFont import PIL.ImageDraw intX = 32 intY = 436 - 64 objImages = [ { 'strFile': 'official - caffe.png', 'strText': 'official - Caffe' }, { 'strFile': 'this - pytorch.png', 'strText': 'this - PyTorch' } ] npyImages = [] for objImage in objImages: objOutput = PIL.Image.open(objImage['strFile']).convert('RGB') for intU in [ intShift - 10 for intShift in range(20) ]: for intV in [ intShift - 10 for intShift in range(20) ]: if math.sqrt(math.pow(intU, 2.0) + math.pow(intV, 2.0)) <= 5.0: PIL.ImageDraw.Draw(objOutput).text((intX + intU, intY + intV), objImage['strText'], (255, 255, 255), PIL.ImageFont.truetype('freefont/FreeSerifBold.ttf', 32)) # end # end # end PIL.ImageDraw.Draw(objOutput).text((intX, intY), objImage['strText'], (0, 0, 0), PIL.ImageFont.truetype('freefont/FreeSerifBold.ttf', 32)) npyImages.append(numpy.array(objOutput)) # end 
moviepy.editor.ImageSequenceClip(sequence=npyImages, fps=1).write_gif(filename='comparison.gif', program='ImageMagick', opt='optimizeplus') ================================================ FILE: code/real/bsrt/pwcnet/correlation/README.md ================================================ This is an adaptation of the FlowNet2 implementation in order to compute cost volumes. Should you be making use of this work, please make sure to adhere to the licensing terms of the original authors. Should you be making use or modify this particular implementation, please acknowledge it appropriately. ================================================ FILE: code/real/bsrt/pwcnet/correlation/correlation.py ================================================ #!/usr/bin/env python import torch import cupy import re # from torch.cuda.amp import custom_fwd, custom_bwd kernel_Correlation_rearrange = ''' extern "C" __global__ void kernel_Correlation_rearrange( const int n, const float* input, float* output ) { int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; if (intIndex >= n) { return; } int intSample = blockIdx.z; int intChannel = blockIdx.y; float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex]; __syncthreads(); int intPaddedY = (intIndex / SIZE_3(input)) + 4; int intPaddedX = (intIndex % SIZE_3(input)) + 4; int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX; output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue; } ''' kernel_Correlation_updateOutput = ''' extern "C" __global__ void kernel_Correlation_updateOutput( const int n, const float* rbot0, const float* rbot1, float* top ) { extern __shared__ char patch_data_char[]; float *patch_data = (float *)patch_data_char; // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1 int x1 = blockIdx.x + 4; int y1 = blockIdx.y + 4; int item = blockIdx.z; 
int ch_off = threadIdx.x; // Load 3D patch into shared shared memory for (int j = 0; j < 1; j++) { // HEIGHT for (int i = 0; i < 1; i++) { // WIDTH int ji_off = (j + i) * SIZE_3(rbot0); for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch; int idxPatchData = ji_off + ch; patch_data[idxPatchData] = rbot0[idx1]; } } } __syncthreads(); __shared__ float sum[32]; // Compute correlation for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) { sum[ch_off] = 0; int s2o = top_channel % 9 - 4; int s2p = top_channel / 9 - 4; for (int j = 0; j < 1; j++) { // HEIGHT for (int i = 0; i < 1; i++) { // WIDTH int ji_off = (j + i) * SIZE_3(rbot0); for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS int x2 = x1 + s2o; int y2 = y1 + s2p; int idxPatchData = ji_off + ch; int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch; sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2]; } } } __syncthreads(); if (ch_off == 0) { float total_sum = 0; for (int idx = 0; idx < 32; idx++) { total_sum += sum[idx]; } const int sumelems = SIZE_3(rbot0); const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x; top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems; } } } ''' kernel_Correlation_updateGradFirst = ''' #define ROUND_OFF 50000 extern "C" __global__ void kernel_Correlation_updateGradFirst( const int n, const int intSample, const float* rbot0, const float* rbot1, const float* gradOutput, float* gradFirst, float* gradSecond ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { int n = intIndex % SIZE_1(gradFirst); // channels int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos // round_off is a trick to enable 
integer division with ceil, even for negative numbers // We use a large offset, for the inner part not to become negative. const int round_off = ROUND_OFF; const int round_off_s1 = round_off; // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) // Same here: int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4) int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4) float sum = 0; if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { xmin = max(0,xmin); xmax = min(SIZE_3(gradOutput)-1,xmax); ymin = max(0,ymin); ymax = min(SIZE_2(gradOutput)-1,ymax); for (int p = -4; p <= 4; p++) { for (int o = -4; o <= 4; o++) { // Get rbot1 data: int s2o = o; int s2p = p; int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n; float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n] // Index offset for gradOutput in following loops: int op = (p+4) * 9 + (o+4); // index[o,p] int idxopoffset = (intSample * SIZE_1(gradOutput) + op); for (int y = ymin; y <= ymax; y++) { for (int x = xmin; x <= xmax; x++) { int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] sum += gradOutput[idxgradOutput] * bot1tmp; } } } } } const int sumelems = SIZE_1(gradFirst); const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4); gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems; } } ''' kernel_Correlation_updateGradSecond = ''' #define ROUND_OFF 50000 extern "C" __global__ void kernel_Correlation_updateGradSecond( const int n, const int intSample, const float* rbot0, const float* rbot1, const float* gradOutput, float* gradFirst, float* gradSecond ) { for (int 
intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { int n = intIndex % SIZE_1(gradSecond); // channels int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos // round_off is a trick to enable integer division with ceil, even for negative numbers // We use a large offset, for the inner part not to become negative. const int round_off = ROUND_OFF; const int round_off_s1 = round_off; float sum = 0; for (int p = -4; p <= 4; p++) { for (int o = -4; o <= 4; o++) { int s2o = o; int s2p = p; //Get X,Y ranges and clamp // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) // Same here: int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o) int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p) if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { xmin = max(0,xmin); xmax = min(SIZE_3(gradOutput)-1,xmax); ymin = max(0,ymin); ymax = min(SIZE_2(gradOutput)-1,ymax); // Get rbot0 data: int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n; float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n] // Index offset for gradOutput in following loops: int op = (p+4) * 9 + (o+4); // index[o,p] int idxopoffset = (intSample * SIZE_1(gradOutput) + op); for (int y = ymin; y <= ymax; y++) { for (int x = xmin; x <= xmax; x++) { int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] sum += gradOutput[idxgradOutput] * bot0tmp; } } } } } const int sumelems = SIZE_1(gradSecond); const int bot1index = ((n * SIZE_2(gradSecond)) 
+ (m-4)) * SIZE_3(gradSecond) + (l-4); gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems; } } ''' def cupy_kernel(strFunction, objVariables): strKernel = globals()[strFunction] while True: objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) if objMatch is None: break # end intArg = int(objMatch.group(2)) strTensor = objMatch.group(4) intSizes = objVariables[strTensor].size() strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) # end while True: objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) if objMatch is None: break # end intArgs = int(objMatch.group(2)) strArgs = objMatch.group(4).split(',') strTensor = strArgs[0] intStrides = objVariables[strTensor].stride() strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') # end return strKernel # end @cupy.memoize(for_each_device=True) def cupy_launch(strFunction, strKernel): return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) # end class _FunctionCorrelation(torch.autograd.Function): @staticmethod # @custom_fwd#(cast_inputs=torch.float32) def forward(self, first, second): rbot0 = first.new_zeros([ first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1] ]) rbot1 = first.new_zeros([ first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1] ]) self.save_for_backward(first, second, rbot0, rbot1) assert(first.is_contiguous() == True) assert(second.is_contiguous() == True) output = first.new_zeros([ first.shape[0], 81, first.shape[2], first.shape[3] ]) if first.is_cuda == True: n = first.shape[2] * first.shape[3] cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { 'input': first, 'output': rbot0 }))( grid=tuple([ int((n + 16 - 1) / 
16), first.shape[1], first.shape[0] ]),
				block=tuple([ 16, 1, 1 ]),
				args=[ n, first.data_ptr(), rbot0.data_ptr() ]
			)

			# Rearrange the second input into rbot1 (same padded layout).
			n = second.shape[2] * second.shape[3]
			cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
				'input': second,
				'output': rbot1
			}))(
				grid=tuple([ int((n + 16 - 1) / 16), second.shape[1], second.shape[0] ]),
				block=tuple([ 16, 1, 1 ]),
				args=[ n, second.data_ptr(), rbot1.data_ptr() ]
			)

			# One thread block per output pixel; 32 threads reduce over the
			# channel dimension using shared memory.
			n = output.shape[1] * output.shape[2] * output.shape[3]
			cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', {
				'rbot0': rbot0,
				'rbot1': rbot1,
				'top': output
			}))(
				grid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]),
				block=tuple([ 32, 1, 1 ]),
				shared_mem=first.shape[1] * 4,
				args=[ n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ]
			)

		elif first.is_cuda == False:
			# No CPU implementation of the cost-volume kernels is provided.
			raise NotImplementedError()

		# end

		return output
	# end

	@staticmethod
	# @custom_bwd
	def backward(self, gradOutput):
		# Gradients w.r.t. both inputs, computed one batch sample per
		# kernel launch.
		first, second, rbot0, rbot1 = self.saved_tensors

		assert(gradOutput.is_contiguous() == True)

		# Allocate gradient buffers only for inputs that require them.
		gradFirst = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[0] == True else None
		gradSecond = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[1] == True else None

		if first.is_cuda == True:
			if gradFirst is not None:
				for intSample in range(first.shape[0]):
					n = first.shape[1] * first.shape[2] * first.shape[3]
					cupy_launch('kernel_Correlation_updateGradFirst', cupy_kernel('kernel_Correlation_updateGradFirst', {
						'rbot0': rbot0,
						'rbot1': rbot1,
						'gradOutput': gradOutput,
						'gradFirst': gradFirst,
						'gradSecond': None
					}))(
						grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
						block=tuple([ 512, 1, 1 ]),
						args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), gradFirst.data_ptr(), None ]
					)
				# end
			# end

			if gradSecond is not None:
				for intSample in range(first.shape[0]):
					n = first.shape[1] * first.shape[2] * first.shape[3]
					cupy_launch('kernel_Correlation_updateGradSecond', cupy_kernel('kernel_Correlation_updateGradSecond', {
						'rbot0': rbot0,
						'rbot1': rbot1,
						'gradOutput': gradOutput,
						'gradFirst': None,
						'gradSecond': gradSecond
					}))(
						grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
						block=tuple([ 512, 1, 1 ]),
						args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, gradSecond.data_ptr() ]
					)
				# end
			# end

		elif first.is_cuda == False:
			raise NotImplementedError()

		# end

		return gradFirst, gradSecond
	# end
# end

def FunctionCorrelation(tenFirst, tenSecond):
	# Functional entry point for the autograd-aware correlation layer.
	return _FunctionCorrelation.apply(tenFirst, tenSecond)
# end

class ModuleCorrelation(torch.nn.Module):
	# nn.Module wrapper so the correlation layer composes with other modules.
	def __init__(self):
		super(ModuleCorrelation, self).__init__()
	# end

	def forward(self, tenFirst, tenSecond):
		return _FunctionCorrelation.apply(tenFirst, tenSecond)
	# end
# end

================================================
FILE: code/real/bsrt/pwcnet/download.bash
================================================
#!/bin/bash

# Fetch the pretrained PWC-Net weights published by the pytorch-pwc author.
wget --verbose --continue --timestamping http://content.sniklaus.com/github/pytorch-pwc/network-chairs-things.pytorch
wget --verbose --continue --timestamping http://content.sniklaus.com/github/pytorch-pwc/network-default.pytorch

================================================
FILE: code/real/bsrt/pwcnet/images/README.md
================================================
The used example originates from the MPI Sintel dataset: http://sintel.is.tue.mpg.de/

================================================
FILE: code/real/bsrt/pwcnet/pwcnet.py
================================================
# Based on run.py from PWCNet

import torch
import getopt
import math
import numpy
import PIL.Image
import sys
from torch.cuda.amp import autocast

try:
	from pwcnet.correlation import correlation # the custom cost volume layer
except:
	sys.path.insert(0, './correlation'); import correlation # you should consider upgrading python

# Caches keyed by flow shape: sampling grid and all-ones partial channel
# used by backwarp().
backwarp_tenGrid = {}
backwarp_tenPartial = {}

# NOTE(review): the trailing '#' below most likely commented out the
# @autocast(enabled=False) decorator that follows in the original file;
# the dump's rewrapping makes this ambiguous -- verify against upstream.
#
@autocast(enabled=False) def backwarp(tenInput, tenFlow): if str(tenFlow.shape) not in backwarp_tenGrid: tenHor = torch.linspace(-1.0 + (1.0 / tenFlow.shape[3]), 1.0 - (1.0 / tenFlow.shape[3]), tenFlow.shape[3]).view(1, 1, 1, -1).expand(-1, -1, tenFlow.shape[2], -1) tenVer = torch.linspace(-1.0 + (1.0 / tenFlow.shape[2]), 1.0 - (1.0 / tenFlow.shape[2]), tenFlow.shape[2]).view(1, 1, -1, 1).expand(-1, -1, -1, tenFlow.shape[3]) backwarp_tenGrid[str(tenFlow.shape)] = torch.cat([tenHor, tenVer], 1).cuda() if str(tenFlow.shape) not in backwarp_tenPartial: backwarp_tenPartial[str(tenFlow.shape)] = tenFlow.new_ones([ tenFlow.shape[0], 1, tenFlow.shape[2], tenFlow.shape[3] ]) tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1) tenInput = torch.cat([ tenInput, backwarp_tenPartial[str(tenFlow.shape)] ], 1) tenOutput = torch.nn.functional.grid_sample(input=tenInput, grid=(backwarp_tenGrid[str(tenFlow.shape)] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=False) tenMask = tenOutput[:, -1:, :, :] tenMask[tenMask > 0.999] = 1.0 tenMask[tenMask < 1.0] = 0.0 return tenOutput[:, :-1, :, :].contiguous() * tenMask.contiguous() class Network(torch.nn.Module): def __init__(self): super(Network, self).__init__() class Extractor(torch.nn.Module): def __init__(self): super(Extractor, self).__init__() self.netOne = torch.nn.Sequential( torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netTwo = torch.nn.Sequential( torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, 
padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netThr = torch.nn.Sequential( torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netFou = torch.nn.Sequential( torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netFiv = torch.nn.Sequential( torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netSix = torch.nn.Sequential( torch.nn.Conv2d(in_channels=128, out_channels=196, kernel_size=3, stride=2, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1), 
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) def forward(self, tenInput): tenOne = self.netOne(tenInput) tenTwo = self.netTwo(tenOne) tenThr = self.netThr(tenTwo) tenFou = self.netFou(tenThr) tenFiv = self.netFiv(tenFou) tenSix = self.netSix(tenFiv) return [tenOne, tenTwo, tenThr, tenFou, tenFiv, tenSix] class Decoder(torch.nn.Module): def __init__(self, intLevel): super(Decoder, self).__init__() intPrevious = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 1] intCurrent = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 0] if intLevel < 6: self.netUpflow = torch.nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=4, stride=2, padding=1) if intLevel < 6: self.netUpfeat = torch.nn.ConvTranspose2d(in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=4, stride=2, padding=1) if intLevel < 6: self.fltBackwarp = [ None, None, None, 5.0, 2.5, 1.25, 0.625, None ][intLevel + 1] self.netOne = torch.nn.Sequential( torch.nn.Conv2d(in_channels=intCurrent, out_channels=128, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netTwo = torch.nn.Sequential( torch.nn.Conv2d(in_channels=intCurrent + 128, out_channels=128, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netThr = torch.nn.Sequential( torch.nn.Conv2d(in_channels=intCurrent + 128 + 128, out_channels=96, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netFou = torch.nn.Sequential( torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96, out_channels=64, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netFiv = torch.nn.Sequential( 
torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64, out_channels=32, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netSix = torch.nn.Sequential( torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=3, stride=1, padding=1) ) # end def forward(self, tenFirst, tenSecond, objPrevious): tenFlow = None tenFeat = None if objPrevious is None: tenFlow = None tenFeat = None tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=tenFirst, tenSecond=tenSecond), negative_slope=0.1, inplace=False) tenFeat = torch.cat([ tenVolume ], 1) elif objPrevious is not None: tenFlow = self.netUpflow(objPrevious['tenFlow']) tenFeat = self.netUpfeat(objPrevious['tenFeat']) tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=tenFirst, tenSecond=backwarp(tenInput=tenSecond, tenFlow=tenFlow * self.fltBackwarp)), negative_slope=0.1, inplace=False) tenFeat = torch.cat([ tenVolume, tenFirst, tenFlow, tenFeat ], 1) tenFeat = torch.cat([ self.netOne(tenFeat), tenFeat ], 1) tenFeat = torch.cat([ self.netTwo(tenFeat), tenFeat ], 1) tenFeat = torch.cat([ self.netThr(tenFeat), tenFeat ], 1) tenFeat = torch.cat([ self.netFou(tenFeat), tenFeat ], 1) tenFeat = torch.cat([ self.netFiv(tenFeat), tenFeat ], 1) tenFlow = self.netSix(tenFeat) return { 'tenFlow': tenFlow, 'tenFeat': tenFeat } class Refiner(torch.nn.Module): def __init__(self): super(Refiner, self).__init__() self.netMain = torch.nn.Sequential( torch.nn.Conv2d(in_channels=81 + 32 + 2 + 2 + 128 + 128 + 96 + 64 + 32, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=2, dilation=2), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=4, 
dilation=4), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=8, dilation=8), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=16, dilation=16), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1, dilation=1) ) def forward(self, tenInput): return self.netMain(tenInput) self.netExtractor = Extractor() self.netTwo = Decoder(2) self.netThr = Decoder(3) self.netFou = Decoder(4) self.netFiv = Decoder(5) self.netSix = Decoder(6) self.netRefiner = Refiner() # self.load_state_dict({ strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load(__file__.replace('run.py', 'network-' + arguments_strModel + '.pytorch')).items() }) def forward(self, tenFirst, tenSecond): tenFirst = self.netExtractor(tenFirst) tenSecond = self.netExtractor(tenSecond) objEstimate = self.netSix(tenFirst[-1], tenSecond[-1], None) objEstimate = self.netFiv(tenFirst[-2], tenSecond[-2], objEstimate) objEstimate = self.netFou(tenFirst[-3], tenSecond[-3], objEstimate) objEstimate = self.netThr(tenFirst[-4], tenSecond[-4], objEstimate) objEstimate = self.netTwo(tenFirst[-5], tenSecond[-5], objEstimate) return objEstimate['tenFlow'] + self.netRefiner(objEstimate['tenFeat']) class PWCNet(torch.nn.Module): def __init__(self, load_pretrained=True, weights_path=None, rgb2bgr=False): super(PWCNet, self).__init__() self.net = Network() self.rgb2bgr = rgb2bgr if load_pretrained: if weights_path is None: raise Exception else: weights_dict = torch.load(weights_path) self.net.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in weights_dict.items()}) # 
@autocast() def forward(self, source_img, target_img): assert (source_img.shape[-1] == target_img.shape[-1]) assert (source_img.shape[-2] == target_img.shape[-2]) int_width = source_img.shape[-1] int_height = source_img.shape[-2] source_img = source_img.view(-1, 3, int_height, int_width) target_img = target_img.view(-1, 3, int_height, int_width) if self.rgb2bgr: source_img = source_img[:, [2, 1, 0]].contiguous() target_img = target_img[:, [2, 1, 0]].contiguous() int_preprocessed_width = int(math.floor(math.ceil(int_width / 64.0) * 64.0)) int_preprocessed_height = int(math.floor(math.ceil(int_height / 64.0) * 64.0)) # Make size multiple of 64 source_img_re = torch.nn.functional.interpolate(input=source_img, size=(int_preprocessed_height, int_preprocessed_width), mode='bilinear', align_corners=False) target_img_re = torch.nn.functional.interpolate(input=target_img, size=(int_preprocessed_height, int_preprocessed_width), mode='bilinear', align_corners=False) flow = self.net(target_img_re, source_img_re) flow = 20.0 * torch.nn.functional.interpolate(input=flow, size=(int_height, int_width), mode='bilinear', align_corners=False) scale_factor_x = float(int_width) / float(int_preprocessed_width) scale_factor_y = float(int_height) / float(int_preprocessed_height) flow = torch.stack((flow[:, 0] * scale_factor_x, flow[:, 1] * scale_factor_y), dim=1) return flow ================================================ FILE: code/real/bsrt/pwcnet/requirements.txt ================================================ cupy>=5.0.0 numpy>=1.15.0 Pillow>=5.0.0 torch>=1.3.0 ================================================ FILE: code/real/bsrt/pwcnet/run.py ================================================ #!/usr/bin/env python import torch import getopt import math import numpy import os import PIL import PIL.Image import sys try: from .correlation import correlation # the custom cost volume layer except: sys.path.insert(0, './correlation'); import correlation # you should consider upgrading 
python # end ########################################################## assert(int(str('').join(torch.__version__.split('.')[0:2])) >= 13) # requires at least pytorch version 1.3.0 torch.set_grad_enabled(False) # make sure to not compute gradients for computational performance torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance ########################################################## arguments_strModel = 'default' arguments_strFirst = './images/first.png' arguments_strSecond = './images/second.png' arguments_strOut = './out.flo' for strOption, strArgument in getopt.getopt(sys.argv[1:], '', [ strParameter[2:] + '=' for strParameter in sys.argv[1::2] ])[0]: if strOption == '--model' and strArgument != '': arguments_strModel = strArgument # which model to use if strOption == '--first' and strArgument != '': arguments_strFirst = strArgument # path to the first frame if strOption == '--second' and strArgument != '': arguments_strSecond = strArgument # path to the second frame if strOption == '--out' and strArgument != '': arguments_strOut = strArgument # path to where the output should be stored # end ########################################################## backwarp_tenGrid = {} backwarp_tenPartial = {} def backwarp(tenInput, tenFlow): if str(tenFlow.shape) not in backwarp_tenGrid: tenHor = torch.linspace(-1.0 + (1.0 / tenFlow.shape[3]), 1.0 - (1.0 / tenFlow.shape[3]), tenFlow.shape[3]).view(1, 1, 1, -1).expand(-1, -1, tenFlow.shape[2], -1) tenVer = torch.linspace(-1.0 + (1.0 / tenFlow.shape[2]), 1.0 - (1.0 / tenFlow.shape[2]), tenFlow.shape[2]).view(1, 1, -1, 1).expand(-1, -1, -1, tenFlow.shape[3]) backwarp_tenGrid[str(tenFlow.shape)] = torch.cat([ tenHor, tenVer ], 1).cuda() # end if str(tenFlow.shape) not in backwarp_tenPartial: backwarp_tenPartial[str(tenFlow.shape)] = tenFlow.new_ones([ tenFlow.shape[0], 1, tenFlow.shape[2], tenFlow.shape[3] ]) # end tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) 
/ 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1) tenInput = torch.cat([ tenInput, backwarp_tenPartial[str(tenFlow.shape)] ], 1) tenOutput = torch.nn.functional.grid_sample(input=tenInput, grid=(backwarp_tenGrid[str(tenFlow.shape)] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=False) tenMask = tenOutput[:, -1:, :, :]; tenMask[tenMask > 0.999] = 1.0; tenMask[tenMask < 1.0] = 0.0 return tenOutput[:, :-1, :, :] * tenMask # end ########################################################## class Network(torch.nn.Module): def __init__(self): super(Network, self).__init__() class Extractor(torch.nn.Module): def __init__(self): super(Extractor, self).__init__() self.netOne = torch.nn.Sequential( torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netTwo = torch.nn.Sequential( torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netThr = torch.nn.Sequential( torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=64, 
out_channels=64, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netFou = torch.nn.Sequential( torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netFiv = torch.nn.Sequential( torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netSix = torch.nn.Sequential( torch.nn.Conv2d(in_channels=128, out_channels=196, kernel_size=3, stride=2, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) # end def forward(self, tenInput): tenOne = self.netOne(tenInput) tenTwo = self.netTwo(tenOne) tenThr = self.netThr(tenTwo) tenFou = self.netFou(tenThr) tenFiv = self.netFiv(tenFou) tenSix = self.netSix(tenFiv) return [ tenOne, tenTwo, tenThr, tenFou, tenFiv, tenSix ] # end # end class Decoder(torch.nn.Module): def __init__(self, intLevel): super(Decoder, self).__init__() intPrevious = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 1] 
# NOTE(review): this chunk is a flattened multi-file dump; below is the tail of
# code/real/bsrt/pwcnet/run.py (sniklaus' PWC-Net port) reformatted with
# identical tokens.  The span resumes inside Network.__init__ -> class
# Decoder.__init__ (the enclosing headers are outside this view).

                # Channel count fed into this decoder level: 81 correlation
                # channels plus, on levels < 6, the pyramid features and the
                # 2-channel upsampled flow / 2-channel upsampled feature tail.
                intCurrent = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 0]

                # Levels below the coarsest upsample the previous level's flow
                # and feature tail before correlating.
                if intLevel < 6: self.netUpflow = torch.nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=4, stride=2, padding=1)
                if intLevel < 6: self.netUpfeat = torch.nn.ConvTranspose2d(in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=4, stride=2, padding=1)
                # Per-level scale applied to the upsampled flow before
                # backwarping; indexed by the coarser level, hence intLevel + 1.
                if intLevel < 6: self.fltBackwarp = [ None, None, None, 5.0, 2.5, 1.25, 0.625, None ][intLevel + 1]

                # DenseNet-style decoder: each conv consumes the concatenation
                # of the cost volume and all earlier conv outputs (see forward).
                self.netOne = torch.nn.Sequential( torch.nn.Conv2d(in_channels=intCurrent, out_channels=128, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) )
                self.netTwo = torch.nn.Sequential( torch.nn.Conv2d(in_channels=intCurrent + 128, out_channels=128, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) )
                self.netThr = torch.nn.Sequential( torch.nn.Conv2d(in_channels=intCurrent + 128 + 128, out_channels=96, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) )
                self.netFou = torch.nn.Sequential( torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96, out_channels=64, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) )
                self.netFiv = torch.nn.Sequential( torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64, out_channels=32, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) )
                self.netSix = torch.nn.Sequential( torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=3, stride=1, padding=1) )
            # end

            def forward(self, tenFirst, tenSecond, objPrevious):
                # Estimate (flow, features) at this level; objPrevious carries
                # the coarser level's output, or None at the coarsest level.
                tenFlow = None
                tenFeat = None

                if objPrevious is None:
                    tenFlow = None
                    tenFeat = None

                    # Coarsest level: raw correlation cost volume only.
                    tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=tenFirst, tenSecond=tenSecond), negative_slope=0.1, inplace=False)

                    tenFeat = torch.cat([ tenVolume ], 1)

                elif objPrevious is not None:
                    # Finer levels: upsample previous flow/features, warp the
                    # second image by the scaled flow, then correlate.
                    tenFlow = self.netUpflow(objPrevious['tenFlow'])
                    tenFeat = self.netUpfeat(objPrevious['tenFeat'])

                    tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=tenFirst, tenSecond=backwarp(tenInput=tenSecond, tenFlow=tenFlow * self.fltBackwarp)), negative_slope=0.1, inplace=False)

                    tenFeat = torch.cat([ tenVolume, tenFirst, tenFlow, tenFeat ], 1)

                # end

                # Dense connectivity: grow the feature stack conv by conv.
                tenFeat = torch.cat([ self.netOne(tenFeat), tenFeat ], 1)
                tenFeat = torch.cat([ self.netTwo(tenFeat), tenFeat ], 1)
                tenFeat = torch.cat([ self.netThr(tenFeat), tenFeat ], 1)
                tenFeat = torch.cat([ self.netFou(tenFeat), tenFeat ], 1)
                tenFeat = torch.cat([ self.netFiv(tenFeat), tenFeat ], 1)

                tenFlow = self.netSix(tenFeat)

                return { 'tenFlow': tenFlow, 'tenFeat': tenFeat }
            # end
        # end

        class Refiner(torch.nn.Module):
            """Context network: a stack of dilated convs refining the flow."""

            def __init__(self):
                super(Refiner, self).__init__()

                self.netMain = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=81 + 32 + 2 + 2 + 128 + 128 + 96 + 64 + 32, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=2, dilation=2),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=4, dilation=4),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=8, dilation=8),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=16, dilation=16),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1, dilation=1)
                )
            # end

            def forward(self, tenInput):
                return self.netMain(tenInput)
            # end
        # end

        # Assemble the pyramid: shared extractor, one decoder per level
        # (2..6, coarse to fine), plus the context refiner.
        self.netExtractor = Extractor()

        self.netTwo = Decoder(2)
        self.netThr = Decoder(3)
        self.netFou = Decoder(4)
        self.netFiv = Decoder(5)
        self.netSix = Decoder(6)

        self.netRefiner = Refiner()

        # Load pretrained weights shipped next to this file; checkpoint keys
        # use the 'module' prefix, the submodules here use 'net'.
        self.load_state_dict({ strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load(__file__.replace('run.py', 'network-' + arguments_strModel + '.pytorch')).items() })
    # end

    def forward(self, tenFirst, tenSecond):
        # Coarse-to-fine estimation over the two feature pyramids; the final
        # flow is the finest decoder's flow plus the refiner residual.
        tenFirst = self.netExtractor(tenFirst)
        tenSecond = self.netExtractor(tenSecond)

        objEstimate = self.netSix(tenFirst[-1], tenSecond[-1], None)
        objEstimate = self.netFiv(tenFirst[-2], tenSecond[-2], objEstimate)
        objEstimate = self.netFou(tenFirst[-3], tenSecond[-3], objEstimate)
        objEstimate = self.netThr(tenFirst[-4], tenSecond[-4], objEstimate)
        objEstimate = self.netTwo(tenFirst[-5], tenSecond[-5], objEstimate)

        return objEstimate['tenFlow'] + self.netRefiner(objEstimate['tenFeat'])
    # end
# end

netNetwork = None

##########################################################

def estimate(tenFirst, tenSecond):
    # Run PWC-Net on one 3xHxW image pair and return a 2xHxW flow (CPU
    # tensor).  Lazily builds the CUDA model on first call.
    global netNetwork

    if netNetwork is None:
        netNetwork = Network().cuda().eval()
    # end

    assert(tenFirst.shape[1] == tenSecond.shape[1])
    assert(tenFirst.shape[2] == tenSecond.shape[2])

    intWidth = tenFirst.shape[2]
    intHeight = tenFirst.shape[1]

    assert(intWidth == 1024) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue
    assert(intHeight == 436) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue

    tenPreprocessedFirst = tenFirst.cuda().view(1, 3, intHeight, intWidth)
    tenPreprocessedSecond = tenSecond.cuda().view(1, 3, intHeight, intWidth)

    # The network expects dimensions that are multiples of 64: resize up,
    # run, resize the flow back, and rescale its two components accordingly.
    intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0))
    intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0))

    tenPreprocessedFirst = torch.nn.functional.interpolate(input=tenPreprocessedFirst, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)
    tenPreprocessedSecond = torch.nn.functional.interpolate(input=tenPreprocessedSecond, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)

    # The 20.0 factor undoes the network's internal flow scaling (per the
    # upstream PWC-Net implementation).
    tenFlow = 20.0 * torch.nn.functional.interpolate(input=netNetwork(tenPreprocessedFirst, tenPreprocessedSecond), size=(intHeight, intWidth), mode='bilinear', align_corners=False)

    tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth)
    tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight)

    return tenFlow[0, :, :, :].cpu()
# end

##########################################################

if __name__ == '__main__':
    # CLI: read two images (channel-reversed, CHW, float in [0, 1]), estimate
    # flow, and write Middlebury .flo output (magic 'PIEH', width, height, data).
    tenFirst = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(arguments_strFirst))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0)))
    tenSecond = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(arguments_strSecond))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0)))

    tenOutput = estimate(tenFirst, tenSecond)

    objOutput = open(arguments_strOut, 'wb')

    numpy.array([ 80, 73, 69, 72 ], numpy.uint8).tofile(objOutput)
    numpy.array([ tenOutput.shape[2], tenOutput.shape[1] ], numpy.int32).tofile(objOutput)
    numpy.array(tenOutput.numpy().transpose(1, 2, 0), numpy.float32).tofile(objOutput)

    objOutput.close()
    print('finished...')
# end

================================================ FILE: code/real/bsrt/requirements.txt ================================================
matplotlib
imageio
opencv-python
tensorboardX
tqdm
timm

================================================ FILE: code/real/bsrt/scripts/__init__.py ================================================

================================================ FILE: code/real/bsrt/scripts/cal_mean_std.py ================================================
import torch
import numpy as np
from tqdm import tqdm
from datasets.burstsr_dataset import BurstSRDataset, \
flatten_raw_image from datasets.synthetic_burst_train_set import SyntheticBurst from datasets.zurich_raw2rgb_dataset import ZurichRAW2RGB def main(): train_zurich_raw2rgb = ZurichRAW2RGB(root='/data/dataset/ntire21/burstsr/synthetic', split='train') train_data = SyntheticBurst(train_zurich_raw2rgb, burst_size=14, crop_sz=384) means = [] stds = [] for data in tqdm(train_data): print(data.shape) break if __name__ == '__main__': # if not args.cpu: torch.cuda.set_device(0) main() ================================================ FILE: code/real/bsrt/scripts/demo.sh ================================================ set -ex rlaunch --cpu=4 --gpu=1 --memory=10240 -- python ./scripts/evaluate_burstsr_val.py ================================================ FILE: code/real/bsrt/scripts/download_burstsr_dataset.py ================================================ import os import urllib.request import zipfile import shutil import argparse def download_burstsr_dataset(download_path): out_dir = download_path + '/burstsr_dataset' # Download train folders for i in range(9): if not os.path.isfile('{}/train_{:02d}.zip'.format(out_dir, i)): print('Downloading train_{:02d}'.format(i)) urllib.request.urlretrieve('https://data.vision.ee.ethz.ch/bhatg/BurstSRChallenge/train_{:02d}.zip'.format(i), '{}/tmp.zip'.format(out_dir)) os.rename('{}/tmp.zip'.format(out_dir), '{}/train_{:02d}.zip'.format(out_dir, i)) # Download val folder if not os.path.isfile('{}/val.zip'.format(out_dir)): print('Downloading val') urllib.request.urlretrieve('https://data.vision.ee.ethz.ch/bhatg/BurstSRChallenge/val.zip', '{}/tmp.zip'.format(out_dir)) os.rename('{}/tmp.zip'.format(out_dir), '{}/val.zip'.format(out_dir)) # Unpack train set for i in range(9): print('Unpacking train_{:02d}'.format(i)) with zipfile.ZipFile('{}/train_{:02d}.zip'.format(out_dir, i), 'r') as zip_ref: zip_ref.extractall('{}'.format(out_dir)) # Move files to a common directory os.makedirs('{}/train'.format(out_dir), exist_ok=True) for i in 
range(9): file_list = os.listdir('{}/train_{:02d}'.format(out_dir, i)) for b in file_list: source_dir = '{}/train_{:02d}/{}'.format(out_dir, i, b) dst_dir = '{}/train/{}'.format(out_dir, b) if os.path.isdir(source_dir): shutil.move(source_dir, dst_dir) # Delete individual subsets for i in range(9): shutil.rmtree('{}/train_{:02d}'.format(out_dir, i)) # Unpack val set print('Unpacking val') with zipfile.ZipFile('{}/val.zip'.format(out_dir), 'r') as zip_ref: zip_ref.extractall('{}'.format(out_dir)) def main(): parser = argparse.ArgumentParser(description='Downloads and unpacks BurstSR dataset') parser.add_argument('path', type=str, help='Path where the dataset will be downloaded') args = parser.parse_args() download_burstsr_dataset(args.path) if __name__ == '__main__': main() ================================================ FILE: code/real/bsrt/scripts/evaluate.sh ================================================ set -ex rlaunch --cpu=4 --gpu=1 --memory=10240 -- python scripts/evaluate_burstsr_val.py ================================================ FILE: code/real/bsrt/scripts/evaluate_burstsr_val.py ================================================ import torch.nn.functional as F from datasets.burstsr_dataset import BurstSRDataset from utils.metrics import AlignedPSNR from pwcnet.pwcnet import PWCNet root = '/data/dataset/ntire21/burstsr/real/NTIRE/burstsr_dataset' class SimpleBaseline: def __init__(self): pass def __call__(self, burst): burst_rgb = burst[:, 0, [0, 1, 3]] burst_rgb = burst_rgb.view(-1, *burst_rgb.shape[-3:]) burst_rgb = F.interpolate(burst_rgb, scale_factor=8, mode='bilinear') return burst_rgb def main(): # Load dataset dataset = BurstSRDataset(root=root, split='val', burst_size=14, crop_sz=80, random_flip=False) # TODO Set your network here net = SimpleBaseline() device = 'cuda' # Load alignment network, used in AlignedPSNR alignment_net = PWCNet(load_pretrained=True, weights_path='PATH_TO_PWCNET_WEIGHTS') alignment_net = alignment_net.to(device) 
    # Continuation of evaluate_burstsr_val.main(): score every validation
    # burst with the alignment-based PSNR and report the mean.
    aligned_psnr_fn = AlignedPSNR(alignment_net=alignment_net, boundary_ignore=40)

    scores_all = []
    for idx in range(len(dataset)):
        burst, frame_gt, meta_info_burst, meta_info_gt = dataset[idx]
        burst = burst.unsqueeze(0).to(device)
        frame_gt = frame_gt.unsqueeze(0).to(device)

        net_pred = net(burst)

        # Calculate Aligned PSNR
        score = aligned_psnr_fn(net_pred, frame_gt, burst)
        scores_all.append(score)

    mean_psnr = sum(scores_all) / len(scores_all)
    print('Mean PSNR is {:0.3f}'.format(mean_psnr.item()))


if __name__ == '__main__':
    main()

================================================ FILE: code/real/bsrt/scripts/save_results_synburst_val.py ================================================
import torch.nn.functional as F
import cv2
from datasets.synthetic_burst_val_set import SyntheticBurstVal
import torch
import numpy as np
import os


class SimpleBaseline:
    """Baseline 'network': bilinear x8 upsampling of the base frame's RGB
    channels (indices 0, 1, 3 of the packed RGGB mosaic)."""

    def __init__(self):
        pass

    def __call__(self, burst):
        burst_rgb = burst[:, 0, [0, 1, 3]]
        burst_rgb = burst_rgb.view(-1, *burst_rgb.shape[-3:])
        burst_rgb = F.interpolate(burst_rgb, scale_factor=8, mode='bilinear')
        return burst_rgb


def main():
    # Run a model over the synthetic validation set and save 16-bit PNGs.
    dataset = SyntheticBurstVal('PATH_TO_SyntheticBurstVal')
    out_dir = 'PATH_WHERE_RESULTS_ARE_SAVED'

    # TODO Set your network here
    net = SimpleBaseline()

    device = 'cuda'
    os.makedirs(out_dir, exist_ok=True)

    for idx in range(len(dataset)):
        burst, burst_name = dataset[idx]
        burst = burst.to(device).unsqueeze(0)

        with torch.no_grad():
            net_pred = net(burst)

        # Normalize to 0 2^14 range and convert to numpy array
        net_pred_np = (net_pred.squeeze(0).permute(1, 2, 0).clamp(0.0, 1.0) * 2 ** 14).cpu().numpy().astype(np.uint16)

        # Save predictions as png
        cv2.imwrite('{}/{}.png'.format(out_dir, burst_name), net_pred_np)


if __name__ == '__main__':
    main()

================================================ FILE: code/real/bsrt/scripts/test_burstsr_dataset.py ================================================
import torch.nn.functional as F
import cv2
from datasets.burstsr_dataset import BurstSRDataset
from torch.utils.data.dataloader import DataLoader
from utils.metrics import AlignedPSNR
from utils.postprocessing_functions import BurstSRPostProcess
from utils.data_format_utils import convert_dict
from pwcnet.pwcnet import PWCNet


def main():
    # Sanity-check the BurstSR loader: score a bilinear baseline and show
    # the post-processed prediction next to the ground truth.
    # Load dataset
    dataset = BurstSRDataset(root='PATH_TO_BURST_SR', split='val', burst_size=3, crop_sz=56, random_flip=False)
    data_loader = DataLoader(dataset, batch_size=2)

    # Load alignment network, used in AlignedPSNR
    alignment_net = PWCNet(load_pretrained=True, weights_path='PATH_TO_PWCNET_WEIGHTS')
    alignment_net = alignment_net.to('cuda')
    aligned_psnr_fn = AlignedPSNR(alignment_net=alignment_net, boundary_ignore=40)

    # Postprocessing function to obtain sRGB images
    postprocess_fn = BurstSRPostProcess(return_np=True)

    for d in data_loader:
        burst, frame_gt, meta_info_burst, meta_info_gt = d

        # A simple baseline which upsamples the base image using bilinear upsampling
        burst_rgb = burst[:, 0, [0, 1, 3]]
        burst_rgb = burst_rgb.view(-1, *burst_rgb.shape[-3:])
        burst_rgb = F.interpolate(burst_rgb, scale_factor=8, mode='bilinear')

        # Calculate Aligned PSNR
        score = aligned_psnr_fn(burst_rgb.cuda(), frame_gt.cuda(), burst.cuda())
        print('PSNR is {:0.3f}'.format(score))

        meta_info_gt = convert_dict(meta_info_gt, burst.shape[0])

        # Apply simple post-processing to obtain RGB images
        pred_0 = postprocess_fn.process(burst_rgb[0], meta_info_gt[0])
        gt_0 = postprocess_fn.process(frame_gt[0], meta_info_gt[0])

        pred_0 = cv2.cvtColor(pred_0, cv2.COLOR_RGB2BGR)
        gt_0 = cv2.cvtColor(gt_0, cv2.COLOR_RGB2BGR)

        # Visualize input, ground truth
        cv2.imshow('Input (Demosaicekd + Upsampled)', pred_0)
        cv2.imshow('GT', gt_0)

        input_key = cv2.waitKey(0)
        if input_key == ord('q'):
            return


if __name__ == '__main__':
    main()

================================================ FILE: code/real/bsrt/scripts/test_synthetic_bursts.py ================================================
import torch.nn.functional as F
import cv2
from datasets.synthetic_burst_train_set import SyntheticBurst
from torch.utils.data.dataloader import DataLoader
from utils.metrics import PSNR
from utils.postprocessing_functions import SimplePostProcess
from utils.data_format_utils import convert_dict
from datasets.zurich_raw2rgb_dataset import ZurichRAW2RGB


def main():
    # Sanity-check the synthetic burst pipeline with a bilinear baseline.
    zurich_raw2rgb = ZurichRAW2RGB(root='PATH_TO_ZURICH_RAW_TO_RGB', split='test')
    dataset = SyntheticBurst(zurich_raw2rgb, burst_size=3, crop_sz=256)
    data_loader = DataLoader(dataset, batch_size=2)

    # Function to calculate PSNR. Note that the boundary pixels (40 pixels) will be ignored during PSNR computation
    psnr_fn = PSNR(boundary_ignore=40)

    # Postprocessing function to obtain sRGB images
    postprocess_fn = SimplePostProcess(return_np=True)

    for d in data_loader:
        burst, frame_gt, flow_vectors, meta_info = d

        # A simple baseline which upsamples the base image using bilinear upsampling
        burst_rgb = burst[:, 0, [0, 1, 3]]
        burst_rgb = burst_rgb.view(-1, *burst_rgb.shape[-3:])
        burst_rgb = F.interpolate(burst_rgb, scale_factor=8, mode='bilinear')

        # Calculate PSNR
        score = psnr_fn(burst_rgb, frame_gt)
        print('PSNR is {:0.3f}'.format(score))

        meta_info = convert_dict(meta_info, burst.shape[0])

        # Apply simple post-processing to obtain RGB images
        pred_0 = postprocess_fn.process(burst_rgb[0], meta_info[0])
        gt_0 = postprocess_fn.process(frame_gt[0], meta_info[0])

        pred_0 = cv2.cvtColor(pred_0, cv2.COLOR_RGB2BGR)
        gt_0 = cv2.cvtColor(gt_0, cv2.COLOR_RGB2BGR)

        # Visualize input, ground truth
        cv2.imshow('Input (Demosaicekd + Upsampled)', pred_0)
        cv2.imshow('GT', gt_0)

        input_key = cv2.waitKey(0)
        if input_key == ord('q'):
            return


if __name__ == '__main__':
    main()

================================================ FILE: code/real/bsrt/test.py ================================================
import torch.nn.functional as F
import cv2
import torch
import numpy as np
import os
from tqdm import tqdm
from datasets.realworld_burst_test_set import RealWorldBurstTest
from datasets.burstsr_dataset import flatten_raw_image_batch, pack_raw_image_batch
import model
import utility
from option import args
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn
import torch.utils.data.distributed
import time

checkpoint = utility.checkpoint(args)


def main_worker(local_rank, nprocs, args):
    # Run BSRT over the real-world test set and save 16-bit PNG predictions.
    device = 'cuda'
    cudnn.benchmark = True
    args.local_rank = local_rank
    utility.setup(local_rank, nprocs)
    torch.cuda.set_device(local_rank)

    dataset = RealWorldBurstTest(args.root)
    out_dir = 'bsrt_realworld'
    os.makedirs(out_dir, exist_ok=True)

    _model = model.Model(args, checkpoint)

    tt = []
    for idx in tqdm(range(len(dataset))):
        burst, meta_info = dataset[idx]
        burst_name = meta_info['burst_name']
        burst = burst.to(device).unsqueeze(0)

        with torch.no_grad():
            tic = time.time()
            sr = _model(burst, 0).float()
            toc = time.time()
            tt.append(toc-tic)

        # Normalize to 0 2^14 range and convert to numpy array
        net_pred_np = (sr.squeeze(0).permute(1, 2, 0).clamp(0.0, 1.0) * 2 ** 14).cpu().numpy().astype(np.uint16)
        cv2.imwrite('{}/{}.png'.format(out_dir, burst_name), net_pred_np)

    print('avg time: {:.4f}'.format(np.mean(tt)))
    utility.cleanup()


def main():
    mp.spawn(main_worker, nprocs=1, args=(1, args))


if __name__ == '__main__':
    main()

================================================ FILE: code/real/bsrt/test_real.py ================================================
import cv2
import torch
import numpy as np
import os
from tqdm import tqdm
import random
import utility
from option import args
import torchvision.utils as tvutils
from pwcnet.pwcnet import PWCNet
from utils.postprocessing_functions import BurstSRPostProcess
from datasets.burstsr_dataset import BurstSRDataset, flatten_raw_image_batch, pack_raw_image
from utils.metrics import AlignedPSNR
from utils.data_format_utils import convert_dict
from data_processing.camera_pipeline import demosaic
import model
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn
import torch.utils.data.distributed
import time
from torchsummaryX import summary

checkpoint = utility.checkpoint(args)


def main():
    mp.spawn(main_worker, nprocs=1, args=(1, args))


def main_worker(local_rank, nprocs, args):
    # Evaluate BSRT on the real BurstSR validation split: aligned PSNR,
    # SSIM and LPIPS, plus per-burst forward timing.
    cudnn.benchmark = True
    args.local_rank = local_rank
    utility.setup(local_rank, nprocs)
    torch.cuda.set_device(local_rank)

    dataset = BurstSRDataset(root=args.root, burst_size=14, crop_sz=80, split='val')
    out_dir = 'val/bsrt_real'

    _model = model.Model(args, checkpoint)
    for param in _model.parameters():
        param.requires_grad = False

    alignment_net = PWCNet(load_pretrained=True, weights_path='./pwcnet/pwcnet-network-default.pth')
    alignment_net = alignment_net.to('cuda')
    for param in alignment_net.parameters():
        param.requires_grad = False

    aligned_psnr_fn = AlignedPSNR(alignment_net=alignment_net, boundary_ignore=40)
    postprocess_fn = BurstSRPostProcess(return_np=True)

    os.makedirs(out_dir, exist_ok=True)

    tt = []
    psnrs, ssims, lpipss = [], [], []
    for idx in tqdm(range(len(dataset))):
        burst_, gt, meta_info_burst, meta_info_gt = dataset[idx]
        burst_ = burst_.unsqueeze(0).cuda()
        gt = gt.unsqueeze(0).cuda()
        # burst = flatten_raw_image_batch(burst_)
        name = meta_info_burst['burst_name']

        with torch.no_grad():
            tic = time.time()
            sr = _model(burst_, 0).float()
            toc = time.time()
            tt.append(toc-tic)

            # sr_int = (sr.clamp(0.0, 1.0) * 2 ** 14).short()
            # sr = sr_int.float() / (2 ** 14)

            psnr, ssim, lpips = aligned_psnr_fn(sr, gt, burst_)
            psnrs.append(psnr.item())
            ssims.append(ssim.item())
            lpipss.append(lpips.item())

        # lrs = burst_[0]
        # os.makedirs(f'{out_dir}/{name}', exist_ok=True)
        # for i, lr in enumerate(lrs):
        #     # print(lr[[0, 1, 3],...].shape)
        #     lr = postprocess_fn.process(lr[[0, 1, 3],...], meta_info_burst)
        #     lr = cv2.cvtColor(lr, cv2.COLOR_RGB2BGR)
        #     cv2.imwrite('{}/{}/{:2d}.png'.format(out_dir, name, i), lr)

        # gt = postprocess_fn.process(gt[0], meta_info_burst)
        # gt = cv2.cvtColor(gt, cv2.COLOR_RGB2BGR)
        # cv2.imwrite('{}/{}_gt.png'.format(out_dir, name), gt)

        # sr_ = postprocess_fn.process(sr[0], meta_info_burst)
        # sr_ = cv2.cvtColor(sr_, cv2.COLOR_RGB2BGR)
        # cv2.imwrite('{}/{}_bsrt.png'.format(out_dir, name), sr_)

        del burst_
        del sr
        del gt

    print(f'avg PSNR: {np.mean(psnrs):.6f}')
    print(f'avg SSIM: {np.mean(ssims):.6f}')
    print(f'avg LPIPS: {np.mean(lpipss):.6f}')
    print(f' avg time: {np.mean(tt):.6f}')
    # utility.cleanup()


if __name__ == '__main__':
    main()

================================================ FILE: code/real/bsrt/trainer.py ================================================
import os
import sys
from decimal import Decimal
import cv2
import utility
import torchvision.utils as tvutils
import torch.nn.functional as F
import random
import torch
from tensorboardX import SummaryWriter
from pwcnet.pwcnet import PWCNet
from utils.postprocessing_functions import BurstSRPostProcess
from utils.data_format_utils import convert_dict
from utils.metrics import AlignedL1, AlignedPSNR
from datasets.burstsr_dataset import pack_raw_image, flatten_raw_image_batch, pack_raw_image_batch
from data_processing.camera_pipeline import demosaic
from tqdm import tqdm
from loss.filter import Filter
from torch.cuda.amp import autocast as autocast, GradScaler

# Experiment bookkeeping: all logs/checkpoints live under
# ../train_log/<directory-name-of-this-file>/.
train_log_dir = '../train_log/'
exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1]
tfboard_name = exp_name + "_"
exp_train_log_dir = os.path.join(train_log_dir, exp_name)

LOG_DIR = os.path.join(exp_train_log_dir, 'logs')
# save img path
IMG_SAVE_DIR = os.path.join(exp_train_log_dir, 'img_log')
# Where to load model
LOAD_MODEL_DIR = os.path.join(exp_train_log_dir, 'models')
# Where to save new model
SAVE_MODEL_DIR = os.path.join(exp_train_log_dir, 'real_models')
# Where to save visualization images (for report)
RESULTS_DIR = os.path.join(exp_train_log_dir, 'report')

utility.mkdir(SAVE_MODEL_DIR)
utility.mkdir(IMG_SAVE_DIR)
utility.mkdir(LOG_DIR)


class Trainer():
    """Training/validation driver for BSRT on the real BurstSR data,
    using a frozen PWC-Net for the aligned loss and metrics."""

    def __init__(self, args, train_loader, train_sampler, valid_loader, my_model, my_loss, ckp):
        self.args = args
        self.scale = args.scale[0]

        self.ckp = ckp
        self.loader_train = train_loader
        self.loader_valid = valid_loader
        self.train_sampler = train_sampler
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)

        # Postprocessing function to obtain sRGB images
        self.postprocess_fn = BurstSRPostProcess(return_np=True)

        # Frozen alignment network shared by the aligned loss and metrics.
        self.alignment_net = PWCNet(load_pretrained=True, weights_path='./pwcnet/pwcnet-network-default.pth')
        self.alignment_net = self.alignment_net.to('cuda')
        for param in self.alignment_net.parameters():
            param.requires_grad = False
        self.aligned_psnr_fn = AlignedPSNR(alignment_net=self.alignment_net, boundary_ignore=40)
        if 'L1' in args.loss:
            self.aligned_loss = AlignedL1(alignment_net=self.alignment_net, boundary_ignore=40)

        if self.args.fp16:
            self.scaler = GradScaler()

        self.best_psnr = 0.
        self.best_epoch = 0

        if self.args.load != '':
            self.optimizer.load(ckp.dir, epoch=len(ckp.log))

        self.error_last = 1e8
        self.glob_iter = 0

        self.log_dir = LOG_DIR + "/" + args.save
        self.img_save_dir = IMG_SAVE_DIR + "/" + args.save
        # Where to load model
        self.load_model_dir = LOAD_MODEL_DIR + "/" + args.save
        # Where to save new model
        self.save_model_dir = SAVE_MODEL_DIR + "/" + args.save
        # Where to save visualization images (for report)
        self.results_dir = RESULTS_DIR + "/" + args.save

        self.writer = SummaryWriter(log_dir=self.log_dir)
        utility.mkdir(self.save_model_dir)
        utility.mkdir(self.img_save_dir)
        utility.mkdir(self.log_dir)
        utility.mkdir('frames')

    def train(self):
        # One training epoch; logs to tensorboard, checkpoints every 200
        # batches and at epoch end, then runs validation via self.test().
        self.loss.step()
        epoch = self.optimizer.get_last_epoch() + 1
        lr = self.optimizer.get_lr()

        if self.train_sampler:
            self.train_sampler.set_epoch(epoch)

        if epoch % 100 == 0:
            self.ckp.write_log(
                '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
            )
        self.loss.start_log()
        # self.test()
        self.model.train()

        if self.args.local_rank <= 0:
            timer_data, timer_model, timer_epoch = utility.timer(), utility.timer(), utility.timer()
            timer_epoch.tic()
        for batch, batch_value in enumerate(self.loader_train):
            burst, gt, meta_info_burst, meta_info_gt = batch_value
            burst, gt = self.prepare(burst, gt)
            # burst = flatten_raw_image_batch(burst_)

            if self.args.local_rank == 0:
                timer_data.hold()
                timer_model.tic()

            if self.args.fp16:
                with autocast():
                    sr = self.model(burst, 0).float()
                    # loss = self.aligned_loss(sr, gt, burst)
            else:
                sr = self.model(burst, 0)
            # Loss is computed in fp32 even on the fp16 path.
            loss = self.aligned_loss(sr, gt, burst)

            if self.args.n_GPUs > 1:
                torch.distributed.barrier()
                reduced_loss = utility.reduce_mean(loss, self.args.n_GPUs)
            else:
                reduced_loss = loss

            self.model.zero_grad()
            if self.args.fp16:
                self.scaler.scale(loss).backward()
                # torch.nn.utils.clip_grad_value_(self.model.parameters(), .01)
                # Skip the optimizer step (and abort) if the output blew up.
                if torch.isinf(sr).sum() + torch.isnan(sr).sum() <= 0:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    print(f'Nan num: {torch.isnan(sr).sum()}, inf num: {torch.isinf(sr).sum()}')
                    reduced_loss = None
                    os._exit(0)
                    sys.exit(0)
            else:
                loss.backward()
                # torch.nn.utils.clip_grad_value_(self.model.parameters(), .01)
                if torch.isinf(sr).sum() + torch.isnan(sr).sum() <= 0:
                    self.optimizer.step()
                else:
                    print(f'Nan num: {torch.isnan(sr).sum()}, inf num: {torch.isinf(sr).sum()}')
                    reduced_loss = None

            if self.args.local_rank == 0:
                timer_model.hold()
                if epoch % 1 == 0 and batch % 10 == 0:
                    self.writer.add_scalars('Loss', {tfboard_name + '_mse_L1': reduced_loss.detach().cpu().numpy()}, self.glob_iter)
                if (batch + 1) % self.args.print_every == 0:
                    self.ckp.write_log('[{}/{}]\t[{:.4f}]\t{:.1f}+{:.1f}s'.format(
                        (batch + 1) * self.args.batch_size,
                        len(self.loader_train.dataset),
                        reduced_loss.item(),
                        timer_model.release(),
                        timer_data.release()))
                self.glob_iter += 1
                timer_data.tic()

            if self.args.local_rank <= 0 and (batch + 1) % 200 == 0:
                if not self.args.test_only:
                    filename = exp_name + '_latest' + '.pth'
                    self.save_model(filename)

        if self.args.local_rank <= 0:
            timer_epoch.hold()
            print('Epoch {} cost time: {:.1f}s, lr: {:5f}'.format(epoch, timer_epoch.release(), lr))
            if (epoch) % 1 == 0 and not self.args.test_only:
                filename = exp_name + '_epoch_' + str(epoch) + '.pth'
                self.save_model(filename)
            if not self.args.test_only:
                filename = exp_name + '_latest' + '.pth'
                self.save_model(filename)

        torch.cuda.synchronize()
        torch.cuda.empty_cache()

        self.test()
        self.loss.end_log(len(self.loader_train))
        self.error_last = self.loss.log[-1, -1]
        self.optimizer.schedule()

    def test(self):
        # Validation pass: aligned PSNR/SSIM/LPIPS over loader_valid, with
        # best-checkpoint tracking.  TTA is currently disabled (see ttaup).
        torch.set_grad_enabled(False)

        def ttaup(burst):
            # Test-time-augmentation expansion (transpose variant disabled).
            # burst0 = flatten_raw_image_batch(burst)  # B, T, C, H, W
            # burst1 = utility.bayer_aug(burst0, flip_h=False, flip_w=False, transpose=True)
            # burst0 = pack_raw_image_batch(burst0)
            # burst1 = pack_raw_image_batch(burst1)
            return [burst]

        def ttadown(bursts):
            # Merge of the TTA variants (currently identity).
            burst0 = bursts[0]
            # burst1 = bursts[1].permute(0, 1, 3, 2)
            # out = (burst0 + burst1) / 2
            out = burst0
            return out

        epoch = self.optimizer.get_last_epoch() + 1
        self.model.eval()
        if self.args.local_rank == 0:
            print("Testing...")
            timer_test = utility.timer()
        if epoch == 1 or epoch % 1 == 0:
            self.model.eval()
            total_psnr = 0
            total_ssim = 0
            total_lpips = 0
            count = 0
            for i, batch_value in tqdm(enumerate(self.loader_valid)):
                burst, gt, meta_info_burst, meta_info_gt = batch_value
                burst, gt = self.prepare(burst, gt)
                # burst_ = flatten_raw_image_batch(burst)
                bursts = ttaup(burst)

                with torch.no_grad():
                    srs = []
                    for b in bursts:
                        if self.args.fp16:
                            with autocast():
                                sr = self.model(b, 0).float()
                        else:
                            sr = self.model(b, 0).float()
                        srs.append(sr)
                    sr = ttadown(srs)
                    # sr_int = (sr.clamp(0.0, 1.0) * 2 ** 14).short()
                    # sr = sr_int.float() / (2 ** 14)

                score, ssim_score, lpips_score = self.aligned_psnr_fn(sr, gt, burst)
                if self.args.n_GPUs > 1:
                    torch.distributed.barrier()
                    score = utility.reduce_mean(score, self.args.n_GPUs)
                    ssim_score = utility.reduce_mean(ssim_score, self.args.n_GPUs)
                    lpips_score = utility.reduce_mean(lpips_score, self.args.n_GPUs)
                total_psnr += score
                total_ssim += ssim_score
                total_lpips += lpips_score
                count += 1

                # # if i > 3 and i < 6 and self.args.local_rank == 0:
                # if i > 200 and i < 400 and self.args.local_rank <= 0:
                #     meta_info_gt = convert_dict(meta_info_gt, burst.shape[0])
                #     meta_info_burst = convert_dict(meta_info_burst, burst.shape[0])
                #     # Apply simple post-processing to obtain RGB images
                #     in_ = demosaic(burst[0][0])
                #     in_ = self.postprocess_fn.process(in_, meta_info_burst[0])
                #     sr_ = self.postprocess_fn.process(sr[0], meta_info_gt[0])
                #     # gt_ = self.postprocess_fn.process(gt[0], meta_info_gt[0])
                #     in_ = cv2.cvtColor(in_, cv2.COLOR_RGB2BGR)
                #     sr_ = cv2.cvtColor(sr_, cv2.COLOR_RGB2BGR)
                #     # gt_ = cv2.cvtColor(gt_, cv2.COLOR_RGB2BGR)
                #     cv2.imwrite('frames/{}_in.png'.format(i), in_)
                #     cv2.imwrite('frames/{}_gt.png'.format(i), gt_)
                #     cv2.imwrite('frames/{}_sr.png'.format(i), sr_)

            total_psnr = total_psnr / count
            total_ssim = total_ssim / count
            total_lpips = total_lpips / count
            if self.args.local_rank == 0:
                print("[Epoch: {}]\n[PSNR: {:.4f}][SSIM: {:.4f}][LPIPS: {:.4f}][Best PSNR: {:.4f}][Best Epoch: {}]"
                      .format(epoch, total_psnr, total_ssim, total_lpips, self.best_psnr, self.best_epoch))
                if epoch >= 1 and total_psnr > self.best_psnr:
                    self.best_psnr = total_psnr
                    self.best_epoch = epoch
                    filename = exp_name + 'best_epoch.pth'
                    self.save_model(filename)
                self.writer.add_scalars('PSNR', {tfboard_name + '_PSNR': total_psnr}, self.glob_iter)
                print('Forward: {:.2f}s\n'.format(timer_test.toc()))

        torch.cuda.synchronize()
        torch.set_grad_enabled(True)
        torch.cuda.empty_cache()

    def save_model(self, filename):
        # Save the bare network's state dict (unwrapping DDP if needed).
        print('save model...')
        net_save_path = os.path.join(self.save_model_dir, filename)
        model = self.model.model
        if self.args.n_GPUs > 1:
            model = model.module
        torch.save(model.state_dict(), net_save_path)

    def prepare(self, *args):
        # Move tensors to the worker's device, halving them if requested.
        device = torch.device('cpu' if self.args.cpu else 'cuda:{}'.format(self.args.local_rank))

        def _prepare(tensor):
            if self.args.precision == 'half':
                tensor = tensor.half()
            return tensor.to(device)
        # print(_prepare(args[0]).device)
        return [_prepare(a) for a in args]

    def terminate(self):
        # Returns True when training should stop (test-only mode runs a
        # single evaluation first).
        if self.args.test_only:
            self.test()
            return True
        else:
            epoch = self.optimizer.get_last_epoch() + 1
            return epoch >= self.args.epochs
================================================ FILE: code/real/bsrt/utility.py ================================================ import math import time import datetime from multiprocessing import Process from multiprocessing import Queue import matplotlib.pyplot as plt import numpy as np import imageio import os import sys import torch import torch.optim as optim import torch.optim.lr_scheduler as lrs import torch.distributed as dist import matplotlib matplotlib.use('Agg') def reduce_mean(tensor, nprocs): rt = tensor.clone() dist.all_reduce(rt, op=dist.ReduceOp.SUM) rt /= nprocs return rt def setup(rank, world_size): if sys.platform == 'win32': # Distributed package only covers collective communications with Gloo # backend and FileStore on Windows platform. Set init_method parameter # in init_process_group to a local file. # Example init_method="file:///f:/libtmp/some_file" init_method = "tcp://localhost:1234" # initialize the process group dist.init_process_group( "gloo", init_method=init_method, rank=rank, world_size=world_size ) else: os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '12256' # initialize the process group dist.init_process_group("nccl", rank=rank, world_size=world_size) def cleanup(): dist.destroy_process_group() def mkdir(path): if not os.path.exists(path): os.makedirs(path) class timer(): def __init__(self): self.acc = 0 self.tic() def tic(self): self.t0 = time.time() def toc(self, restart=False): diff = time.time() - self.t0 if restart: self.t0 = time.time() return diff def hold(self): self.acc += self.toc() def release(self): ret = self.acc self.acc = 0 return ret def reset(self): self.acc = 0 class checkpoint(): def __init__(self, args): self.args = args self.ok = True self.log = torch.Tensor() now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S') if not args.load: if not args.save: args.save = now self.dir = os.path.join('..', 'experiment', args.save) else: self.dir = os.path.join('..', 'experiment', args.load) if 
os.path.exists(self.dir): self.log = torch.load(self.get_path('psnr_log.pt')) print('Continue from epoch {}...'.format(len(self.log))) else: args.load = '' if args.reset: os.system('rm -rf ' + self.dir) args.load = '' os.makedirs(self.dir, exist_ok=True) os.makedirs(self.get_path('model'), exist_ok=True) # for d in args.data_test: # os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True) open_type = 'a' if os.path.exists(self.get_path('log.txt')) else 'w' self.log_file = open(self.get_path('log.txt'), open_type) with open(self.get_path('config.txt'), open_type) as f: f.write(now + '\n\n') for arg in vars(args): f.write('{}: {}\n'.format(arg, getattr(args, arg))) f.write('\n') self.n_processes = 8 def get_path(self, *subdir): return os.path.join(self.dir, *subdir) def save(self, trainer, epoch, is_best=False): trainer.model.save(self.get_path('model'), epoch, is_best=is_best) trainer.loss.save(self.dir) trainer.loss.plot_loss(self.dir, epoch) self.plot_psnr(epoch) trainer.optimizer.save(self.dir) torch.save(self.log, self.get_path('psnr_log.pt')) def add_log(self, log): self.log = torch.cat([self.log, log]) def write_log(self, log, refresh=False): print(log) self.log_file.write(log + '\n') if refresh: self.log_file.close() self.log_file = open(self.get_path('log.txt'), 'a') def done(self): self.log_file.close() def plot_psnr(self, epoch): axis = np.linspace(1, epoch, epoch) for idx_data, d in enumerate(self.args.data_test): label = 'SR on {}'.format(d) fig = plt.figure() plt.title(label) for idx_scale, scale in enumerate(self.args.scale): plt.plot( axis, self.log[:, idx_data, idx_scale].numpy(), label='Scale {}'.format(scale) ) plt.legend() plt.xlabel('Epochs') plt.ylabel('PSNR') plt.grid(True) plt.savefig(self.get_path('test_{}.pdf'.format(d))) plt.close(fig) def begin_background(self): self.queue = Queue() def bg_target(queue): while True: if not queue.empty(): filename, tensor = queue.get() if filename is None: break imageio.imwrite(filename, 
tensor.numpy()) self.process = [ Process(target=bg_target, args=(self.queue,)) \ for _ in range(self.n_processes) ] for p in self.process: p.start() def end_background(self): for _ in range(self.n_processes): self.queue.put((None, None)) while not self.queue.empty(): time.sleep(1) for p in self.process: p.join() def save_results(self, dataset, filename, save_list, scale): if self.args.save_results: filename = self.get_path( 'results-{}'.format(dataset.dataset.name), '{}_x{}_'.format(filename, scale) ) postfix = ('SR', 'LR', 'HR') for v, p in zip(save_list, postfix): normalized = v[0].mul(255 / self.args.rgb_range) tensor_cpu = normalized.byte().permute(1, 2, 0).cpu() self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu)) def quantize(img, rgb_range): pixel_range = 255 / rgb_range return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range) def calc_psnr(sr, hr, scale, rgb_range, dataset=None): if hr.nelement() == 1: return 0 diff = (sr - hr) / rgb_range if dataset and dataset.dataset.benchmark: shave = scale if diff.size(1) > 1: gray_coeffs = [65.738, 129.057, 25.064] convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256 diff = diff.mul(convert).sum(dim=1) else: shave = scale + 6 valid = diff[..., shave:-shave, shave:-shave] mse = valid.pow(2).mean() return -10 * math.log10(mse) def make_optimizer(args, target): ''' make optimizer and scheduler together ''' # optimizer trainable = filter(lambda x: x.requires_grad, target.parameters()) kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay} if args.optimizer == 'SGD': optimizer_class = optim.SGD kwargs_optimizer['momentum'] = args.momentum elif args.optimizer == 'ADAM': optimizer_class = optim.Adam kwargs_optimizer['betas'] = args.betas kwargs_optimizer['eps'] = args.epsilon elif args.optimizer == 'RMSprop': optimizer_class = optim.RMSprop kwargs_optimizer['eps'] = args.epsilon # scheduler milestones = list(map(lambda x: int(x), args.decay.split('-'))) kwargs_scheduler = 
{'milestones': milestones, 'gamma': args.gamma} scheduler_class = lrs.MultiStepLR class CustomOptimizer(optimizer_class): def __init__(self, *args, **kwargs): super(CustomOptimizer, self).__init__(*args, **kwargs) def _register_scheduler(self, scheduler_class, **kwargs): self.scheduler = scheduler_class(self, **kwargs) def save(self, save_dir): torch.save(self.state_dict(), self.get_dir(save_dir)) def load(self, load_dir, epoch=1): self.load_state_dict(torch.load(self.get_dir(load_dir))) if epoch > 1: for _ in range(epoch): self.scheduler.step() def get_dir(self, dir_path): return os.path.join(dir_path, 'optimizer.pt') def schedule(self): self.scheduler.step() def get_lr(self): return self.scheduler.get_last_lr()[0] def get_last_epoch(self): return self.scheduler.last_epoch optimizer = CustomOptimizer(trainable, **kwargs_optimizer) optimizer._register_scheduler(scheduler_class, **kwargs_scheduler) return optimizer def write_gray_to_tfboard(img): img_debug = img[0, ...].detach().cpu().numpy() # img_debug = cv2.normalize(img_debug, None, 0, 255, # cv2.NORM_MINMAX, cv2.CV_8U) img_debug = img_debug * 255 img_debug = np.clip(img_debug, 0, 255) img_debug = img_debug.astype(np.uint8) return img_debug[0, ...] ######################## BayerUnifyAug ############################ BAYER_PATTERNS = ["RGGB", "BGGR", "GRBG", "GBRG"] NORMALIZATION_MODE = ["crop", "pad"] def bayer_unify(raw, input_pattern, target_pattern, mode) -> np.ndarray: """ Convert a bayer raw image from one bayer pattern to another. mode: {"crop", "pad"} The way to handle submosaic shift. "crop" abandons the outmost pixels, and "pad" introduces extra pixels. Use "crop" in training and "pad" in testing. 
""" if input_pattern == target_pattern: h_offset, w_offset = 0, 0 elif input_pattern[0] == target_pattern[2] and input_pattern[1] == target_pattern[3]: h_offset, w_offset = 1, 0 elif input_pattern[0] == target_pattern[1] and input_pattern[2] == target_pattern[3]: h_offset, w_offset = 0, 1 elif input_pattern[0] == target_pattern[3] and input_pattern[1] == target_pattern[2]: h_offset, w_offset = 1, 1 else: # This is not happening in ["RGGB", "BGGR", "GRBG", "GBRG"] raise RuntimeError('Unexpected pair of input and target bayer pattern!') if mode == "pad": # out = np.pad(raw, [[h_offset, h_offset], [w_offset, w_offset]], 'reflect') out = F.pad(raw, (w_offset, w_offset, h_offset, h_offset), mode='reflect') elif mode == "crop": _, _, _, h, w = raw.shape out = raw[..., h_offset:h - h_offset, w_offset:w - w_offset] else: raise ValueError('Unknown normalization mode!') return out def bayer_aug(raw, flip_h=False, flip_w=False, transpose=False, input_pattern='RGGB') -> np.ndarray: """ Apply augmentation to a bayer raw image. 
""" aug_pattern, target_pattern = input_pattern, input_pattern out = raw if flip_h: out = torch.flip(out, [3]) # GBRG, RGGB aug_pattern = aug_pattern[2] + aug_pattern[3] + aug_pattern[0] + aug_pattern[1] if flip_w: out = torch.flip(out, [4]) aug_pattern = aug_pattern[1] + aug_pattern[0] + aug_pattern[3] + aug_pattern[2] if transpose: out = out.permute(0, 1, 2, 4, 3) aug_pattern = aug_pattern[0] + aug_pattern[2] + aug_pattern[1] + aug_pattern[3] out = bayer_unify(out, aug_pattern, target_pattern, "crop") return out ================================================ FILE: code/real/bsrt/utils/__init__.py ================================================ ================================================ FILE: code/real/bsrt/utils/data_format_utils.py ================================================ import numpy as np import torch import cv2 as cv def numpy_to_torch(a: np.ndarray): return torch.from_numpy(a).float().permute(2, 0, 1) def torch_to_numpy(a: torch.Tensor): return a.permute(1, 2, 0).cpu().numpy() def torch_to_npimage(a: torch.Tensor, unnormalize=True): a_np = torch_to_numpy(a) if unnormalize: a_np = a_np * 255 a_np = a_np.astype(np.uint8) return cv.cvtColor(a_np, cv.COLOR_RGB2BGR) def npimage_to_torch(a, normalize=True, input_bgr=True): if input_bgr: a = cv.cvtColor(a, cv.COLOR_BGR2RGB) a_t = numpy_to_torch(a) if normalize: a_t = a_t / 255.0 return a_t def convert_dict(base_dict, batch_sz): out_dict = [] for b_elem in range(batch_sz): b_info = {} for k, v in base_dict.items(): if isinstance(v, (list, torch.Tensor)): b_info[k] = v[b_elem] out_dict.append(b_info) return out_dict ================================================ FILE: code/real/bsrt/utils/debayer.py ================================================ import torch import torch.nn import torch.nn.functional class Debayer3x3(torch.nn.Module): '''Demosaicing of Bayer images using 3x3 convolutions. Requires BG-Bayer color filter array layout. That is, the image[1,1]='B', image[1,2]='G'. 
This corresponds to OpenCV naming conventions. Compared to Debayer2x2 this method does not use upsampling. Instead, we identify five 3x3 interpolation kernels that are sufficient to reconstruct every color channel at every pixel location. We convolve the image with these 5 kernels using stride=1 and a one pixel replication padding. Finally, we gather the correct channel values for each pixel location. Todo so, we recognize that the Bayer pattern repeats horizontally and vertically every 2 pixels. Therefore, we define the correct index lookups for a 2x2 grid cell and then repeat to image dimensions. Note, in every 2x2 grid cell we have red, blue and two greens (G1,G2). The lookups for the two greens differ. ''' def __init__(self): super(Debayer3x3, self).__init__() self.kernels = torch.nn.Parameter( torch.tensor([ [0,0,0], [0,1,0], [0,0,0], [0, 0.25, 0], [0.25, 0, 0.25], [0, 0.25, 0], [0.25, 0, 0.25], [0, 0, 0], [0.25, 0, 0.25], [0, 0, 0], [0.5, 0, 0.5], [0, 0, 0], [0, 0.5, 0], [0, 0, 0], [0, 0.5, 0], ]).view(5,1,3,3), requires_grad=False ) self.index = torch.nn.Parameter( torch.tensor([ # dest channel r [0, 3], # pixel is R,G1 [4, 2], # pixel is G2,B # dest channel g [1, 0], # pixel is R,G1 [0, 1], # pixel is G2,B # dest channel b [2, 4], # pixel is R,G1 [3, 0], # pixel is G2,B ]).view(1,3,2,2), requires_grad=False ) def forward(self, x): '''Debayer image. Parameters ---------- x : Bx1xHxW tensor Images to debayer Returns ------- rgb : Bx3xHxW tensor Color images in RGB channel order. ''' B,C,H,W = x.shape x = torch.nn.functional.pad(x, (1,1,1,1), mode='replicate') c = torch.nn.functional.conv2d(x, self.kernels, stride=1) rgb = torch.gather(c, 1, self.index.repeat(B,1,H//2,W//2)) return rgb class Debayer2x2(torch.nn.Module): '''Demosaicing of Bayer images using 2x2 convolutions. Requires BG-Bayer color filter array layout. That is, the image[1,1]='B', image[1,2]='G'. This corresponds to OpenCV naming conventions. 
''' def __init__(self): super(Debayer2x2, self).__init__() self.kernels = torch.nn.Parameter( torch.tensor([ [1, 0], [0, 0], [0, 0.5], [0.5, 0], [0, 0], [0, 1], ]).view(3,1,2,2), requires_grad=False ) def forward(self, x): '''Debayer image. Parameters ---------- x : Bx1xHxW tensor Images to debayer Returns ------- rgb : Bx3xHxW tensor Color images in RGB channel order. ''' x = torch.nn.functional.conv2d(x, self.kernels, stride=2) x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) return x class DebayerSplit(torch.nn.Module): '''Demosaicing of Bayer images using 3x3 green convolution and red,blue upsampling. Requires BG-Bayer color filter array layout. That is, the image[1,1]='B', image[1,2]='G'. This corresponds to OpenCV naming conventions. ''' def __init__(self): super().__init__() self.pad = torch.nn.ReflectionPad2d(1) self.kernel = torch.nn.Parameter( torch.tensor([ [0,1,0], [1,0,1], [0,1,0] ])[None, None] * 0.25) def forward(self, x): '''Debayer image. Parameters ---------- x : Bx1xHxW tensor Images to debayer Returns ------- rgb : Bx3xHxW tensor Color images in RGB channel order. 
''' B,_,H,W = x.shape red = x[:, :, ::2, ::2] blue = x[:, :, 1::2, 1::2] green = torch.nn.functional.conv2d(self.pad(x), self.kernel) green[:, :, ::2, 1::2] = x[:, :, ::2, 1::2] green[:, :, 1::2, ::2] = x[:, :, 1::2, ::2] return torch.cat(( torch.nn.functional.interpolate(red, size=(H, W), mode='bilinear', align_corners=False), green, torch.nn.functional.interpolate(blue, size=(H, W), mode='bilinear', align_corners=False)), dim=1) ================================================ FILE: code/real/bsrt/utils/interp_methods.py ================================================ from math import pi try: import torch except ImportError: torch = None try: import numpy except ImportError: numpy = None if numpy is None and torch is None: raise ImportError("Must have either Numpy or PyTorch but both not found") def set_framework_dependencies(x): if type(x) is numpy.ndarray: to_dtype = lambda a: a fw = numpy else: to_dtype = lambda a: a.to(x.dtype) fw = torch eps = fw.finfo(fw.float32).eps return fw, to_dtype, eps def support_sz(sz): def wrapper(f): f.support_sz = sz return f return wrapper @support_sz(4) def cubic(x): fw, to_dtype, eps = set_framework_dependencies(x) absx = fw.abs(x) absx2 = absx ** 2 absx3 = absx ** 3 return ((1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) + (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) * to_dtype((1. 
< absx) & (absx <= 2.))) @support_sz(4) def lanczos2(x): fw, to_dtype, eps = set_framework_dependencies(x) return (((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) / ((pi**2 * x**2 / 2) + eps)) * to_dtype(abs(x) < 2)) @support_sz(6) def lanczos3(x): fw, to_dtype, eps = set_framework_dependencies(x) return (((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) / ((pi**2 * x**2 / 3) + eps)) * to_dtype(abs(x) < 3)) @support_sz(2) def linear(x): fw, to_dtype, eps = set_framework_dependencies(x) return ((x + 1) * to_dtype((-1 <= x) & (x < 0)) + (1 - x) * to_dtype((0 <= x) & (x <= 1))) @support_sz(1) def box(x): fw, to_dtype, eps = set_framework_dependencies(x) return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1)) ================================================ FILE: code/real/bsrt/utils/metrics.py ================================================ import math import torch import torch.nn as nn import torch.nn.functional as F import utils.spatial_color_alignment as sca_utils from utils.spatial_color_alignment import get_gaussian_kernel, match_colors from utils.warp import warp from torch.cuda.amp import autocast from loss.Charbonnier import CharbonnierLoss as CBLoss from loss.mssim import MSSSIM from pytorch_msssim import ssim import lpips class MSSSIMLoss(nn.Module): def __init__(self, boundary_ignore=None): super().__init__() self.boundary_ignore = boundary_ignore self.msssim = MSSSIM() def forward(self, pred, gt, valid=None): if self.boundary_ignore is not None: pred = pred[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore] gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore] pred_m = pred gt_m = gt loss = self.msssim(pred_m, gt_m) return loss class CharbonnierLoss(nn.Module): def __init__(self, boundary_ignore=None): super().__init__() self.boundary_ignore = boundary_ignore self.charbonnier_loss = CBLoss(reduce=True) def forward(self, pred, gt, valid=None): if 
self.boundary_ignore is not None: pred = pred[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore] gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore] pred_m = pred gt_m = gt loss = self.charbonnier_loss(pred_m, gt_m) return loss class L1(nn.Module): def __init__(self, boundary_ignore=None): super().__init__() self.boundary_ignore = boundary_ignore def forward(self, pred, gt, valid=None): if self.boundary_ignore is not None: pred = pred[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore] gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore] if valid is not None: valid = valid[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore] pred_m = pred gt_m = gt if valid is None: mse = F.l1_loss(pred_m, gt_m) else: mse = F.l1_loss(pred_m, gt_m, reduction='none') eps = 1e-12 elem_ratio = mse.numel() / valid.numel() mse = (mse * valid.float()).sum() / (valid.float().sum()*elem_ratio + eps) return mse class L2(nn.Module): def __init__(self, boundary_ignore=None): super().__init__() self.boundary_ignore = boundary_ignore def forward(self, pred, gt, valid=None): if self.boundary_ignore is not None: pred = pred[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore] gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore] if valid is not None: valid = valid[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore] pred_m = pred gt_m = gt if valid is None: mse = F.mse_loss(pred_m, gt_m) else: mse = F.mse_loss(pred_m, gt_m, reduction='none') eps = 1e-12 elem_ratio = mse.numel() / valid.numel() mse = (mse * valid.float()).sum() / (valid.float().sum()*elem_ratio + eps) return mse class PSNR(nn.Module): def __init__(self, boundary_ignore=None, max_value=1.0): 
super().__init__() self.l2 = L2(boundary_ignore=boundary_ignore) self.max_value = max_value def psnr(self, pred, gt, valid=None): mse = self.l2(pred, gt, valid=valid) psnr = 20 * math.log10(self.max_value) - 10.0 * mse.log10() return psnr def forward(self, pred, gt, valid=None): assert pred.dim() == 4 and pred.shape == gt.shape if valid is None: psnr_all = [self.psnr(p.unsqueeze(0), g.unsqueeze(0)) for p, g in zip(pred, gt)] else: psnr_all = [self.psnr(p.unsqueeze(0), g.unsqueeze(0), v.unsqueeze(0)) for p, g, v in zip(pred, gt, valid)] psnr = sum(psnr_all) / len(psnr_all) return psnr class AlignedL1(nn.Module): def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None): super().__init__() self.sr_factor = sr_factor self.boundary_ignore = boundary_ignore self.alignment_net = alignment_net self.gauss_kernel, self.ksz = get_gaussian_kernel(sd=1.5) def forward(self, pred, gt, burst_input): # Estimate flow between the prediction and the ground truth with torch.no_grad(): flow = self.alignment_net(pred / (pred.max() + 1e-6), gt / (gt.max() + 1e-6)) # Warp the prediction to the ground truth coordinates pred_warped = warp(pred, flow) # Warp the base input frame to the ground truth. 
This will be used to estimate the color transformation between # the input and the ground truth sr_factor = self.sr_factor ds_factor = 1.0 / float(2.0 * sr_factor) flow_ds = F.interpolate(flow, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False) * ds_factor burst_0 = burst_input[:, 0, [0, 1, 3]].contiguous() burst_0_warped = warp(burst_0, flow_ds) frame_gt_ds = F.interpolate(gt, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False) # Match the colorspace between the prediction and ground truth pred_warped_m, valid = match_colors(frame_gt_ds, burst_0_warped, pred_warped, self.ksz, self.gauss_kernel) # Ignore boundary pixels if specified if self.boundary_ignore is not None: pred_warped_m = pred_warped_m[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore] gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore] valid = valid[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore] pred_warped_m = pred_warped_m.contiguous() gt = gt.contiguous() # Estimate MSE l1 = F.l1_loss(pred_warped_m, gt, reduction='none') eps = 1e-12 elem_ratio = l1.numel() / valid.numel() l1 = (l1 * valid.float()).sum() / (valid.float().sum()*elem_ratio + eps) return l1 class AlignedL2(nn.Module): def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None): super().__init__() self.sr_factor = sr_factor self.boundary_ignore = boundary_ignore self.alignment_net = alignment_net self.loss_fn = lpips.LPIPS(net='alex').cuda() self.gauss_kernel, self.ksz = sca_utils.get_gaussian_kernel(sd=1.5) def forward(self, pred, gt, burst_input): # Estimate flow between the prediction and the ground truth with torch.no_grad(): flow = self.alignment_net(pred / (pred.max() + 1e-6), gt / (gt.max() + 1e-6)) # Warp the prediction to the ground truth coordinates pred_warped = warp(pred, flow) # Warp the base input 
frame to the ground truth. This will be used to estimate the color transformation between # the input and the ground truth sr_factor = self.sr_factor ds_factor = 1.0 / float(2.0 * sr_factor) flow_ds = F.interpolate(flow, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False) * ds_factor burst_0 = burst_input[:, 0, [0, 1, 3]].contiguous() burst_0_warped = warp(burst_0, flow_ds) frame_gt_ds = F.interpolate(gt, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False) # Match the colorspace between the prediction and ground truth pred_warped_m, valid = sca_utils.match_colors(frame_gt_ds, burst_0_warped, pred_warped, self.ksz, self.gauss_kernel) # Ignore boundary pixels if specified if self.boundary_ignore is not None: pred_warped_m = pred_warped_m[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore] gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore] valid = valid[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore] # Estimate MSE mse = F.mse_loss(pred_warped_m.contiguous(), gt.contiguous(), reduction='none') eps = 1e-12 elem_ratio = mse.numel() / valid.numel() mse = (mse * valid.float()).sum() / (valid.float().sum()*elem_ratio + eps) ss = ssim(pred_warped_m.contiguous(), gt.contiguous(), data_range=1.0, size_average=True) # eps = 1e-12 # elem_ratio = ss.numel() / valid.numel() # ss = (ss * valid.float()).sum() / (valid.float().sum()*elem_ratio + eps) lp = self.loss_fn(pred_warped_m.contiguous(), gt.contiguous()).squeeze() return mse, ss, lp class AlignedPSNR(nn.Module): def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None, max_value=1.0): super().__init__() self.l2 = AlignedL2(alignment_net=alignment_net, sr_factor=sr_factor, boundary_ignore=boundary_ignore) self.max_value = max_value def psnr(self, pred, gt, burst_input): mse, ss, lp = self.l2(pred, gt, 
burst_input) psnr = 20 * math.log10(self.max_value) - 10.0 * mse.log10() return psnr, ss, lp def forward(self, pred, gt, burst_input): all_scores = [self.psnr(p.unsqueeze(0), g.unsqueeze(0), bi.unsqueeze(0)) for p, g, bi in zip(pred, gt, burst_input)] psnr = sum([score[0] for score in all_scores]) / len(all_scores) ssim_ = sum([score[1] for score in all_scores]) / len(all_scores) lpips_ = sum([score[2] for score in all_scores]) / len(all_scores) return psnr, ssim_, lpips_ class AlignedSSIM(nn.Module): def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None): super().__init__() self.sr_factor = sr_factor self.boundary_ignore = boundary_ignore self.alignment_net = alignment_net self.gauss_kernel, self.ksz = sca_utils.get_gaussian_kernel(sd=1.5) def _ssim(self, pred, gt, burst_input): # Estimate flow between the prediction and the ground truth with torch.no_grad(): flow = self.alignment_net(pred / (pred.max() + 1e-6), gt / (gt.max() + 1e-6)) # Warp the prediction to the ground truth coordinates pred_warped = warp(pred, flow) # Warp the base input frame to the ground truth. 
This will be used to estimate the color transformation between # the input and the ground truth sr_factor = self.sr_factor ds_factor = 1.0 / float(2.0 * sr_factor) flow_ds = F.interpolate(flow, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False) * ds_factor burst_0 = burst_input[:, 0, [0, 1, 3]].contiguous() burst_0_warped = warp(burst_0, flow_ds) frame_gt_ds = F.interpolate(gt, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False) # Match the colorspace between the prediction and ground truth pred_warped_m, valid = sca_utils.match_colors(frame_gt_ds, burst_0_warped, pred_warped, self.ksz, self.gauss_kernel) # Ignore boundary pixels if specified if self.boundary_ignore is not None: pred_warped_m = pred_warped_m[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore] gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore] valid = valid[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore] # Estimate MSE mse = ssim(pred_warped_m.contiguous(), gt.contiguous(), data_range=1.0, size_average=True) # print(mse.shape) # eps = 1e-12 # elem_ratio = mse.numel() / valid.numel() # mse = (mse * valid.float()).sum() / (valid.float().sum()*elem_ratio + eps) return mse def forward(self, pred, gt, burst_input): ssim_all = [self._ssim(p.unsqueeze(0), g.unsqueeze(0), bi.unsqueeze(0)) for p, g, bi in zip(pred, gt, burst_input)] _ssim = sum(ssim_all) / len(ssim_all) return _ssim class AlignedLPIPS(nn.Module): def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None): super().__init__() self.sr_factor = sr_factor self.boundary_ignore = boundary_ignore self.alignment_net = alignment_net self.loss_fn = lpips.LPIPS(net='alex').cuda() self.gauss_kernel, self.ksz = sca_utils.get_gaussian_kernel(sd=1.5) def _lpips(self, pred, gt, burst_input): # Estimate flow between the prediction and the 
ground truth with torch.no_grad(): flow = self.alignment_net(pred / (pred.max() + 1e-6), gt / (gt.max() + 1e-6)) # Warp the prediction to the ground truth coordinates pred_warped = warp(pred, flow) # Warp the base input frame to the ground truth. This will be used to estimate the color transformation between # the input and the ground truth sr_factor = self.sr_factor ds_factor = 1.0 / float(2.0 * sr_factor) flow_ds = F.interpolate(flow, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False) * ds_factor burst_0 = burst_input[:, 0, [0, 1, 3]].contiguous() burst_0_warped = warp(burst_0, flow_ds) frame_gt_ds = F.interpolate(gt, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False) # Match the colorspace between the prediction and ground truth pred_warped_m, valid = sca_utils.match_colors(frame_gt_ds, burst_0_warped, pred_warped, self.ksz, self.gauss_kernel) # Ignore boundary pixels if specified if self.boundary_ignore is not None: pred_warped_m = pred_warped_m[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore] gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore] valid = valid[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore] # Estimate MSE mse = self.loss_fn(pred_warped_m.contiguous(), gt.contiguous()).squeeze() return mse def forward(self, pred, gt, burst_input): lpips_all = [self._lpips(p.unsqueeze(0), g.unsqueeze(0), bi.unsqueeze(0)) for p, g, bi in zip(pred, gt, burst_input)] _lpips = sum(lpips_all) / len(lpips_all) return _lpips ================================================ FILE: code/real/bsrt/utils/postprocessing_functions.py ================================================ import torch import numpy as np import utils.data_format_utils as df_utils from data_processing.camera_pipeline import apply_gains, apply_ccm, apply_smoothstep, gamma_compression class 
SimplePostProcess: def __init__(self, gains=True, ccm=True, gamma=True, smoothstep=True, return_np=False): self.gains = gains self.ccm = ccm self.gamma = gamma self.smoothstep = smoothstep self.return_np = return_np def process(self, image, meta_info): return process_linear_image_rgb(image, meta_info, self.gains, self.ccm, self.gamma, self.smoothstep, self.return_np) def process_linear_image_rgb(image, meta_info, gains=True, ccm=True, gamma=True, smoothstep=True, return_np=False): if gains: image = apply_gains(image, meta_info['rgb_gain'], meta_info['red_gain'], meta_info['blue_gain']) if ccm: image = apply_ccm(image, meta_info['cam2rgb']) if meta_info['gamma'] and gamma: image = gamma_compression(image) if meta_info['smoothstep'] and smoothstep: image = apply_smoothstep(image) image = image.clamp(0.0, 1.0) if return_np: image = df_utils.torch_to_npimage(image) return image class BurstSRPostProcess: def __init__(self, no_white_balance=False, gamma=True, smoothstep=True, return_np=False): self.no_white_balance = no_white_balance self.gamma = gamma self.smoothstep = smoothstep self.return_np = return_np def process(self, image, meta_info, external_norm_factor=None): return process_burstsr_image_rgb(image, meta_info, external_norm_factor=external_norm_factor, no_white_balance=self.no_white_balance, gamma=self.gamma, smoothstep=self.smoothstep, return_np=self.return_np) def process_burstsr_image_rgb(im, meta_info, return_np=False, external_norm_factor=None, gamma=True, smoothstep=True, no_white_balance=False): im = im * meta_info.get('norm_factor', 1.0) if not meta_info.get('black_level_subtracted', False): im = (im - torch.tensor(meta_info['black_level'])[[0, 1, -1]].view(3, 1, 1).to(im.device)) if not meta_info.get('while_balance_applied', False) and not no_white_balance: im = im * (meta_info['cam_wb'][[0, 1, -1]].view(3, 1, 1) / meta_info['cam_wb'][1]).to(im.device) im_out = im if external_norm_factor is None: im_out = im_out / im_out.max() else: im_out = im_out / 
external_norm_factor im_out = im_out.clamp(0.0, 1.0) if gamma: im_out = im_out ** (1.0 / 2.2) if smoothstep: # Smooth curve im_out = 3 * im_out ** 2 - 2 * im_out ** 3 if return_np: im_out = im_out.permute(1, 2, 0).cpu().numpy() * 255.0 im_out = im_out.astype(np.uint8) return im_out ================================================ FILE: code/real/bsrt/utils/resize_right.py ================================================ import warnings from math import ceil import interp_methods class NoneClass: pass try: import torch from torch import nn nnModuleWrapped = nn.Module except ImportError: warnings.warn('No PyTorch found, will work only with Numpy') torch = None nnModuleWrapped = NoneClass try: import numpy except ImportError: warnings.warn('No Numpy found, will work only with PyTorch') numpy = None if numpy is None and torch is None: raise ImportError("Must have either Numpy or PyTorch but both not found") def resize(input, scale_factors=None, out_shape=None, interp_method=interp_methods.cubic, support_sz=None, antialiasing=True): # get properties of the input tensor in_shape, n_dims = input.shape, input.ndim # fw stands for framework that can be either numpy or torch, # determined by the input type fw = numpy if type(input) is numpy.ndarray else torch eps = fw.finfo(fw.float32).eps # set missing scale factors or output shapem one according to another, # scream if both missing scale_factors, out_shape = set_scale_and_out_sz(in_shape, out_shape, scale_factors, fw) # sort indices of dimensions according to scale of each dimension. # since we are going dim by dim this is efficient sorted_filtered_dims_and_scales = [(dim, scale_factors[dim]) for dim in sorted(range(n_dims), key=lambda ind: scale_factors[ind]) if scale_factors[dim] != 1.] 
# unless support size is specified by the user, it is an attribute # of the interpolation method if support_sz is None: support_sz = interp_method.support_sz # when using pytorch, we need to know what is the input tensor device if fw is torch: device = input.device # output begins identical to input and changes with each iteration output = input # iterate over dims for dim, scale_factor in sorted_filtered_dims_and_scales: # get 1d set of weights and fields of view for each output location # along this dim field_of_view, weights = prepare_weights_and_field_of_view_1d( dim, scale_factor, in_shape[dim], out_shape[dim], interp_method, support_sz, antialiasing, fw, eps, device) # multiply the weights by the values in the field of view and # aggreagate output = apply_weights(output, field_of_view, weights, dim, n_dims, fw) return output class ResizeLayer(nnModuleWrapped): def __init__(self, in_shape, scale_factors=None, out_shape=None, interp_method=interp_methods.cubic, support_sz=None, antialiasing=True): super(ResizeLayer, self).__init__() # fw stands for framework, that can be either numpy or torch. since # this is a torch layer, only one option in this case. fw = torch eps = fw.finfo(fw.float32).eps # set missing scale factors or output shapem one according to another, # scream if both missing scale_factors, out_shape = set_scale_and_out_sz(in_shape, out_shape, scale_factors, fw) # unless support size is specified by the user, it is an attribute # of the interpolation method if support_sz is None: support_sz = interp_method.support_sz self.n_dims = len(in_shape) # sort indices of dimensions according to scale of each dimension. # since we are going dim by dim this is efficient self.sorted_filtered_dims_and_scales = [(dim, scale_factors[dim]) for dim in sorted(range(self.n_dims), key=lambda ind: scale_factors[ind]) if scale_factors[dim] != 1.] 
        # iterate over dims: precompute per-dim weights / fields of view once,
        # so forward() can reuse them for every input
        field_of_view_list = []
        weights_list = []
        for dim, scale_factor in self.sorted_filtered_dims_and_scales:
            # get 1d set of weights and fields of view for each output
            # location along this dim
            field_of_view, weights = prepare_weights_and_field_of_view_1d(
                dim, scale_factor, in_shape[dim], out_shape[dim],
                interp_method, support_sz, antialiasing, fw, eps,
                input.device)

            # keep weights and fields of views for all dims
            # (stored as frozen nn.Parameters so .to(device) moves them)
            weights_list.append(nn.Parameter(weights, requires_grad=False))
            field_of_view_list.append(nn.Parameter(field_of_view,
                                                   requires_grad=False))

        self.field_of_view = nn.ParameterList(field_of_view_list)
        self.weights = nn.ParameterList(weights_list)
        self.in_shape = in_shape

    def forward(self, input):
        # output begins identical to input and changes with each iteration,
        # resizing one dim at a time (separable resize)
        output = input
        for (dim, scale_factor), field_of_view, weights in zip(
                self.sorted_filtered_dims_and_scales, self.field_of_view,
                self.weights):
            # multiply the weights by the values in the field of view and
            # aggregate
            output = apply_weights(output, field_of_view, weights, dim,
                                   self.n_dims, torch)
        return output


def prepare_weights_and_field_of_view_1d(dim, scale_factor, in_sz, out_sz,
                                         interp_method, support_sz,
                                         antialiasing, fw, eps, device=None):
    """Compute the 1-D field of view and matching weights for one resized dim."""
    # If antialiasing is taking place, we modify the window size and the
    # interpolation method (see inside function)
    interp_method, cur_support_sz = apply_antialiasing_if_needed(
        interp_method, support_sz, scale_factor, antialiasing)

    # STEP 1- PROJECTED GRID: The non-integer locations of the projection of
    # output pixel locations to the input tensor
    projected_grid = get_projected_grid(in_sz, out_sz, scale_factor, fw,
                                        device)

    # STEP 2- FIELDS OF VIEW: for each output pixels, map the input pixels
    # that influence it
    field_of_view = get_field_of_view(projected_grid, cur_support_sz, in_sz,
                                      fw, eps)

    # STEP 3- CALCULATE WEIGHTS: Match a set of weights to the pixels in the
    # field of view for each output pixel
    weights = get_weights(interp_method, projected_grid, field_of_view)

    return field_of_view, weights


def apply_weights(input, field_of_view, weights, dim, n_dims, fw):
    """STEP 4 - apply precomputed weights along one dim of `input`."""
    # STEP 4- APPLY WEIGHTS: Each output pixel is calculated by multiplying
    # its set of weights with the pixel values in its field of view.
    # We now multiply the fields of view with their matching weights.
    # We do this by tensor multiplication and broadcasting.
    # this step is separated to a different function, so that it can be
    # repeated with the same calculated weights and fields.

    # for this operations we assume the resized dim is the first one.
    # so we transpose and will transpose back after multiplying
    tmp_input = fw_swapaxes(input, dim, 0, fw)

    # field_of_view is a tensor of order 2: for each output (1d location
    # along cur dim)- a list of 1d neighbors locations.
    # note that this whole operations is applied to each dim separately,
    # this is why it is all in 1d.
    # neighbors = tmp_input[field_of_view] is a tensor of order image_dims+1:
    # for each output pixel (this time indicated in all dims), these are the
    # values of the neighbors in the 1d field of view. note that we only
    # consider neighbors along the current dim, but such set exists for every
    # multi-dim location, hence the final tensor order is image_dims+1.
    neighbors = tmp_input[field_of_view]

    # weights is an order 2 tensor: for each output location along 1d- a list
    # of weights matching the field of view. we augment it with ones, for
    # broadcasting, so that when it multiplies some tensor the weights affect
    # only its first dim.
    tmp_weights = fw.reshape(weights, (*weights.shape, * [1] * (n_dims - 1)))

    # now we simply multiply the weights with the neighbors, and then sum
    # along the field of view, to get a single value per out pixel
    tmp_output = (neighbors * tmp_weights).sum(1)

    # we transpose back the resized dim to its original position
    return fw_swapaxes(tmp_output, 0, dim, fw)


def set_scale_and_out_sz(in_shape, out_shape, scale_factors, fw):
    """Resolve (scale_factors, out_shape) from whichever partial spec was given."""
    # eventually we must have both scale-factors and out-sizes for all in/out
    # dims. however, we support many possible partial arguments
    if scale_factors is None and out_shape is None:
        raise ValueError("either scale_factors or out_shape should be "
                         "provided")
    if out_shape is not None:
        # if out_shape has less dims than in_shape, we defaultly resize the
        # first dims for numpy and last dims for torch
        out_shape = (list(out_shape) + list(in_shape[:-len(out_shape)])
                     if fw is numpy
                     else list(in_shape[:-len(out_shape)]) + list(out_shape))
        if scale_factors is None:
            # if no scale given, we calculate it as the out to in ratio
            # (not recommended)
            scale_factors = [out_sz / in_sz
                             for out_sz, in_sz in zip(out_shape, in_shape)]
    if scale_factors is not None:
        # by default, if a single number is given as scale, we assume resizing
        # two dims (most common are images with 2 spatial dims)
        scale_factors = (scale_factors
                         if isinstance(scale_factors, (list, tuple))
                         else [scale_factors, scale_factors])

        # if less scale_factors than in_shape dims, we defaultly resize the
        # first dims for numpy and last dims for torch
        scale_factors = (list(scale_factors) + [1] *
                         (len(in_shape) - len(scale_factors))
                         if fw is numpy
                         else [1] * (len(in_shape) - len(scale_factors)) +
                         list(scale_factors))
        if out_shape is None:
            # when no out_shape given, it is calculated by multiplying the
            # scale by the in_shape (not recommended)
            out_shape = [ceil(scale_factor * in_sz)
                         for scale_factor, in_sz
                         in zip(scale_factors, in_shape)]

    # next line intentionally after out_shape determined for stability
    scale_factors = [float(sf) for sf in scale_factors]
    return scale_factors, out_shape


def get_projected_grid(in_sz, out_sz, scale_factor, fw, device=None):
    """Project output pixel centers onto (non-integer) input coordinates, 1-D."""
    # we start by having the output coordinates which are just integer locations
    out_coordinates = fw.arange(out_sz)

    # if using torch we need to match the grid tensor device to the input device
    out_coordinates = fw_set_device(out_coordinates, device, fw)

    # This is projecting the output pixel locations in 1d to the input tensor,
    # as non-integer locations.
    # the following formula is derived in the paper
    # "From Discrete to Continuous Convolutions" by Shocher et al.
    return (out_coordinates / scale_factor +
            (in_sz - 1) / 2 - (out_sz - 1) / (2 * scale_factor))


def get_field_of_view(projected_grid, cur_support_sz, in_sz, fw, eps):
    """For each 1-D output location, list the input indices that influence it."""
    # for each output pixel, map which input pixels influence it, in 1d.
    # we start by calculating the leftmost neighbor, using half of the window
    # size (eps is for when boundary is exact int)
    left_boundaries = fw_ceil(projected_grid - cur_support_sz / 2 - eps, fw)

    # then we simply take all the pixel centers in the field by counting
    # window size pixels from the left boundary
    ordinal_numbers = fw.arange(ceil(cur_support_sz - eps))
    # in case using torch we need to match the device
    ordinal_numbers = fw_set_device(ordinal_numbers, projected_grid.device, fw)
    field_of_view = left_boundaries[:, None] + ordinal_numbers

    # next we do a trick instead of padding, we map the field of view so that
    # it would be like mirror padding, without actually padding
    # (which would require enlarging the input tensor)
    mirror = fw_cat((fw.arange(in_sz), fw.arange(in_sz - 1, -1, step=-1)), fw)
    field_of_view = mirror[fw.remainder(field_of_view, mirror.shape[0])]
    field_of_view = fw_set_device(field_of_view, projected_grid.device, fw)
    return field_of_view


def get_weights(interp_method, projected_grid, field_of_view):
    """Evaluate the interpolation kernel and normalize weights to sum to 1."""
    # the set of weights per each output pixels is the result of the chosen
    # interpolation method applied to the distances between projected grid
    # locations and the pixel-centers in the field of view (distances are
    # directed, can be positive or negative)
    weights = interp_method(projected_grid[:, None] - field_of_view)

    # we now carefully normalize the weights to sum to 1 per each output pixel
    sum_weights = weights.sum(1, keepdims=True)
    sum_weights[sum_weights == 0] = 1
    return weights / sum_weights


def apply_antialiasing_if_needed(interp_method, support_sz, scale_factor,
                                 antialiasing):
    """Stretch kernel + window by 1/scale when downscaling (low-pass filter)."""
    # antialiasing is "stretching" the field of view according to the scale
    # factor (only for downscaling). this is low-pass filtering. this
    # requires modifying both the interpolation (stretching the 1d
    # function and multiplying by the scale-factor) and the window size.
    if scale_factor >= 1.0 or not antialiasing:
        return interp_method, support_sz
    cur_interp_method = (lambda arg: scale_factor *
                         interp_method(scale_factor * arg))
    cur_support_sz = support_sz / scale_factor
    return cur_interp_method, cur_support_sz


# The fw_* helpers abstract over the "framework" argument, which is either the
# numpy module or torch (dispatched with `fw is numpy` checks).

def fw_ceil(x, fw):
    if fw is numpy:
        return fw.int_(fw.ceil(x))
    else:
        return x.ceil().long()


def fw_cat(x, fw):
    if fw is numpy:
        return fw.concatenate(x)
    else:
        return fw.cat(x)


def fw_swapaxes(x, ax_1, ax_2, fw):
    if fw is numpy:
        return fw.swapaxes(x, ax_1, ax_2)
    else:
        return x.transpose(ax_1, ax_2)


def fw_set_device(x, device, fw):
    # numpy arrays have no device; only torch tensors are moved
    if fw is numpy:
        return x
    else:
        return x.to(device)


================================================
FILE: code/real/bsrt/utils/spatial_color_alignment.py
================================================
import math
import torch
import torch.nn.functional as F


def gauss_1d(sz, sigma, center, end_pad=0, density=False):
    """ Returns a 1-D Gaussian """
    k = torch.arange(-(sz-1)/2, (sz+1)/2 + end_pad).reshape(1, -1)
    gauss = torch.exp(-1.0/(2*sigma**2) * (k - center.reshape(-1, 1))**2)

    if density:
        # normalize so the curve integrates (approximately) to 1
        gauss /= math.sqrt(2*math.pi) * sigma

    return gauss


def gauss_2d(sz, sigma, center, end_pad=(0, 0), density=False):
    """ Returns a 2-D Gaussian (outer product of two 1-D Gaussians) """
    if isinstance(sigma, (float, int)):
        sigma = (sigma, sigma)

    if isinstance(sz, int):
        sz = (sz, sz)

    if isinstance(center, (list, tuple)):
        center = torch.tensor(center).view(1, 2)

    return gauss_1d(sz[0], sigma[0], center[:, 0], end_pad[0], density).reshape(center.shape[0], 1, -1) * \
           gauss_1d(sz[1], sigma[1], center[:, 1], end_pad[1], density).reshape(center.shape[0], -1, 1)


def get_gaussian_kernel(sd):
    """ Returns a Gaussian kernel with standard deviation sd """
    ksz = int(4 * sd + 1)
    assert ksz % 2 == 1
    K = gauss_2d(ksz, sd, (0.0, 0.0), density=True)
    K = K / K.sum()
    return K.unsqueeze(0), ksz


def apply_kernel(im, ksz, gauss_kernel):
    """ Convolve every channel of im with gauss_kernel (reflect padding). """
    shape = im.shape
    im = im.view(-1, 1, *im.shape[-2:])

    pad = [ksz // 2, ksz // 2, ksz // 2, ksz // 2]
    im = F.pad(im, pad, mode='reflect')
    im_mean = F.conv2d(im, gauss_kernel).view(shape)
    return im_mean


def match_colors(im_ref, im_q, im_test, ksz, gauss_kernel):
    """ Estimates a color transformation matrix between im_ref and im_q. Applies the estimated transformation to
        im_test
    """
    gauss_kernel = gauss_kernel.to(im_ref.device)
    bi = 5

    # Apply Gaussian smoothing
    im_ref_mean = apply_kernel(im_ref, ksz, gauss_kernel)[:, :, bi:-bi, bi:-bi].contiguous()
    im_q_mean = apply_kernel(im_q, ksz, gauss_kernel)[:, :, bi:-bi, bi:-bi].contiguous()

    im_ref_mean_re = im_ref_mean.view(*im_ref_mean.shape[:2], -1)
    im_q_mean_re = im_q_mean.view(*im_q_mean.shape[:2], -1)

    # Estimate color transformation matrix by minimizing the least squares error
    # NOTE(review): torch.lstsq is deprecated/removed in recent PyTorch;
    # torch.linalg.lstsq is the replacement (this repo targets torch 1.6).
    c_mat_all = []
    for ir, iq in zip(im_ref_mean_re, im_q_mean_re):
        c = torch.lstsq(ir.t(), iq.t())
        c = c.solution[:3]
        c_mat_all.append(c)

    c_mat = torch.stack(c_mat_all, dim=0)
    im_q_mean_conv = torch.matmul(im_q_mean_re.permute(0, 2, 1), c_mat).permute(0, 2, 1)
    im_q_mean_conv = im_q_mean_conv.view(im_q_mean.shape)

    # per-pixel residual of the color fit, scaled to 8-bit units
    err = ((im_q_mean_conv - im_ref_mean) * 255.0).norm(dim=1)

    thresh = 20

    # If error is larger than a threshold, ignore these pixels
    valid = err < thresh

    # pad the validity mask back to the pre-crop (bi) spatial size
    pad = (im_q.shape[-1] - valid.shape[-1]) // 2
    pad = [pad, pad, pad, pad]
    valid = F.pad(valid, pad)
upsample_factor = im_test.shape[-1] / valid.shape[-1] valid = F.interpolate(valid.unsqueeze(1).float(), scale_factor=upsample_factor, mode='bilinear', align_corners=False) valid = valid > 0.9 # Apply the transformation to test image im_test_re = im_test.view(*im_test.shape[:2], -1) im_t_conv = torch.matmul(im_test_re.permute(0, 2, 1), c_mat).permute(0, 2, 1) im_t_conv = im_t_conv.view(im_test.shape) return im_t_conv, valid ================================================ FILE: code/real/bsrt/utils/stn.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F class SpatialTransformer(nn.Module): """ [SpatialTransformer] represesents a spatial transformation block that uses the output from the UNet to preform an grid_sample https://pytorch.org/docs/stable/nn.functional.html#grid-sample """ def __init__(self, size, mode='bilinear'): """ Instiatiate the block :param size: size of input to the spatial transformer block :param mode: method of interpolation for grid_sampler """ super(OldSpatialTransformer, self).__init__() if isinstance(size, int): size = (size, size) # Create sampling grid vectors = [ torch.arange(0, s) for s in size ] grids = torch.meshgrid(vectors) grid = torch.stack(grids) # y, x, z grid = torch.unsqueeze(grid, 0) #add batch grid = grid.type(torch.FloatTensor) self.register_buffer('grid', grid) self.mode = mode def forward(self, src, flow): """ Push the src and flow through the spatial transform block :param src: the original moving image :param flow: the output from the U-Net """ new_locs = self.grid + flow shape = flow.shape[2:] # Need to normalize grid values to [-1, 1] for resampler for i in range(len(shape)): new_locs[:,i,...] 
= 2*(new_locs[:,i,...]/(shape[i]-1) - 0.5) if len(shape) == 2: new_locs = new_locs.permute(0, 2, 3, 1) new_locs = new_locs[..., [1,0]] elif len(shape) == 3: new_locs = new_locs.permute(0, 2, 3, 4, 1) new_locs = new_locs[..., [2,1,0]] return F.grid_sample(src, new_locs, mode=self.mode, align_corners=True) ================================================ FILE: code/real/bsrt/utils/warp.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F def warp(feat, flow, mode='bilinear', padding_mode='zeros'): """ warp an image/tensor (im2) back to im1, according to the optical flow im1 --> im2 input flow must be in format (x, y) at every pixel feat: [B, C, H, W] (im2) flow: [B, 2, H, W] flow (x, y) """ B, C, H, W = feat.size() # print(feat.device, flow.device) # mesh grid rowv, colv = torch.meshgrid([torch.arange(0.5, H + 0.5), torch.arange(0.5, W + 0.5)]) grid = torch.stack((colv, rowv), dim=0).unsqueeze(0).float().to(flow.device) # print(grid.device, flow.device, feat.device) # grid = grid.cuda() grid = grid + flow # scale grid to [-1,1] grid_norm_c = 2.0 * grid[:, 0] / W - 1.0 grid_norm_r = 2.0 * grid[:, 1] / H - 1.0 grid_norm = torch.stack((grid_norm_c, grid_norm_r), dim=1).to(flow.device) grid_norm = grid_norm.permute(0, 2, 3, 1) output = F.grid_sample(feat, grid_norm, mode=mode, align_corners=False, padding_mode=padding_mode) return output ================================================ FILE: code/real/bsrt/validate.py ================================================ import cv2 import torch import numpy as np import os from tqdm import tqdm import random import utility from option import args import torchvision.utils as tvutils from pwcnet.pwcnet import PWCNet from utils.postprocessing_functions import BurstSRPostProcess from datasets.burstsr_dataset import BurstSRDataset, flatten_raw_image_batch, pack_raw_image from utils.metrics import AlignedPSNR from utils.data_format_utils import convert_dict from 
data_processing.camera_pipeline import demosaic import model import torch.multiprocessing as mp import torch.backends.cudnn as cudnn import torch.utils.data.distributed import time checkpoint = utility.checkpoint(args) def main(): mp.spawn(main_worker, nprocs=1, args=(1, args)) def main_worker(local_rank, nprocs, args): cudnn.benchmark = True args.local_rank = local_rank utility.setup(local_rank, nprocs) torch.cuda.set_device(local_rank) dataset = BurstSRDataset(root=args.root, burst_size=14, crop_sz=80, split='val') # out_dir = 'val/ebsr_real' _model = model.Model(args, checkpoint) for param in _model.parameters(): param.requires_grad = False alignment_net = PWCNet(load_pretrained=True, weights_path='./pwcnet/pwcnet-network-default.pth') alignment_net = alignment_net.to('cuda') for param in alignment_net.parameters(): param.requires_grad = False aligned_psnr_fn = AlignedPSNR(alignment_net=alignment_net, boundary_ignore=40) postprocess_fn = BurstSRPostProcess(return_np=True) # os.makedirs(out_dir, exist_ok=True) tt = [] psnrs, ssims, lpipss = [], [], [] for idx in tqdm(range(len(dataset))): burst, gt, meta_info_burst, meta_info_gt = dataset[idx] burst = burst.unsqueeze(0).cuda() gt = gt.unsqueeze(0).cuda() with torch.no_grad(): tic = time.time() sr = _model(burst, 0).float() toc = time.time() tt.append(toc-tic) # sr_int = (sr.clamp(0.0, 1.0) * 2 ** 14).short() # sr = sr_int.float() / (2 ** 14) psnr, ssim, lpips = aligned_psnr_fn(sr, gt, burst) psnrs.append(psnr.item()) ssims.append(ssim.item()) lpipss.append(lpips.item()) # os.makedirs(f'{out_dir}/{idx}', exist_ok=True) # sr_ = postprocess_fn.process(sr[0], meta_info_burst) # sr_ = cv2.cvtColor(sr_, cv2.COLOR_RGB2BGR) # cv2.imwrite('{}/{}_sr.png'.format(out_dir, idx), sr_) del burst del sr del gt print(f'avg PSNR: {np.mean(psnrs):.6f}') print(f'avg SSIM: {np.mean(ssims):.6f}') print(f'avg LPIPS: {np.mean(lpipss):.6f}') print(f' avg time: {np.mean(tt):.6f}') # utility.cleanup() if __name__ == '__main__': main() 
================================================
FILE: code/synthetic/bsrt/README.md
================================================
# BSRT: Improving Burst Super-Resolution with Swin Transformer and Flow-Guided Deformable Alignment (Synthetic)

## Dependencies
- OS: Ubuntu 18.04
- Python: Python 3.7
- NVIDIA:
  - CUDA: 10.1
  - cuDNN: 7.6.1
- Other reference requirements

## Quick Start
1. Create a conda virtual environment and activate it
```python3
conda create -n pytorch_1.6 python=3.7
source activate pytorch_1.6
```
2. Install PyTorch and torchvision following the official instructions
```python3
conda install pytorch==1.6.0 torchvision==0.7.0 cudatoolkit=10.1 -c pytorch
```
3. Install build requirements
```python3
pip3 install -r requirements.txt
```
4. Install DCN
```python3
cd DCNv2
python3 setup.py build develop # build
python3 test.py # run examples and check
```

## Training
```python3
# Modify the root path of the training dataset, the model path, etc.
# The number of GPUs should be more than 1
python main.py --n_GPUs 8 --print_every 40 --lr 0.00003 --decay 50-100 --save bsrt_tiny --model BSRT --fp16 --model_level S --swinfeature --batch_size 32 --burst_size 14 --patch_size 256
```

## Test
```python3
# Modify the path of the test dataset and the path of the trained model
python test_synburst.py --n_GPUs 1 --model BSRT --model_level S --fp16 --swinfeature --burst_size 14 --patch_size 384 --pre_train ../train_log/bsrt/real_models/bsrt_tiny/bsrt_best_epoch.pth --root /data/dataset/ntire21/burstsr/synthetic
```

================================================
FILE: code/synthetic/bsrt/data_processing/__init__.py
================================================


================================================
FILE: code/synthetic/bsrt/data_processing/camera_pipeline.py
================================================
import torch
import random
import math
import cv2 as cv
import numpy as np

import utils.data_format_utils as df_utils

""" Based on
http://timothybrooks.com/tech/unprocessing Functions for forward and inverse camera pipeline. All functions input a torch float tensor of shape (c, h, w). Additionally, some also support batch operations, i.e. inputs of shape (b, c, h, w) """ def random_ccm(): """Generates random RGB -> Camera color correction matrices.""" # Takes a random convex combination of XYZ -> Camera CCMs. xyz2cams = [[[1.0234, -0.2969, -0.2266], [-0.5625, 1.6328, -0.0469], [-0.0703, 0.2188, 0.6406]], [[0.4913, -0.0541, -0.0202], [-0.613, 1.3513, 0.2906], [-0.1564, 0.2151, 0.7183]], [[0.838, -0.263, -0.0639], [-0.2887, 1.0725, 0.2496], [-0.0627, 0.1427, 0.5438]], [[0.6596, -0.2079, -0.0562], [-0.4782, 1.3016, 0.1933], [-0.097, 0.1581, 0.5181]]] num_ccms = len(xyz2cams) xyz2cams = torch.tensor(xyz2cams) weights = torch.FloatTensor(num_ccms, 1, 1).uniform_(0.0, 1.0) weights_sum = weights.sum() xyz2cam = (xyz2cams * weights).sum(dim=0) / weights_sum # Multiplies with RGB -> XYZ to get RGB -> Camera CCM. rgb2xyz = torch.tensor([[0.4124564, 0.3575761, 0.1804375], [0.2126729, 0.7151522, 0.0721750], [0.0193339, 0.1191920, 0.9503041]]) rgb2cam = torch.mm(xyz2cam, rgb2xyz) # Normalizes each row. rgb2cam = rgb2cam / rgb2cam.sum(dim=-1, keepdims=True) return rgb2cam def random_gains(): """Generates random gains for brightening and white balance.""" # RGB gain represents brightening. rgb_gain = 1.0 / random.gauss(mu=0.8, sigma=0.1) # Red and blue gains represent white balance. 
red_gain = random.uniform(1.9, 2.4) blue_gain = random.uniform(1.5, 1.9) return rgb_gain, red_gain, blue_gain def apply_smoothstep(image): """Apply global tone mapping curve.""" image_out = 3 * image**2 - 2 * image**3 return image_out def invert_smoothstep(image): """Approximately inverts a global tone mapping curve.""" image = image.clamp(0.0, 1.0) return 0.5 - torch.sin(torch.asin(1.0 - 2.0 * image) / 3.0) def gamma_expansion(image): """Converts from gamma to linear space.""" # Clamps to prevent numerical instability of gradients near zero. return image.clamp(1e-8) ** 2.2 def gamma_compression(image): """Converts from linear to gammaspace.""" # Clamps to prevent numerical instability of gradients near zero. return image.clamp(1e-8) ** (1.0 / 2.2) def apply_ccm(image, ccm): """Applies a color correction matrix.""" assert image.dim() == 3 and image.shape[0] == 3 shape = image.shape image = image.view(3, -1) ccm = ccm.to(image.device).type_as(image) image = torch.mm(ccm, image) return image.view(shape) def apply_gains(image, rgb_gain, red_gain, blue_gain): """Inverts gains while safely handling saturated pixels.""" assert image.dim() == 3 and image.shape[0] in [3, 4] if image.shape[0] == 3: gains = torch.tensor([red_gain, 1.0, blue_gain]) * rgb_gain else: gains = torch.tensor([red_gain, 1.0, 1.0, blue_gain]) * rgb_gain gains = gains.view(-1, 1, 1) gains = gains.to(image.device).type_as(image) return (image * gains).clamp(0.0, 1.0) def safe_invert_gains(image, rgb_gain, red_gain, blue_gain): """Inverts gains while safely handling saturated pixels.""" assert image.dim() == 3 and image.shape[0] == 3 gains = torch.tensor([1.0 / red_gain, 1.0, 1.0 / blue_gain]) / rgb_gain gains = gains.view(-1, 1, 1) # Prevents dimming of saturated pixels by smoothly masking gains near white. 
gray = image.mean(dim=0, keepdims=True) inflection = 0.9 mask = ((gray - inflection).clamp(0.0) / (1.0 - inflection)) ** 2.0 safe_gains = torch.max(mask + (1.0 - mask) * gains, gains) return image * safe_gains def mosaic(image, mode='rggb'): """Extracts RGGB Bayer planes from an RGB image.""" shape = image.shape if image.dim() == 3: image = image.unsqueeze(0) if mode == 'rggb': red = image[:, 0, 0::2, 0::2] green_red = image[:, 1, 0::2, 1::2] green_blue = image[:, 1, 1::2, 0::2] blue = image[:, 2, 1::2, 1::2] image = torch.stack((red, green_red, green_blue, blue), dim=1) elif mode == 'grbg': green_red = image[:, 1, 0::2, 0::2] red = image[:, 0, 0::2, 1::2] blue = image[:, 2, 0::2, 1::2] green_blue = image[:, 1, 1::2, 1::2] image = torch.stack((green_red, red, blue, green_blue), dim=1) if len(shape) == 3: return image.view((4, shape[-2] // 2, shape[-1] // 2)) else: return image.view((-1, 4, shape[-2] // 2, shape[-1] // 2)) def demosaic(image): assert isinstance(image, torch.Tensor) image = image.clamp(0.0, 1.0) * 255 if image.dim() == 4: num_images = image.dim() batch_input = True else: num_images = 1 batch_input = False image = image.unsqueeze(0) # Generate single channel input for opencv im_sc = torch.zeros((num_images, image.shape[-2] * 2, image.shape[-1] * 2, 1)) im_sc[:, ::2, ::2, 0] = image[:, 0, :, :] im_sc[:, ::2, 1::2, 0] = image[:, 1, :, :] im_sc[:, 1::2, ::2, 0] = image[:, 2, :, :] im_sc[:, 1::2, 1::2, 0] = image[:, 3, :, :] im_sc = im_sc.numpy().astype(np.uint8) out = [] for im in im_sc: # cv.imwrite('frames/tmp.png', im) im_dem_np = cv.cvtColor(im, cv.COLOR_BAYER_BG2RGB)#_VNG) # Convert to torch image im_t = df_utils.npimage_to_torch(im_dem_np, input_bgr=False) out.append(im_t) if batch_input: return torch.stack(out, dim=0) else: return out[0] def random_noise_levels(): """Generates random noise levels from a log-log linear distribution.""" log_min_shot_noise = math.log(0.0001) log_max_shot_noise = math.log(0.012) log_shot_noise = 
random.uniform(log_min_shot_noise, log_max_shot_noise)
    shot_noise = math.exp(log_shot_noise)

    # read noise is log-log linearly related to shot noise, with jitter
    line = lambda x: 2.18 * x + 1.20
    log_read_noise = line(log_shot_noise) + random.gauss(mu=0.0, sigma=0.26)
    read_noise = math.exp(log_read_noise)
    return shot_noise, read_noise


def add_noise(image, shot_noise=0.01, read_noise=0.0005):
    """Adds random shot (proportional to image) and read (independent) noise."""
    variance = image * shot_noise + read_noise
    noise = torch.FloatTensor(image.shape).normal_().to(image.device) * variance.sqrt()
    return image + noise


def process_linear_image_rgb(image, meta_info, return_np=False):
    # Forward camera pipeline for a linear RGB image: gains -> CCM -> tone curves
    image = apply_gains(image, meta_info['rgb_gain'], meta_info['red_gain'], meta_info['blue_gain'])

    image = apply_ccm(image, meta_info['cam2rgb'])

    if meta_info['gamma']:
        image = gamma_compression(image)

    if meta_info['smoothstep']:
        image = apply_smoothstep(image)

    image = image.clamp(0.0, 1.0)

    if return_np:
        image = df_utils.torch_to_npimage(image)
    return image


def process_linear_image_raw(image, meta_info):
    # Forward camera pipeline for a packed RAW image: gains -> demosaic -> CCM -> tone curves
    image = apply_gains(image, meta_info['rgb_gain'], meta_info['red_gain'], meta_info['blue_gain'])
    image = demosaic(image)
    image = apply_ccm(image, meta_info['cam2rgb'])

    if meta_info['gamma']:
        image = gamma_compression(image)

    if meta_info['smoothstep']:
        image = apply_smoothstep(image)
    return image.clamp(0.0, 1.0)


================================================
FILE: code/synthetic/bsrt/data_processing/synthetic_burst_generation.py
================================================
import torch
import random
import cv2
import numpy as np
import torch.nn.functional as F
from data_processing.camera_pipeline import *
from utils.data_format_utils import torch_to_numpy, numpy_to_torch


def random_crop(frames, crop_sz):
    """ Extract a random crop of size crop_sz from the input frames. If the crop_sz is larger than the input image
    size, then the largest possible crop of same aspect ratio as crop_sz will be extracted from frames, and
    upsampled to crop_sz.
    """
    if not isinstance(crop_sz, (tuple, list)):
        crop_sz = (crop_sz, crop_sz)

    crop_sz = torch.tensor(crop_sz).float()

    shape = frames.shape

    # Select scale_factor. Ensure the crop fits inside the image
    max_scale_factor = torch.tensor(shape[-2:]).float() / crop_sz
    max_scale_factor = max_scale_factor.min().item()

    if max_scale_factor < 1.0:
        scale_factor = max_scale_factor
    else:
        scale_factor = 1.0

    # Extract the crop
    orig_crop_sz = (crop_sz * scale_factor).floor()

    assert orig_crop_sz[-2] <= shape[-2] and orig_crop_sz[-1] <= shape[-1], 'Bug in crop size estimation!'

    r1 = random.randint(0, shape[-2] - orig_crop_sz[-2])
    c1 = random.randint(0, shape[-1] - orig_crop_sz[-1])

    r2 = r1 + orig_crop_sz[0].int().item()
    c2 = c1 + orig_crop_sz[1].int().item()

    frames_crop = frames[:, r1:r2, c1:c2]

    # Resize to crop_sz
    if scale_factor < 1.0:
        frames_crop = F.interpolate(frames_crop.unsqueeze(0), size=crop_sz.int().tolist(), mode='bilinear',
                                    align_corners=False).squeeze(0)
    return frames_crop


def rgb2rawburst(image, burst_size, downsample_factor=1, burst_transformation_params=None,
                 image_processing_params=None, interpolation_type='bilinear'):
    """ Generates a synthetic LR RAW burst from the input image. The input sRGB image is first converted to linear
    sensor space using an inverse camera pipeline. A LR burst is then generated by applying random transformations
    defined by burst_transformation_params to the input image, and downsampling it by the downsample_factor.
    The generated burst is then mosaicekd and corrputed by random noise.
    """
    if image_processing_params is None:
        image_processing_params = {}

    # fill in defaults for any unspecified pipeline switches
    _defaults = {'random_ccm': True, 'random_gains': True, 'smoothstep': True, 'gamma': True, 'add_noise': True}
    for k, v in _defaults.items():
        if k not in image_processing_params:
            image_processing_params[k] = v

    # Sample camera pipeline params
    if image_processing_params['random_ccm']:
        rgb2cam = random_ccm()
    else:
        rgb2cam = torch.eye(3).float()
    cam2rgb = rgb2cam.inverse()

    # Sample gains
    if image_processing_params['random_gains']:
        rgb_gain, red_gain, blue_gain = random_gains()
    else:
        rgb_gain, red_gain, blue_gain = (1.0, 1.0, 1.0)

    # Approximately inverts global tone mapping.
    use_smoothstep = image_processing_params['smoothstep']
    if use_smoothstep:
        image = invert_smoothstep(image)

    # Inverts gamma compression.
    use_gamma = image_processing_params['gamma']
    if use_gamma:
        image = gamma_expansion(image)

    # Inverts color correction.
    image = apply_ccm(image, rgb2cam)

    # Approximately inverts white balance and brightening.
    image = safe_invert_gains(image, rgb_gain, red_gain, blue_gain)

    # Clip saturated pixels.
    image = image.clamp(0.0, 1.0)

    # Generate LR burst
    image_burst_rgb, flow_vectors = single2lrburst(image, burst_size=burst_size,
                                                   downsample_factor=downsample_factor,
                                                   transformation_params=burst_transformation_params,
                                                   interpolation_type=interpolation_type)

    # mosaic
    image_burst = mosaic(image_burst_rgb.clone())

    # Add noise
    if image_processing_params['add_noise']:
        shot_noise_level, read_noise_level = random_noise_levels()
        image_burst = add_noise(image_burst, shot_noise_level, read_noise_level)
    else:
        shot_noise_level = 0
        read_noise_level = 0

    # Clip saturated pixels.
    image_burst = image_burst.clamp(0.0, 1.0)

    meta_info = {'rgb2cam': rgb2cam, 'cam2rgb': cam2rgb, 'rgb_gain': rgb_gain, 'red_gain': red_gain,
                 'blue_gain': blue_gain, 'smoothstep': use_smoothstep, 'gamma': use_gamma,
                 'shot_noise_level': shot_noise_level, 'read_noise_level': read_noise_level}
    return image_burst, image, image_burst_rgb, flow_vectors, meta_info


def get_tmat(image_shape, translation, theta, shear_values, scale_factors):
    """ Generates a transformation matrix corresponding to the input transformation parameters """
    im_h, im_w = image_shape

    t_mat = np.identity(3)

    t_mat[0, 2] = translation[0]
    t_mat[1, 2] = translation[1]
    t_rot = cv2.getRotationMatrix2D((im_w * 0.5, im_h * 0.5), theta, 1.0)
    t_rot = np.concatenate((t_rot, np.array([0.0, 0.0, 1.0]).reshape(1, 3)))

    t_shear = np.array([[1.0, shear_values[0], -shear_values[0] * 0.5 * im_w],
                        [shear_values[1], 1.0, -shear_values[1] * 0.5 * im_h],
                        [0.0, 0.0, 1.0]])

    t_scale = np.array([[scale_factors[0], 0.0, 0.0],
                        [0.0, scale_factors[1], 0.0],
                        [0.0, 0.0, 1.0]])

    # compose: translate, then shear, then rotate, then scale
    t_mat = t_scale @ t_rot @ t_shear @ t_mat

    # keep the 2x3 affine form expected by cv2.warpAffine
    t_mat = t_mat[:2, :]

    return t_mat


def single2lrburst(image, burst_size, downsample_factor=1, transformation_params=None,
                   interpolation_type='bilinear'):
    """ Generates a burst of size burst_size from the input image by applying random transformations defined by
    transformation_params, and downsampling the resulting burst by downsample_factor.

    NOTE(review): transformation_params defaults to None but is dereferenced
    (.get(...)) both for i > 0 and for the 'border_crop' lookup below — a None
    value will raise AttributeError; callers appear expected to always pass a
    dict. TODO confirm against call sites.
    """
    if interpolation_type == 'bilinear':
        interpolation = cv2.INTER_LINEAR
    elif interpolation_type == 'lanczos':
        interpolation = cv2.INTER_LANCZOS4
    else:
        raise ValueError

    normalize = False
    if isinstance(image, torch.Tensor):
        # heuristic: tensors in [0, 1] are scaled up, then normalized back at the end
        if image.max() < 2.0:
            image = image * 255.0
            normalize = True
        image = torch_to_numpy(image).astype(np.uint8)

    burst = []
    sample_pos_inv_all = []

    rvs, cvs = torch.meshgrid([torch.arange(0, image.shape[0]),
                               torch.arange(0, image.shape[1])])

    # homogeneous (x, y, 1) coordinates for every pixel
    sample_grid = torch.stack((cvs, rvs, torch.ones_like(cvs)), dim=-1).float()

    for i in range(burst_size):
        if i == 0:
            # For base image, do not apply any random transformations. We only translate the image to center the
            # sampling grid
            shift = (downsample_factor / 2.0) - 0.5
            translation = (shift, shift)
            theta = 0.0
            shear_factor = (0.0, 0.0)
            scale_factor = (1.0, 1.0)
        else:
            # Sample random image transformation parameters
            max_translation = transformation_params.get('max_translation', 0.0)

            if max_translation <= 0.01:
                # no translation jitter requested: use the same centering shift as the base frame
                shift = (downsample_factor / 2.0) - 0.5
                translation = (shift, shift)
            else:
                translation = (random.uniform(-max_translation, max_translation),
                               random.uniform(-max_translation, max_translation))

            max_rotation = transformation_params.get('max_rotation', 0.0)
            theta = random.uniform(-max_rotation, max_rotation)

            max_shear = transformation_params.get('max_shear', 0.0)
            shear_x = random.uniform(-max_shear, max_shear)
            shear_y = random.uniform(-max_shear, max_shear)
            shear_factor = (shear_x, shear_y)

            max_ar_factor = transformation_params.get('max_ar_factor', 0.0)
            ar_factor = np.exp(random.uniform(-max_ar_factor, max_ar_factor))

            max_scale = transformation_params.get('max_scale', 0.0)
            scale_factor = np.exp(random.uniform(-max_scale, max_scale))

            scale_factor = (scale_factor, scale_factor * ar_factor)

        output_sz = (image.shape[1], image.shape[0])

        # Generate a affine transformation matrix corresponding to the sampled parameters
        t_mat = get_tmat((image.shape[0], image.shape[1]), translation, theta, shear_factor, scale_factor)
        t_mat_tensor = torch.from_numpy(t_mat)

        # Apply the sampled affine transformation
        image_t = cv2.warpAffine(image, t_mat, output_sz, flags=interpolation,
                                 borderMode=cv2.BORDER_CONSTANT)

        t_mat_tensor_3x3 = torch.cat((t_mat_tensor.float(), torch.tensor([0.0, 0.0, 1.0]).view(1, 3)), dim=0)
        t_mat_tensor_inverse = t_mat_tensor_3x3.inverse()[:2, :].contiguous()

        # map every pixel of the transformed frame back to base-frame coordinates
        sample_pos_inv = torch.mm(sample_grid.view(-1, 3), t_mat_tensor_inverse.t().float()).view(
            *sample_grid.shape[:2], -1)

        if transformation_params.get('border_crop') is not None:
            border_crop = transformation_params.get('border_crop')

            image_t = image_t[border_crop:-border_crop, border_crop:-border_crop, :]
            sample_pos_inv = sample_pos_inv[border_crop:-border_crop, border_crop:-border_crop, :]

        # Downsample the image
        image_t = cv2.resize(image_t, None, fx=1.0 / downsample_factor, fy=1.0 / downsample_factor,
                             interpolation=interpolation)
        sample_pos_inv = cv2.resize(sample_pos_inv.numpy(), None, fx=1.0 / downsample_factor,
                                    fy=1.0 / downsample_factor,
                                    interpolation=interpolation)

        sample_pos_inv = torch.from_numpy(sample_pos_inv).permute(2, 0, 1).contiguous()

        if normalize:
            image_t = numpy_to_torch(image_t).float() / 255.0
        else:
            image_t = numpy_to_torch(image_t).float()

        burst.append(image_t)
        sample_pos_inv_all.append(sample_pos_inv / downsample_factor)

    burst_images = torch.stack(burst)
    sample_pos_inv_all = torch.stack(sample_pos_inv_all)

    # Compute the flow vectors to go from the i'th burst image to the base image
    flow_vectors = sample_pos_inv_all - sample_pos_inv_all[:, :1, ...]
class SamsungRAWImage:
    """Packed 4-channel (RGGB) Samsung RAW burst frame plus its capture meta-data.

    The pixel data is held as a (4, H, W) int16 tensor; black-level
    subtraction, white balance and normalization are applied lazily in
    get_image_data() / postprocess().
    """

    @staticmethod
    def load(path):
        """Load 'im_raw.png' and the pickled 'meta_info.pkl' stored under *path*."""
        im_raw = cv2.imread('{}/im_raw.png'.format(path), cv2.IMREAD_UNCHANGED)
        # HWC -> CHW so the packed Bayer channels come first.
        im_raw = np.transpose(im_raw, (2, 0, 1)).astype(np.int16)
        im_raw = torch.from_numpy(im_raw)

        meta_data = pkl.load(open('{}/meta_info.pkl'.format(path), "rb", -1))

        return SamsungRAWImage(im_raw, meta_data['black_level'], meta_data['cam_wb'],
                               meta_data['daylight_wb'], meta_data['color_matrix'],
                               meta_data['exif_data'],
                               meta_data.get('crop_info', None), meta_data.get('im_preview', None))

    def __init__(self, im_raw, black_level, cam_wb, daylight_wb, color_matrix, exif_data,
                 crop_info=None, im_preview=None):
        self.im_raw = im_raw                  # (4, H, W) packed RGGB raw data
        self.black_level = black_level        # per-channel black level (length 4)
        self.cam_wb = cam_wb                  # camera white-balance gains (length 4)
        self.daylight_wb = daylight_wb
        self.color_matrix = color_matrix
        self.exif_data = exif_data
        self.crop_info = crop_info
        self.im_preview = im_preview          # optional full-resolution preview (2x spatial size)

        # Raw value normalizer; presumably 10-bit data, max value 1023 —
        # TODO(review): confirm against the capture pipeline.
        self.norm_factor = 1023.0

    def get_all_meta_data(self):
        """Return the picklable subset of meta-data used downstream."""
        return {'black_level': self.black_level, 'cam_wb': self.cam_wb,
                'daylight_wb': self.daylight_wb, 'color_matrix': self.color_matrix.tolist()}

    def get_exposure_time(self):
        # NOTE(review): exif_data values look like exifread Ratio objects
        # (they expose .decimal()) — confirm the loader.
        return self.exif_data['Image ExposureTime'].values[0].decimal()

    def get_noise_profile(self):
        """Return the sensor noise profile from EXIF tag 0xC761 as a (3, 2) array."""
        noise = self.exif_data['Image Tag 0xC761'].values
        noise = [n[0] for n in noise]
        noise = np.array(noise).reshape(3, 2)
        return noise

    def get_f_number(self):
        return self.exif_data['Image FNumber'].values[0].decimal()

    def get_iso(self):
        return self.exif_data['Image ISOSpeedRatings'].values[0]

    def get_image_data(self, substract_black_level=False, white_balance=False, normalize=False):
        """Return the raw data as float, with the requested processing applied.

        args:
            substract_black_level: subtract the per-channel black level.
            white_balance: multiply by the camera white-balance gains.
            normalize: scale into [0, 1] by self.norm_factor.
        """
        im_raw = self.im_raw.float()

        if substract_black_level:
            im_raw = im_raw - torch.tensor(self.black_level).view(4, 1, 1)

        if white_balance:
            im_raw = im_raw * torch.tensor(self.cam_wb).view(4, 1, 1)

        if normalize:
            im_raw = im_raw / self.norm_factor

        return im_raw

    def shape(self):
        """Return the (4, H, W) shape of the packed raw data."""
        shape = (4, self.im_raw.shape[1], self.im_raw.shape[2])
        return shape

    def crop_image(self, r1, r2, c1, c2):
        """Crop the raw data in place to rows [r1, r2) and columns [c1, c2)."""
        self.im_raw = self.im_raw[:, r1:r2, c1:c2]

    def get_crop(self, r1, r2, c1, c2):
        """Return a new SamsungRAWImage containing the given crop.

        The preview, when present, is cropped with doubled coordinates since
        it has twice the spatial resolution of the packed raw data.
        NOTE(review): crop_info is intentionally not propagated here.
        """
        im_raw = self.im_raw[:, r1:r2, c1:c2]

        if self.im_preview is not None:
            im_preview = self.im_preview[2*r1:2*r2, 2*c1:2*c2]
        else:
            im_preview = None

        return SamsungRAWImage(im_raw, self.black_level, self.cam_wb, self.daylight_wb,
                               self.color_matrix, self.exif_data, im_preview=im_preview)

    def postprocess(self, return_np=True, norm_factor=None):
        """Produce a rough RGB visualization of the raw frame.

        Black level and camera white balance are applied, the two green
        channels are averaged into one, and the result is clamped to [0, 1]
        (uint8 in [0, 255] when return_np is True).
        """
        # Convert to rgb
        # im = torch.from_numpy(self.im_raw.astype(np.float32))
        im = self.im_raw
        im = (im - torch.tensor(self.black_level).view(4, 1, 1)) * torch.tensor(self.cam_wb).view(4, 1, 1)

        if norm_factor is None:
            # Scale by the image maximum when no explicit normalizer is given.
            im = im / im.max()
        else:
            im = im / norm_factor

        # Average the two green channels to collapse RGGB -> RGB.
        im = torch.stack((im[0], (im[1] + im[2])/2, im[3]), dim=0)
        # im = torch.stack((im[0], im[1], im[3]), dim=0)
        im_out = im.clamp(0.0, 1.0)

        if return_np:
            im_out = im_out.permute(1, 2, 0).numpy() * 255.0
            im_out = im_out.astype(np.uint8)
        return im_out


class CanonImage:
    """Demosaicked 3-channel Canon DSLR frame (the HR ground truth) plus meta-data."""

    @staticmethod
    def load(path, split='train'):
        """Load 'im_raw.png' and the pickled 'meta_info.pkl' stored under *path*."""
        im_raw = cv2.imread('{}/im_raw.png'.format(path), cv2.IMREAD_UNCHANGED)
        im_raw = np.transpose(im_raw, (2, 0, 1)).astype(np.int16)
        im_raw = torch.from_numpy(im_raw)

        meta_data = pkl.load(open('{}/meta_info.pkl'.format(path), "rb", -1))

        return CanonImage(im_raw.float(), meta_data['black_level'], meta_data['cam_wb'],
                          meta_data['daylight_wb'], meta_data['rgb_xyz_matrix'],
                          meta_data.get('exif_data', None), meta_data.get('crop_info', None))

    def __init__(self, im_raw, black_level, cam_wb, daylight_wb, rgb_xyz_matrix, exif_data,
                 crop_info=None):
        super(CanonImage, self).__init__()
        self.im_raw = im_raw                  # (3, H, W) RGB data

        # Meta-data may still carry 4 Bayer entries (RGGB); drop the second
        # green so it matches the 3-channel image.
        if len(black_level) == 4:
            black_level = [black_level[0], black_level[1], black_level[3]]
        self.black_level = black_level

        if len(cam_wb) == 4:
            cam_wb = [cam_wb[0], cam_wb[1], cam_wb[3]]
        self.cam_wb = cam_wb

        if len(daylight_wb) == 4:
            daylight_wb = [daylight_wb[0], daylight_wb[1], daylight_wb[3]]
        self.daylight_wb = daylight_wb

        self.rgb_xyz_matrix = rgb_xyz_matrix

        # Standard XYZ -> sRGB (D65) conversion matrix.
        self.xyz_srgb_matrix = torch.tensor([3.2404542, -1.5371385, -0.4985314,
                                             -0.9692660, 1.8760108, 0.0415560,
                                             0.0556434, -0.2040259, 1.0572252]).view(3, 3)
        self.exif_data = exif_data
        self.crop_info = crop_info

        # Raw value normalizer; presumably 14-bit data, max value 16383 —
        # TODO(review): confirm.
        self.norm_factor = 16383

    def shape(self):
        """Return the (3, H, W) shape of the image data."""
        shape = (3, self.im_raw.shape[1], self.im_raw.shape[2])
        return shape

    def get_all_meta_data(self):
        """Return the picklable subset of meta-data used downstream."""
        return {'black_level': self.black_level, 'cam_wb': self.cam_wb,
                'daylight_wb': self.daylight_wb, 'rgb_xyz_matrix': self.rgb_xyz_matrix.tolist(),
                'crop_info': self.crop_info, 'norm_factor': self.norm_factor}

    def get_exposure_time(self):
        return self.exif_data['EXIF ExposureTime'].values[0].decimal()

    def get_f_number(self):
        return self.exif_data['EXIF FNumber'].values[0].decimal()

    def get_iso(self):
        return self.exif_data['EXIF ISOSpeedRatings'].values[0]

    def get_image_data(self, substract_black_level=False, white_balance=False, normalize=False):
        """Return the image as float, with the requested processing applied."""
        im_raw = self.im_raw.float()

        if substract_black_level:
            im_raw = im_raw - torch.tensor(self.black_level).view(3, 1, 1)

        if white_balance:
            # NOTE(review): cam_wb appears to be stored in 1/1024 units,
            # hence the division — confirm against the metadata writer.
            im_raw = im_raw * torch.tensor(self.cam_wb).view(3, 1, 1) / 1024.0

        if normalize:
            im_raw = im_raw / self.norm_factor

        return im_raw

    def set_image_data(self, im_data):
        self.im_raw = im_data

    def crop_image(self, r1, r2, c1, c2):
        """Crop the image in place to rows [r1, r2) and columns [c1, c2)."""
        self.im_raw = self.im_raw[:, r1:r2, c1:c2]

    def get_crop(self, r1, r2, c1, c2):
        """Return a new CanonImage containing the given crop."""
        im_raw = self.im_raw[:, r1:r2, c1:c2]
        return CanonImage(im_raw, self.black_level, self.cam_wb, self.daylight_wb,
                          self.rgb_xyz_matrix, self.exif_data, self.crop_info)

    def set_crop_info(self, crop_info):
        self.crop_info = crop_info

    def resize(self, size=None, scale_factor=None):
        """Bilinearly resize the image data in place."""
        self.im_raw = F.interpolate(self.im_raw.unsqueeze(0), size=size,
                                    scale_factor=scale_factor, mode='bilinear').squeeze(0)

    def postprocess(self, return_np=True):
        """Produce a rough RGB visualization (black level + white balance, max-normalized)."""
        # Convert to rgb
        im = self.im_raw

        im = (im - torch.tensor(self.black_level).view(3, 1, 1)).float() * torch.tensor(self.cam_wb).view(3, 1, 1)

        im_out = im / im.max()
        im_out = im_out.clamp(0.0, 1.0)

        if return_np:
            im_out = im_out.permute(1, 2, 0).numpy() * 255.0
            im_out = im_out.astype(np.uint8)
        return im_out


def load_txt(path):
    """Read *path* and return its lines with trailing whitespace stripped."""
    with open(path, 'r') as fh:
        out = [d.rstrip() for d in fh.readlines()]
    return out
    def _get_gt_image(self, burst_id):
        """Load the Canon DSLR frame used as HR ground truth for this burst."""
        canon_im = CanonImage.load('{}/{}/canon'.format(self.root, self.burst_list[burst_id]),
                                   split=self.split)
        return canon_im

    def get_burst(self, burst_id, im_ids, info=None):
        """Return (burst frames, GT image, info dict) for the given burst id."""
        frames = [self._get_raw_image(burst_id, i) for i in im_ids]
        gt = self._get_gt_image(burst_id)

        if info is None:
            info = self.get_burst_info(burst_id)

        return frames, gt, info

    def _sample_images(self):
        """Pick which of the 14 captured frames form the burst.

        The base frame (index 0) is always kept; the remaining
        burst_size - 1 frames are sampled at random from the other 13.
        """
        burst_size = 14
        ids = random.sample(range(1, burst_size), k=self.burst_size - 1)
        ids = [0, ] + ids
        return ids

    def __len__(self):
        return len(self.burst_list)

    def __getitem__(self, index):
        """Return (burst, frame_gt, meta_info_burst, meta_info_gt) for one burst.

        burst is a [burst_size, 4, crop_sz, crop_sz] float tensor of packed
        RGGB raw frames; frame_gt is the exposure-normalized Canon HR image.
        """
        # Sample the images in the burst, in case a burst_size < 14 is used.
        im_ids = self._sample_images()

        # Read the burst images along with HR ground truth
        frames, gt, meta_info = self.get_burst(index, im_ids)

        # Extract crop if needed
        if frames[0].shape()[-1] != self.crop_sz:
            if getattr(self, 'center_crop', False):
                r1 = (frames[0].shape()[-2] - self.crop_sz) // 2
                c1 = (frames[0].shape()[-1] - self.crop_sz) // 2
            else:
                r1 = random.randint(0, frames[0].shape()[-2] - self.crop_sz)
                c1 = random.randint(0, frames[0].shape()[-1] - self.crop_sz)
            r2 = r1 + self.crop_sz
            c2 = c1 + self.crop_sz

            # The GT is larger than the LR frames by an integer scale factor;
            # crop it at the matching up-scaled coordinates.
            scale_factor = gt.shape()[-1] // frames[0].shape()[-1]
            frames = [im.get_crop(r1, r2, c1, c2) for im in frames]
            gt = gt.get_crop(scale_factor * r1, scale_factor * r2,
                             scale_factor * c1, scale_factor * c2)

        # Load the RAW image data
        burst_image_data = [im.get_image_data(normalize=True,
                                              substract_black_level=self.substract_black_level,
                                              white_balance=self.white_balance) for im in frames]

        # Convert to tensor
        gt_image_data = gt.get_image_data(normalize=True, white_balance=self.white_balance,
                                          substract_black_level=self.substract_black_level)

        if self.random_flip:
            # Flip on the flat (unpacked) mosaic.
            # NOTE(review): flipping a Bayer mosaic shifts its phase; the
            # 1-pixel trim followed by replicate padding below appears to
            # restore the RGGB alignment — confirm.
            burst_image_data = [flatten_raw_image(im) for im in burst_image_data]

            pad = [0, 0, 0, 0]
            if random.random() > 0.5:
                # Horizontal flip.
                burst_image_data = [im.flip([1, ])[:, 1:-1].contiguous() for im in burst_image_data]
                gt_image_data = gt_image_data.flip([2, ])[:, :, 2:-2].contiguous()
                pad[1] = 1
            if random.random() > 0.5:
                # Vertical flip.
                burst_image_data = [im.flip([0, ])[1:-1, :].contiguous() for im in burst_image_data]
                gt_image_data = gt_image_data.flip([1, ])[:, 2:-2, :].contiguous()
                pad[3] = 1

            burst_image_data = [pack_raw_image(im) for im in burst_image_data]
            burst_image_data = [F.pad(im.unsqueeze(0), pad, mode='replicate').squeeze(0)
                                for im in burst_image_data]
            # GT is 4x the LR resolution, hence 4x the padding.
            gt_image_data = F.pad(gt_image_data.unsqueeze(0), [4 * p for p in pad],
                                  mode='replicate').squeeze(0)

        burst_image_meta_info = frames[0].get_all_meta_data()
        burst_image_meta_info['black_level_subtracted'] = self.substract_black_level
        burst_image_meta_info['while_balance_applied'] = self.white_balance
        burst_image_meta_info['norm_factor'] = frames[0].norm_factor

        gt_image_meta_info = gt.get_all_meta_data()

        burst = torch.stack(burst_image_data, dim=0)

        burst_exposure = frames[0].get_exposure_time()
        canon_exposure = gt.get_exposure_time()
        burst_f_number = frames[0].get_f_number()
        canon_f_number = gt.get_f_number()
        burst_iso = frames[0].get_iso()
        canon_iso = gt.get_iso()

        # Normalize the GT image to account for differences in exposure, ISO etc
        # Light gathered scales with exposure * ISO / f-number^2.
        light_factor_burst = burst_exposure * burst_iso / (burst_f_number ** 2)
        light_factor_canon = canon_exposure * canon_iso / (canon_f_number ** 2)

        exp_scale_factor = (light_factor_burst / light_factor_canon)
        gt_image_data = gt_image_data * exp_scale_factor

        gt_image_meta_info['black_level_subtracted'] = self.substract_black_level
        gt_image_meta_info['while_balance_applied'] = self.white_balance
        gt_image_meta_info['norm_factor'] = gt.norm_factor / exp_scale_factor

        burst_image_meta_info['exposure'] = burst_exposure
        burst_image_meta_info['f_number'] = burst_f_number
        burst_image_meta_info['iso'] = burst_iso

        gt_image_meta_info['exposure'] = canon_exposure
        gt_image_meta_info['f_number'] = canon_f_number
        gt_image_meta_info['iso'] = canon_iso

        burst = burst.float()
        frame_gt = gt_image_data.float()

        meta_info_burst = burst_image_meta_info
        meta_info_gt = gt_image_meta_info

        del meta_info_gt['crop_info']

        # Convert list/tuple meta entries to tensors so the default
        # DataLoader collate can batch them.
        for k, v in meta_info_gt.items():
            if isinstance(v, (list, tuple)):
                meta_info_gt[k] = torch.tensor(v)

        for k, v in meta_info_burst.items():
            if isinstance(v, (list, tuple)):
                meta_info_burst[k] = torch.tensor(v)

        meta_info_burst['burst_name'] = meta_info['burst_name']

        return burst, frame_gt, meta_info_burst, meta_info_gt
def pack_raw_image(im_raw):
    """Pack a flat Bayer mosaic (H, W) into 4 channels (4, H/2, W/2).

    The channel order follows the 2x2 Bayer tile layout:
    [top-left, top-right, bottom-left, bottom-right] (RGGB for this dataset).
    Accepts a numpy array or a torch tensor; the output matches the input's
    type, dtype and (for tensors) device.

    Raises:
        TypeError: if im_raw is neither a numpy array nor a torch tensor.
    """
    if isinstance(im_raw, np.ndarray):
        im_out = np.zeros_like(im_raw, shape=(4, im_raw.shape[0] // 2, im_raw.shape[1] // 2))
    elif isinstance(im_raw, torch.Tensor):
        # Allocate directly on the source device instead of zeros().to(device),
        # which would allocate on CPU first and then copy.
        im_out = torch.zeros((4, im_raw.shape[0] // 2, im_raw.shape[1] // 2),
                             dtype=im_raw.dtype, device=im_raw.device)
    else:
        raise TypeError('im_raw must be a numpy array or torch tensor, got {}'.format(type(im_raw)))

    im_out[0, :, :] = im_raw[0::2, 0::2]
    im_out[1, :, :] = im_raw[0::2, 1::2]
    im_out[2, :, :] = im_raw[1::2, 0::2]
    im_out[3, :, :] = im_raw[1::2, 1::2]
    return im_out


def flatten_raw_image(im_raw_4ch):
    """Inverse of pack_raw_image: (4, H, W) packed channels -> (2H, 2W) mosaic.

    Raises:
        TypeError: if im_raw_4ch is neither a numpy array nor a torch tensor.
    """
    if isinstance(im_raw_4ch, np.ndarray):
        im_out = np.zeros_like(im_raw_4ch, shape=(im_raw_4ch.shape[1] * 2, im_raw_4ch.shape[2] * 2))
    elif isinstance(im_raw_4ch, torch.Tensor):
        # Allocate directly on the source device (see pack_raw_image).
        im_out = torch.zeros((im_raw_4ch.shape[1] * 2, im_raw_4ch.shape[2] * 2),
                             dtype=im_raw_4ch.dtype, device=im_raw_4ch.device)
    else:
        raise TypeError('im_raw_4ch must be a numpy array or torch tensor, got {}'.format(type(im_raw_4ch)))

    im_out[0::2, 0::2] = im_raw_4ch[0, :, :]
    im_out[0::2, 1::2] = im_raw_4ch[1, :, :]
    im_out[1::2, 0::2] = im_raw_4ch[2, :, :]
    im_out[1::2, 1::2] = im_raw_4ch[3, :, :]
    return im_out


def pack_raw_image_batch(im_raw):
    """Batched pack_raw_image: (B, N, 1, H, W) mosaics -> (B, N, 4, H/2, W/2)."""
    im_out = torch.zeros((im_raw.shape[0], im_raw.shape[1], 4,
                          im_raw.shape[3] // 2, im_raw.shape[4] // 2),
                         dtype=im_raw.dtype, device=im_raw.device)
    im_out[:, :, 0, :, :] = im_raw[:, :, 0, 0::2, 0::2]
    im_out[:, :, 1, :, :] = im_raw[:, :, 0, 0::2, 1::2]
    im_out[:, :, 2, :, :] = im_raw[:, :, 0, 1::2, 0::2]
    im_out[:, :, 3, :, :] = im_raw[:, :, 0, 1::2, 1::2]
    return im_out


def flatten_raw_image_batch(im_raw_4ch):
    """Batched flatten_raw_image: (B, N, 4, H, W) packed -> (B, N, 1, 2H, 2W) mosaics."""
    im_out = torch.zeros((im_raw_4ch.shape[0], im_raw_4ch.shape[1], 1,
                          im_raw_4ch.shape[3] * 2, im_raw_4ch.shape[4] * 2),
                         dtype=im_raw_4ch.dtype, device=im_raw_4ch.device)
    im_out[:, :, 0, 0::2, 0::2] = im_raw_4ch[:, :, 0, :, :]
    im_out[:, :, 0, 0::2, 1::2] = im_raw_4ch[:, :, 1, :, :]
    im_out[:, :, 0, 1::2, 0::2] = im_raw_4ch[:, :, 2, :, :]
    im_out[:, :, 0, 1::2, 1::2] = im_raw_4ch[:, :, 3, :, :]
    return im_out
im_ids] if info is None: info = self.get_burst_info(burst_id) return frames, info def _sample_images(self): burst_size = 14 ids = random.sample(range(1, burst_size), k=self.burst_size - 1) ids = [0, ] + ids return ids def __len__(self): return len(self.burst_list) def __getitem__(self, index): # Sample the images in the burst, in case a burst_size < 14 is used. im_ids = self._sample_images() # Read the burst images along with HR ground truth frames, meta_info = self.get_burst(index, im_ids) # Extract crop if needed if frames[0].shape()[-1] != self.crop_sz: if getattr(self, 'center_crop', False): r1 = (frames[0].shape()[-2] - self.crop_sz) // 2 c1 = (frames[0].shape()[-1] - self.crop_sz) // 2 else: r1 = random.randint(0, frames[0].shape()[-2] - self.crop_sz) c1 = random.randint(0, frames[0].shape()[-1] - self.crop_sz) r2 = r1 + self.crop_sz c2 = c1 + self.crop_sz frames = [im.get_crop(r1, r2, c1, c2) for im in frames] # Load the RAW image data burst_image_data = [im.get_image_data(normalize=True, substract_black_level=self.substract_black_level, white_balance=self.white_balance) for im in frames] if self.random_flip: burst_image_data = [flatten_raw_image(im) for im in burst_image_data] pad = [0, 0, 0, 0] if random.random() > 0.5: burst_image_data = [im.flip([1, ])[:, 1:-1].contiguous() for im in burst_image_data] pad[1] = 1 if random.random() > 0.5: burst_image_data = [im.flip([0, ])[1:-1, :].contiguous() for im in burst_image_data] pad[3] = 1 burst_image_data = [pack_raw_image(im) for im in burst_image_data] burst_image_data = [F.pad(im.unsqueeze(0), pad, mode='replicate').squeeze(0) for im in burst_image_data] burst_image_meta_info = frames[0].get_all_meta_data() burst_image_meta_info['black_level_subtracted'] = self.substract_black_level burst_image_meta_info['while_balance_applied'] = self.white_balance burst_image_meta_info['norm_factor'] = frames[0].norm_factor burst = torch.stack(burst_image_data, dim=0) burst_exposure = frames[0].get_exposure_time() 
class DistIterSampler(Sampler):
    """Distributed sampler that virtually enlarges the dataset by *ratio*.

    Every process draws the same epoch-seeded permutation of
    ``num_samples * num_replicas`` indices (wrapped modulo the dataset size)
    and keeps an interleaved, mutually exclusive slice of it. Enlarging the
    index space lets iteration-oriented training run many "epochs" without
    restarting the DataLoader.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training (defaults to the world size).
        rank (optional): Rank of the current process within num_replicas.
        ratio: Enlargement factor applied to the dataset length.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, ratio=100):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Per-process share of the enlarged index space, rounded up.
        self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # Epoch-seeded generator: every process builds the same permutation.
        rng = torch.Generator()
        rng.manual_seed(self.epoch)
        permutation = torch.randperm(self.total_size, generator=rng).tolist()

        # Wrap indices so the enlarged range maps back onto real samples.
        dataset_len = len(self.dataset)
        wrapped = [idx % dataset_len for idx in permutation]

        # Each process takes an interleaved, disjoint slice.
        shard = wrapped[self.rank:self.total_size:self.num_replicas]
        assert len(shard) == self.num_samples
        return iter(shard)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        """Change the shuffle seed; call once per epoch for a fresh order."""
        self.epoch = epoch
class SyntheticBurstTest(torch.utils.data.Dataset):
    """Synthetic burst test set.

    The test bursts were generated with the same synthetic pipeline as the
    one employed by the SyntheticBurst training dataset.
    """

    def __init__(self, root):
        self.root = root
        self.burst_list = list(range(92))   # 92 pre-generated bursts on disk
        self.burst_size = 14                # frames per burst

    def __len__(self):
        return len(self.burst_list)

    def _read_burst_image(self, index, image_id):
        # RAW data is stored as 16-bit PNG; rescale by 2^14 into [0, 1].
        raw = cv2.imread('{}/{:04d}/im_raw_{:02d}.png'.format(self.root, index, image_id),
                         cv2.IMREAD_UNCHANGED)
        return torch.from_numpy(raw.astype(np.float32)).permute(2, 0, 1).float() / (2**14)

    def __getitem__(self, index):
        """Return (burst, meta_info) for the burst at *index*.

        burst is a [burst_size, 4, H, W] tensor whose 4 channels hold the
        'R', 'G', 'G', 'B' values of the RGGB Bayer mosaic; meta_info carries
        the burst name.
        """
        frames = [self._read_burst_image(index, frame_id)
                  for frame_id in range(self.burst_size)]
        return torch.stack(frames, 0), {'burst_name': '{:04d}'.format(index)}
[1] Unprocessing Images for Learned Raw Denoising, Brooks, Tim and Mildenhall, Ben and Xue, Tianfan and Chen, Jiawen and Sharlet, Dillon and Barron, Jonathan T, CVPR 2019 """ def __init__(self, base_dataset, burst_size=8, crop_sz=384, transform=tfm.ToTensor()): self.base_dataset = base_dataset self.burst_size = burst_size self.crop_sz = crop_sz self.transform = transform self.downsample_factor = 4 self.burst_transformation_params = {'max_translation': 24.0, 'max_rotation': 1.0, 'max_shear': 0.0, 'max_scale': 0.0, 'border_crop': 24} self.image_processing_params = {'random_ccm': True, 'random_gains': True, 'smoothstep': True, 'gamma': True, 'add_noise': True} self.interpolation_type = 'bilinear' def __len__(self): return len(self.base_dataset) def __getitem__(self, index): """ Generates a synthetic burst args: index: Index of the image in the base_dataset used to generate the burst returns: burst: Generated LR RAW burst, a torch tensor of shape [burst_size, 4, self.crop_sz / (2*self.downsample_factor), self.crop_sz / (2*self.downsample_factor)] The 4 channels correspond to 'R', 'G', 'G', and 'B' values in the RGGB bayer mosaick. The extra factor 2 in the denominator (2*self.downsample_factor) corresponds to the mosaicking operation. frame_gt: The HR RGB ground truth in the linear sensor space, a torch tensor of shape [3, self.crop_sz, self.crop_sz] flow_vectors: The ground truth flow vectors between a burst image and the base image (i.e. the first image in the burst). The flow_vectors can be used to warp the burst images to the base frame, using the 'warp' function in utils.warp package. flow_vectors is torch tensor of shape [burst_size, 2, self.crop_sz / self.downsample_factor, self.crop_sz / self.downsample_factor]. Note that the flow_vectors are in the LR RGB space, before mosaicking. Hence it has twice the number of rows and columns, compared to the output burst. 
NOTE: The flow_vectors are only available during training for the purpose of using any auxiliary losses if needed. The flow_vectors will NOT be provided for the bursts in the test set meta_info: A dictionary containing the parameters used to generate the synthetic burst. """ frame = self.base_dataset[index] # Augmentation, e.g. convert to tensor if self.transform is not None: # frame = Image.fromarray(frame) frame = self.transform(frame) # Extract a random crop from the image crop_sz = self.crop_sz + 2 * self.burst_transformation_params.get('border_crop', 0) frame_crop = random_crop(frame, crop_sz) # Generate RAW burst burst, frame_gt, burst_rgb, flow_vectors, meta_info = rgb2rawburst(frame_crop, self.burst_size, self.downsample_factor, burst_transformation_params=self.burst_transformation_params, image_processing_params=self.image_processing_params, interpolation_type=self.interpolation_type ) if self.burst_transformation_params.get('border_crop') is not None: border_crop = self.burst_transformation_params.get('border_crop') frame_gt = frame_gt[:, border_crop:-border_crop, border_crop:-border_crop] return burst, frame_gt, flow_vectors, meta_info ================================================ FILE: code/synthetic/bsrt/datasets/synthetic_burst_val_set.py ================================================ import os import torch import cv2 import numpy as np import pickle as pkl class SyntheticBurstVal(torch.utils.data.Dataset): """ Synthetic burst validation set introduced in [1]. The validation burst have been generated using a synthetic data generation pipeline. The dataset can be downloaded from https://data.vision.ee.ethz.ch/bhatg/SyntheticBurstVal.zip [1] Deep Burst Super-Resolution. Goutam Bhat, Martin Danelljan, Luc Van Gool, and Radu Timofte. 
class ZurichRAW2RGB(torch.utils.data.Dataset):
    """Canon RGB images from the "Zurich RAW to RGB mapping" dataset.

    The full dataset (22 GB) can be downloaded from
    http://people.ee.ethz.ch/~ihnatova/pynet.html#dataset. Alternatively, the
    Canon RGB images alone (5.5 GB) are available at
    https://data.vision.ee.ethz.ch/bhatg/zurich-raw-to-rgb.zip
    """

    def __init__(self, root, split='train'):
        super().__init__()
        # Guard clause: only the two shipped splits are valid.
        if split not in ['train', 'test']:
            raise Exception('Unknown split {}'.format(split))
        self.img_pth = os.path.join(root, split, 'canon')
        self.image_list = self._get_image_list(split)

    def _get_image_list(self, split):
        # File names are plain integers: 0.jpg, 1.jpg, ...
        if split == 'train':
            count = 46839
        elif split == 'test':
            # image_list = ['{:d}.jpg'.format(int(i)) for i in np.linspace(1, 1200, 400)]
            count = 1200
        else:
            raise Exception
        return ['{:d}.jpg'.format(idx) for idx in range(count)]

    def _get_image(self, im_id):
        return imread(os.path.join(self.img_pth, self.image_list[im_id]))

    def get_image(self, im_id):
        return self._get_image(im_id)

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        return self._get_image(index)
class CharbonnierLoss(nn.Module):
    """Charbonnier penalty: sqrt((x - y)^2 + epsilon^2).

    A smooth, differentiable approximation of the L1 loss that behaves like
    L2 near zero. With reduce=True the mean over all elements is returned,
    otherwise the element-wise error tensor.
    """

    def __init__(self, epsilon=1e-3, reduce=True):
        super(CharbonnierLoss, self).__init__()
        # Only the squared epsilon is needed in the forward pass.
        self.eps = epsilon * epsilon
        self.reduce = reduce

    def forward(self, X, Y):
        residual = X - Y
        per_element = torch.sqrt(residual * residual + self.eps)
        return torch.mean(per_element) if self.reduce else per_element
import_module('loss.mssim') loss_function = getattr(module, 'MSSSIM')(args) self.loss.append({ 'type': loss_type, 'weight': float(weight), 'function': loss_function} ) if loss_type.find('GAN') >= 0: self.loss.append({'type': 'DIS', 'weight': 1, 'function': None}) if len(self.loss) > 1: self.loss.append({'type': 'Total', 'weight': 0, 'function': None}) for l in self.loss: if l['function'] is not None: if args.local_rank == 0: print('{:.3f} * {}'.format(l['weight'], l['type'])) self.loss_module.append(l['function']) self.log = torch.Tensor() device = torch.device('cpu' if args.cpu else 'cuda') self.loss_module.to(device) if args.precision == 'half': self.loss_module.half() if not args.cpu and args.n_GPUs > 1: self.loss_module = nn.DataParallel( self.loss_module, range(args.n_GPUs) ) if args.load != '': self.load(ckp.dir, cpu=args.cpu) def forward(self, sr, hr): losses = [] for i, l in enumerate(self.loss): if l['function'] is not None: loss = l['function'](sr, hr) effective_loss = l['weight'] * loss losses.append(effective_loss) self.log[-1, i] += effective_loss.item() elif l['type'] == 'DIS': self.log[-1, i] += self.loss[i - 1]['function'].loss loss_sum = sum(losses) if len(self.loss) > 1: self.log[-1, -1] += loss_sum.item() return loss_sum def step(self): for l in self.get_loss_module(): if hasattr(l, 'scheduler'): l.scheduler.step() def start_log(self): self.log = torch.cat((self.log, torch.zeros(1, len(self.loss)))) def end_log(self, n_batches): self.log[-1].div_(n_batches) def display_loss(self, batch): n_samples = batch + 1 log = [] for l, c in zip(self.loss, self.log[-1]): log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples)) return ''.join(log) def plot_loss(self, apath, epoch): axis = np.linspace(1, epoch, epoch) for i, l in enumerate(self.loss): label = '{} Loss'.format(l['type']) fig = plt.figure() plt.title(label) plt.plot(axis, self.log[:, i].numpy(), label=label) plt.legend() plt.xlabel('Epochs') plt.ylabel('Loss') plt.grid(True) 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from types import SimpleNamespace

# NOTE: the original file header also imports the project modules
# ``utility``, ``model.common`` and ``loss.discriminator``; they are only
# used inside ``__init__`` below.


class Adversarial(nn.Module):
    """Adversarial (GAN) loss term.

    Owns its discriminator and the discriminator's optimizer. ``forward``
    performs ``gan_k`` discriminator updates (recording their mean loss in
    ``self.loss``) and returns the generator loss for backpropagation.
    Supported ``gan_type`` values: GAN, SNGAN, WGAN, WGAN_GP, RGAN.
    """

    def __init__(self, args, gan_type):
        super(Adversarial, self).__init__()
        self.gan_type = gan_type
        self.gan_k = args.gan_k
        self.dis = discriminator.Discriminator(args)
        # if gan_type == 'WGAN_GP':
        if True:
            # see https://arxiv.org/pdf/1704.00028.pdf pp.4
            optim_dict = {
                'optimizer': 'ADAM',
                'betas': (0.5, 0.9),
                'epsilon': 1e-8,
                'lr': 1e-5,
                'weight_decay': args.weight_decay,
                'decay': args.decay,
                'gamma': args.gamma
            }
            optim_args = SimpleNamespace(**optim_dict)
        else:
            optim_args = args

        self.optimizer = utility.make_optimizer(optim_args, self.dis)

    def forward(self, fake, real):
        # ---- discriminator update (gan_k steps) ----
        self.loss = 0
        fake_detach = fake.detach()  # do not backpropagate through G
        for _ in range(self.gan_k):
            self.optimizer.zero_grad()
            # d: B x 1 tensor
            d_fake = self.dis(fake_detach)
            d_real = self.dis(real)
            retain_graph = False
            if self.gan_type in ['GAN', 'SNGAN']:
                loss_d = self.bce(d_real, d_fake)
            elif self.gan_type.find('WGAN') >= 0:
                loss_d = (d_fake - d_real).mean()
                if self.gan_type.find('GP') >= 0:
                    # BUGFIX: epsilon must be one uniform sample per batch
                    # element (shape B x 1 x 1 x 1). The previous
                    # ``torch.rand_like(fake).view(-1, 1, 1, 1)`` produced
                    # B*C*H*W samples, which cannot broadcast against the
                    # B x C x H x W images and raised a runtime error.
                    epsilon = torch.rand(
                        fake_detach.size(0), 1, 1, 1,
                        device=fake_detach.device, dtype=fake_detach.dtype
                    )
                    hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
                    hat.requires_grad = True
                    d_hat = self.dis(hat)
                    gradients = torch.autograd.grad(
                        outputs=d_hat.sum(), inputs=hat,
                        retain_graph=True, create_graph=True, only_inputs=True
                    )[0]
                    gradients = gradients.view(gradients.size(0), -1)
                    gradient_norm = gradients.norm(2, dim=1)
                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
                    loss_d += gradient_penalty
            # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
            elif self.gan_type == 'RGAN':
                better_real = d_real - d_fake.mean(dim=0, keepdim=True)
                better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
                loss_d = self.bce(better_real, better_fake)
                # d_real/d_fake are reused for the generator loss below.
                retain_graph = True

            # Discriminator update
            self.loss += loss_d.item()
            loss_d.backward(retain_graph=retain_graph)
            self.optimizer.step()

            if self.gan_type == 'WGAN':
                # weight clipping for vanilla WGAN
                for p in self.dis.parameters():
                    p.data.clamp_(-1, 1)

        self.loss /= self.gan_k

        # ---- generator loss ----
        d_fake_bp = self.dis(fake)  # for backpropagation, use fake as it is
        if self.gan_type in ['GAN', 'SNGAN']:
            label_real = torch.ones_like(d_fake_bp)
            loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)
        elif self.gan_type.find('WGAN') >= 0:
            loss_g = -d_fake_bp.mean()
        elif self.gan_type == 'RGAN':
            better_real = d_real.detach() - d_fake_bp.mean(dim=0, keepdim=True)
            better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True).detach()
            loss_g = self.bce(better_fake, better_real)

        # Generator loss
        return loss_g

    def state_dict(self, *args, **kwargs):
        # Merge discriminator weights and optimizer state into one dict.
        state_discriminator = self.dis.state_dict(*args, **kwargs)
        state_optimizer = self.optimizer.state_dict()
        return dict(**state_discriminator, **state_optimizer)

    def bce(self, real, fake):
        """Standard BCE-with-logits: real logits -> 1, fake logits -> 0."""
        label_real = torch.ones_like(real)
        label_fake = torch.zeros_like(fake)
        bce_real = F.binary_cross_entropy_with_logits(real, label_real)
        bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)
        return bce_real + bce_fake

# Some references
# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
# OR
# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class Discriminator(nn.Module):
    '''output is not normalized (raw logits)

    One stride-1 conv block followed by `depth` stride-2 conv blocks
    (channels doubling each time), then a two-layer MLP classifier.
    Assumes ``args.patch_size`` is divisible by 2**depth.
    '''

    def __init__(self, args, gan_type='GAN'):
        super(Discriminator, self).__init__()

        in_channels = args.n_colors
        out_channels = 32
        depth = 6

        def _block(_in_channels, _out_channels, stride=1):
            Conv = nn.Conv2d(
                _in_channels, _out_channels, 3,
                padding=1, stride=stride, bias=False
            )
            if gan_type == 'SNGAN':
                # BUGFIX: `spectral_norm` was referenced without being
                # imported; use the canonical torch.nn.utils implementation.
                return nn.Sequential(
                    nn.utils.spectral_norm(Conv),
                    nn.BatchNorm2d(_out_channels),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True)
                )
            else:
                return nn.Sequential(
                    Conv,
                    nn.BatchNorm2d(_out_channels),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True)
                )

        m_features = [_block(in_channels, out_channels)]
        for i in range(depth):
            in_channels = out_channels
            out_channels *= 2
            stride = 2
            m_features.append(_block(in_channels, out_channels, stride=stride))

        # BUGFIX: `depth` stride-2 blocks halve the spatial size `depth`
        # times, so the flattened feature map is
        # out_channels * (patch_size // 2**depth)**2. The original divided
        # by 2**(depth-1), making the Linear layer's input size wrong.
        patch_size = args.patch_size // 2**depth
        m_classifier = [
            nn.Flatten(),
            nn.Linear(out_channels * patch_size**2, 512),
            nn.LeakyReLU(0.2, True),
            nn.Linear(512, 1)
        ]

        self.features = nn.Sequential(*m_features)
        self.classifier = nn.Sequential(*m_classifier)

    def forward(self, x):
        features = self.features(x)
        return self.classifier(features)


class Filter(nn.Module):
    """L1 distance between filtered versions of two images.

    A fixed 3x3 high-frequency kernel is broadcast into every conv filter;
    the same conv is applied to both inputs, so the (random, trainable)
    bias cancels in the L1 difference.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        kernel = torch.tensor([[1, 4, 1], [4, -20, 4], [1, 4, 1]])
        # NOTE(review): positional args are (kernel_size=3, stride=3) —
        # confirm stride 3 (rather than padding=1) is intended.
        self.conv = nn.Conv2d(args.n_colors, args.n_colors, 3, 3)
        with torch.no_grad():
            self.conv.weight.copy_(kernel.float())  # broadcast to all filters
        self.loss = nn.L1Loss()

    def forward(self, x, y):
        return self.loss(self.conv(x), self.conv(y))


class HistEntropy(nn.Module):
    """Mean over (batch, channel) of the spatial sum of -p*log(p), where p
    is the softmax of the input over the channel dimension."""

    def __init__(self, args):
        super().__init__()
        self.args = args

    def forward(self, x):
        p = torch.softmax(x, dim=1)
        logp = torch.log_softmax(x, dim=1)
        return (-p * logp).sum(dim=(2, 3)).mean()
import torch
import torch.nn.functional as F
from math import exp


def gaussian(window_size, sigma):
    """1-D Gaussian kernel of length ``window_size``, normalized to sum 1."""
    gauss = torch.Tensor([
        exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2))
        for x in range(window_size)
    ])
    return gauss / gauss.sum()


def create_window(window_size, channel=1):
    """Depthwise 2-D Gaussian window of shape (channel, 1, size, size)."""
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window


def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
    """Single-scale SSIM between two (N, C, H, W) images.

    Returns the mean SSIM (or per-sample SSIM if ``size_average`` is False);
    with ``full=True`` also returns the contrast-sensitivity term.
    """
    # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
    if val_range is None:
        if torch.max(img1) > 128:
            max_val = 255
        else:
            max_val = 1
        if torch.min(img1) < -0.5:
            min_val = -1
        else:
            min_val = 0
        L = max_val - min_val
    else:
        L = val_range

    padd = 0
    (_, channel, height, width) = img1.size()
    if window is None:
        # Shrink the window for inputs smaller than window_size.
        real_size = min(window_size, height, width)
        window = create_window(real_size, channel=channel).to(img1.device)

    mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
    mu2 = F.conv2d(img2, window, padding=padd, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2

    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2

    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2
    cs = torch.mean(v1 / v2)  # contrast sensitivity

    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

    if size_average:
        ret = ssim_map.mean()
    else:
        ret = ssim_map.mean(1).mean(1).mean(1)

    if full:
        return ret, cs
    return ret


def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=None):
    """Multi-scale SSIM over 5 dyadic scales (2x average pooling between scales)."""
    device = img1.device
    weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
    levels = weights.size()[0]
    ssims = []
    mcs = []
    for _ in range(levels):
        sim, cs = ssim(img1, img2, window_size=window_size,
                       size_average=size_average, full=True, val_range=val_range)

        # Relu normalize (not compliant with original definition)
        if normalize == "relu":
            ssims.append(torch.relu(sim))
            mcs.append(torch.relu(cs))
        else:
            ssims.append(sim)
            mcs.append(cs)

        img1 = F.avg_pool2d(img1, (2, 2))
        img2 = F.avg_pool2d(img2, (2, 2))

    ssims = torch.stack(ssims)
    mcs = torch.stack(mcs)

    # Simple normalize (not compliant with original definition)
    # TODO: remove support for normalize == True (kept for backward support)
    if normalize == "simple" or normalize == True:
        ssims = (ssims + 1) / 2
        mcs = (mcs + 1) / 2

    pow1 = mcs ** weights
    pow2 = ssims ** weights

    # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
    # MS-SSIM = prod_{i<L}(mcs_i ** w_i) * (ssim_L ** w_L).
    # BUGFIX: the original `torch.prod(pow1[:-1] * pow2[-1])` multiplied the
    # final SSIM factor into every contrast term (raising it to the sum of
    # the first L-1 weights) instead of applying it exactly once.
    output = torch.prod(pow1[:-1]) * pow2[-1]
    return output


# Classes to re-use window
class SSIM(torch.nn.Module):
    """SSIM module that caches its window between calls.

    NOTE(review): loss/__init__.py constructs this as ``SSIM(args)``, which
    would pass the whole args namespace as ``window_size`` — verify the
    intended call signature against that caller.
    """

    def __init__(self, window_size=11, size_average=True, val_range=None):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.val_range = val_range

        # Assume 1 channel for SSIM
        self.channel = 1
        self.window = create_window(window_size)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()
        if channel == self.channel and self.window.dtype == img1.dtype:
            window = self.window
        else:
            # Rebuild (and cache) the window for a new channel count / dtype.
            window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
            self.window = window
            self.channel = channel
        return ssim(img1, img2, window=window, window_size=self.window_size,
                    size_average=self.size_average)


class MSSSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True, channel=3):
        super(MSSSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = channel

    def forward(self, img1, img2):
        # TODO: store window between calls if possible
        return msssim(img1, img2, window_size=self.window_size,
                      size_average=self.size_average)
from model import common

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class VGG(nn.Module):
    """VGG-19 perceptual loss: MSE between frozen VGG feature maps.

    ``conv_index`` selects the feature depth: a string containing '22'
    uses the first 8 layers (relu2_2); '54' uses the first 35 (relu5_4).
    """

    def __init__(self, conv_index, rgb_range=1):
        super(VGG, self).__init__()
        vgg_features = models.vgg19(pretrained=True).features
        modules = [m for m in vgg_features]
        if conv_index.find('22') >= 0:
            self.vgg = nn.Sequential(*modules[:8])
        elif conv_index.find('54') >= 0:
            self.vgg = nn.Sequential(*modules[:35])
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        # The extractor is fixed — no gradients through VGG weights.
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, sr, hr):
        def _forward(x):
            # x = self.sub_mean(x)   # input normalization currently disabled
            return self.vgg(x)

        # NOTE(review): repeat(1, 3, 1, 1) assumes single-channel inputs
        # tiled to VGG's 3 channels — confirm against the caller.
        sr = sr.repeat(1, 3, 1, 1)
        hr = hr.repeat(1, 3, 1, 1)

        vgg_sr = _forward(sr)
        with torch.no_grad():
            vgg_hr = _forward(hr.detach())

        return F.mse_loss(vgg_sr, vgg_hr)
import torch
import random
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms as T

import utility
import model
import loss
from option import args
from trainer import Trainer
from datasets.synthetic_burst_train_set import SyntheticBurst
from datasets.synthetic_burst_val_set import SyntheticBurstVal
from datasets.zurich_raw2rgb_dataset import ZurichRAW2RGB
from datasets.data_sampler import DistIterSampler

import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn
import torch.utils.data.distributed


def init_seeds(seed=0, cuda_deterministic=True):
    """Seed python/numpy/torch RNGs and pick a cudnn determinism mode."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
    if cuda_deterministic:  # slower, more reproducible
        cudnn.deterministic = True
        cudnn.benchmark = False
    else:  # faster, less reproducible
        cudnn.deterministic = False
        cudnn.benchmark = True


checkpoint = utility.checkpoint(args)


def main():
    """Spawn one worker per GPU for DDP, or run a single worker inline."""
    if args.n_GPUs > 1:
        mp.spawn(main_worker, nprocs=args.n_GPUs, args=(args.n_GPUs, args), join=True)
    else:
        main_worker(0, args.n_GPUs, args)


def main_worker(local_rank, nprocs, args):
    """Build datasets/loaders, the model and the loss, then train to completion."""
    if checkpoint.ok:
        args.local_rank = local_rank
        if nprocs > 1:
            init_seeds(local_rank + 1)
        cudnn.benchmark = True
        utility.setup(local_rank, nprocs)
        torch.cuda.set_device(args.local_rank)

        batch_size = int(args.batch_size / nprocs)
        train_zurich_raw2rgb = ZurichRAW2RGB(root=args.root, split='train')
        train_data = SyntheticBurst(train_zurich_raw2rgb,
                                    burst_size=args.burst_size,
                                    crop_sz=args.patch_size)
        # valid_zurich_raw2rgb = ZurichRAW2RGB(root=args.root, split='test')
        # valid_data = SyntheticBurst(valid_zurich_raw2rgb, burst_size=14, crop_sz=1024)
        valid_data = SyntheticBurstVal(root=args.root)

        if local_rank <= 0:
            print(f"train data: {len(train_data)}, test data: {len(valid_data)}")
            print(f"Test only: {args.test_only}")

        if nprocs > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
            # train_sampler = DistIterSampler(train_data, nprocs, local_rank, 1)
            valid_sampler = torch.utils.data.distributed.DistributedSampler(
                valid_data, shuffle=False)
            # NOTE(review): num_workers=args.batch_size ties the worker count
            # to the batch size — looks unintended, confirm.
            train_loader = DataLoader(dataset=train_data, batch_size=batch_size,
                                      num_workers=args.batch_size, pin_memory=True,
                                      drop_last=True, sampler=train_sampler)
            valid_loader = DataLoader(dataset=valid_data, batch_size=1,
                                      num_workers=1, pin_memory=True,
                                      drop_last=True, sampler=valid_sampler)
        else:
            train_sampler = None
            train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size,
                                      num_workers=8, shuffle=True, pin_memory=True,
                                      drop_last=True)  # args.cpus
            # NOTE(review): drop_last=True on the validation loader silently
            # discards the final partial batch — confirm this is intended.
            valid_loader = DataLoader(dataset=valid_data, batch_size=args.batch_size,
                                      num_workers=4, shuffle=False, pin_memory=True,
                                      drop_last=True)  # args.cpus

        _model = model.Model(args, checkpoint)
        _loss = loss.Loss(args, checkpoint) if not args.test_only else None
        t = Trainer(args, train_loader, train_sampler, valid_loader,
                    _model, _loss, checkpoint)
        while not t.terminate():
            t.train()

        del _model
        del _loss
        del train_loader
        del valid_loader
        # utility.cleanup()

        checkpoint.done()


if __name__ == '__main__':
    main()
train_loader, train_sampler, valid_loader, _model, _loss, checkpoint) while not t.terminate(): t.train() del _model del _loss del train_loader del valid_loader # utility.cleanup() checkpoint.done() if __name__ == '__main__': main() ================================================ FILE: code/synthetic/bsrt/model/DCNv2/LICENSE ================================================ BSD 3-Clause License Copyright (c) 2019, Charles Shang All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
================================================ FILE: code/synthetic/bsrt/model/DCNv2/README.md ================================================
## Deformable Convolutional Networks V2 with Pytorch 1.0

### Build
```bash
./make.sh         # build
python test.py    # run examples and gradient check
```

### An Example
- deformable conv
```python
from dcn_v2 import DCN
input = torch.randn(2, 64, 128, 128).cuda()
# wrap all things (offset and mask) in DCN
dcn = DCN(64, 64, kernel_size=(3,3), stride=1, padding=1, deformable_groups=2).cuda()
output = dcn(input)
print(output.shape)
```
- deformable roi pooling
```python
from dcn_v2 import DCNPooling
input = torch.randn(2, 32, 64, 64).cuda()
batch_inds = torch.randint(2, (20, 1)).cuda().float()
x = torch.randint(256, (20, 1)).cuda().float()
y = torch.randint(256, (20, 1)).cuda().float()
w = torch.randint(64, (20, 1)).cuda().float()
h = torch.randint(64, (20, 1)).cuda().float()
rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
# deformable pooling (V2)
# wrap all things (offset and mask) in DCNPooling
dpooling = DCNPooling(spatial_scale=1.0 / 4, pooled_size=7, output_dim=32, no_trans=False, group_size=1, trans_std=0.1).cuda()
dout = dpooling(input, rois)
```
### Note
Now the master branch is for pytorch 1.0 (new ATen API), you can switch back to pytorch 0.4 with,
```bash
git checkout pytorch_0.4
```
### Known Issues:
- [x] Gradient check w.r.t offset (solved)
- [ ] Backward is not reentrant (minor)

This is an adaptation of the official [Deformable-ConvNets](https://github.com/msracver/Deformable-ConvNets/tree/master/DCNv2_op).

I have run the gradient check many times with DOUBLE type. Every tensor **except offset** passes. However, when I set the offset to 0.5, it passes. I'm still wondering what causes this problem. Is it because of some non-differentiable points?

Update: all gradient checks pass with double precision.

Another issue is that it raises `RuntimeError: Backward is not reentrant`.
However, the error is very small (`<1e-7` for float `<1e-15` for double), so it may not be a serious problem (?) Please post an issue or PR if you have any comments. ================================================ FILE: code/synthetic/bsrt/model/DCNv2/__init__.py ================================================ ================================================ FILE: code/synthetic/bsrt/model/DCNv2/dcn_v2.py ================================================ #!/usr/bin/env python from __future__ import absolute_import, division, print_function import math import torch from torch import nn from torch.autograd import Function from torch.autograd.function import once_differentiable from torch.nn.modules.utils import _pair from torch.cuda.amp import custom_fwd, custom_bwd # from apex import amp import _ext as _backend class _DCNv2(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) # @amp.float_function def forward( ctx, input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups ): ctx.stride = _pair(stride) ctx.padding = _pair(padding) ctx.dilation = _pair(dilation) ctx.kernel_size = _pair(weight.shape[2:4]) ctx.deformable_groups = deformable_groups output = _backend.dcn_v2_forward( input, weight, bias, offset, mask, ctx.kernel_size[0], ctx.kernel_size[1], ctx.stride[0], ctx.stride[1], ctx.padding[0], ctx.padding[1], ctx.dilation[0], ctx.dilation[1], ctx.deformable_groups, ) ctx.save_for_backward(input, offset, mask, weight, bias) return output @staticmethod @once_differentiable @custom_bwd # @amp.float_function def backward(ctx, grad_output): input, offset, mask, weight, bias = ctx.saved_tensors grad_input, grad_offset, grad_mask, grad_weight, grad_bias = _backend.dcn_v2_backward( input, weight, bias, offset, mask, grad_output, ctx.kernel_size[0], ctx.kernel_size[1], ctx.stride[0], ctx.stride[1], ctx.padding[0], ctx.padding[1], ctx.dilation[0], ctx.dilation[1], ctx.deformable_groups, ) return grad_input, grad_offset, grad_mask, 
grad_weight, grad_bias, None, None, None, None @staticmethod def symbolic( g, input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups ): from torch.nn.modules.utils import _pair stride = _pair(stride) padding = _pair(padding) dilation = _pair(dilation) # as of trt 7, the dcn operation will be translated again by modifying the onnx file # so the exporting code is kept to resemble the forward() return g.op( "DCNv2_2", input, offset, mask, weight, bias, stride_i=stride, padding_i=padding, dilation_i=dilation, deformable_groups_i=deformable_groups, ) dcn_v2_conv = _DCNv2.apply class DCNv2(nn.Module): def __init__( self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, deformable_groups=1, ): super(DCNv2, self).__init__() self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = _pair(kernel_size) self.stride = _pair(stride) self.padding = _pair(padding) self.dilation = _pair(dilation) self.deformable_groups = deformable_groups self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size)) self.bias = nn.Parameter(torch.Tensor(out_channels)) self.reset_parameters() def reset_parameters(self): n = self.in_channels for k in self.kernel_size: n *= k stdv = 1.0 / math.sqrt(n) self.weight.data.uniform_(-stdv, stdv) self.bias.data.zero_() def forward(self, input, offset, mask): assert ( 2 * self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == offset.shape[1] ) assert self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == mask.shape[1] return dcn_v2_conv( input, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, self.deformable_groups, ) class DCN(DCNv2): def __init__( self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, deformable_groups=1, ): super(DCN, self).__init__( in_channels, out_channels, kernel_size, stride, padding, dilation, deformable_groups ) channels_ = self.deformable_groups * 3 * 
self.kernel_size[0] * self.kernel_size[1] self.conv_offset_mask = nn.Conv2d( self.in_channels, channels_, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=True, ) self.init_offset() def init_offset(self): self.conv_offset_mask.weight.data.zero_() self.conv_offset_mask.bias.data.zero_() def forward(self, input): out = self.conv_offset_mask(input) o1, o2, mask = torch.chunk(out, 3, dim=1) offset = torch.cat((o1, o2), dim=1) mask = torch.sigmoid(mask) return dcn_v2_conv( input, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, self.deformable_groups, ) class DCN_sep(DCNv2): '''Use other features to generate offsets and masks''' def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, deformable_groups=1): super(DCN_sep, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, deformable_groups) channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1] self.conv_offset_mask = nn.Conv2d( self.in_channels, channels_, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=True) self.init_offset() def init_offset(self): self.conv_offset_mask.weight.data.zero_() self.conv_offset_mask.bias.data.zero_() def forward(self, input, fea): '''input: input features for deformable conv fea: other features used for generating offsets and mask''' out = self.conv_offset_mask(fea) o1, o2, mask = torch.chunk(out, 3, dim=1) offset = torch.cat((o1, o2), dim=1) # offset = torch.clamp(offset, -100, 100) offset_mean = torch.mean(torch.abs(offset)) if offset_mean > 250: print('Offset mean is {}, larger than 100.'.format(offset_mean)) # return None # offset[offset>=150] = 1e-3 # offset = offset.clamp(-50, 50) mask = torch.sigmoid(mask) return dcn_v2_conv(input, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, self.deformable_groups) class FlowGuidedDCN(DCNv2): '''Use other features to generate offsets and 
class FlowGuidedDCN(DCNv2):
    '''Use other features to generate offsets and masks.

    Predicted offsets are bounded residuals (tanh * 10) added on top of the
    supplied optical flow.
    '''

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                 dilation=1, deformable_groups=1):
        super(FlowGuidedDCN, self).__init__(in_channels, out_channels, kernel_size,
                                            stride, padding, dilation,
                                            deformable_groups)

        channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
        self.conv_offset_mask = nn.Conv2d(
            in_channels, channels_, kernel_size, stride, padding, bias=True)
        self.init_offset()

    def init_offset(self):
        # Zero init => predicted residual offsets/masks start at zero.
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, input, fea, flows):
        '''input: input features for deformable conv: N, C, H, W.
        fea: other features used for generating offsets and mask: N, C, H, W.
        flows: N, 2, H, W.
        '''
        out = self.conv_offset_mask(fea)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.tanh(torch.cat((o1, o2), dim=1)) * 10  # max_residue_magnitude
        # flows are (x, y); flip(1) reorders to match the offset layout.
        offset = offset + flows.flip(1).repeat(1, offset.size(1) // 2, 1, 1)

        offset_mean = torch.mean(torch.abs(offset))
        if offset_mean > 250:
            # BUGFIX: warning text now matches the 250 threshold (was "100").
            print('FlowGuidedDCN: Offset mean is {}, larger than 250.'.format(offset_mean))
            # offset = offset.clamp(-50, 50)
            # return None

        mask = torch.sigmoid(mask)
        return dcn_v2_conv(input, offset, mask, self.weight, self.bias,
                           self.stride, self.padding, self.dilation,
                           self.deformable_groups)


class InsideFlowGuidedDCN(DCNv2):
    '''Use other features to generate offsets and masks'''

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                 dilation=1, deformable_groups=1):
        super(InsideFlowGuidedDCN, self).__init__(in_channels, out_channels,
                                                  kernel_size, stride, padding,
                                                  dilation, deformable_groups)

        channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
        # Offsets/masks are predicted from [warped, ref, flow] => 2C+2 inputs.
        self.conv_offset_mask = nn.Sequential(
            nn.Conv2d(in_channels * 2 + 2, out_channels, kernel_size, stride, padding, bias=True),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=True),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(out_channels, channels_, kernel_size, stride, padding, bias=True)
        )
        # NOTE(review): reset_parameters() duplicates the base-class version
        # and was already run by DCNv2.__init__; the explicit call is
        # redundant but harmless.
        self.reset_parameters()
        self.init_offset()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1.0 / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        self.bias.data.zero_()

    def init_offset(self):
        # Only the final prediction layer is zero-initialized.
        self.conv_offset_mask[-1].weight.data.zero_()
        self.conv_offset_mask[-1].bias.data.zero_()

    def forward(self, input, warped, ref, flows):
        '''input: input features for deformable conv: N, C, H, W.
        warped/ref: features used for generating offsets and mask: N, C, H, W.
        flows: N, 2, H, W.
        '''
        out = self.conv_offset_mask(torch.cat([warped, ref, flows], dim=1))
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.tanh(torch.cat((o1, o2), dim=1)) * 10  # max_residue_magnitude
        offset = offset + flows.flip(1).repeat(1, offset.size(1) // 2, 1, 1)

        offset_mean = torch.mean(torch.abs(offset))
        if offset_mean > 250:
            # BUGFIX: warning text now matches the 250 threshold (was "100").
            print('InsideFlowGuidedDCN: Offset mean is {}, larger than 250.'.format(offset_mean))
            print('flow mean is {}'.format(torch.abs(flows).mean()))
            offset = offset.clamp(-50, 50)
            # return None

        mask = torch.sigmoid(mask)
        return dcn_v2_conv(input, offset, mask, self.weight, self.bias,
                           self.stride, self.padding, self.dilation,
                           self.deformable_groups)
class _DCNv2Pooling(Function):
    """autograd Function for deformable PS-ROI pooling (compiled backend)."""

    @staticmethod
    def forward(ctx, input, rois, offset, spatial_scale, pooled_size, output_dim,
                no_trans, group_size=1, part_size=None, sample_per_part=4,
                trans_std=0.0):
        ctx.spatial_scale = spatial_scale
        ctx.no_trans = int(no_trans)
        ctx.output_dim = output_dim
        ctx.group_size = group_size
        ctx.pooled_size = pooled_size
        ctx.part_size = pooled_size if part_size is None else part_size
        ctx.sample_per_part = sample_per_part
        ctx.trans_std = trans_std

        output, output_count = _backend.dcn_v2_psroi_pooling_forward(
            input, rois, offset,
            ctx.no_trans, ctx.spatial_scale, ctx.output_dim, ctx.group_size,
            ctx.pooled_size, ctx.part_size, ctx.sample_per_part, ctx.trans_std,
        )
        ctx.save_for_backward(input, rois, offset, output_count)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        input, rois, offset, output_count = ctx.saved_tensors
        grad_input, grad_offset = _backend.dcn_v2_psroi_pooling_backward(
            grad_output, input, rois, offset, output_count,
            ctx.no_trans, ctx.spatial_scale, ctx.output_dim, ctx.group_size,
            ctx.pooled_size, ctx.part_size, ctx.sample_per_part, ctx.trans_std,
        )
        # Only `input` and `offset` receive gradients; every other forward
        # argument gets None.
        return (grad_input, None, grad_offset,
                None, None, None, None, None, None, None, None)


dcn_v2_pooling = _DCNv2Pooling.apply


class DCNv2Pooling(nn.Module):
    """Module wrapper around ``dcn_v2_pooling`` with a fixed pooling config."""

    def __init__(self, spatial_scale, pooled_size, output_dim, no_trans,
                 group_size=1, part_size=None, sample_per_part=4, trans_std=0.0):
        super(DCNv2Pooling, self).__init__()
        self.spatial_scale = spatial_scale
        self.pooled_size = pooled_size
        self.output_dim = output_dim
        self.no_trans = no_trans
        self.group_size = group_size
        self.part_size = pooled_size if part_size is None else part_size
        self.sample_per_part = sample_per_part
        self.trans_std = trans_std

    def forward(self, input, rois, offset):
        assert input.shape[1] == self.output_dim
        if self.no_trans:
            # Empty placeholder tensor when offsets are unused.
            offset = input.new()
        return dcn_v2_pooling(input, rois, offset,
                              self.spatial_scale, self.pooled_size,
                              self.output_dim, self.no_trans,
                              self.group_size, self.part_size,
                              self.sample_per_part, self.trans_std)
self.pooled_size * self.pooled_size * 3), ) self.offset_mask_fc[4].weight.data.zero_() self.offset_mask_fc[4].bias.data.zero_() def forward(self, input, rois): offset = input.new() if not self.no_trans: # do roi_align first n = rois.shape[0] roi = dcn_v2_pooling( input, rois, offset, self.spatial_scale, self.pooled_size, self.output_dim, True, # no trans self.group_size, self.part_size, self.sample_per_part, self.trans_std, ) # build mask and offset offset_mask = self.offset_mask_fc(roi.view(n, -1)) offset_mask = offset_mask.view(n, 3, self.pooled_size, self.pooled_size) o1, o2, mask = torch.chunk(offset_mask, 3, dim=1) offset = torch.cat((o1, o2), dim=1) mask = torch.sigmoid(mask) # do pooling with offset and mask return ( dcn_v2_pooling( input, rois, offset, self.spatial_scale, self.pooled_size, self.output_dim, self.no_trans, self.group_size, self.part_size, self.sample_per_part, self.trans_std, ) * mask ) # only roi_align return dcn_v2_pooling( input, rois, offset, self.spatial_scale, self.pooled_size, self.output_dim, self.no_trans, self.group_size, self.part_size, self.sample_per_part, self.trans_std, ) ================================================ FILE: code/synthetic/bsrt/model/DCNv2/files.txt ================================================ /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/_ext.cpython-37m-x86_64-linux-gnu.so /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/_ext.py /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/PKG-INFO /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/SOURCES.txt /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/dependency_links.txt /home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/native_libs.txt 
/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/not-zip-safe
/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/EGG-INFO/top_level.txt
/home/luoziwei/miniconda3/lib/python3.7/site-packages/DCNv2-0.1-py3.7-linux-x86_64.egg/__pycache__/_ext.cpython-37.pyc


================================================
FILE: code/synthetic/bsrt/model/DCNv2/make.sh
================================================
#!/usr/bin/env bash
python setup.py build develop


================================================
FILE: code/synthetic/bsrt/model/DCNv2/setup.py
================================================
#!/usr/bin/env python
import glob
import os

import torch
from setuptools import find_packages, setup
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension

requirements = ["torch", "torchvision"]


def get_extensions():
    """Collect the C++/CUDA sources under src/ and build the `_ext` module spec."""
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, "src")

    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
    source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
    source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
    os.environ["CC"] = "g++"
    sources = main_file + source_cpu
    extension = CppExtension
    extra_compile_args = {"cxx": []}
    define_macros = []

    # NOTE(review): CUDA is unconditionally assumed available (`if True:`);
    # building on a machine without CUDA will fail. CUDA_HOME is imported but
    # never checked here.
    if True:
        extension = CUDAExtension
        sources += source_cuda
        define_macros += [("WITH_CUDA", None)]
        extra_compile_args["nvcc"] = [
            "-DCUDA_HAS_FP16=1",
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
        ]
    else:
        # raise NotImplementedError('Cuda is not available')
        pass

    sources = [os.path.join(extensions_dir, s) for s in sources]
    include_dirs = [extensions_dir]
    ext_modules = [
        extension(
            "_ext",
            sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]
    return ext_modules


setup(
    name="DCNv2",
    version="0.1",
    author="charlesshang",
    url="https://github.com/charlesshang/DCNv2",
    description="deformable convolutional networks",
    packages=find_packages(exclude=("configs", "tests")),
    # install_requires=requirements,
    ext_modules=get_extensions(),
    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)


================================================
FILE: code/synthetic/bsrt/model/DCNv2/src/cpu/dcn_v2_cpu.cpp
================================================
// NOTE(review): the angle-bracket header names below were stripped by the
// text extraction (e.g. `#include <vector>` became a bare `#include`) —
// restore them from the upstream DCNv2 sources before compiling.
#include
#include "cpu/dcn_v2_im2col_cpu.h"
#include
//#include
#include
//#include
//#include

//extern THCState *state;

// author: Charles Shang
// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
// modified from the CUDA version for CPU use by Daniel K. Suhendro

// Forward pass of modulated deformable convolution (DCNv2) on CPU:
// bias broadcast via GEMM, deformable im2col, then weight GEMM per batch item.
at::Tensor dcn_v2_cpu_forward(const at::Tensor &input,
                              const at::Tensor &weight,
                              const at::Tensor &bias,
                              const at::Tensor &offset,
                              const at::Tensor &mask,
                              const int kernel_h,
                              const int kernel_w,
                              const int stride_h,
                              const int stride_w,
                              const int pad_h,
                              const int pad_w,
                              const int dilation_h,
                              const int dilation_w,
                              const int deformable_group)
{
    // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask));
    /*AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
    AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
    AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
    AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
    AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");*/

    const int batch = input.size(0);
    const int channels = input.size(1);
    const int height = input.size(2);
    const int width = input.size(3);

    const int channels_out = weight.size(0);
    const int channels_kernel = weight.size(1);
    const int kernel_h_ = weight.size(2);
    const int kernel_w_ = weight.size(3);

    // printf("Kernels: %d %d %d %d\n", kernel_h_, kernel_w_, kernel_w, kernel_h);
    // printf("Channels: %d %d\n", channels, channels_kernel);
    // printf("Channels: %d %d\n", channels_out,
channels_kernel);

    AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
               "Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
               kernel_h_, kernel_w, kernel_h_, kernel_w_);

    AT_ASSERTM(channels == channels_kernel,
               "Input shape and kernel channels wont match: (%d vs %d).",
               channels, channels_kernel);

    // Standard convolution output-size arithmetic.
    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

    auto ones = at::ones({height_out, width_out}, input.options());
    auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
    auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());

    using scalar_t = float;
    for (int b = 0; b < batch; b++)
    {
        auto input_n = input.select(0, b);
        auto offset_n = offset.select(0, b);
        auto mask_n = mask.select(0, b);
        auto output_n = output.select(0, b);

        // Do Bias first:
        // M,N,K are dims of matrix A and B
        // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
        // (N x 1) (1 x M)
        long m_ = channels_out;
        long n_ = height_out * width_out;
        long k_ = 1;
        THFloatBlas_gemm('t', 'n', n_, m_, k_, 1.0f,
                         ones.contiguous().data(), k_,
                         bias.contiguous().data(), k_, 0.0f,
                         output_n.data(), n_);

        // Expand input patches (with learned offsets and modulation mask)
        // into the `columns` matrix.
        modulated_deformable_im2col_cpu(input_n.data(),
                                        offset_n.data(),
                                        mask_n.data(),
                                        1, channels, height, width,
                                        height_out, width_out, kernel_h, kernel_w,
                                        pad_h, pad_w, stride_h, stride_w,
                                        dilation_h, dilation_w, deformable_group,
                                        columns.data());

        //(k * m) x (m * n)
        // Y = WC
        long m = channels_out;
        long n = height_out * width_out;
        long k = channels * kernel_h * kernel_w;
        THFloatBlas_gemm('n', 'n', n, m, k, 1.0f,
                         columns.data(), n,
                         weight.data(), k, 1.0f,
                         output_n.data(), n);
    }
    return output;
}

// Backward pass: returns {grad_input, grad_offset, grad_mask, grad_weight, grad_bias}.
// NOTE(review): the extraction stripped the template argument here — upstream
// this is `std::vector<at::Tensor>`.
std::vector dcn_v2_cpu_backward(const at::Tensor &input,
                                const at::Tensor &weight,
                                const at::Tensor &bias,
                                const at::Tensor &offset,
                                const at::Tensor &mask,
                                const at::Tensor &grad_output,
                                int kernel_h, int kernel_w,
                                int stride_h, int stride_w,
                                int pad_h, int pad_w,
                                int dilation_h, int dilation_w,
                                int deformable_group)
{
    THArgCheck(input.is_contiguous(), 1, "input tensor has to be contiguous");
    THArgCheck(weight.is_contiguous(), 2, "weight tensor has to be contiguous");

    /*AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
    AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
    AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
    AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
    AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");*/

    const int batch = input.size(0);
    const int channels = input.size(1);
    const int height = input.size(2);
    const int width = input.size(3);

    const int channels_out = weight.size(0);
    const int channels_kernel = weight.size(1);
    const int kernel_h_ = weight.size(2);
    const int kernel_w_ = weight.size(3);

    AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
               "Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
               kernel_h_, kernel_w, kernel_h_, kernel_w_);

    AT_ASSERTM(channels == channels_kernel,
               "Input shape and kernel channels wont match: (%d vs %d).",
               channels, channels_kernel);

    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

    auto ones = at::ones({height_out, width_out}, input.options());
    auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
    auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());

    auto grad_input = at::zeros_like(input);
    auto grad_weight = at::zeros_like(weight);
    auto grad_bias = at::zeros_like(bias);
    auto grad_offset = at::zeros_like(offset);
    auto grad_mask = at::zeros_like(mask);

    using scalar_t = float;

    for (int b = 0; b < batch; b++)
    {
        auto input_n = input.select(0, b);
        auto offset_n = offset.select(0, b);
        auto mask_n = mask.select(0, b);
        auto grad_output_n = grad_output.select(0, b);
        auto grad_input_n = grad_input.select(0, b);
        auto grad_offset_n = grad_offset.select(0, b);
        auto grad_mask_n = grad_mask.select(0, b);

        // columns = W^T * grad_output (gradient flowing into the im2col matrix)
        long m = channels * kernel_h * kernel_w;
        long n = height_out * width_out;
        long k = channels_out;
        THFloatBlas_gemm('n', 't', n, m, k, 1.0f,
                         grad_output_n.data(), n,
                         weight.data(), m, 0.0f,
                         columns.data(), n);

        // gradient w.r.t. input coordinate data
        modulated_deformable_col2im_coord_cpu(columns.data(),
                                              input_n.data(),
                                              offset_n.data(),
                                              mask_n.data(),
                                              1, channels, height, width,
                                              height_out, width_out, kernel_h, kernel_w,
                                              pad_h, pad_w, stride_h, stride_w,
                                              dilation_h, dilation_w, deformable_group,
                                              grad_offset_n.data(),
                                              grad_mask_n.data());
        // gradient w.r.t. input data
        modulated_deformable_col2im_cpu(columns.data(),
                                        offset_n.data(),
                                        mask_n.data(),
                                        1, channels, height, width,
                                        height_out, width_out, kernel_h, kernel_w,
                                        pad_h, pad_w, stride_h, stride_w,
                                        dilation_h, dilation_w, deformable_group,
                                        grad_input_n.data());

        // gradient w.r.t. weight, dWeight should accumulate across the batch and group
        modulated_deformable_im2col_cpu(input_n.data(),
                                        offset_n.data(),
                                        mask_n.data(),
                                        1, channels, height, width,
                                        height_out, width_out, kernel_h, kernel_w,
                                        pad_h, pad_w, stride_h, stride_w,
                                        dilation_h, dilation_w, deformable_group,
                                        columns.data());

        long m_ = channels_out;
        long n_ = channels * kernel_h * kernel_w;
        long k_ = height_out * width_out;
        THFloatBlas_gemm('t', 'n', n_, m_, k_, 1.0f,
                         columns.data(), k_,
                         grad_output_n.data(), k_, 1.0f,
                         grad_weight.data(), n_);

        // gradient w.r.t. bias
        // NOTE(review): the bias gradient below is commented out, so grad_bias
        // stays all-zero — confirm whether this is intentional in this fork.
        // long m_ = channels_out;
        // long k__ = height_out * width_out;
        // THFloatBlas_gemv('t', k_, m_, 1.0f,
        //                  grad_output_n.data(), k_,
        //                  ones.data(), 1, 1.0f,
        //                  grad_bias.data(), 1);
    }

    return {
        grad_input, grad_offset, grad_mask, grad_weight, grad_bias
    };
}


================================================
FILE: code/synthetic/bsrt/model/DCNv2/src/cpu/dcn_v2_im2col_cpu.cpp
================================================
// NOTE(review): angle-bracket header names stripped by extraction (see above).
#include "dcn_v2_im2col_cpu.h"
#include
#include
#include
#include
//#include
#include
//#include
//#include

// modified from the CUDA version for CPU use by Daniel K. Suhendro

/*#define CUDA_KERNEL_LOOP(i, n) \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
       i < (n); \
       i += blockDim.x * gridDim.x)

const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N)
{
  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}*/

// Bilinear interpolation at fractional coordinates (h, w); out-of-range
// corners contribute 0.
float dmcn_im2col_bilinear_cpu(const float *bottom_data, const int data_width,
                               const int height, const int width, float h, float w)
{
    int h_low = floor(h);
    int w_low = floor(w);
    int h_high = h_low + 1;
    int w_high = w_low + 1;

    float lh = h - h_low;
    float lw = w - w_low;
    float hh = 1 - lh, hw = 1 - lw;

    float v1 = 0;
    if (h_low >= 0 && w_low >= 0)
        v1 = bottom_data[h_low * data_width + w_low];
    float v2 = 0;
    if (h_low >= 0 && w_high <= width - 1)
        v2 = bottom_data[h_low * data_width + w_high];
    float v3 = 0;
    if (h_high <= height - 1 && w_low >= 0)
        v3 = bottom_data[h_high * data_width + w_low];
    float v4 = 0;
    if (h_high <= height - 1 && w_high <= width - 1)
        v4 = bottom_data[h_high * data_width + w_high];

    float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

    float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
    return val;
}

// Bilinear weight of pixel (h, w) w.r.t. the sampling point (argmax_h, argmax_w);
// 0 when the sampling point falls outside the image.
float dmcn_get_gradient_weight_cpu(float argmax_h, float argmax_w,
                                   const int h, const int w, const int height, const int width)
{
    if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
    {
        //empty
        return 0;
    }

    int argmax_h_low = floor(argmax_h);
    int argmax_w_low =
floor(argmax_w);
    int argmax_h_high = argmax_h_low + 1;
    int argmax_w_high = argmax_w_low + 1;

    float weight = 0;
    if (h == argmax_h_low && w == argmax_w_low)
        weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
    if (h == argmax_h_low && w == argmax_w_high)
        weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
    if (h == argmax_h_high && w == argmax_w_low)
        weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
    if (h == argmax_h_high && w == argmax_w_high)
        weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
    return weight;
}

// Derivative of the bilinear sample at (argmax_h, argmax_w) w.r.t. the
// sampling coordinate; bp_dir == 0 differentiates in h, bp_dir == 1 in w.
float dmcn_get_coordinate_weight_cpu(float argmax_h, float argmax_w,
                                     const int height, const int width, const float *im_data,
                                     const int data_width, const int bp_dir)
{
    if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
    {
        //empty
        return 0;
    }

    int argmax_h_low = floor(argmax_h);
    int argmax_w_low = floor(argmax_w);
    int argmax_h_high = argmax_h_low + 1;
    int argmax_w_high = argmax_w_low + 1;

    float weight = 0;

    if (bp_dir == 0)
    {
        if (argmax_h_low >= 0 && argmax_w_low >= 0)
            weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
        if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
            weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
        if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
            weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
        if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
            weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
    }
    else if (bp_dir == 1)
    {
        if (argmax_h_low >= 0 && argmax_w_low >= 0)
            weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
        if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
            weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
        if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
            weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
        if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
            weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
    }

    return weight;
}

// Deformable im2col: gathers bilinear samples (shifted by learned offsets and
// scaled by the modulation mask) into the column matrix consumed by the GEMM.
void modulated_deformable_im2col_cpu_kernel(const int n, const float *data_im, const float *data_offset, const float *data_mask,
                                            const int height, const int width, const int kernel_h, const int kernel_w,
                                            const int pad_h, const int pad_w,
                                            const int stride_h, const int stride_w,
                                            const int dilation_h, const int dilation_w,
                                            const int channel_per_deformable_group,
                                            const int batch_size, const int num_channels, const int deformable_group,
                                            const int height_col, const int width_col,
                                            float *data_col)
{
    // launch channels * batch_size * height_col * width_col cores
    // NOTE(review): the extraction mangled this loop — the text between
    // `index` and `(0)` (loop bound/increment, index decomposition, pointer
    // setup, the kernel i/j loops, and the offset/mask reads) is missing from
    // this dump; restore it from the upstream DCNv2 CPU sources.
    for(int index=0; index(0);
        const float h_im = h_in + i * dilation_h + offset_h;
        const float w_im = w_in + j * dilation_w + offset_w;
        //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width)
        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
        {
            //const float map_h = i * dilation_h + offset_h;
            //const float map_w = j * dilation_w + offset_w;
            //const int cur_height = height - h_in;
            //const int cur_width = width - w_in;
            //val = dmcn_im2col_bilinear_cpu(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
            val = dmcn_im2col_bilinear_cpu(data_im_ptr, width, height, width, h_im, w_im);
        }
        *data_col_ptr = val * mask;
        // data_col_ptr += batch_size * height_col * width_col;
        data_col_ptr += height_col * width_col;
    }
    }
    }
}

// Scatter column-matrix gradients back to the input image (col2im), walking a
// 5x5 neighbourhood around each (fractional) sampling point.
void modulated_deformable_col2im_cpu_kernel(const int n, const float *data_col, const float *data_offset, const float *data_mask,
                                            const int channels, const int height, const int width,
                                            const int kernel_h, const int kernel_w,
                                            const int pad_h, const int pad_w,
                                            const int stride_h, const int stride_w,
                                            const int dilation_h, const int dilation_w,
                                            const int channel_per_deformable_group,
                                            const int batch_size, const int deformable_group,
                                            const int height_col, const int width_col,
                                            float *grad_im)
{
    for(int index = 0; index < n; index++)
    {
        const int j = (index / width_col / height_col / batch_size) % kernel_w;
        const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
        const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
        // compute the start and end of the output

        const int deformable_group_index = c / channel_per_deformable_group;

        int w_out = index % width_col;
        int h_out = (index / width_col) % height_col;
        int b = (index / width_col / height_col) % batch_size;
        int w_in = w_out * stride_w - pad_w;
        int h_in = h_out * stride_h - pad_h;

        const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
        const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
        const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
        const float offset_h = data_offset_ptr[data_offset_h_ptr];
        const float offset_w = data_offset_ptr[data_offset_w_ptr];
        const float mask = data_mask_ptr[data_mask_hw_ptr];
        const float cur_inv_h_data = h_in + i * dilation_h + offset_h;
        const float cur_inv_w_data = w_in + j * dilation_w + offset_w;

        const float cur_top_grad = data_col[index] * mask;
        const int cur_h = (int)cur_inv_h_data;
        const int cur_w = (int)cur_inv_w_data;
        for (int dy = -2; dy <= 2; dy++)
        {
            for (int dx = -2; dx <= 2; dx++)
            {
                if (cur_h + dy >= 0 && cur_h + dy < height &&
                    cur_w + dx >= 0 && cur_w + dx < width &&
                    abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
                    abs(cur_inv_w_data - (cur_w + dx)) < 1)
                {
                    int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
                    float weight = dmcn_get_gradient_weight_cpu(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
                    //atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
                    *(grad_im + cur_bottom_grad_pos) += weight * cur_top_grad;
                }
            }
        }
    }
}

// Gradients w.r.t. the learned offsets (grad_offset) and modulation mask
// (grad_mask); the mask gradient is written once per (h, w) pair on even
// offset channels.
void modulated_deformable_col2im_coord_cpu_kernel(const int n, const float *data_col, const float *data_im,
                                                  const float *data_offset, const float *data_mask,
                                                  const int channels, const int height, const int width,
                                                  const int kernel_h, const int kernel_w,
                                                  const int pad_h, const int pad_w,
                                                  const int stride_h, const int stride_w,
                                                  const int dilation_h, const int dilation_w,
                                                  const int channel_per_deformable_group,
                                                  const int batch_size, const int offset_channels, const int deformable_group,
                                                  const int height_col, const int width_col,
                                                  float *grad_offset, float *grad_mask)
{
    for(int index = 0; index < n; index++)
    {
        float val = 0, mval = 0;
        int w = index % width_col;
        int h = (index / width_col) % height_col;
        int c = (index / width_col / height_col) % offset_channels;
        int b = (index / width_col / height_col) / offset_channels;
        // compute the start and end of the output

        const int deformable_group_index = c / (2 * kernel_h * kernel_w);
        const int col_step = kernel_h * kernel_w;
        int cnt = 0;
        const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
        const float *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
        const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
        const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;

        const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;

        for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
        {
            const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
            const int bp_dir = offset_c % 2;

            int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
            int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
            int w_out = col_pos % width_col;
            int h_out = (col_pos / width_col) % height_col;
            int w_in = w_out * stride_w - pad_w;
            int h_in = h_out * stride_h - pad_h;
            const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
            const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
            const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
            const float offset_h = data_offset_ptr[data_offset_h_ptr];
            const float offset_w = data_offset_ptr[data_offset_w_ptr];
            const float mask = data_mask_ptr[data_mask_hw_ptr];
            float inv_h = h_in + i * dilation_h + offset_h;
            float inv_w = w_in + j * dilation_w + offset_w;
            if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
            {
                // Sampling point out of bounds: sentinel coordinate so the
                // coordinate weight below evaluates to 0.
                inv_h = inv_w = -2;
            }
            else
            {
                mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear_cpu(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
            }
            const float weight = dmcn_get_coordinate_weight_cpu(
                inv_h, inv_w,
                height, width, data_im_ptr + cnt * height * width, width, bp_dir);
            val += weight * data_col_ptr[col_pos] * mask;
            cnt += 1;
        }
        // KERNEL_ASSIGN(grad_offset[index], offset_req, val);
        grad_offset[index] = val;
        if (offset_c % 2 == 0)
            // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
            grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
    }
}

// Host wrapper: computes launch sizes and invokes the im2col kernel loop.
void modulated_deformable_im2col_cpu(const float* data_im, const float* data_offset, const float* data_mask,
                                     const int batch_size, const int channels, const int height_im, const int width_im,
                                     const int height_col, const int width_col, const int kernel_h,
const int kernel_w, const int pad_h, const int pad_w,
                                     const int stride_h, const int stride_w,
                                     const int dilation_h, const int dilation_w,
                                     const int deformable_group, float* data_col)
{
    // num_axes should be smaller than block size
    const int channel_per_deformable_group = channels / deformable_group;
    const int num_kernels = channels * batch_size * height_col * width_col;
    modulated_deformable_im2col_cpu_kernel(
        num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kernel_w,
        pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
        batch_size, channels, deformable_group, height_col, width_col, data_col);

    /*cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
    {
        printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
    }*/
}

// Host wrapper for the col2im (input-gradient) kernel.
void modulated_deformable_col2im_cpu(const float* data_col, const float* data_offset, const float* data_mask,
                                     const int batch_size, const int channels, const int height_im, const int width_im,
                                     const int height_col, const int width_col, const int kernel_h,
                                     const int kernel_w, const int pad_h, const int pad_w,
                                     const int stride_h, const int stride_w,
                                     const int dilation_h, const int dilation_w,
                                     const int deformable_group, float* grad_im)
{
    const int channel_per_deformable_group = channels / deformable_group;
    const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
    // NOTE(review): `pad_h` is passed twice below — `pad_w` was almost
    // certainly intended as the second padding argument. This makes the CPU
    // input gradient wrong whenever pad_h != pad_w; compare against the CUDA
    // version before relying on this path.
    modulated_deformable_col2im_cpu_kernel(
        num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im,
        kernel_h, kernel_w, pad_h, pad_h, stride_h, stride_w,
        dilation_h, dilation_w, channel_per_deformable_group,
        batch_size, deformable_group, height_col, width_col, grad_im);
    /*cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
    {
        printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
    }*/
}

// Host wrapper for the offset/mask-gradient kernel.
void modulated_deformable_col2im_coord_cpu(const float* data_col, const float* data_im, const float* data_offset, const float* data_mask,
                                           const int batch_size, const int channels, const int height_im, const int width_im,
                                           const int height_col, const int width_col, const int kernel_h,
                                           const int kernel_w, const int pad_h, const int pad_w,
                                           const int stride_h, const int stride_w,
                                           const int dilation_h, const int dilation_w,
                                           const int deformable_group,
                                           float* grad_offset, float* grad_mask)
{
    const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
    const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
    modulated_deformable_col2im_coord_cpu_kernel(
        num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im,
        kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, channel_per_deformable_group,
        batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group,
        height_col, width_col, grad_offset, grad_mask);
    /*cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
    {
        printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
    }*/
}


================================================
FILE: code/synthetic/bsrt/model/DCNv2/src/cpu/dcn_v2_im2col_cpu.h
================================================
/*!
 ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
 *
 * COPYRIGHT
 *
 * All contributions by the University of California:
 * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
 * All rights reserved.
 *
 * All other contributions:
 * Copyright (c) 2014-2017, the respective contributors
 * All rights reserved.
 *
 * Caffe uses a shared copyright model: each contributor holds copyright over
 * their contributions to Caffe. The project versioning records all such
 * contribution and copyright details. If a contributor wants to further mark
 * their specific copyright on a particular contribution, they should indicate
 * their copyright solely in the commit message of the change when it is
 * committed.
 *
 * LICENSE
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * CONTRIBUTION AGREEMENT
 *
 * By contributing to the BVLC/caffe repository through pull-request, comment,
 * or otherwise, the contributor releases their content to the
 * license and copyright terms herein.
 *
 ***************** END Caffe Copyright Notice and Disclaimer ********************
 *
 * Copyright (c) 2018 Microsoft
 * Licensed under The MIT License [see LICENSE for details]
 * \file modulated_deformable_im2col.h
 * \brief Function definitions of converting an image to
 * column matrix based on kernel, padding, dilation, and offset.
 * These functions are mainly used in deformable convolution operators.
 * \ref: https://arxiv.org/abs/1811.11168
 * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu
 */

/***************** Adapted by Charles Shang *********************/
// modified from the CUDA version for CPU use by Daniel K. Suhendro

#ifndef DCN_V2_IM2COL_CPU
#define DCN_V2_IM2COL_CPU

#ifdef __cplusplus
extern "C"
{
#endif

    // NOTE(review): the parameter name `kenerl_w` below is a typo for
    // `kernel_w` (harmless in a declaration, kept as in upstream).
    void modulated_deformable_im2col_cpu(const float *data_im,
                                         const float *data_offset, const float *data_mask,
                                         const int batch_size, const int channels,
                                         const int height_im, const int width_im,
                                         const int height_col, const int width_col,
                                         const int kernel_h, const int kenerl_w,
                                         const int pad_h, const int pad_w,
                                         const int stride_h, const int stride_w,
                                         const int dilation_h, const int dilation_w,
                                         const int deformable_group, float *data_col);

    void modulated_deformable_col2im_cpu(const float *data_col,
                                         const float *data_offset, const float *data_mask,
                                         const int batch_size, const int channels,
                                         const int height_im, const int width_im,
                                         const int height_col, const int width_col,
                                         const int kernel_h, const int kenerl_w,
                                         const int pad_h, const int pad_w,
                                         const int stride_h, const int stride_w,
                                         const int dilation_h, const int dilation_w,
                                         const int deformable_group, float *grad_im);

    void modulated_deformable_col2im_coord_cpu(const float *data_col,
                                               const float *data_im, const float *data_offset, const float *data_mask,
                                               const int batch_size, const int channels,
                                               const int height_im, const int width_im,
                                               const int height_col, const int width_col,
                                               const int kernel_h, const int kenerl_w,
                                               const int pad_h, const int pad_w,
                                               const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int deformable_group, float *grad_offset, float *grad_mask); #ifdef __cplusplus } #endif #endif ================================================ FILE: code/synthetic/bsrt/model/DCNv2/src/cpu/dcn_v2_psroi_pooling_cpu.cpp ================================================ /*! * Copyright (c) 2017 Microsoft * Licensed under The MIT License [see LICENSE for details] * \file deformable_psroi_pooling.cu * \brief * \author Yi Li, Guodong Zhang, Jifeng Dai */ /***************** Adapted by Charles Shang *********************/ // modified from the CUDA version for CPU use by Daniel K. Suhendro #include #include #include #include //#include #include //#include //#include /*#define CUDA_KERNEL_LOOP(i, n) \ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ i < (n); \ i += blockDim.x * gridDim.x) const int CUDA_NUM_THREADS = 1024; inline int GET_BLOCKS(const int N) { return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; }*/ template T bilinear_interp_cpu( const T *data, const T x, const T y, const int width, const int height) { int x1 = floor(x); int x2 = ceil(x); int y1 = floor(y); int y2 = ceil(y); T dist_x = static_cast(x - x1); T dist_y = static_cast(y - y1); T value11 = data[y1 * width + x1]; T value12 = data[y2 * width + x1]; T value21 = data[y1 * width + x2]; T value22 = data[y2 * width + x2]; T value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22; return value; } template void DeformablePSROIPoolForwardKernelCpu( const int count, const T *bottom_data, const T spatial_scale, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const T *bottom_rois, const T *bottom_trans, const int no_trans, const T trans_std, const int sample_per_part, const int output_dim, const int group_size, const int part_size, const int num_classes, const int channels_each_class, T *top_data, T *top_count) { 
for(int index = 0; index < count; index++) { // The output is in order (n, ctop, ph, pw) int pw = index % pooled_width; int ph = (index / pooled_width) % pooled_height; int ctop = (index / pooled_width / pooled_height) % output_dim; int n = index / pooled_width / pooled_height / output_dim; // [start, end) interval for spatial sampling const T *offset_bottom_rois = bottom_rois + n * 5; int roi_batch_ind = offset_bottom_rois[0]; T roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; T roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; T roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; T roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; // Force too small ROIs to be 1x1 T roi_width = std::max(roi_end_w - roi_start_w, T(0.1)); //avoid 0 T roi_height = std::max(roi_end_h - roi_start_h, T(0.1)); // Compute w and h at bottom T bin_size_h = roi_height / static_cast(pooled_height); T bin_size_w = roi_width / static_cast(pooled_width); T sub_bin_size_h = bin_size_h / static_cast(sample_per_part); T sub_bin_size_w = bin_size_w / static_cast(sample_per_part); int part_h = floor(static_cast(ph) / pooled_height * part_size); int part_w = floor(static_cast(pw) / pooled_width * part_size); int class_id = ctop / channels_each_class; T trans_x = no_trans ? static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; T trans_y = no_trans ? 
static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std; T wstart = static_cast(pw) * bin_size_w + roi_start_w; wstart += trans_x * roi_width; T hstart = static_cast(ph) * bin_size_h + roi_start_h; hstart += trans_y * roi_height; T sum = 0; int count = 0; int gw = floor(static_cast(pw) * group_size / pooled_width); int gh = floor(static_cast(ph) * group_size / pooled_height); gw = std::min(std::max(gw, 0), group_size - 1); gh = std::min(std::max(gh, 0), group_size - 1); const T *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width; for (int ih = 0; ih < sample_per_part; ih++) { for (int iw = 0; iw < sample_per_part; iw++) { T w = wstart + iw * sub_bin_size_w; T h = hstart + ih * sub_bin_size_h; // bilinear interpolation if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) { continue; } w = std::min(std::max(w, T(0.)), width - T(1.)); h = std::min(std::max(h, T(0.)), height - T(1.)); int c = (ctop * group_size + gh) * group_size + gw; T val = bilinear_interp_cpu(offset_bottom_data + c * height * width, w, h, width, height); sum += val; count++; } } top_data[index] = count == 0 ? 
static_cast(0) : sum / count; top_count[index] = count; } } template void DeformablePSROIPoolBackwardAccKernelCpu( const int count, const T *top_diff, const T *top_count, const int num_rois, const T spatial_scale, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const int output_dim, T *bottom_data_diff, T *bottom_trans_diff, const T *bottom_data, const T *bottom_rois, const T *bottom_trans, const int no_trans, const T trans_std, const int sample_per_part, const int group_size, const int part_size, const int num_classes, const int channels_each_class) { for(int index = 0; index < count; index++) { // The output is in order (n, ctop, ph, pw) int pw = index % pooled_width; int ph = (index / pooled_width) % pooled_height; int ctop = (index / pooled_width / pooled_height) % output_dim; int n = index / pooled_width / pooled_height / output_dim; // [start, end) interval for spatial sampling const T *offset_bottom_rois = bottom_rois + n * 5; int roi_batch_ind = offset_bottom_rois[0]; T roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; T roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; T roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; T roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; // Force too small ROIs to be 1x1 T roi_width = std::max(roi_end_w - roi_start_w, T(0.1)); //avoid 0 T roi_height = std::max(roi_end_h - roi_start_h, T(0.1)); // Compute w and h at bottom T bin_size_h = roi_height / static_cast(pooled_height); T bin_size_w = roi_width / static_cast(pooled_width); T sub_bin_size_h = bin_size_h / static_cast(sample_per_part); T sub_bin_size_w = bin_size_w / static_cast(sample_per_part); int part_h = floor(static_cast(ph) / pooled_height * part_size); int part_w = floor(static_cast(pw) / pooled_width * part_size); int class_id = ctop / channels_each_class; T trans_x = no_trans ? 
static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; T trans_y = no_trans ? static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std; T wstart = static_cast(pw) * bin_size_w + roi_start_w; wstart += trans_x * roi_width; T hstart = static_cast(ph) * bin_size_h + roi_start_h; hstart += trans_y * roi_height; if (top_count[index] <= 0) { continue; } T diff_val = top_diff[index] / top_count[index]; const T *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width; T *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width; int gw = floor(static_cast(pw) * group_size / pooled_width); int gh = floor(static_cast(ph) * group_size / pooled_height); gw = std::min(std::max(gw, 0), group_size - 1); gh = std::min(std::max(gh, 0), group_size - 1); for (int ih = 0; ih < sample_per_part; ih++) { for (int iw = 0; iw < sample_per_part; iw++) { T w = wstart + iw * sub_bin_size_w; T h = hstart + ih * sub_bin_size_h; // bilinear interpolation if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) { continue; } w = std::min(std::max(w, T(0.)), width - T(1.)); h = std::min(std::max(h, T(0.)), height - T(1.)); int c = (ctop * group_size + gh) * group_size + gw; // backward on feature int x0 = floor(w); int x1 = ceil(w); int y0 = floor(h); int y1 = ceil(h); T dist_x = w - x0, dist_y = h - y0; T q00 = (1 - dist_x) * (1 - dist_y); T q01 = (1 - dist_x) * dist_y; T q10 = dist_x * (1 - dist_y); T q11 = dist_x * dist_y; int bottom_index_base = c * height * width; /*atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val); atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val); atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val); atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + 
x1, q11 * diff_val);*/ *(offset_bottom_data_diff + bottom_index_base + y0 * width + x0) += q00 * diff_val; *(offset_bottom_data_diff + bottom_index_base + y1 * width + x0) += q01 * diff_val; *(offset_bottom_data_diff + bottom_index_base + y0 * width + x1) += q10 * diff_val; *(offset_bottom_data_diff + bottom_index_base + y1 * width + x1) += q11 * diff_val; if (no_trans) { continue; } T U00 = offset_bottom_data[bottom_index_base + y0 * width + x0]; T U01 = offset_bottom_data[bottom_index_base + y1 * width + x0]; T U10 = offset_bottom_data[bottom_index_base + y0 * width + x1]; T U11 = offset_bottom_data[bottom_index_base + y1 * width + x1]; T diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val; diff_x *= roi_width; T diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val; diff_y *= roi_height; /*atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x); atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);*/ *(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w) += diff_x; *(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w) += diff_y; } } } } std::tuple dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input, const at::Tensor &bbox, const at::Tensor &trans, const int no_trans, const float spatial_scale, const int output_dim, const int group_size, const int pooled_size, const int part_size, const int sample_per_part, const float trans_std) { /*AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); AT_ASSERTM(bbox.type().is_cuda(), "rois must be a CUDA tensor"); AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor");*/ const int batch = input.size(0); const int channels = input.size(1); const int height = input.size(2); 
const int width = input.size(3); const int channels_trans = no_trans ? 2 : trans.size(1); const int num_bbox = bbox.size(0); AT_ASSERTM(channels == output_dim, "input channels and output channels must equal"); auto pooled_height = pooled_size; auto pooled_width = pooled_size; auto out = at::empty({num_bbox, output_dim, pooled_height, pooled_width}, input.options()); long out_size = num_bbox * output_dim * pooled_height * pooled_width; auto top_count = at::zeros({num_bbox, output_dim, pooled_height, pooled_width}, input.options()); const int num_classes = no_trans ? 1 : channels_trans / 2; const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; //cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if (out.numel() == 0) { //THCudaCheck(cudaGetLastError()); return std::make_tuple(out, top_count); } /*dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L)); dim3 block(512);*/ AT_DISPATCH_FLOATING_TYPES(input.type(), "dcn_v2_psroi_pooling_cpu_forward", [&] { DeformablePSROIPoolForwardKernelCpu( out_size, input.contiguous().data(), spatial_scale, channels, height, width, pooled_height, pooled_width, bbox.contiguous().data(), trans.contiguous().data(), no_trans, trans_std, sample_per_part, output_dim, group_size, part_size, num_classes, channels_each_class, out.data(), top_count.data()); }); //THCudaCheck(cudaGetLastError()); return std::make_tuple(out, top_count); } std::tuple dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad, const at::Tensor &input, const at::Tensor &bbox, const at::Tensor &trans, const at::Tensor &top_count, const int no_trans, const float spatial_scale, const int output_dim, const int group_size, const int pooled_size, const int part_size, const int sample_per_part, const float trans_std) { /*AT_ASSERTM(out_grad.type().is_cuda(), "out_grad must be a CUDA tensor"); AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); AT_ASSERTM(bbox.type().is_cuda(), "bbox must be a CUDA tensor"); 
AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor"); AT_ASSERTM(top_count.type().is_cuda(), "top_count must be a CUDA tensor");*/ const int batch = input.size(0); const int channels = input.size(1); const int height = input.size(2); const int width = input.size(3); const int channels_trans = no_trans ? 2 : trans.size(1); const int num_bbox = bbox.size(0); AT_ASSERTM(channels == output_dim, "input channels and output channels must equal"); auto pooled_height = pooled_size; auto pooled_width = pooled_size; long out_size = num_bbox * output_dim * pooled_height * pooled_width; const int num_classes = no_trans ? 1 : channels_trans / 2; const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; auto input_grad = at::zeros({batch, channels, height, width}, out_grad.options()); auto trans_grad = at::zeros_like(trans); if (input_grad.numel() == 0) { //THCudaCheck(cudaGetLastError()); return std::make_tuple(input_grad, trans_grad); } /*dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L)); dim3 block(512); cudaStream_t stream = at::cuda::getCurrentCUDAStream();*/ AT_DISPATCH_FLOATING_TYPES(out_grad.type(), "dcn_v2_psroi_pooling_cpu_backward", [&] { DeformablePSROIPoolBackwardAccKernelCpu( out_size, out_grad.contiguous().data(), top_count.contiguous().data(), num_bbox, spatial_scale, channels, height, width, pooled_height, pooled_width, output_dim, input_grad.contiguous().data(), trans_grad.contiguous().data(), input.contiguous().data(), bbox.contiguous().data(), trans.contiguous().data(), no_trans, trans_std, sample_per_part, group_size, part_size, num_classes, channels_each_class); }); //THCudaCheck(cudaGetLastError()); return std::make_tuple(input_grad, trans_grad); } ================================================ FILE: code/synthetic/bsrt/model/DCNv2/src/cpu/vision.h ================================================ #pragma once #include at::Tensor dcn_v2_cpu_forward(const at::Tensor &input, const at::Tensor &weight, const at::Tensor 
&bias,
                   const at::Tensor &offset,
                   const at::Tensor &mask,
                   const int kernel_h,
                   const int kernel_w,
                   const int stride_h,
                   const int stride_w,
                   const int pad_h,
                   const int pad_w,
                   const int dilation_h,
                   const int dilation_w,
                   const int deformable_group);

std::vector<at::Tensor>
dcn_v2_cpu_backward(const at::Tensor &input,
                    const at::Tensor &weight,
                    const at::Tensor &bias,
                    const at::Tensor &offset,
                    const at::Tensor &mask,
                    const at::Tensor &grad_output,
                    int kernel_h, int kernel_w,
                    int stride_h, int stride_w,
                    int pad_h, int pad_w,
                    int dilation_h, int dilation_w,
                    int deformable_group);

std::tuple<at::Tensor, at::Tensor>
dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,
                                 const at::Tensor &bbox,
                                 const at::Tensor &trans,
                                 const int no_trans,
                                 const float spatial_scale,
                                 const int output_dim,
                                 const int group_size,
                                 const int pooled_size,
                                 const int part_size,
                                 const int sample_per_part,
                                 const float trans_std);

std::tuple<at::Tensor, at::Tensor>
dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad,
                                  const at::Tensor &input,
                                  const at::Tensor &bbox,
                                  const at::Tensor &trans,
                                  const at::Tensor &top_count,
                                  const int no_trans,
                                  const float spatial_scale,
                                  const int output_dim,
                                  const int group_size,
                                  const int pooled_size,
                                  const int part_size,
                                  const int sample_per_part,
                                  const float trans_std);

================================================
FILE: code/synthetic/bsrt/model/DCNv2/src/cuda/dcn_v2_cuda.cu
================================================
// NOTE(review): the angle-bracket include names below were stripped by
// extraction; restored to a plausible set for this THC/cuBLAS-based variant —
// confirm against the upstream DCNv2 sources.
#include <vector>
#include "cuda/dcn_v2_im2col_cuda.h"

#include <cmath>
#include <cstdio>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDABlas.h>
#include <c10/cuda/CUDAStream.h>
#include <cublas_v2.h>
#include <THC/THC.h>
#include <THC/THCAtomics.cuh>
#include <THC/THCDeviceUtils.cuh>

THCState *state = at::globalContext().lazyInitCUDA();

// Map the legacy 'n'/'t'/'c' BLAS transpose character onto cublasOperation_t.
static cublasOperation_t _cublasOpFromChar(char op)
{
  switch (op)
  {
  case 'n':
  case 'N':
    return CUBLAS_OP_N;
  case 't':
  case 'T':
    return CUBLAS_OP_T;
  case 'c':
  case 'C':
    return CUBLAS_OP_C;
  }
  AT_ERROR(
      "_cublasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
}

static void _cublasAdjustLdLevel2(int64_t m, int64_t n, int64_t *lda)
{
  // Note: leading dimensions generally are checked that they are > 0
  // and at least as big the result requires (even if the value won't
  // be used).
  // Q: Why does Level3 check trans but this doesn't?
  // A: In level 2, the sizes (m, n) specify the size of A
  // (independent of trans value). In level 3. the sizes (m, n, k)
  // specify the sizes of op(A), op(B) where op depend on trans
  // values.
  if (n <= 1)
    *lda = std::max<int64_t>(m, 1);
}

// author: Charles Shang
// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu

// [batch gemm]
// https://github.com/pytorch/pytorch/blob/master/aten/src/THC/generic/THCTensorMathBlas.cu
// Fill the per-sample pointer arrays consumed by the batched GEMM: one thread
// per batch element; weight/bias pointers are shared across the mini-batch.
__global__ void createBatchGemmBuffer(const float **input_b, float **output_b,
                                      float **columns_b, const float **ones_b,
                                      const float **weight_b, const float **bias_b,
                                      float *input, float *output,
                                      float *columns, float *ones,
                                      float *weight, float *bias,
                                      const int input_stride, const int output_stride,
                                      const int columns_stride, const int ones_stride,
                                      const int num_batches)
{
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < num_batches)
  {
    input_b[idx] = input + idx * input_stride;
    output_b[idx] = output + idx * output_stride;
    columns_b[idx] = columns + idx * columns_stride;
    ones_b[idx] = ones + idx * ones_stride;
    // share weights and bias within a Mini-Batch
    weight_b[idx] = weight;
    bias_b[idx] = bias;
  }
}

// Forward pass of modulated deformable convolution (DCNv2) on CUDA:
// bias broadcast via a batched rank-1 GEMM, deformable im2col, then a
// batched GEMM against the weights.
at::Tensor
dcn_v2_cuda_forward(const at::Tensor &input,
                    const at::Tensor &weight,
                    const at::Tensor &bias,
                    const at::Tensor &offset,
                    const at::Tensor &mask,
                    const int kernel_h,
                    const int kernel_w,
                    const int stride_h,
                    const int stride_w,
                    const int pad_h,
                    const int pad_w,
                    const int dilation_h,
                    const int dilation_w,
                    const int deformable_group)
{
  using scalar_t = float;
  // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask));
  AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
  AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
  AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
  AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
  AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);

  const int channels_out = weight.size(0);
  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);

  // printf("Kernels: %d %d %d %d\n", kernel_h_, kernel_w_, kernel_w, kernel_h);
  // printf("Channels: %d %d\n", channels, channels_kernel);
  // printf("Channels: %d %d\n", channels_out, channels_kernel);

  AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
             "Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
             kernel_h_, kernel_w, kernel_h_, kernel_w_);

  AT_ASSERTM(channels == channels_kernel,
             "Input shape and kernel channels wont match: (%d vs %d).",
             channels, channels_kernel);

  const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  auto ones = at::ones({batch, height_out, width_out}, input.options());
  auto columns = at::empty({batch, channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
  auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());

  // prepare for batch-wise computing, which is significantly faster than instance-wise computing
  // when batch size is large.
  // launch batch threads
  int matrices_size = batch * sizeof(float *);
  auto input_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));
  auto output_b = static_cast<float **>(THCudaMalloc(state, matrices_size));
  auto columns_b = static_cast<float **>(THCudaMalloc(state, matrices_size));
  auto ones_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));
  auto weight_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));
  auto bias_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));

  const int block = 128;
  const int grid = (batch + block - 1) / block;

  createBatchGemmBuffer<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
      input_b, output_b,
      columns_b, ones_b,
      weight_b, bias_b,
      input.data_ptr<scalar_t>(),
      output.data_ptr<scalar_t>(),
      columns.data_ptr<scalar_t>(),
      ones.data_ptr<scalar_t>(),
      weight.data_ptr<scalar_t>(),
      bias.data_ptr<scalar_t>(),
      channels * width * height,
      channels_out * width_out * height_out,
      channels * kernel_h * kernel_w * height_out * width_out,
      height_out * width_out,
      batch);

  // output = bias * ones^T  (broadcast the bias over every spatial location)
  long m_ = channels_out;
  long n_ = height_out * width_out;
  long k_ = 1;
  THCudaBlas_SgemmBatched(state, 't', 'n', n_, m_, k_, 1.0f,
                          ones_b, k_,
                          bias_b, k_,
                          0.0f,
                          output_b, n_,
                          batch);

  modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(),
                                   input.data_ptr<scalar_t>(),
                                   offset.data_ptr<scalar_t>(),
                                   mask.data_ptr<scalar_t>(),
                                   batch, channels, height, width,
                                   height_out, width_out, kernel_h, kernel_w,
                                   pad_h, pad_w, stride_h, stride_w,
                                   dilation_h, dilation_w, deformable_group,
                                   columns.data_ptr<scalar_t>());

  // output += weight * columns
  long m = channels_out;
  long n = height_out * width_out;
  long k = channels * kernel_h * kernel_w;
  THCudaBlas_SgemmBatched(state, 'n', 'n', n, m, k, 1.0f,
                          (const float **)columns_b, n,
                          weight_b, k,
                          1.0f,
                          output_b, n,
                          batch);

  THCudaFree(state, input_b);
  THCudaFree(state, output_b);
  THCudaFree(state, columns_b);
  THCudaFree(state, ones_b);
  THCudaFree(state, weight_b);
  THCudaFree(state, bias_b);
  return output;
}

// Pointer-array setup for the backward batched GEMMs (mirrors
// createBatchGemmBuffer; grad_weight/grad_bias are shared accumulators).
__global__ void createBatchGemmBufferBackward(
    float **grad_output_b,
    float **columns_b,
    float **ones_b,
    float **weight_b,
    float **grad_weight_b,
    float **grad_bias_b,
    float *grad_output,
    float *columns,
    float *ones,
    float *weight,
    float *grad_weight,
    float *grad_bias,
    const int grad_output_stride,
    const int columns_stride,
    const int ones_stride,
    const int num_batches)
{
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < num_batches)
  {
    grad_output_b[idx] = grad_output + idx * grad_output_stride;
    columns_b[idx] = columns + idx * columns_stride;
    ones_b[idx] = ones + idx * ones_stride;

    // share weights and bias within a Mini-Batch
    weight_b[idx] = weight;
    grad_weight_b[idx] = grad_weight;
    grad_bias_b[idx] = grad_bias;
  }
}

// Backward pass of DCNv2: per-sample col2im / col2im_coord for the input,
// offset and mask gradients, GEMM accumulation for grad_weight, and a cuBLAS
// gemv against a ones-vector for grad_bias.
std::vector<at::Tensor> dcn_v2_cuda_backward(const at::Tensor &input,
                                             const at::Tensor &weight,
                                             const at::Tensor &bias,
                                             const at::Tensor &offset,
                                             const at::Tensor &mask,
                                             const at::Tensor &grad_output,
                                             int kernel_h, int kernel_w,
                                             int stride_h, int stride_w,
                                             int pad_h, int pad_w,
                                             int dilation_h, int dilation_w,
                                             int deformable_group)
{
  THArgCheck(input.is_contiguous(), 1, "input tensor has to be contiguous");
  THArgCheck(weight.is_contiguous(), 2, "weight tensor has to be contiguous");

  AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
  AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
  AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
  AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
  AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);

  const int channels_out = weight.size(0);
  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);

  AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
             "Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
             kernel_h_, kernel_w, kernel_h_, kernel_w_);

  AT_ASSERTM(channels == channels_kernel,
             "Input shape and kernel channels wont match: (%d vs %d).",
             channels, channels_kernel);

  const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  auto ones = at::ones({height_out, width_out}, input.options());
  auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
  auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());

  auto grad_input = at::zeros_like(input);
  auto grad_weight = at::zeros_like(weight);
  auto grad_bias = at::zeros_like(bias);
  auto grad_offset = at::zeros_like(offset);
  auto grad_mask = at::zeros_like(mask);

  using scalar_t = float;

  for (int b = 0; b < batch; b++)
  {
    auto input_n = input.select(0, b);
    auto offset_n = offset.select(0, b);
    auto mask_n = mask.select(0, b);
    auto grad_output_n = grad_output.select(0, b);
    auto grad_input_n = grad_input.select(0, b);
    auto grad_offset_n = grad_offset.select(0, b);
    auto grad_mask_n = grad_mask.select(0, b);

    // columns = weight^T * grad_output (gradient flowing into the col buffer)
    long m = channels * kernel_h * kernel_w;
    long n = height_out * width_out;
    long k = channels_out;

    THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f,
                     grad_output_n.data_ptr<scalar_t>(), n,
                     weight.data_ptr<scalar_t>(), m, 0.0f,
                     columns.data_ptr<scalar_t>(), n);

    // gradient w.r.t. input coordinate data
    modulated_deformable_col2im_coord_cuda(c10::cuda::getCurrentCUDAStream(),
                                           columns.data_ptr<scalar_t>(),
                                           input_n.data_ptr<scalar_t>(),
                                           offset_n.data_ptr<scalar_t>(),
                                           mask_n.data_ptr<scalar_t>(),
                                           1, channels, height, width,
                                           height_out, width_out, kernel_h, kernel_w,
                                           pad_h, pad_w, stride_h, stride_w,
                                           dilation_h, dilation_w, deformable_group,
                                           grad_offset_n.data_ptr<scalar_t>(),
                                           grad_mask_n.data_ptr<scalar_t>());
    // gradient w.r.t. input data
    modulated_deformable_col2im_cuda(c10::cuda::getCurrentCUDAStream(),
                                     columns.data_ptr<scalar_t>(),
                                     offset_n.data_ptr<scalar_t>(),
                                     mask_n.data_ptr<scalar_t>(),
                                     1, channels, height, width,
                                     height_out, width_out, kernel_h, kernel_w,
                                     pad_h, pad_w, stride_h, stride_w,
                                     dilation_h, dilation_w, deformable_group,
                                     grad_input_n.data_ptr<scalar_t>());

    // gradient w.r.t. weight, dWeight should accumulate across the batch and group
    modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(),
                                     input_n.data_ptr<scalar_t>(),
                                     offset_n.data_ptr<scalar_t>(),
                                     mask_n.data_ptr<scalar_t>(),
                                     1, channels, height, width,
                                     height_out, width_out, kernel_h, kernel_w,
                                     pad_h, pad_w, stride_h, stride_w,
                                     dilation_h, dilation_w, deformable_group,
                                     columns.data_ptr<scalar_t>());

    long m_ = channels_out;
    long n_ = channels * kernel_h * kernel_w;
    long k_ = height_out * width_out;

    THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f,
                     columns.data_ptr<scalar_t>(), k_,
                     grad_output_n.data_ptr<scalar_t>(), k_, 1.0f,
                     grad_weight.data_ptr<scalar_t>(), n_);

    // gradient w.r.t. bias: grad_bias += grad_output_n^T * ones
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
    cublasOperation_t op = _cublasOpFromChar('t');
    _cublasAdjustLdLevel2(k_, m_, &k_);
    scalar_t *grad_output_n_float = grad_output_n.data_ptr<scalar_t>();
    scalar_t *one_float = ones.data_ptr<scalar_t>();
    scalar_t alpha = 1.0;
    scalar_t beta = 1.0;
    cublasSgemv(handle,
                op,
                k_, m_, &alpha,
                grad_output_n_float, k_,
                one_float, 1, &beta,
                grad_bias.data_ptr<scalar_t>(), 1);
  }

  return {
      grad_input, grad_offset, grad_mask, grad_weight, grad_bias};
}

================================================
FILE: code/synthetic/bsrt/model/DCNv2/src/cuda/dcn_v2_im2col_cuda.cu
================================================
#include "dcn_v2_im2col_cuda.h"
// NOTE(review): stripped include names restored to the upstream file's set.
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THC.h>
#include <THC/THCAtomics.cuh>
#include <THC/THCDeviceUtils.cuh>

// Grid-stride loop over the n work items.
#define CUDA_KERNEL_LOOP(i, n)                        \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
       i < (n);                                       \
       i += blockDim.x * gridDim.x)

const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N)
{
  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}

// Bilinear sample of a single feature plane at fractional (h, w); out-of-range
// neighbours contribute zero.
__device__ float dmcn_im2col_bilinear_cuda(const float *bottom_data, const int data_width,
                                           const int height, const int width, float h, float w)
{
  int h_low = floor(h);
  int w_low = floor(w);
  int h_high = h_low + 1;
  int w_high = w_low + 1;

  float lh = h - h_low;
  float lw = w - w_low;
  float hh = 1 - lh, hw = 1 - lw;

  float v1 = 0;
  if (h_low >= 0 && w_low >= 0)
    v1 = bottom_data[h_low * data_width + w_low];
float v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
    v2 = bottom_data[h_low * data_width + w_high];
  float v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
    v3 = bottom_data[h_high * data_width + w_low];
  float v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
    v4 = bottom_data[h_high * data_width + w_high];

  float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

  float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
}

// Weight of pixel (h, w) in the gradient of a bilinear sample taken at
// (argmax_h, argmax_w); zero when the sample fell entirely outside the plane.
__device__ float dmcn_get_gradient_weight_cuda(float argmax_h, float argmax_w,
                                               const int h, const int w, const int height, const int width)
{
  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
  {
    //empty
    return 0;
  }

  int argmax_h_low = floor(argmax_h);
  int argmax_w_low = floor(argmax_w);
  int argmax_h_high = argmax_h_low + 1;
  int argmax_w_high = argmax_w_low + 1;

  float weight = 0;
  if (h == argmax_h_low && w == argmax_w_low)
    weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
  if (h == argmax_h_low && w == argmax_w_high)
    weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
  if (h == argmax_h_high && w == argmax_w_low)
    weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
  if (h == argmax_h_high && w == argmax_w_high)
    weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
  return weight;
}

// Partial derivative of the bilinear sample w.r.t. the sampling coordinate
// (bp_dir == 0: d/dh, bp_dir == 1: d/dw); used for the offset gradient.
__device__ float dmcn_get_coordinate_weight_cuda(float argmax_h, float argmax_w,
                                                 const int height, const int width, const float *im_data,
                                                 const int data_width, const int bp_dir)
{
  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
  {
    //empty
    return 0;
  }

  int argmax_h_low = floor(argmax_h);
  int argmax_w_low = floor(argmax_w);
  int argmax_h_high = argmax_h_low + 1;
  int argmax_w_high = argmax_w_low + 1;

  float weight = 0;

  if (bp_dir == 0)
  {
    if (argmax_h_low >= 0 && argmax_w_low >= 0)
      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
  }
  else if (bp_dir == 1)
  {
    if (argmax_h_low >= 0 && argmax_w_low >= 0)
      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
  }

  return weight;
}

// im2col for modulated deformable convolution: each work item fills the
// kernel_h*kernel_w column entries of one (batch, channel, oh, ow) position,
// sampling the input bilinearly at the offset-shifted locations and scaling
// by the modulation mask.
__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
                                                       const float *data_im, const float *data_offset, const float *data_mask,
                                                       const int height, const int width, const int kernel_h, const int kernel_w,
                                                       const int pad_h, const int pad_w,
                                                       const int stride_h, const int stride_w,
                                                       const int dilation_h, const int dilation_w,
                                                       const int channel_per_deformable_group,
                                                       const int batch_size, const int num_channels, const int deformable_group,
                                                       const int height_col, const int width_col,
                                                       float *data_col)
{
  // launch channels * batch_size * height_col * width_col cores
  CUDA_KERNEL_LOOP(index, n)
  {
    // NOTE(CharlesShang): different from Dai Jifeng's MXNet implementation, col_buffer is of shape (c*kw*kh, N, oh, ow)
    // here columns is of shape (N, c*kw*kh, oh * ow), need to adapt axis
    // index index of output matrix
    const int w_col = index % width_col;
    const int h_col = (index / width_col) % height_col;
    // const int b_col = (index / width_col / height_col) % batch_size;
    const int b_col = (index / width_col / height_col / num_channels) % batch_size;
    // const int c_im = (index / width_col / height_col) / batch_size;
    const int c_im = (index / width_col / height_col) % num_channels;
    // const int c_col = c_im * kernel_h * kernel_w;
    const int c_col = c_im * kernel_h * kernel_w;

    // compute deformable group index
    const int deformable_group_index = c_im / channel_per_deformable_group;

    const int h_in = h_col * stride_h - pad_h;
    const int w_in = w_col * stride_w - pad_w;

    //  float *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
    float *data_col_ptr = data_col + ((b_col * num_channels * kernel_w * kernel_h + c_col) * height_col + h_col) * width_col + w_col;
    //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
    const float *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
    const float *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;

    const float *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;

    for (int i = 0; i < kernel_h; ++i)
    {
      for (int j = 0; j < kernel_w; ++j)
      {
        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
        const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
        const float offset_h = data_offset_ptr[data_offset_h_ptr];
        const float offset_w = data_offset_ptr[data_offset_w_ptr];
        const float mask = data_mask_ptr[data_mask_hw_ptr];
        float val = static_cast<float>(0);
        const float h_im = h_in + i * dilation_h + offset_h;
        const float w_im = w_in + j * dilation_w + offset_w;
        //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width)
        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
        {
          //const float map_h = i * dilation_h + offset_h;
          //const float map_w = j * dilation_w + offset_w;
          //const int cur_height = height - h_in;
          //const int cur_width = width - w_in;
          //val = dmcn_im2col_bilinear_cuda(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
          val = dmcn_im2col_bilinear_cuda(data_im_ptr, width, height, width, h_im, w_im);
        }
        *data_col_ptr = val * mask;
        // data_col_ptr += batch_size * height_col * width_col;
        data_col_ptr += height_col * width_col;
      }
    }
  }
}

// col2im: scatter column-buffer gradients back onto the input image via
// atomicAdd over the bilinear footprint of each sample.
__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
                                                       const float *data_col, const float *data_offset, const float *data_mask,
                                                       const int channels, const int height, const int width,
                                                       const int kernel_h, const int kernel_w,
                                                       const int pad_h, const int pad_w,
                                                       const int stride_h, const int stride_w,
                                                       const int dilation_h, const int dilation_w,
                                                       const int channel_per_deformable_group,
                                                       const int batch_size, const int deformable_group,
                                                       const int height_col, const int width_col,
                                                       float *grad_im)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    const int j = (index / width_col / height_col / batch_size) % kernel_w;
    const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
    // compute the start and end of the output
    const int deformable_group_index = c / channel_per_deformable_group;

    int w_out = index % width_col;
    int h_out = (index / width_col) % height_col;
    int b = (index / width_col / height_col) % batch_size;
    int w_in = w_out * stride_w - pad_w;
    int h_in = h_out * stride_h - pad_h;

    const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
    const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col +
w_out; const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; const float offset_h = data_offset_ptr[data_offset_h_ptr]; const float offset_w = data_offset_ptr[data_offset_w_ptr]; const float mask = data_mask_ptr[data_mask_hw_ptr]; const float cur_inv_h_data = h_in + i * dilation_h + offset_h; const float cur_inv_w_data = w_in + j * dilation_w + offset_w; const float cur_top_grad = data_col[index] * mask; const int cur_h = (int)cur_inv_h_data; const int cur_w = (int)cur_inv_w_data; for (int dy = -2; dy <= 2; dy++) { for (int dx = -2; dx <= 2; dx++) { if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && abs(cur_inv_w_data - (cur_w + dx)) < 1) { int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; float weight = dmcn_get_gradient_weight_cuda(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); } } } } } __global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, const float *data_col, const float *data_im, const float *data_offset, const float *data_mask, const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const int channel_per_deformable_group, const int batch_size, const int offset_channels, const int deformable_group, const int height_col, const int width_col, float *grad_offset, float *grad_mask) { CUDA_KERNEL_LOOP(index, n) { float val = 0, mval = 0; int w = index % width_col; int h = (index / width_col) % height_col; int c = (index / width_col / height_col) % offset_channels; int b = (index / width_col / height_col) / offset_channels; // compute the start and end of the output const int deformable_group_index = c / (2 * kernel_h * kernel_w); const int 
col_step = kernel_h * kernel_w; int cnt = 0; const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; const float *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) { const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; const int bp_dir = offset_c % 2; int j = (col_pos / width_col / height_col / batch_size) % kernel_w; int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; int w_out = col_pos % width_col; int h_out = (col_pos / width_col) % height_col; int w_in = w_out * stride_w - pad_w; int h_in = h_out * stride_h - pad_h; const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); const float offset_h = data_offset_ptr[data_offset_h_ptr]; const float offset_w = data_offset_ptr[data_offset_w_ptr]; const float mask = data_mask_ptr[data_mask_hw_ptr]; float inv_h = h_in + i * dilation_h + offset_h; float inv_w = w_in + j * dilation_w + offset_w; if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) { inv_h = inv_w = -2; } else { mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear_cuda(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); } 
const float weight = dmcn_get_coordinate_weight_cuda( inv_h, inv_w, height, width, data_im_ptr + cnt * height * width, width, bp_dir); val += weight * data_col_ptr[col_pos] * mask; cnt += 1; } // KERNEL_ASSIGN(grad_offset[index], offset_req, val); grad_offset[index] = val; if (offset_c % 2 == 0) // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; } } void modulated_deformable_im2col_cuda(cudaStream_t stream, const float* data_im, const float* data_offset, const float* data_mask, const int batch_size, const int channels, const int height_im, const int width_im, const int height_col, const int width_col, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const int deformable_group, float* data_col) { // num_axes should be smaller than block size const int channel_per_deformable_group = channels / deformable_group; const int num_kernels = channels * batch_size * height_col * width_col; modulated_deformable_im2col_gpu_kernel <<>>( num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, batch_size, channels, deformable_group, height_col, width_col, data_col); cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); } } void modulated_deformable_col2im_cuda(cudaStream_t stream, const float* data_col, const float* data_offset, const float* data_mask, const int batch_size, const int channels, const int height_im, const int width_im, const int height_col, const int width_col, const int kernel_h, const int 
kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const int deformable_group, float* grad_im){ const int channel_per_deformable_group = channels / deformable_group; const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; modulated_deformable_col2im_gpu_kernel <<>>( num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im, kernel_h, kernel_w, pad_h, pad_h, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, batch_size, deformable_group, height_col, width_col, grad_im); cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); } } void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, const float* data_col, const float* data_im, const float* data_offset, const float* data_mask, const int batch_size, const int channels, const int height_im, const int width_im, const int height_col, const int width_col, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const int deformable_group, float* grad_offset, float* grad_mask) { const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; modulated_deformable_col2im_coord_gpu_kernel <<>>( num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, grad_offset, grad_mask); cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { printf("error in modulated_deformable_col2im_coord_cuda: %s\n", 
cudaGetErrorString(err)); } } ================================================ FILE: code/synthetic/bsrt/model/DCNv2/src/cuda/dcn_v2_im2col_cuda.h ================================================ /*! ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** * * COPYRIGHT * * All contributions by the University of California: * Copyright (c) 2014-2017 The Regents of the University of California (Regents) * All rights reserved. * * All other contributions: * Copyright (c) 2014-2017, the respective contributors * All rights reserved. * * Caffe uses a shared copyright model: each contributor holds copyright over * their contributions to Caffe. The project versioning records all such * contribution and copyright details. If a contributor wants to further mark * their specific copyright on a particular contribution, they should indicate * their copyright solely in the commit message of the change when it is * committed. * * LICENSE * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * * CONTRIBUTION AGREEMENT * * By contributing to the BVLC/caffe repository through pull-request, comment, * or otherwise, the contributor releases their content to the * license and copyright terms herein. * ***************** END Caffe Copyright Notice and Disclaimer ******************** * * Copyright (c) 2018 Microsoft * Licensed under The MIT License [see LICENSE for details] * \file modulated_deformable_im2col.h * \brief Function definitions of converting an image to * column matrix based on kernel, padding, dilation, and offset. * These functions are mainly used in deformable convolution operators. 
* \ref: https://arxiv.org/abs/1811.11168 * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu */ /***************** Adapted by Charles Shang *********************/ #ifndef DCN_V2_IM2COL_CUDA #define DCN_V2_IM2COL_CUDA #ifdef __cplusplus extern "C" { #endif void modulated_deformable_im2col_cuda(cudaStream_t stream, const float *data_im, const float *data_offset, const float *data_mask, const int batch_size, const int channels, const int height_im, const int width_im, const int height_col, const int width_col, const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const int deformable_group, float *data_col); void modulated_deformable_col2im_cuda(cudaStream_t stream, const float *data_col, const float *data_offset, const float *data_mask, const int batch_size, const int channels, const int height_im, const int width_im, const int height_col, const int width_col, const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const int deformable_group, float *grad_im); void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, const float *data_col, const float *data_im, const float *data_offset, const float *data_mask, const int batch_size, const int channels, const int height_im, const int width_im, const int height_col, const int width_col, const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const int deformable_group, float *grad_offset, float *grad_mask); #ifdef __cplusplus } #endif #endif ================================================ FILE: code/synthetic/bsrt/model/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda.cu ================================================ /*! 
* Copyright (c) 2017 Microsoft * Licensed under The MIT License [see LICENSE for details] * \file deformable_psroi_pooling.cu * \brief * \author Yi Li, Guodong Zhang, Jifeng Dai */ /***************** Adapted by Charles Shang *********************/ #include #include #include #include #include #include #include #include #include #define CUDA_KERNEL_LOOP(i, n) \ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ i < (n); \ i += blockDim.x * gridDim.x) const int CUDA_NUM_THREADS = 1024; inline int GET_BLOCKS(const int N) { return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; } template __device__ T bilinear_interp_cuda( const T *data, const T x, const T y, const int width, const int height) { int x1 = floor(x); int x2 = ceil(x); int y1 = floor(y); int y2 = ceil(y); T dist_x = static_cast(x - x1); T dist_y = static_cast(y - y1); T value11 = data[y1 * width + x1]; T value12 = data[y2 * width + x1]; T value21 = data[y1 * width + x2]; T value22 = data[y2 * width + x2]; T value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22; return value; } template __global__ void DeformablePSROIPoolForwardKernelCuda( const int count, const T *bottom_data, const T spatial_scale, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const T *bottom_rois, const T *bottom_trans, const int no_trans, const T trans_std, const int sample_per_part, const int output_dim, const int group_size, const int part_size, const int num_classes, const int channels_each_class, T *top_data, T *top_count) { CUDA_KERNEL_LOOP(index, count) { // The output is in order (n, ctop, ph, pw) int pw = index % pooled_width; int ph = (index / pooled_width) % pooled_height; int ctop = (index / pooled_width / pooled_height) % output_dim; int n = index / pooled_width / pooled_height / output_dim; // [start, end) interval for spatial sampling const T *offset_bottom_rois = bottom_rois + 
n * 5; int roi_batch_ind = offset_bottom_rois[0]; T roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; T roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; T roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; T roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; // Force too small ROIs to be 1x1 T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0 T roi_height = max(roi_end_h - roi_start_h, 0.1); // Compute w and h at bottom T bin_size_h = roi_height / static_cast(pooled_height); T bin_size_w = roi_width / static_cast(pooled_width); T sub_bin_size_h = bin_size_h / static_cast(sample_per_part); T sub_bin_size_w = bin_size_w / static_cast(sample_per_part); int part_h = floor(static_cast(ph) / pooled_height * part_size); int part_w = floor(static_cast(pw) / pooled_width * part_size); int class_id = ctop / channels_each_class; T trans_x = no_trans ? static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; T trans_y = no_trans ? 
static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std; T wstart = static_cast(pw) * bin_size_w + roi_start_w; wstart += trans_x * roi_width; T hstart = static_cast(ph) * bin_size_h + roi_start_h; hstart += trans_y * roi_height; T sum = 0; int count = 0; int gw = floor(static_cast(pw) * group_size / pooled_width); int gh = floor(static_cast(ph) * group_size / pooled_height); gw = min(max(gw, 0), group_size - 1); gh = min(max(gh, 0), group_size - 1); const T *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width; for (int ih = 0; ih < sample_per_part; ih++) { for (int iw = 0; iw < sample_per_part; iw++) { T w = wstart + iw * sub_bin_size_w; T h = hstart + ih * sub_bin_size_h; // bilinear interpolation if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) { continue; } w = min(max(w, 0.), width - 1.); h = min(max(h, 0.), height - 1.); int c = (ctop * group_size + gh) * group_size + gw; T val = bilinear_interp_cuda(offset_bottom_data + c * height * width, w, h, width, height); sum += val; count++; } } top_data[index] = count == 0 ? 
static_cast(0) : sum / count; top_count[index] = count; } } template __global__ void DeformablePSROIPoolBackwardAccKernelCuda( const int count, const T *top_diff, const T *top_count, const int num_rois, const T spatial_scale, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const int output_dim, T *bottom_data_diff, T *bottom_trans_diff, const T *bottom_data, const T *bottom_rois, const T *bottom_trans, const int no_trans, const T trans_std, const int sample_per_part, const int group_size, const int part_size, const int num_classes, const int channels_each_class) { CUDA_KERNEL_LOOP(index, count) { // The output is in order (n, ctop, ph, pw) int pw = index % pooled_width; int ph = (index / pooled_width) % pooled_height; int ctop = (index / pooled_width / pooled_height) % output_dim; int n = index / pooled_width / pooled_height / output_dim; // [start, end) interval for spatial sampling const T *offset_bottom_rois = bottom_rois + n * 5; int roi_batch_ind = offset_bottom_rois[0]; T roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; T roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; T roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; T roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; // Force too small ROIs to be 1x1 T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0 T roi_height = max(roi_end_h - roi_start_h, 0.1); // Compute w and h at bottom T bin_size_h = roi_height / static_cast(pooled_height); T bin_size_w = roi_width / static_cast(pooled_width); T sub_bin_size_h = bin_size_h / static_cast(sample_per_part); T sub_bin_size_w = bin_size_w / static_cast(sample_per_part); int part_h = floor(static_cast(ph) / pooled_height * part_size); int part_w = floor(static_cast(pw) / pooled_width * part_size); int class_id = ctop / channels_each_class; T trans_x = no_trans ? 
static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; T trans_y = no_trans ? static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std; T wstart = static_cast(pw) * bin_size_w + roi_start_w; wstart += trans_x * roi_width; T hstart = static_cast(ph) * bin_size_h + roi_start_h; hstart += trans_y * roi_height; if (top_count[index] <= 0) { continue; } T diff_val = top_diff[index] / top_count[index]; const T *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width; T *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width; int gw = floor(static_cast(pw) * group_size / pooled_width); int gh = floor(static_cast(ph) * group_size / pooled_height); gw = min(max(gw, 0), group_size - 1); gh = min(max(gh, 0), group_size - 1); for (int ih = 0; ih < sample_per_part; ih++) { for (int iw = 0; iw < sample_per_part; iw++) { T w = wstart + iw * sub_bin_size_w; T h = hstart + ih * sub_bin_size_h; // bilinear interpolation if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) { continue; } w = min(max(w, 0.), width - 1.); h = min(max(h, 0.), height - 1.); int c = (ctop * group_size + gh) * group_size + gw; // backward on feature int x0 = floor(w); int x1 = ceil(w); int y0 = floor(h); int y1 = ceil(h); T dist_x = w - x0, dist_y = h - y0; T q00 = (1 - dist_x) * (1 - dist_y); T q01 = (1 - dist_x) * dist_y; T q10 = dist_x * (1 - dist_y); T q11 = dist_x * dist_y; int bottom_index_base = c * height * width; atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val); atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val); atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val); atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val); if (no_trans) { continue; } T U00 = 
offset_bottom_data[bottom_index_base + y0 * width + x0]; T U01 = offset_bottom_data[bottom_index_base + y1 * width + x0]; T U10 = offset_bottom_data[bottom_index_base + y0 * width + x1]; T U11 = offset_bottom_data[bottom_index_base + y1 * width + x1]; T diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val; diff_x *= roi_width; T diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val; diff_y *= roi_height; atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x); atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y); } } } } std::tuple dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input, const at::Tensor &bbox, const at::Tensor &trans, const int no_trans, const float spatial_scale, const int output_dim, const int group_size, const int pooled_size, const int part_size, const int sample_per_part, const float trans_std) { AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); AT_ASSERTM(bbox.type().is_cuda(), "rois must be a CUDA tensor"); AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor"); const int batch = input.size(0); const int channels = input.size(1); const int height = input.size(2); const int width = input.size(3); const int channels_trans = no_trans ? 2 : trans.size(1); const int num_bbox = bbox.size(0); AT_ASSERTM(channels == output_dim, "input channels and output channels must equal"); auto pooled_height = pooled_size; auto pooled_width = pooled_size; auto out = at::empty({num_bbox, output_dim, pooled_height, pooled_width}, input.options()); long out_size = num_bbox * output_dim * pooled_height * pooled_width; auto top_count = at::zeros({num_bbox, output_dim, pooled_height, pooled_width}, input.options()); const int num_classes = no_trans ? 
1 : channels_trans / 2; const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if (out.numel() == 0) { THCudaCheck(cudaGetLastError()); return std::make_tuple(out, top_count); } dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L)); dim3 block(512); AT_DISPATCH_FLOATING_TYPES(input.type(), "dcn_v2_psroi_pooling_cuda_forward", [&] { DeformablePSROIPoolForwardKernelCuda<<>>( out_size, input.contiguous().data_ptr(), spatial_scale, channels, height, width, pooled_height, pooled_width, bbox.contiguous().data_ptr(), trans.contiguous().data_ptr(), no_trans, trans_std, sample_per_part, output_dim, group_size, part_size, num_classes, channels_each_class, out.data_ptr(), top_count.data_ptr()); }); THCudaCheck(cudaGetLastError()); return std::make_tuple(out, top_count); } std::tuple dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad, const at::Tensor &input, const at::Tensor &bbox, const at::Tensor &trans, const at::Tensor &top_count, const int no_trans, const float spatial_scale, const int output_dim, const int group_size, const int pooled_size, const int part_size, const int sample_per_part, const float trans_std) { AT_ASSERTM(out_grad.type().is_cuda(), "out_grad must be a CUDA tensor"); AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); AT_ASSERTM(bbox.type().is_cuda(), "bbox must be a CUDA tensor"); AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor"); AT_ASSERTM(top_count.type().is_cuda(), "top_count must be a CUDA tensor"); const int batch = input.size(0); const int channels = input.size(1); const int height = input.size(2); const int width = input.size(3); const int channels_trans = no_trans ? 
2 : trans.size(1); const int num_bbox = bbox.size(0); AT_ASSERTM(channels == output_dim, "input channels and output channels must equal"); auto pooled_height = pooled_size; auto pooled_width = pooled_size; long out_size = num_bbox * output_dim * pooled_height * pooled_width; const int num_classes = no_trans ? 1 : channels_trans / 2; const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; auto input_grad = at::zeros({batch, channels, height, width}, out_grad.options()); auto trans_grad = at::zeros_like(trans); if (input_grad.numel() == 0) { THCudaCheck(cudaGetLastError()); return std::make_tuple(input_grad, trans_grad); } dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L)); dim3 block(512); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); AT_DISPATCH_FLOATING_TYPES(out_grad.type(), "dcn_v2_psroi_pooling_cuda_backward", [&] { DeformablePSROIPoolBackwardAccKernelCuda<<>>( out_size, out_grad.contiguous().data_ptr(), top_count.contiguous().data_ptr(), num_bbox, spatial_scale, channels, height, width, pooled_height, pooled_width, output_dim, input_grad.contiguous().data_ptr(), trans_grad.contiguous().data_ptr(), input.contiguous().data_ptr(), bbox.contiguous().data_ptr(), trans.contiguous().data_ptr(), no_trans, trans_std, sample_per_part, group_size, part_size, num_classes, channels_each_class); }); THCudaCheck(cudaGetLastError()); return std::make_tuple(input_grad, trans_grad); } ================================================ FILE: code/synthetic/bsrt/model/DCNv2/src/cuda/vision.h ================================================ #pragma once #include #include at::Tensor dcn_v2_cuda_forward(const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, const at::Tensor &offset, const at::Tensor &mask, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, const int dilation_h, const int dilation_w, const int deformable_group); std::vector 
dcn_v2_cuda_backward(const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, const at::Tensor &offset, const at::Tensor &mask, const at::Tensor &grad_output, int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int deformable_group); std::tuple dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input, const at::Tensor &bbox, const at::Tensor &trans, const int no_trans, const float spatial_scale, const int output_dim, const int group_size, const int pooled_size, const int part_size, const int sample_per_part, const float trans_std); std::tuple dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad, const at::Tensor &input, const at::Tensor &bbox, const at::Tensor &trans, const at::Tensor &top_count, const int no_trans, const float spatial_scale, const int output_dim, const int group_size, const int pooled_size, const int part_size, const int sample_per_part, const float trans_std); ================================================ FILE: code/synthetic/bsrt/model/DCNv2/src/dcn_v2.h ================================================ #pragma once #include "cpu/vision.h" #ifdef WITH_CUDA #include "cuda/vision.h" #endif at::Tensor dcn_v2_forward(const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, const at::Tensor &offset, const at::Tensor &mask, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, const int dilation_h, const int dilation_w, const int deformable_group) { if (input.type().is_cuda()) { #ifdef WITH_CUDA return dcn_v2_cuda_forward(input, weight, bias, offset, mask, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, deformable_group); #else AT_ERROR("Not compiled with GPU support"); #endif } else{ return dcn_v2_cpu_forward(input, weight, bias, offset, mask, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, deformable_group); } } std::vector 
dcn_v2_backward(const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, const at::Tensor &offset, const at::Tensor &mask, const at::Tensor &grad_output, int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int deformable_group) { if (input.type().is_cuda()) { #ifdef WITH_CUDA return dcn_v2_cuda_backward(input, weight, bias, offset, mask, grad_output, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, deformable_group); #else AT_ERROR("Not compiled with GPU support"); #endif } else{ return dcn_v2_cpu_backward(input, weight, bias, offset, mask, grad_output, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, deformable_group); } } std::tuple dcn_v2_psroi_pooling_forward(const at::Tensor &input, const at::Tensor &bbox, const at::Tensor &trans, const int no_trans, const float spatial_scale, const int output_dim, const int group_size, const int pooled_size, const int part_size, const int sample_per_part, const float trans_std) { if (input.type().is_cuda()) { #ifdef WITH_CUDA return dcn_v2_psroi_pooling_cuda_forward(input, bbox, trans, no_trans, spatial_scale, output_dim, group_size, pooled_size, part_size, sample_per_part, trans_std); #else AT_ERROR("Not compiled with GPU support"); #endif } else{ return dcn_v2_psroi_pooling_cpu_forward(input, bbox, trans, no_trans, spatial_scale, output_dim, group_size, pooled_size, part_size, sample_per_part, trans_std); } } std::tuple dcn_v2_psroi_pooling_backward(const at::Tensor &out_grad, const at::Tensor &input, const at::Tensor &bbox, const at::Tensor &trans, const at::Tensor &top_count, const int no_trans, const float spatial_scale, const int output_dim, const int group_size, const int pooled_size, const int part_size, const int sample_per_part, const float trans_std) { if (input.type().is_cuda()) { #ifdef WITH_CUDA return dcn_v2_psroi_pooling_cuda_backward(out_grad, input, bbox, trans, top_count, 
def conv_identify(weight, bias):
    """Initialize a conv weight/bias pair so the convolution acts as identity.

    Zeros both tensors, then writes a 1.0 at the spatial center of every
    matching (out_channel == in_channel) kernel slice, so that with zero
    offsets the deformable conv reproduces its input.
    """
    weight.data.zero_()
    bias.data.zero_()
    out_ch, in_ch, kh, kw = weight.shape
    cy, cx = kh // 2, kw // 2
    # Only channel pairs that exist in both dimensions get the identity tap.
    for c in range(min(in_ch, out_ch)):
        weight.data[c, c, cy, cx] = 1.0
def check_gradient_dconv():
    """Numerically verify dcn_v2_conv gradients with torch.autograd.gradcheck.

    Builds small random CUDA tensors for input, offset, mask, weight and
    bias (all requiring grad) and prints whether gradcheck passes at a
    relaxed tolerance suitable for fp32.
    """
    x = torch.rand(N, inC, inH, inW).cuda() * 0.01
    x.requires_grad = True
    offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW).cuda() * 2
    offset.requires_grad = True
    mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW).cuda()
    mask.requires_grad = True
    # sigmoid keeps mask values in (0, 1) while remaining differentiable
    mask = torch.sigmoid(mask)
    weight = torch.randn(outC, inC, kH, kW).cuda()
    weight.requires_grad = True
    bias = torch.rand(outC).cuda()
    bias.requires_grad = True
    stride, padding, dilation = 1, 1, 1
    ok = gradcheck(dcn_v2_conv,
                   (x, offset, mask, weight, bias,
                    stride, padding, dilation, deformable_groups),
                   eps=1e-3, atol=1e-4, rtol=1e-2)
    print('check_gradient_dconv: ', ok)
def check_gradient_dpooling():
    """Numerically verify dcn_v2_pooling gradients w.r.t. input and offset."""
    feats = torch.randn(2, 3, 5, 5).cuda() * 0.01
    N = 4  # number of rois (intentionally shadows the module-level N)
    batch_inds = torch.randint(2, (N, 1)).cuda().float()
    x = torch.rand((N, 1)).cuda().float() * 15
    y = torch.rand((N, 1)).cuda().float() * 15
    w = torch.rand((N, 1)).cuda().float() * 10
    h = torch.rand((N, 1)).cuda().float() * 10
    # rois are (batch_index, x1, y1, x2, y2)
    rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)
    offset = torch.randn(N, 2, 3, 3).cuda()
    feats.requires_grad = True
    offset.requires_grad = True
    spatial_scale = 1.0 / 4
    pooled_size = 3
    output_dim = 3
    no_trans = 0
    group_size = 1
    trans_std = 0.0
    sample_per_part = 4
    part_size = pooled_size
    ok = gradcheck(dcn_v2_pooling,
                   (feats, rois, offset, spatial_scale, pooled_size,
                    output_dim, no_trans, group_size, part_size,
                    sample_per_part, trans_std),
                   eps=1e-4)
    print('check_gradient_dpooling:', ok)
def example_mdpooling():
    """Run modulated deformable RoI pooling (DCNPooling) once and backprop."""
    feats = torch.randn(2, 32, 64, 64).cuda()
    feats.requires_grad = True
    batch_inds = torch.randint(2, (20, 1)).cuda().float()
    x0 = torch.randint(256, (20, 1)).cuda().float()
    y0 = torch.randint(256, (20, 1)).cuda().float()
    w = torch.randint(64, (20, 1)).cuda().float()
    h = torch.randint(64, (20, 1)).cuda().float()
    # rois are (batch_index, x1, y1, x2, y2)
    rois = torch.cat((batch_inds, x0, y0, x0 + w, y0 + h), dim=1)
    # modulated deformable pooling (V2) with its built-in offset/mask branch
    dpooling = DCNPooling(spatial_scale=1.0 / 4,
                          pooled_size=7,
                          output_dim=32,
                          no_trans=False,
                          group_size=1,
                          trans_std=0.1,
                          deform_fc_dim=1024).cuda()
    dout = dpooling(feats, rois)
    target = dout.new(*dout.size())
    target.data.uniform_(-0.1, 0.1)
    error = (target - dout).mean()
    error.backward()
    print(dout.shape)
error is less than 1e-7, # ****** Still looking for what trigger this problem # """ ================================================ FILE: code/synthetic/bsrt/model/__init__.py ================================================ import os from importlib import import_module import torch import torch.nn as nn import torch.nn.parallel as P import torch.utils.model_zoo import time class Model(nn.Module): def __init__(self, args, ckp): super(Model, self).__init__() self.args = args if args.local_rank == 0: print("Making model: ", args.model) print("Patch size: ", args.patch_size) self.scale = args.scale self.idx_scale = 0 self.input_large = (args.model == 'VDSR') self.self_ensemble = args.self_ensemble self.chop = args.chop self.precision = args.precision self.cpu = args.cpu self.device = torch.device('cpu' if args.cpu else 'cuda:%d' % args.local_rank) self.n_GPUs = args.n_GPUs self.save_models = args.save_models module = import_module('model.' + args.model.lower()) self.model = module.make_model(args).to(self.device) if args.precision == 'half': self.model.half() self.load( ckp.get_path('model'), pre_train=args.pre_train, resume=args.resume, cpu=args.cpu ) # time.sleep(3) if args.n_GPUs > 1: self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[args.local_rank], find_unused_parameters=True ) print(self.model, file=ckp.log_file) def forward(self, x, idx_scale): self.idx_scale = idx_scale if hasattr(self.model, 'set_scale'): self.model.set_scale(idx_scale) if self.training: # if self.n_GPUs > 1: return self.model(x) else: if self.chop: forward_function = self.forward_chop else: forward_function = self.model.forward if self.self_ensemble: return self.forward_x8(x, forward_function=forward_function) else: # return self.model(x) return forward_function(x) def save(self, apath, epoch, is_best=False): save_dirs = [os.path.join(apath, 'model_latest.pt')] if is_best: save_dirs.append(os.path.join(apath, 'model_best.pt')) if self.save_models: save_dirs.append( 
os.path.join(apath, 'model_{}.pt'.format(epoch)) ) if self.n_GPUs > 1: model = self.model.module else: model = self.model for s in save_dirs: torch.save(self.model.state_dict(), s) def load(self, apath, pre_train='', resume=-1, cpu=False): load_from = None kwargs = {} if cpu: kwargs = {'map_location': lambda storage, loc: storage} if resume == -1: load_from = torch.load( os.path.join(apath, 'model_latest.pt'), **kwargs ) elif resume == 0: if pre_train == 'download': print('Download the model') dir_model = os.path.join('..', 'models') os.makedirs(dir_model, exist_ok=True) load_from = torch.utils.model_zoo.load_url( self.model.url, model_dir=dir_model, **kwargs ) elif pre_train: if self.args.local_rank == 0: print('Load the model from {}'.format(pre_train)) map_location = {'cuda:%d' % 0: 'cuda:%d' % self.args.local_rank} load_from = torch.load(pre_train, map_location=map_location) else: load_from = torch.load( os.path.join(apath, 'model_{}.pt'.format(resume)), **kwargs ) if load_from: self.model.load_state_dict(load_from, strict=True) del load_from if self.args.finetune: if self.args.local_rank == 0: print('finetune') for param in self.model.parameters(): param.requires_grad = False for param in self.model.HRconv.parameters(): param.requires_grad = True for param in self.model.conv_last.parameters(): param.requires_grad = True if self.args.finetune_prelayer: if self.args.local_rank == 0: print('finetune_prelayer') if self.args.swinfeature: if self.args.model == 'MBSRT': for param in self.model.pre_layer1.parameters(): param.requires_grad = True for param in self.model.pre_layer2.parameters(): param.requires_grad = True else: for param in self.model.pre_layers.parameters(): param.requires_grad = True else: for param in self.model.feature_extraction.parameters(): param.requires_grad = True for param in self.model.conv_after_pre_layer.parameters(): param.requires_grad = True if self.args.finetune_align: if self.args.local_rank == 0: print('finetune_align') for param in 
    def forward_chop(self, *args, shave=10, min_size=160000):
        """Memory-saving forward: split each input into 4 overlapping
        quadrants, run the model on them (recursively re-chopping while a
        quadrant exceeds ``min_size`` pixels), then stitch the outputs back
        together.

        Args:
            *args: input tensors, each of shape (..., h, w).
            shave: overlap (in input pixels) added around each quadrant so
                border artifacts from the model are cut away when stitching.
            min_size: max h*w (per quadrant, times 4) processed directly.

        Returns:
            A single output tensor, or a list if the model returns a list.
        """
        # VDSR-style models take an already-upscaled input, so output scale is 1.
        scale = 1 if self.input_large else self.scale[self.idx_scale]
        n_GPUs = min(self.n_GPUs, 4)
        # height, width
        h, w = args[0].size()[-2:]
        # Quadrant slices in INPUT coordinates, each extended by `shave`.
        top = slice(0, h//2 + shave)
        bottom = slice(h - h//2 - shave, h)
        left = slice(0, w//2 + shave)
        right = slice(w - w//2 - shave, w)
        # Stack the 4 quadrants of every input along the batch dimension.
        x_chops = [torch.cat([
            a[..., top, left],
            a[..., top, right],
            a[..., bottom, left],
            a[..., bottom, right]
        ]) for a in args]

        y_chops = []
        if h * w < 4 * min_size:
            # Small enough: run the 4 quadrants through the model,
            # n_GPUs quadrants at a time via data_parallel.
            for i in range(0, 4, n_GPUs):
                x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops]
                y = P.data_parallel(self.model, *x, range(n_GPUs))
                if not isinstance(y, list): y = [y]
                if not y_chops:
                    y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]
                else:
                    for y_chop, _y in zip(y_chops, y):
                        y_chop.extend(_y.chunk(n_GPUs, dim=0))
        else:
            # Still too large: recurse on each quadrant individually.
            for p in zip(*x_chops):
                y = self.forward_chop(*p, shave=shave, min_size=min_size)
                if not isinstance(y, list): y = [y]
                if not y_chops:
                    y_chops = [[_y] for _y in y]
                else:
                    for y_chop, _y in zip(y_chops, y):
                        y_chop.append(_y)

        # Switch to OUTPUT coordinates and rebuild the non-overlapping
        # quadrant slices; *_r variants index from the end of a chop so the
        # shaved overlap is discarded.
        h *= scale
        w *= scale
        top = slice(0, h//2)
        bottom = slice(h - h//2, h)
        bottom_r = slice(h//2 - h, None)
        left = slice(0, w//2)
        right = slice(w - w//2, w)
        right_r = slice(w//2 - w, None)

        # batch size, number of color channels
        b, c = y_chops[0][0].size()[:-2]
        y = [y_chop[0].new(b, c, h, w) for y_chop in y_chops]
        for y_chop, _y in zip(y_chops, y):
            _y[..., top, left] = y_chop[0][..., top, left]
            _y[..., top, right] = y_chop[1][..., top, right_r]
            _y[..., bottom, left] = y_chop[2][..., bottom_r, left]
            _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r]

        if len(y) == 1: y = y[0]

        return y
def initialize_weights(net_l, scale=1):
    """Kaiming-initialize every Conv2d/Linear in the given network(s).

    Args:
        net_l: a module or a list of modules to initialize in place.
        scale: factor applied to the initialized weights; use a small value
            (e.g. 0.1) for layers inside residual branches.
    """
    nets = net_l if isinstance(net_l, list) else [net_l]
    for net in nets:
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # damp residual-branch weights
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias.data, 0.0)
class SELayer(nn.Module):
    """Squeeze-and-Excitation block.

    Global-average-pools each channel, passes the vector through a
    bottleneck MLP (no final sigmoid here — callers apply their own
    gating if needed), and rescales the input channel-wise.
    """

    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
        )

    def forward(self, x):
        batch, chans = x.shape[:2]
        squeezed = self.avg_pool(x).view(batch, chans)
        weights = self.fc(squeezed).view(batch, chans, 1, 1)
        return x * weights.expand_as(x)
self.conv3 = nn.Conv2d(3 * nf, nf, 1, padding=0, dilation=1, bias=True) self.se = SELayer(nf, reduction) # initialization initialize_weights([self.conv1, self.conv2, self.conv3], 0.1) def forward(self, x): identity = x basic_out = F.relu(self.conv1(x), inplace=True) basic_out = self.conv2(basic_out) se_out = self.se(basic_out) out = torch.cat((identity, basic_out, se_out), 1) out = self.conv3(out) return out class _PositionAttentionModule(nn.Module): """ Position attention module""" def __init__(self, in_channels, **kwargs): super(_PositionAttentionModule, self).__init__() self.conv_b = nn.Conv2d(in_channels, in_channels // 8, 1) self.conv_c = nn.Conv2d(in_channels, in_channels // 8, 1) self.conv_d = nn.Conv2d(in_channels, in_channels, 1) self.alpha = nn.Parameter(torch.zeros(1)) self.softmax = nn.Softmax(dim=-1) def forward(self, x): batch_size, _, height, width = x.size() feat_b = self.conv_b(x).view(batch_size, -1, height * width).permute(0, 2, 1) feat_c = self.conv_c(x).view(batch_size, -1, height * width) attention_s = self.softmax(torch.bmm(feat_b, feat_c)) feat_d = self.conv_d(x).view(batch_size, -1, height * width) feat_e = torch.bmm(feat_d, attention_s.permute(0, 2, 1)).view(batch_size, -1, height, width) out = self.alpha * feat_e + x return out ## Spatial Attention (CA) Layer class SALayer(nn.Module): def __init__(self, wn=None): super(SALayer,self).__init__() self.body = nn.Sequential( wn(nn.Conv2d(2, 1, 7, 1, 3, bias=False)), nn.Sigmoid() ) def forward(self, x): avg_f = torch.mean(x, dim=1, keepdim=True) max_f = torch.max(x, dim=1, keepdim=True)[0] y = torch.cat([avg_f, max_f], dim=1) return self.body(y).expand_as(x) * x ## Channel Attention (CA) Layer class CALayerV2(nn.Module): def __init__(self, n_feat, reduction=16, wn=None): super(CALayerV2, self).__init__() # global average pooling: feature --> point self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1) # feature channel downscale and upscale --> channel weight 
## Channel Attention (CA) Layer
class CALayer(nn.Module):
    """Channel attention: global average pool -> bottleneck 1x1 convs ->
    sigmoid gate, multiplied back onto the input channel-wise.

    Args:
        channel: number of input/output channels.
        reduction: bottleneck reduction factor.
        wn: weight wrapper (e.g. torch.nn.utils.weight_norm) applied to
            each conv.
    """

    def __init__(self, channel, reduction, wn):
        super(CALayer, self).__init__()
        # global average pooling: feature map -> per-channel scalar
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # channel downscale then upscale -> gate in (0, 1)
        self.conv_du = nn.Sequential(
            wn(nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True)),
            nn.ReLU(inplace=True),
            wn(nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True)),
            nn.Sigmoid()
        )

    def forward(self, x):
        gate = self.conv_du(self.avg_pool(x))
        return x * gate
def make_layer_idx(block, n_layers):
    """Stack ``n_layers`` instances of ``block`` into an nn.Sequential,
    passing each instance its position as the keyword argument ``idx``.
    """
    return nn.Sequential(*[block(idx=i) for i in range(n_layers)])
## Residual Channel Attention Block with Pyramidal conv (RCAB variant)
class LRSCPYRCAB(nn.Module):
    """Long-Range Skip-Connect RCAB using a pyramidal convolution.

    Wide-activation bottleneck (1x1 expand -> act -> 1x1 shrink) followed
    by a PyConv2d and channel/dual attention. Instead of adding a residual,
    the block CONCATENATES its output with its input, so the channel count
    grows by ``n_feat`` per block; when ``idx > 0`` a leading 1x1 conv
    fuses the accumulated ``n_feat*(idx+1)`` channels back to ``n_feat``.

    Args:
        conv: unused here (kept for signature parity with sibling blocks).
        n_feat: base feature width.
        kernel_size: unused here (PyConv2d fixes its own kernel sizes).
        reduction: attention bottleneck reduction.
        wn: weight wrapper (e.g. weight_norm) for the 1x1 convs.
        bias: bias flag for the bottleneck convs.
        bn / act / res_scale: kept for signature parity; res_scale is stored
            but not applied in forward.
        da: use dual attention (DALayer) instead of channel attention.
        idx: position of this block inside its group (controls the fuse conv).
    """
    def __init__(
        self, conv, n_feat, kernel_size, reduction, wn,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1, da=False, idx=0):
        super(LRSCPYRCAB, self).__init__()
        expand = 6
        linear = 0.75
        # Fuse the long-range concatenated features back to n_feat channels.
        modules_body = [wn(nn.Conv2d(n_feat*(idx+1), n_feat, 1, 1, 0, bias=True))] if idx > 0 else []
        modules_body.append(wn(nn.Conv2d(n_feat, n_feat*expand, 1, bias=bias)))
        modules_body.append(act)
        modules_body.append(wn(nn.Conv2d(n_feat*expand, int(n_feat*linear), 1, bias=bias)))
        # Pyramidal conv mixes 3/5/7 receptive fields across channel groups.
        modules_body.append(
            PyConv2d(in_channels=int(n_feat*linear),
                     out_channels=[n_feat//4, n_feat//4, n_feat//2],
                     pyconv_kernels=[3, 5, 7], pyconv_groups=[1, 4, 8]))
        if da:
            modules_body.append(DALayer(n_feat, reduction, wn))
        else:
            modules_body.append(CALayer(n_feat, reduction, wn))
        self.body = nn.Sequential(*modules_body)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x)
        # Long-range skip: concatenate instead of add.
        res = torch.cat([res, x], dim=1)
        return res
## Long-Range Skip-connect Residual Group (RG) with pyramidal blocks
class LRSCPyResidualGroup(nn.Module):
    """Group of LRSCPYRCAB blocks with long-range concat skips.

    head: optional 1x1 fuse conv when this group itself sits at position
    ``idx > 0`` inside a larger long-range chain.
    body: n_resblocks LRSCPYRCAB blocks, each concatenating its input, so
    the channel count grows to n_feat*(n_resblocks+1).
    tail: 1x1 conv fusing the accumulated channels back to n_feat, after
    which the group's own input is concatenated (output: 2*n_feat channels).
    """
    def __init__(self, n_feat, n_resblocks, da=False, idx=0):
        super(LRSCPyResidualGroup, self).__init__()
        kernel_size = 3
        res_scale = 1
        reduction = 16
        conv = PyConv2d
        wn = lambda x: torch.nn.utils.weight_norm(x)
        modules_head = [wn(nn.Conv2d(n_feat*(idx+1), n_feat, 1, 1, 0, bias=True))] if idx > 0 else []
        modules_body = [
            LRSCPYRCAB(
                conv, n_feat, kernel_size, reduction, wn=wn, bias=True,
                bn=False, act=nn.ReLU(True), res_scale=res_scale, da=da, idx=i) \
            for i in range(n_resblocks)]
        modules_tail = [wn(nn.Conv2d(n_feat*(n_resblocks+1), n_feat, 1))]
        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x):
        res = self.head(x)
        res = self.body(res)
        res = self.tail(res)
        # Group-level long-range skip: concatenate input to fused output.
        res = torch.cat([res, x], dim=1)
        return res
## Long-Range Skip-connect Residual Group (RG) of wide-activation blocks
class LRSCWideActResGroup(nn.Module):
    """Group of LRSCWideActResBlock blocks with long-range concat skips.

    head: optional 1x1 fuse conv when this group sits at position
    ``idx > 0`` inside a larger long-range chain.
    body: n_resblocks wide-activation blocks, each concatenating its input,
    growing the channel count to nf*(n_resblocks+1).
    tail: 1x1 conv fusing the accumulated channels back to nf, after which
    the group's own input is concatenated (output: 2*nf channels).

    Note: removed the unused locals ``kernel_size`` and ``conv = PyConv2d``
    from the original — this group never uses PyConv2d and the name was
    misleading.
    """

    def __init__(self, nf, n_resblocks, idx=0):
        super(LRSCWideActResGroup, self).__init__()
        wn = lambda x: torch.nn.utils.weight_norm(x)
        modules_head = [wn(nn.Conv2d(nf*(idx+1), nf, 1, 1, 0, bias=True))] if idx > 0 else []
        modules_body = [
            LRSCWideActResBlock(nf=nf, idx=i) for i in range(n_resblocks)]
        modules_tail = [wn(nn.Conv2d(nf*(n_resblocks+1), nf, 1))]
        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x):
        res = self.head(x)
        res = self.body(res)
        res = self.tail(res)
        # Group-level long-range skip: concatenate input to fused output.
        res = torch.cat([res, x], dim=1)
        return res
                reduction, wn=wn, bias=True, bn=False, act=nn.ReLU(True), res_scale=res_scale, da=da) \
            for _ in range(n_resblocks)]
        # Trailing pyramidal conv fuses the group output at several kernel sizes.
        modules_body.append(
            PyConv2d(in_channels=n_feat, out_channels=[n_feat//4, n_feat//4, n_feat//2],
                     pyconv_kernels=[3, 5, 7], pyconv_groups=[1, 4, 8]))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        # Long skip connection over the whole residual group.
        res = self.body(x)
        res += x
        return res


class WideActResBlock(nn.Module):
    """Wide-activation residual block (WDSR-style).

    Expands channels 6x with a 1x1 conv before the activation, shrinks back to
    a `linear` fraction, then projects to `nf` with a 3x3 conv. Every conv is
    weight-normalized.
    """

    def __init__(self, nf=64):
        super(WideActResBlock, self).__init__()
        self.res_scale = 1
        body = []
        expand = 6     # channel expansion factor ahead of the ReLU
        linear = 0.8   # low-rank shrink ratio after the ReLU
        kernel_size = 3

        wn = lambda x: torch.nn.utils.weight_norm(x)
        act = nn.ReLU(True)

        # NOTE: padding=1//2 evaluates to 0, which is correct for 1x1 convs.
        body.append(wn(nn.Conv2d(nf, nf*expand, 1, padding=1//2)))
        body.append(act)
        body.append(wn(nn.Conv2d(nf*expand, int(nf*linear), 1, padding=1//2)))
        body.append(wn(nn.Conv2d(int(nf*linear), nf, kernel_size, padding=kernel_size//2)))

        self.body = nn.Sequential(*body)

    def forward(self, x):
        res = self.body(x) * self.res_scale
        res += x
        return res


class PSWideActResBlock(nn.Module):
    """Wide-activation residual block whose final projection is a PSConv2d
    (poly-scale conv) instead of a plain 3x3 conv."""

    def __init__(self, nf=64):
        super(PSWideActResBlock, self).__init__()
        self.res_scale = 1
        body = []
        expand = 6
        linear = 0.75
        kernel_size = 3

        wn = lambda x: torch.nn.utils.weight_norm(x)
        act = nn.ReLU(True)

        body.append(wn(nn.Conv2d(nf, nf*expand, 1, padding=1//2)))
        body.append(act)
        body.append(wn(nn.Conv2d(nf*expand, int(nf*linear), 1, padding=1//2)))
        # Poly-scale conv replaces the plain 3x3 projection of WideActResBlock.
        body.append(wn(PSConv2d(int(nf*linear), nf, kernel_size, padding=kernel_size//2)))

        self.body = nn.Sequential(*body)

    def forward(self, x):
        res = self.body(x) * self.res_scale
        res += x
        return res


class PyWideActResBlock(nn.Module):
    """Wide-activation residual block whose final projection is a pyramidal
    conv (PyConv2d) mixing 3x3/5x5/7x7 kernels."""

    def __init__(self, nf=64):
        super(PyWideActResBlock, self).__init__()
        self.res_scale = 1
        body = []
        expand = 6
        linear = 0.75
        kernel_size = 3  # NOTE(review): unused here; the tail conv is a PyConv2d

        wn = lambda x: torch.nn.utils.weight_norm(x)
        act = nn.ReLU(True)

        expand_nf = nf*expand       # NOTE(review): computed but never used below
        linear_nf = int(nf * linear)

        body.append(wn(nn.Conv2d(nf, nf*expand, 1, padding=1//2)))
        body.append(act)
        body.append(wn(nn.Conv2d(nf*expand, int(nf*linear), 1, padding=1//2)))
        body.append(
            PyConv2d(in_channels=linear_nf, out_channels=[nf//4, nf//4, nf//2],
                     pyconv_kernels=[3, 5, 7], pyconv_groups=[1, 4, 8]))

        self.body = nn.Sequential(*body)

    def forward(self, x):
        res = self.body(x) * self.res_scale
        res += x
        return res


def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True, use_pad_mask=False):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2), normal value.
        interp_mode (str): 'nearest' or 'bilinear' or 'nearest4'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'. Default: 'zeros'.
        align_corners (bool): Before pytorch 1.3, the default value is align_corners=True.
            After pytorch 1.3, the default value is align_corners=False.
            Here, we use the True as default.
        use_pad_mask (bool): only used for PWCNet, x is first padded with ones along the
            channel dimension. The mask is generated according to the grid_sample results
            of the padded dimension.

    Returns:
        Tensor: Warped image or feature map.
    """
    # assert x.size()[-2:] == flow.size()[1:3] # temporaily turned off for image-wise shift
    n, _, h, w = x.size()
    x = x.float()

    # create mesh grid
    # grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x)) # an illegal memory access on TITAN RTX + PyTorch1.9.1
    grid_y, grid_x = torch.meshgrid(torch.arange(0, h, dtype=x.dtype, device=x.device),
                                    torch.arange(0, w, dtype=x.dtype, device=x.device))
    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
    grid.requires_grad = False

    grid = grid.type_as(x)
    # Absolute sampling positions = base pixel grid plus per-pixel flow offsets.
    vgrid = grid + flow

    # if use_pad_mask: # for PWCNet
    #     x = F.pad(x, (0,0,0,0,0,1), mode='constant', value=1)

    # scale grid to [-1,1]
    if interp_mode == 'nearest4': # todo: bug, no gradient for flow model in this case!!!
        # but the result is good
        # Sample all four integer-neighbour positions with 'nearest' and return
        # them stacked on the channel axis; see the todo above — torch.floor/ceil
        # cut the gradient path back to `flow` in this branch.
        vgrid_x_floor = 2.0 * torch.floor(vgrid[:, :, :, 0]) / max(w - 1, 1) - 1.0
        vgrid_x_ceil = 2.0 * torch.ceil(vgrid[:, :, :, 0]) / max(w - 1, 1) - 1.0
        vgrid_y_floor = 2.0 * torch.floor(vgrid[:, :, :, 1]) / max(h - 1, 1) - 1.0
        vgrid_y_ceil = 2.0 * torch.ceil(vgrid[:, :, :, 1]) / max(h - 1, 1) - 1.0

        output00 = F.grid_sample(x, torch.stack((vgrid_x_floor, vgrid_y_floor), dim=3), mode='nearest', padding_mode=padding_mode, align_corners=align_corners)
        output01 = F.grid_sample(x, torch.stack((vgrid_x_floor, vgrid_y_ceil), dim=3), mode='nearest', padding_mode=padding_mode, align_corners=align_corners)
        output10 = F.grid_sample(x, torch.stack((vgrid_x_ceil, vgrid_y_floor), dim=3), mode='nearest', padding_mode=padding_mode, align_corners=align_corners)
        output11 = F.grid_sample(x, torch.stack((vgrid_x_ceil, vgrid_y_ceil), dim=3), mode='nearest', padding_mode=padding_mode, align_corners=align_corners)

        return torch.cat([output00, output01, output10, output11], 1)

    else:
        # Normalize absolute coordinates into grid_sample's [-1, 1] range.
        vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
        vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
        vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
        output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)

        # if use_pad_mask: # for PWCNet
        #     output = _flow_warp_masking(output)

        # TODO, what if align_corners=False
        return output


# def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
#     """Warp an image or feature map with optical flow
#     Args:
#         x (Tensor): size (N, C, H, W)
#         flow (Tensor): size (N, H, W, 2), normal value
#         interp_mode (str): 'nearest' or 'bilinear'
#         padding_mode (str): 'zeros' or 'border' or 'reflection'
#     Returns:
#         Tensor: warped image or feature map
#     """
#     assert x.size()[-2:] == flow.size()[1:3]
#     B, C, H, W = x.size()
#     # mesh grid
#     grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
#     grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
#
grid.requires_grad = False # grid = grid.type_as(x) # vgrid = grid + flow # # scale grid to [-1,1] # vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 # vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 # vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) # output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) # return output ================================================ FILE: code/synthetic/bsrt/model/bsrt.py ================================================ import functools import torch import torch.nn as nn import torch.nn.functional as F import model.arch_util as arch_util from torch.cuda.amp import autocast import model.swin_util as swu import time import os import math from utils.debayer import Debayer3x3 import torchvision.utils as tvutils from datasets.burstsr_dataset import pack_raw_image, flatten_raw_image_batch try: from model.non_local.non_local_cross_dot_product import NONLocalBlock2D as NonLocalCross from model.non_local.non_local_dot_product import NONLocalBlock2D as NonLocal except ImportError: raise ImportError('Failed to import Non_Local module.') try: from model.DCNv2.dcn_v2 import DCN_sep as DCN, FlowGuidedDCN, InsideFlowGuidedDCN except ImportError: raise ImportError('Failed to import DCNv2 module.') def make_model(args, parent=False): nframes = args.burst_size img_size = args.patch_size // args.scale[0] patch_size = 1 in_chans = args.burst_channel out_chans = args.n_colors if args.model_level == "S": depths = [6]*1 + [6] * 4 num_heads = [6]*1 + [6] * 4 embed_dim = 60 elif args.model_level == "L": depths = [6]*1 + [8] * 6 num_heads = [6]*1 + [6] * 6 embed_dim = 180 window_size = 8 mlp_ratio = 2 upscale = args.scale[0] non_local = args.non_local use_checkpoint=args.use_checkpoint if args.local_rank <= 0: print("depths: ", depths) return BSRT(args=args,nframes=nframes, img_size=img_size, patch_size=patch_size, in_chans=in_chans, out_chans=out_chans, embed_dim=embed_dim, depths=depths, 
num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio, upscale=upscale, non_local=non_local, use_checkpoint=use_checkpoint) class BasicModule(nn.Module): """Basic Module for SpyNet. """ def __init__(self): super(BasicModule, self).__init__() self.basic_module = nn.Sequential( nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)) def forward(self, tensor_input): return self.basic_module(tensor_input) class SpyNet(nn.Module): """SpyNet architecture. Args: load_path (str): path for pretrained SpyNet. Default: None. return_levels (list[int]): return flows of different levels. Default: [5]. 
""" def __init__(self, load_path=None, return_levels=[5]): super(SpyNet, self).__init__() self.return_levels = return_levels self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)]) if load_path: if not os.path.exists(load_path): import requests url = 'https://github.com/JingyunLiang/VRT/releases/download/v0.0/spynet_sintel_final-3d2a1287.pth' r = requests.get(url, allow_redirects=True) print(f'downloading SpyNet pretrained model from {url}') os.makedirs(os.path.dirname(load_path), exist_ok=True) open(load_path, 'wb').write(r.content) self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params']) self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) def preprocess(self, tensor_input): tensor_output = (tensor_input - self.mean) / self.std return tensor_output def process(self, ref, supp, w, h, w_floor, h_floor): flow_list = [] ref = [self.preprocess(ref)] supp = [self.preprocess(supp)] # ref = [ref] # supp = [supp] for level in range(5): ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False)) supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False)) flow = ref[0].new_zeros( [ref[0].size(0), 2, int(math.floor(ref[0].size(2) / 2.0)), int(math.floor(ref[0].size(3) / 2.0))]) for level in range(len(ref)): upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0 if upsampled_flow.size(2) != ref[level].size(2): upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate') if upsampled_flow.size(3) != ref[level].size(3): upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate') flow = self.basic_module[level](torch.cat([ ref[level], arch_util.flow_warp( supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'), upsampled_flow 
], 1)) + upsampled_flow if level in self.return_levels: scale = 2**(5-level) # level=5 (scale=1), level=4 (scale=2), level=3 (scale=4), level=2 (scale=8) flow_out = F.interpolate(input=flow, size=(h//scale, w//scale), mode='bilinear', align_corners=False) flow_out[:, 0, :, :] *= float(w//scale) / float(w_floor//scale) flow_out[:, 1, :, :] *= float(h//scale) / float(h_floor//scale) if torch.abs(flow_out).mean() > 200: print(f"level {level}, flow > 200: {torch.abs(flow_out).mean():.4f}") # return None flow_out.clamp(-250, 250) flow_list.insert(0, flow_out) return flow_list def forward(self, ref, supp): assert ref.size() == supp.size() h, w = ref.size(2), ref.size(3) w_floor = math.floor(math.ceil(w / 32.0) * 32.0) h_floor = math.floor(math.ceil(h / 32.0) * 32.0) ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False) supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False) flow_list = self.process(ref, supp, w, h, w_floor, h_floor) return flow_list[0] if len(flow_list) == 1 else flow_list class FlowGuidedPCDAlign(nn.Module): ''' Alignment module using Pyramid, Cascading and Deformable convolution with 3 pyramid levels. 
    [From EDVR]
    '''

    def __init__(self, nf=64, groups=8):
        super(FlowGuidedPCDAlign, self).__init__()
        # Offset convs take nf*2+2 channels: warped-neighbour feats + ref feats + 2-channel flow.
        # L3: level 3, 1/4 spatial size
        self.L3_offset_conv1 = nn.Conv2d(nf * 2 + 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L3_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L3_dcnpack = FlowGuidedDCN(nf, nf, 3, stride=1, padding=1, dilation=1,
                                        deformable_groups=groups)
        # L2: level 2, 1/2 spatial size
        self.L2_offset_conv1 = nn.Conv2d(nf * 2 + 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L2_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset
        self.L2_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L2_dcnpack = FlowGuidedDCN(nf, nf, 3, stride=1, padding=1, dilation=1,
                                        deformable_groups=groups)
        self.L2_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea
        # L1: level 1, original spatial size
        self.L1_offset_conv1 = nn.Conv2d(nf * 2 + 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L1_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset
        self.L1_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L1_dcnpack = FlowGuidedDCN(nf, nf, 3, stride=1, padding=1, dilation=1,
                                        deformable_groups=groups)
        self.L1_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea
        # Cascading DCN: one extra plain deformable conv refining the L1 result.
        self.cas_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.cas_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.cas_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1,
                               deformable_groups=groups)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, nbr_fea_l, nbr_fea_warped_l, ref_fea_l, flows_l):
        '''align other neighboring frames to the reference frame in the feature level
        nbr_fea_l, ref_fea_l: [L1, L2, L3], each with [B,C,H,W] features
        '''
        # L3: predict offsets from flow-warped neighbour feats, reference feats
        # and the flow itself (coarsest level, processed first).
        L3_offset = torch.cat([nbr_fea_warped_l[2], ref_fea_l[2], flows_l[2]], dim=1)
        L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset))
        L3_offset = self.lrelu(self.L3_offset_conv2(L3_offset))
L3_fea = self.lrelu(self.L3_dcnpack(nbr_fea_l[2], L3_offset, flows_l[2])) # L2 L3_offset = F.interpolate(L3_offset, scale_factor=2, mode='bilinear', align_corners=False) L2_offset = torch.cat([nbr_fea_warped_l[1], ref_fea_l[1], flows_l[1]], dim=1) L2_offset = self.lrelu(self.L2_offset_conv1(L2_offset)) L2_offset = self.lrelu(self.L2_offset_conv2(torch.cat([L2_offset, L3_offset*2], dim=1))) L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset)) L2_fea = self.L2_dcnpack(nbr_fea_l[1], L2_offset, flows_l[1]) L3_fea = F.interpolate(L3_fea, scale_factor=2, mode='bilinear', align_corners=False) L2_fea = self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L3_fea], dim=1))) # L1 L2_offset = F.interpolate(L2_offset, scale_factor=2, mode='bilinear', align_corners=False) L1_offset = torch.cat([nbr_fea_warped_l[0], ref_fea_l[0], flows_l[0]], dim=1) L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset)) L1_offset = self.lrelu(self.L1_offset_conv2(torch.cat([L1_offset, L2_offset * 2], dim=1))) L1_offset = self.lrelu(self.L1_offset_conv3(L1_offset)) L1_fea = self.L1_dcnpack(nbr_fea_l[0], L1_offset, flows_l[0]) L2_fea = F.interpolate(L2_fea, scale_factor=2, mode='bilinear', align_corners=False) L1_fea = self.L1_fea_conv(torch.cat([L1_fea, L2_fea], dim=1)) # Cascading offset = torch.cat([L1_fea, ref_fea_l[0]], dim=1) offset = self.lrelu(self.cas_offset_conv1(offset)) offset = self.lrelu(self.cas_offset_conv2(offset)) L1_fea = self.cas_dcnpack(L1_fea, offset) return L1_fea class CrossNonLocal_Fusion(nn.Module): ''' Cross Non Local fusion module ''' def __init__(self, nf=64, out_feat=96, nframes=5, center=2): super(CrossNonLocal_Fusion, self).__init__() self.center = center self.non_local_T = nn.ModuleList() self.non_local_F = nn.ModuleList() for i in range(nframes): self.non_local_T.append(NonLocalCross(nf, inter_channels=nf//2, sub_sample=True, bn_layer=False)) self.non_local_F.append(NonLocal(nf, inter_channels=nf//2, sub_sample=True, bn_layer=False)) # fusion conv: using 1x1 to save 
parameters and computation self.fea_fusion = nn.Conv2d(nframes * nf*2, out_feat, 3, 1, 1, bias=True) self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) def forward(self, aligned_fea): B, N, C, H, W = aligned_fea.size() # N video frames ref = aligned_fea[:, self.center, :, :, :].clone() cor_l = [] non_l = [] for i in range(N): nbr = aligned_fea[:, i, :, :, :] non_l.append(self.non_local_F[i](nbr)) cor_l.append(self.non_local_T[i](nbr, ref)) aligned_fea_T = torch.cat(cor_l, dim=1) aligned_fea_F = torch.cat(non_l, dim=1) aligned_fea = torch.cat([aligned_fea_T, aligned_fea_F], dim=1) #### fusion fea = self.fea_fusion(aligned_fea) return fea class BSRT(nn.Module): def __init__(self, args, nframes=8, img_size=64, patch_size=1, in_chans=3, out_chans=3, embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, use_checkpoint=False, upscale=4, non_local=False, **kwargs): super(BSRT, self).__init__() num_in_ch = in_chans num_out_ch = out_chans num_feat = 64 groups = 8 # embed_dim = num_feat back_RBs = 5 n_resblocks = 6 self.args = args self.center = 0 self.upscale = upscale self.window_size = window_size self.non_local = non_local self.nframes = nframes self.num_layers = len(depths) self.embed_dim = embed_dim self.ape = ape self.patch_norm = patch_norm self.num_features = embed_dim self.mlp_ratio = mlp_ratio spynet_path='/home/luoziwei/.pretrained_models/spynet_sintel_final-3d2a1287.pth' self.spynet = SpyNet(spynet_path, [3, 4, 5]) self.conv_flow = nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1) self.flow_ps = nn.PixelShuffle(2) # split image into non-overlapping patches self.patch_embed = swu.PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None) num_patches = self.patch_embed.num_patches patches_resolution 
= self.patch_embed.patches_resolution self.patches_resolution = patches_resolution # merge non-overlapping patches into image self.patch_unembed = swu.PatchUnEmbed( img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None) ##################################################################################################### ################################### 1, shallow feature extraction ################################### self.conv_first = nn.Conv2d(num_in_ch*(1+2*0), embed_dim, 3, 1, 1, bias=True) # # stochastic depth dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule if args.swinfeature: if self.args.local_rank <= 0: print("using swinfeature") self.pre_layers = nn.ModuleList() for i_layer in range(depths[0]): layer = swu.SwinTransformerBlock(dim=embed_dim, input_resolution=(patches_resolution[0]//2, patches_resolution[1]//2), num_heads=num_heads[0], window_size=window_size, shift_size=0 if (i_layer % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i_layer], norm_layer=norm_layer) self.pre_layers.append(layer) self.pre_norm = norm_layer(embed_dim) else: WARB = functools.partial(arch_util.WideActResBlock, nf=embed_dim) self.feature_extraction = arch_util.make_layer(WARB, 5) self.conv_after_pre_layer = nn.Conv2d(embed_dim, num_feat*4, 3, 1, 1, bias=True) self.mid_ps = nn.PixelShuffle(2) self.fea_L2_conv1 = nn.Conv2d(num_feat, num_feat*2, 3, 2, 1, bias=True) self.fea_L3_conv1 = nn.Conv2d(num_feat*2, num_feat*4, 3, 2, 1, bias=True) ##################################################################################################### ################################### 2, Feature Enhanced PCD Align ################################### # Top layers self.toplayer = nn.Conv2d(num_feat*4, num_feat, kernel_size=1, stride=1, padding=0) # Smooth layers 
self.smooth1 = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1) self.smooth2 = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1) # Lateral layers self.latlayer1 = nn.Conv2d(num_feat*2, num_feat, kernel_size=1, stride=1, padding=0) self.latlayer2 = nn.Conv2d(num_feat*1, num_feat, kernel_size=1, stride=1, padding=0) # self.align = PCD_Align(nf=num_feat, groups=groups) self.align = FlowGuidedPCDAlign(nf=num_feat, groups=groups) ##################################################################################################### ################################### 3, Multi-frame Feature Fusion ################################## if self.non_local: if self.args.local_rank <= 0: print("using non_local") self.fusion = CrossNonLocal_Fusion(nf=num_feat, out_feat=embed_dim, nframes=nframes, center=self.center) else: self.fusion = nn.Conv2d(nframes * num_feat, embed_dim, 1, 1, bias=True) ##################################################################################################### ################################### 4, deep feature extraction ###################################### # absolute position embedding if self.ape: self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) swu.trunc_normal_(self.absolute_pos_embed, std=.02) self.pos_drop = nn.Dropout(p=drop_rate) # build Residual Swin Transformer blocks (RSTB) self.layers = nn.ModuleList() for i_layer in range(1, self.num_layers): layer = swu.RSTB(dim=embed_dim, input_resolution=(patches_resolution[0], patches_resolution[1]), depth=depths[i_layer], num_heads=num_heads[i_layer], window_size=window_size, mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results norm_layer=norm_layer, downsample=None, use_checkpoint=use_checkpoint, img_size=img_size, patch_size=patch_size ) self.layers.append(layer) self.norm = 
norm_layer(self.num_features) # build the last conv layer in deep feature extraction self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) ##################################################################################################### ################################ 5, high quality image reconstruction ################################ self.upconv1 = nn.Conv2d(embed_dim, num_feat * 4, 3, 1, 1, bias=True) self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True) self.pixel_shuffle = nn.PixelShuffle(2) self.HRconv = nn.Conv2d(64, 64, 3, 1, 1, bias=True) self.conv_last = nn.Conv2d(64, args.n_colors, 3, 1, 1, bias=True) #### skip ############# self.skip_pixel_shuffle = nn.PixelShuffle(2) self.skipup1 = nn.Conv2d(num_in_ch//4, num_feat * 4, 3, 1, 1, bias=True) self.skipup2 = nn.Conv2d(num_feat, args.n_colors * 4, 3, 1, 1, bias=True) #### activation function self.lrelu = nn.LeakyReLU(0.1, inplace=True) self.lrelu2 = nn.LeakyReLU(0.1, inplace=True) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): swu.trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) @torch.jit.ignore def no_weight_decay(self): return {'absolute_pos_embed'} @torch.jit.ignore def no_weight_decay_keywords(self): return {'relative_position_bias_table'} def _upsample_add(self, x, y): return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + y def check_image_size(self, x): _, _, h, w = x.size() mod_pad_h = (self.window_size - h % self.window_size) % self.window_size mod_pad_w = (self.window_size - w % self.window_size) % self.window_size x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') return x def pre_forward_features(self, x): if self.args.swinfeature: x_size = (x.shape[-2], x.shape[-1]) x = self.patch_embed(x, use_norm=True) if self.ape: x = x + 
self.absolute_pos_embed x = self.pos_drop(x) for idx, layer in enumerate(self.pre_layers): x = layer(x, x_size) x = self.pre_norm(x) x = self.patch_unembed(x, x_size) else: x = self.feature_extraction(x) return x def forward_features(self, x): x_size = (x.shape[-2], x.shape[-1]) x = self.patch_embed(x) if self.ape: x = x + self.absolute_pos_embed x = self.pos_drop(x) for idx, layer in enumerate(self.layers): x = layer(x, x_size) if torch.any(torch.isinf(x)) or torch.any(torch.isnan(x)): print('layer: ', idx) x = self.norm(x) # B L C x = self.patch_unembed(x, x_size) return x @autocast() def forward(self, x, print_time=False): B, N, C, H, W = x.size() # N video frames x_center = x[:, self.center, :, :, :].contiguous() #### skip module ######## skip1 = self.lrelu2(self.skip_pixel_shuffle(self.skipup1(self.skip_pixel_shuffle(x_center)))) skip2 = self.skip_pixel_shuffle(self.skipup2(skip1)) x_ = self.conv_flow(self.flow_ps(x.view(B*N, C, H, W))).view(B, N, -1, H*2, W*2) # calculate flows ref_flows = self.get_ref_flows(x_) #### extract LR features x = self.lrelu(self.conv_first(x.view(B*N, -1, H, W))) L1_fea = self.mid_ps(self.conv_after_pre_layer(self.pre_forward_features(x))) _, _, H, W = L1_fea.size() L2_fea = self.lrelu(self.fea_L2_conv1(L1_fea)) L3_fea = self.lrelu(self.fea_L3_conv1(L2_fea)) # FPN enhance features L3_fea = self.lrelu(self.toplayer(L3_fea)) L2_fea = self.smooth1(self._upsample_add(L3_fea, self.latlayer1(L2_fea))) L1_fea = self.smooth2(self._upsample_add(L2_fea, self.latlayer2(L1_fea))) L1_fea = L1_fea.view(B, N, -1, H, W).contiguous() L2_fea = L2_fea.view(B, N, -1, H // 2, W // 2 ).contiguous() L3_fea = L3_fea.view(B, N, -1, H // 4, W // 4).contiguous() #### PCD align # ref feature list ref_fea_l = [ L1_fea[:, self.center, :, :, :].clone(), L2_fea[:, self.center, :, :, :].clone(), L3_fea[:, self.center, :, :, :].clone() ] aligned_fea = [] for i in range(N): nbr_fea_l = [ L1_fea[:, i, :, :, :].clone(), L2_fea[:, i, :, :, :].clone(), L3_fea[:, i, :, 
:, :].clone() ] flows_l = [ ref_flows[0][:, i, :, :, :].clone(), ref_flows[1][:, i, :, :, :].clone(), ref_flows[2][:, i, :, :, :].clone() ] # print(nbr_fea_l[0].shape, flows_l[0].shape) nbr_warped_l = [ arch_util.flow_warp(nbr_fea_l[0], flows_l[0].permute(0, 2, 3, 1), 'bilinear'), arch_util.flow_warp(nbr_fea_l[1], flows_l[1].permute(0, 2, 3, 1), 'bilinear'), arch_util.flow_warp(nbr_fea_l[2], flows_l[2].permute(0, 2, 3, 1), 'bilinear') ] aligned_fea.append(self.align(nbr_fea_l, nbr_warped_l, ref_fea_l, flows_l)) aligned_fea = torch.stack(aligned_fea, dim=1) # [B, N, C, H, W] --> [B, T, C, H, W] if not self.non_local: aligned_fea = aligned_fea.view(B, -1, H, W) x = self.lrelu(self.fusion(aligned_fea)) x = self.lrelu(self.conv_after_body(self.forward_features(x))) + x x = self.lrelu(self.pixel_shuffle(self.upconv1(x))) x = skip1 + x x = self.lrelu(self.pixel_shuffle(self.upconv2(x))) x = self.lrelu(self.HRconv(x)) x = self.conv_last(x) x = skip2 + x return x def get_ref_flows(self, x): '''Get flow between frames ref and other''' b, n, c, h, w = x.size() x_nbr = x.reshape(-1, c, h, w) x_ref = x[:, self.center:self.center+1, :, :, :].repeat(1, n, 1, 1, 1).reshape(-1, c, h, w) # backward flows = self.spynet(x_ref, x_nbr) flows_list = [flow.view(b, n, 2, h // (2 ** (i)), w // (2 ** (i))) for flow, i in zip(flows, range(3))] return flows_list ================================================ FILE: code/synthetic/bsrt/model/checkpoint.py ================================================ import torch import warnings def detach_variable(inputs): if isinstance(inputs, tuple): out = [] for inp in inputs: x = inp.detach() x.requires_grad = inp.requires_grad out.append(x) return tuple(out) else: raise RuntimeError( "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__) def check_backward_validity(inputs): if not any(inp.requires_grad for inp in inputs): warnings.warn("None of the inputs have requires_grad=True. 
Gradients will be None") class CheckpointFunction(torch.autograd.Function): @staticmethod def forward(ctx, run_function, length, *args): ctx.run_function = run_function ctx.input_tensors = list(args[:length]) ctx.input_params = list(args[length:]) with torch.no_grad(): output_tensors = ctx.run_function(*ctx.input_tensors) return output_tensors @staticmethod def backward(ctx, *output_grads): for i in range(len(ctx.input_tensors)): temp = ctx.input_tensors[i] ctx.input_tensors[i] = temp.detach() ctx.input_tensors[i].requires_grad = temp.requires_grad with torch.enable_grad(): output_tensors = ctx.run_function(*ctx.input_tensors) input_grads = torch.autograd.grad(output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True) return (None, None) + input_grads ================================================ FILE: code/synthetic/bsrt/model/common.py ================================================ import math import numpy as np import torch import torch.nn as nn import torch.nn.functional as F def default_conv(in_channels, out_channels, kernel_size, bias=True): return nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias) class MeanShift(nn.Conv2d): def __init__( self, rgb_range, rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): super(MeanShift, self).__init__(3, 3, kernel_size=1) std = torch.Tensor(rgb_std) self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std for p in self.parameters(): p.requires_grad = False class BasicBlock(nn.Sequential): def __init__( self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, bn=True, act=nn.ReLU(True)): m = [conv(in_channels, out_channels, kernel_size, bias=bias)] if bn: m.append(nn.BatchNorm2d(out_channels)) if act is not None: m.append(act) super(BasicBlock, self).__init__(*m) class ResBlock(nn.Module): def __init__( self, conv, n_feats, kernel_size, 
bias=True, bn=False, act=nn.ReLU(True), res_scale=1): super(ResBlock, self).__init__() m = [] for i in range(2): m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) if bn: m.append(nn.BatchNorm2d(n_feats)) if i == 0: m.append(act) self.body = nn.Sequential(*m) self.res_scale = res_scale def forward(self, x): res = self.body(x).mul(self.res_scale) res += x return res class Upsampler(nn.Sequential): def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): m = [] if (scale & (scale - 1)) == 0: # Is scale = 2^n? for _ in range(int(math.log(scale, 2))): m.append(conv(n_feats, 4 * n_feats, 3, bias)) m.append(nn.PixelShuffle(2)) if bn: m.append(nn.BatchNorm2d(n_feats)) if act == 'relu': m.append(nn.ReLU(True)) elif act == 'prelu': m.append(nn.PReLU(n_feats)) elif scale == 3: m.append(conv(n_feats, 9 * n_feats, 3, bias)) m.append(nn.PixelShuffle(3)) if bn: m.append(nn.BatchNorm2d(n_feats)) if act == 'relu': m.append(nn.ReLU(True)) elif act == 'prelu': m.append(nn.PReLU(n_feats)) else: raise NotImplementedError super(Upsampler, self).__init__(*m) class UpOnly(nn.Sequential): def __init__(self, scale): m = [] if (scale & (scale - 1)) == 0: # Is scale = 2^n? for _ in range(int(math.log(scale, 2))): m.append(nn.PixelShuffle(2)) elif scale == 3: m.append(nn.PixelShuffle(3)) else: raise NotImplementedError super(UpOnly, self).__init__(*m) def lanczos_kernel(dx, a=3, N=None, dtype=None, device=None): ''' Generates 1D Lanczos kernels for translation and interpolation. Args: dx : float, tensor (batch_size, 1), the translation in pixels to shift an image. a : int, number of lobes in the kernel support. If N is None, then the width is the kernel support (length of all lobes), S = 2(a + ceil(dx)) + 1. N : int, width of the kernel. If smaller than S then N is set to S. 
    Returns:
        k: tensor (?, ?), lanczos kernel
    '''
    if not torch.is_tensor(dx):
        dx = torch.tensor(dx, dtype=dtype, device=device)
    if device is None:
        device = dx.device
    if dtype is None:
        dtype = dx.dtype

    D = dx.abs().ceil().int()
    S = 2 * (a + D) + 1  # width of kernel support

    S_max = S.max() if hasattr(S, 'shape') else S

    if (N is None) or (N < S_max):
        N = S

    Z = (N - S) // 2  # width of zeros beyond kernel support

    start = (-(a + D + Z)).min()
    end = (a + D + Z + 1).max()
    x = torch.arange(start, end, dtype=dtype, device=device).view(1, -1) - dx
    px = (np.pi * x) + 1e-3  # small epsilon avoids division by zero at x == 0

    sin_px = torch.sin(px)
    sin_pxa = torch.sin(px / a)

    k = a * sin_px * sin_pxa / px ** 2  # sinc(x) masked by sinc(x/a)

    return k


def lanczos_shift(img, shift, p=5, a=3):
    '''
    Shifts an image by convolving it with a Lanczos kernel.
    Lanczos interpolation is an approximation to ideal sinc interpolation,
    by windowing a sinc kernel with another sinc function extending
    up to a few number of its lobes (typically a=3).

    Args:
        img : tensor (batch_size, channels, height, width), the images to be shifted
        shift : tensor (batch_size, 2) of translation parameters (dy, dx)
        p : int, padding width prior to convolution (default=5)
        a : int, number of lobes in the Lanczos interpolation kernel (default=3)
    Returns:
        I_s: tensor (batch_size, channels, height, width), shifted images
    '''
    # NOTE(review): the transpose appears to move channels to dim 0 so the
    # grouped conv1d below runs one kernel per batch element — confirm against
    # the conv calls at the end of this function.
    img = img.transpose(0, 1)
    dtype = img.dtype

    if len(img.shape) == 2:
        img = img[None, None].repeat(1, shift.shape[0], 1, 1)  # batch of one image
    elif len(img.shape) == 3:  # one image per shift
        assert img.shape[0] == shift.shape[0]
        img = img[None, ]

    # Apply padding
    padder = torch.nn.ReflectionPad2d(p)  # reflect pre-padding
    I_padded = padder(img)

    # Create 1D shifting kernels
    y_shift = shift[:, [0]]
    x_shift = shift[:, [1]]

    k_y = (lanczos_kernel(y_shift, a=a, N=None, dtype=dtype)
           .flip(1)  # flip axis of convolution
           )[:, None, :, None]  # expand dims to get shape (batch, channels, y_kernel, 1)
    k_x = (lanczos_kernel(x_shift, a=a, N=None, dtype=dtype)
           .flip(1)
           )[:, None,
                None, :]  # shape (batch, channels, 1, x_kernel)

    # Apply kernels.
    # NOTE(review): torch.conv1d is called here with 4-D input and 2-D padding
    # (conv2d-style arguments); groups == batch so every shift gets its own
    # 1-D kernel -- confirm this dispatches correctly on the pinned torch version.
    # print(I_padded.shape, k_y.shape)
    I_s = torch.conv1d(I_padded,
                       groups=k_y.shape[0],
                       weight=k_y,
                       padding=[k_y.shape[2] // 2, 0])  # "same" padding along y
    I_s = torch.conv1d(I_s,
                       groups=k_x.shape[0],
                       weight=k_x,
                       padding=[0, k_x.shape[3] // 2])  # "same" padding along x

    I_s = I_s[..., p:-p, p:-p]  # remove the reflection pre-padding

    # print(I_s.shape)
    return I_s.transpose(0, 1)  # , k.squeeze()


================================================
FILE: code/synthetic/bsrt/model/ebsr.py
================================================
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import model.arch_util as arch_util
from torch.cuda.amp import autocast
import model.swin_util as swu
import time
import os
import math
from utils.debayer import Debayer3x3
import torchvision.utils as tvutils
from datasets.burstsr_dataset import pack_raw_image, flatten_raw_image_batch

try:
    from model.non_local.non_local_cross_dot_product import NONLocalBlock2D as NonLocalCross
    from model.non_local.non_local_dot_product import NONLocalBlock2D as NonLocal
except ImportError:
    raise ImportError('Failed to import Non_Local module.')

try:
    from model.DCNv2.dcn_v2 import DCN_sep as DCN, FlowGuidedDCN, InsideFlowGuidedDCN
except ImportError:
    raise ImportError('Failed to import DCNv2 module.')


def make_model(args, parent=False):
    """Build an EBSR model from the parsed command-line options.

    Args:
        args: option namespace; reads burst_size, patch_size, scale,
            burst_channel, n_colors, n_feats, non_local and local_rank.
        parent: unused; kept for the framework's model-factory signature.

    Returns:
        A configured EBSR instance.
    """
    nframes = args.burst_size
    img_size = args.patch_size // args.scale[0]  # LR patch size seen by the network
    patch_size = 1
    in_chans = args.burst_channel
    out_chans = args.n_colors
    embed_dim = args.n_feats
    depths = [6]*1 + [8] * 6   # depths[0] drives the swin pre-layers; the rest feed the (disabled) RSTB stack
    num_heads = [6]*1 + [6] * 6
    window_size = 8
    mlp_ratio = 2
    upscale = args.scale[0]
    non_local = args.non_local
    if args.local_rank <= 0:  # only log from the main (rank-0 / non-distributed) process
        print("depths: ", depths)
    return EBSR(args=args,nframes=nframes, img_size=img_size, patch_size=patch_size, in_chans=in_chans,
                out_chans=out_chans, embed_dim=embed_dim, depths=depths, num_heads=num_heads,
                window_size=window_size, mlp_ratio=mlp_ratio, upscale=upscale, non_local=non_local)


class BasicModule(nn.Module):
"""Basic Module for SpyNet. """ def __init__(self): super(BasicModule, self).__init__() self.basic_module = nn.Sequential( nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)) def forward(self, tensor_input): return self.basic_module(tensor_input) class SpyNet(nn.Module): """SpyNet architecture. Args: load_path (str): path for pretrained SpyNet. Default: None. return_levels (list[int]): return flows of different levels. Default: [5]. """ def __init__(self, load_path=None, return_levels=[5]): super(SpyNet, self).__init__() self.return_levels = return_levels self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)]) if load_path: if not os.path.exists(load_path): import requests url = 'https://github.com/JingyunLiang/VRT/releases/download/v0.0/spynet_sintel_final-3d2a1287.pth' r = requests.get(url, allow_redirects=True) print(f'downloading SpyNet pretrained model from {url}') os.makedirs(os.path.dirname(load_path), exist_ok=True) open(load_path, 'wb').write(r.content) self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params']) self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) def preprocess(self, tensor_input): tensor_output = (tensor_input - self.mean) / self.std return tensor_output def process(self, ref, supp, w, h, w_floor, h_floor): flow_list = [] ref = [self.preprocess(ref)] supp = [self.preprocess(supp)] # ref = [ref] # supp = [supp] for level in range(5): 
ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False)) supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False)) flow = ref[0].new_zeros( [ref[0].size(0), 2, int(math.floor(ref[0].size(2) / 2.0)), int(math.floor(ref[0].size(3) / 2.0))]) for level in range(len(ref)): upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0 if upsampled_flow.size(2) != ref[level].size(2): upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate') if upsampled_flow.size(3) != ref[level].size(3): upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate') flow = self.basic_module[level](torch.cat([ ref[level], arch_util.flow_warp( supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'), upsampled_flow ], 1)) + upsampled_flow if level in self.return_levels: scale = 2**(5-level) # level=5 (scale=1), level=4 (scale=2), level=3 (scale=4), level=2 (scale=8) flow_out = F.interpolate(input=flow, size=(h//scale, w//scale), mode='bilinear', align_corners=False) flow_out[:, 0, :, :] *= float(w//scale) / float(w_floor//scale) flow_out[:, 1, :, :] *= float(h//scale) / float(h_floor//scale) if torch.abs(flow_out).mean() > 200: print(f"level {level}, flow > 200: {torch.abs(flow_out).mean():.4f}") # return None flow_out.clamp(-50, 50) flow_list.insert(0, flow_out) return flow_list def forward(self, ref, supp): assert ref.size() == supp.size() h, w = ref.size(2), ref.size(3) w_floor = math.floor(math.ceil(w / 32.0) * 32.0) h_floor = math.floor(math.ceil(h / 32.0) * 32.0) ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False) supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False) flow_list = self.process(ref, supp, w, h, w_floor, h_floor) return flow_list[0] if len(flow_list) == 1 else flow_list class PCD_Align(nn.Module): ''' 
    Alignment module using Pyramid, Cascading and Deformable convolution
    with 3 pyramid levels. [From EDVR]

    Offsets are predicted from the concatenation of neighbour and reference
    features at each level, refined coarse-to-fine, and fed to DCN packs.
    '''

    def __init__(self, nf=64, groups=8, wn=None):
        # nf: feature width at every level; groups: deformable groups for DCN;
        # wn: weight wrapper, defaults to torch weight normalisation.
        super(PCD_Align, self).__init__()
        if wn is None:
            wn = lambda x: torch.nn.utils.weight_norm(x)
        # L3: level 3, 1/4 spatial size
        self.L3_offset_conv1 = wn(nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True))  # concat for diff
        self.L3_offset_conv2 = wn(nn.Conv2d(nf, nf, 3, 1, 1, bias=True))
        # self.L3_shift = ShiftAlign(nf)
        self.L3_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1,
                              deformable_groups=groups)  # extra_offset_mask=True)
        # L2: level 2, 1/2 spatial size
        self.L2_offset_conv1 = wn(nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True))  # concat for diff
        self.L2_offset_conv2 = wn(nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True))  # concat for offset
        self.L2_offset_conv3 = wn(nn.Conv2d(nf, nf, 3, 1, 1, bias=True))
        # self.L2_shift = ShiftAlign(nf)
        self.L2_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1,
                              deformable_groups=groups)  # extra_offset_mask=True)
        self.L2_fea_conv = wn(nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True))  # concat for fea
        # L1: level 1, original spatial size
        self.L1_offset_conv1 = wn(nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True))  # concat for diff
        self.L1_offset_conv2 = wn(nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True))  # concat for offset
        self.L1_offset_conv3 = wn(nn.Conv2d(nf, nf, 3, 1, 1, bias=True))
        # self.L1_shift = ShiftAlign(nf)
        self.L1_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1,
                              deformable_groups=groups)  # extra_offset_mask=True)
        self.L1_fea_conv = wn(nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True))  # concat for fea
        # Cascading DCN: one extra refinement pass at full resolution.
        self.cas_offset_conv1 = wn(nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True))  # concat for diff
        self.cas_offset_conv2 = wn(nn.Conv2d(nf, nf, 3, 1, 1, bias=True))
        self.cas_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1,
                              deformable_groups=groups)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, nbr_fea_l, ref_fea_l):
        '''align other neighboring frames to the reference frame in the feature level

        nbr_fea_l, ref_fea_l: [L1, L2, L3], each with [B,C,H,W] features
        Returns the aligned L1 (full-resolution) feature map.
        '''
        # L3: coarsest level, offsets predicted from the feature pair alone.
        L3_offset = torch.cat([nbr_fea_l[2], ref_fea_l[2]], dim=1)
        L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset))
        L3_offset = self.lrelu(self.L3_offset_conv2(L3_offset))
        # L3_nbr_fea = self.L3_shift(L3_offset, nbr_fea_l[2])
        L3_fea = self.lrelu(self.L3_dcnpack(nbr_fea_l[2], L3_offset))
        # L2: fuse upsampled L3 offsets (x2, since coordinates double) and features.
        L3_offset = F.interpolate(L3_offset, scale_factor=2, mode='bilinear', align_corners=False)
        L2_offset = torch.cat([nbr_fea_l[1], ref_fea_l[1]], dim=1)
        L2_offset = self.lrelu(self.L2_offset_conv1(L2_offset))
        L2_offset = self.lrelu(self.L2_offset_conv2(torch.cat([L2_offset, L3_offset*2], dim=1)))
        L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset))
        # L2_nbr_fea = self.L2_shift(L2_offset, nbr_fea_l[1])
        L2_fea = self.L2_dcnpack(nbr_fea_l[1], L2_offset)
        L3_fea = F.interpolate(L3_fea, scale_factor=2, mode='bilinear', align_corners=False)
        L2_fea = self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L3_fea], dim=1)))
        # L1: same scheme at full resolution.
        L2_offset = F.interpolate(L2_offset, scale_factor=2, mode='bilinear', align_corners=False)
        L1_offset = torch.cat([nbr_fea_l[0], ref_fea_l[0]], dim=1)
        L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset))
        L1_offset = self.lrelu(self.L1_offset_conv2(torch.cat([L1_offset, L2_offset * 2], dim=1)))
        L1_offset = self.lrelu(self.L1_offset_conv3(L1_offset))
        # L1_nbr_fea = self.L1_shift(L1_offset, nbr_fea_l[0])
        L1_fea = self.L1_dcnpack(nbr_fea_l[0], L1_offset)
        L2_fea = F.interpolate(L2_fea, scale_factor=2, mode='bilinear', align_corners=False)
        L1_fea = self.L1_fea_conv(torch.cat([L1_fea, L2_fea], dim=1))
        # Cascading: refine the aligned feature once more against the reference.
        offset = torch.cat([L1_fea, ref_fea_l[0]], dim=1)
        offset = self.lrelu(self.cas_offset_conv1(offset))
        offset = self.lrelu(self.cas_offset_conv2(offset))
        L1_fea = self.cas_dcnpack(L1_fea, offset)

        return L1_fea


class FlowGuidedPCDAlign(nn.Module):
    '''
    Alignment module using Pyramid, Cascading and Deformable convolution with 3 pyramid levels.
    [From EDVR] Offsets are additionally conditioned on precomputed optical
    flow (2 extra input channels per level) and passed to FlowGuidedDCN.
    '''

    def __init__(self, nf=64, groups=8):
        # nf: feature width; groups: deformable groups for each DCN pack.
        super(FlowGuidedPCDAlign, self).__init__()
        # L3: level 3, 1/4 spatial size; "+ 2" = the 2 flow channels.
        self.L3_offset_conv1 = nn.Conv2d(nf * 2 + 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L3_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L3_dcnpack = FlowGuidedDCN(nf, nf, 3, stride=1, padding=1, dilation=1,
                                        deformable_groups=groups)
        # L2: level 2, 1/2 spatial size
        self.L2_offset_conv1 = nn.Conv2d(nf * 2 + 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L2_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset
        self.L2_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L2_dcnpack = FlowGuidedDCN(nf, nf, 3, stride=1, padding=1, dilation=1,
                                        deformable_groups=groups)
        self.L2_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea
        # L1: level 1, original spatial size
        self.L1_offset_conv1 = nn.Conv2d(nf * 2 + 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L1_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset
        self.L1_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L1_dcnpack = FlowGuidedDCN(nf, nf, 3, stride=1, padding=1, dilation=1,
                                        deformable_groups=groups)
        self.L1_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea
        # Cascading DCN (disabled in this variant)
        # self.cas_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff
        # self.cas_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        # self.cas_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, nbr_fea_l, nbr_fea_warped_l, ref_fea_l, flows_l):
        '''align other neighboring frames to the reference frame in the feature level

        nbr_fea_l: un-warped neighbour features [L1, L2, L3], each [B,C,H,W]
        nbr_fea_warped_l: neighbour features pre-warped by the flow
        ref_fea_l: reference features [L1, L2, L3]
        flows_l: per-level flows guiding the DCN sampling
        '''
        # L3: offsets from (warped neighbour, reference, flow).
        L3_offset = torch.cat([nbr_fea_warped_l[2], ref_fea_l[2], flows_l[2]], dim=1)
        L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset))
        L3_offset = self.lrelu(self.L3_offset_conv2(L3_offset))
        L3_fea = self.lrelu(self.L3_dcnpack(nbr_fea_l[2], L3_offset, flows_l[2]))
        # L2: fuse upsampled L3 offsets (x2 to match doubled coordinates).
        L3_offset = F.interpolate(L3_offset, scale_factor=2, mode='bilinear', align_corners=False)
        L2_offset = torch.cat([nbr_fea_warped_l[1], ref_fea_l[1], flows_l[1]], dim=1)
        L2_offset = self.lrelu(self.L2_offset_conv1(L2_offset))
        L2_offset = self.lrelu(self.L2_offset_conv2(torch.cat([L2_offset, L3_offset*2], dim=1)))
        L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset))
        L2_fea = self.L2_dcnpack(nbr_fea_l[1], L2_offset, flows_l[1])
        L3_fea = F.interpolate(L3_fea, scale_factor=2, mode='bilinear', align_corners=False)
        L2_fea = self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L3_fea], dim=1)))
        # L1: same scheme at full resolution.
        L2_offset = F.interpolate(L2_offset, scale_factor=2, mode='bilinear', align_corners=False)
        L1_offset = torch.cat([nbr_fea_warped_l[0], ref_fea_l[0], flows_l[0]], dim=1)
        L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset))
        L1_offset = self.lrelu(self.L1_offset_conv2(torch.cat([L1_offset, L2_offset * 2], dim=1)))
        L1_offset = self.lrelu(self.L1_offset_conv3(L1_offset))
        L1_fea = self.L1_dcnpack(nbr_fea_l[0], L1_offset, flows_l[0])
        L2_fea = F.interpolate(L2_fea, scale_factor=2, mode='bilinear', align_corners=False)
        L1_fea = self.L1_fea_conv(torch.cat([L1_fea, L2_fea], dim=1))
        # Cascading (disabled)
        # offset = torch.cat([L1_fea, ref_fea_l[0]], dim=1)
        # offset = self.lrelu(self.cas_offset_conv1(offset))
        # offset = self.lrelu(self.cas_offset_conv2(offset))
        # L1_fea = self.cas_dcnpack(L1_fea, offset)

        return L1_fea


class CrossNonLocal_Fusion(nn.Module):
    ''' Cross Non Local fusion module '''

    def __init__(self, nf=64, out_feat=96, nframes=5, center=2):
        super(CrossNonLocal_Fusion, self).__init__()
        self.center = center

        self.non_local_T = nn.ModuleList()
        self.non_local_F = nn.ModuleList()

        for i in range(nframes):
            self.non_local_T.append(NonLocalCross(nf, inter_channels=nf//2, sub_sample=True, bn_layer=False))
            self.non_local_F.append(NonLocal(nf, inter_channels=nf//2,
                                             sub_sample=True, bn_layer=False))

        # fusion conv: using 1x1 to save parameters and computation
        self.fea_fusion = nn.Conv2d(nframes * nf*2, out_feat, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, aligned_fea):
        # aligned_fea: (B, N, C, H, W) stack of aligned per-frame features.
        B, N, C, H, W = aligned_fea.size()  # N video frames
        ref = aligned_fea[:, self.center, :, :, :].clone()

        cor_l = []
        non_l = []
        for i in range(N):
            nbr = aligned_fea[:, i, :, :, :]
            non_l.append(self.non_local_F[i](nbr))        # self non-local on each frame
            cor_l.append(self.non_local_T[i](nbr, ref))   # cross non-local against the reference

        aligned_fea_T = torch.cat(cor_l, dim=1)
        aligned_fea_F = torch.cat(non_l, dim=1)
        aligned_fea = torch.cat([aligned_fea_T, aligned_fea_F], dim=1)

        #### fusion
        fea = self.fea_fusion(aligned_fea)

        return fea


class EBSR(nn.Module):
    r"""SwinBSR: burst super-resolution network combining flow-guided PCD
    alignment, (optional) cross non-local fusion, and swin/wide-activation
    feature extraction with pixel-shuffle upsampling.
    """

    def __init__(self, args, nframes=8, img_size=64, patch_size=1, in_chans=3, out_chans=3,
                 embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, upscale=4, non_local=False,
                 **kwargs):
        super(EBSR, self).__init__()
        num_in_ch = in_chans
        num_out_ch = out_chans
        num_feat = 128
        groups = 8
        back_RBs = 5
        n_resblocks = 8
        # NOTE: embed_dim argument is overridden to num_feat here.
        embed_dim = num_feat

        self.args = args
        self.center = 0          # reference frame is the first burst frame
        self.upscale = upscale
        self.window_size = window_size
        self.non_local = non_local
        self.nframes = nframes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio

        # NOTE(review): hard-coded user path; SpyNet downloads the weights if missing.
        spynet_path='/home/luoziwei/.pretrained_models/spynet_sintel_final-3d2a1287.pth'
        self.spynet = SpyNet(spynet_path, [3, 4, 5])
        self.conv_flow = nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1)
        self.flow_ps = nn.PixelShuffle(2)
        # self.debayer = Debayer3x3()

        # split image into non-overlapping patches
        self.patch_embed = swu.PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # merge non-overlapping patches into image
        self.patch_unembed = swu.PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        #####################################################################################################
        ################################### 1, shallow feature extraction ###################################
        self.conv_first = nn.Conv2d(num_in_ch*(1+2*0), embed_dim, 3, 1, 1, bias=True)

        # # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        if args.swinfeature:
            if self.args.local_rank <= 0:
                print("using swinfeature")
            # Shallow swin transformer stack (depths[0] blocks, alternating shifts).
            self.pre_layers = nn.ModuleList()
            for i_layer in range(depths[0]):
                layer = swu.SwinTransformerBlock(dim=embed_dim,
                                 input_resolution=(patches_resolution[0]//2, patches_resolution[1]//2),
                                 num_heads=num_heads[i_layer], window_size=window_size,
                                 shift_size=0 if (i_layer % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop_rate, attn_drop=attn_drop_rate,
                                 drop_path=dpr[i_layer],
                                 norm_layer=norm_layer)
                self.pre_layers.append(layer)

            # self.pre_linear = nn.Linear(embed_dim, embed_dim)
            self.pre_norm = norm_layer(embed_dim)
        else:
            # CNN fallback: 5 wide-activation residual blocks.
            WARB = functools.partial(arch_util.WideActResBlock, nf=embed_dim)
            self.feature_extraction = arch_util.make_layer(WARB, 5)

        self.conv_after_pre_layer = nn.Conv2d(embed_dim, num_feat*4, 3, 1, 1, bias=True)
        self.mid_ps = nn.PixelShuffle(2)

        # Strided convs build the L2 (1/2) and L3 (1/4) feature pyramid levels.
        self.fea_L2_conv1 = nn.Conv2d(num_feat, num_feat*2, 3, 2, 1, bias=True)
        self.fea_L3_conv1 = nn.Conv2d(num_feat*2, num_feat*4, 3, 2, 1, bias=True)

        #####################################################################################################
        ################################### 2, Feature Enhanced PCD Align ###################################
        # Top layers
        self.toplayer = nn.Conv2d(num_feat*4, num_feat, kernel_size=1, stride=1, padding=0)
        # Smooth layers
        self.smooth1 = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1)
        self.smooth2 = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1)
        # Lateral layers
        self.latlayer1 = nn.Conv2d(num_feat*2, num_feat, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(num_feat*1, num_feat, kernel_size=1, stride=1, padding=0)

        # self.align = PCD_Align(nf=num_feat, groups=groups)
        self.align = FlowGuidedPCDAlign(nf=num_feat, groups=groups)

        #####################################################################################################
        ################################### 3, Multi-frame Feature Fusion  ##################################
        if self.non_local:
            if self.args.local_rank <= 0:
                print("using non_local")
            self.fusion = CrossNonLocal_Fusion(nf=num_feat, out_feat=embed_dim, nframes=nframes, center=self.center)
        else:
            self.fusion = nn.Conv2d(nframes * num_feat, embed_dim, 1, 1, bias=True)

        #####################################################################################################
        ################################### 4, deep feature extraction ######################################
        # absolute position embedding
        # if self.ape:
        #     self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        #     swu.trunc_normal_(self.absolute_pos_embed, std=.02)
        # self.pos_drop = nn.Dropout(p=drop_rate)
        #
        # NOTE(review): the original RSTB (Residual Swin Transformer Block) deep
        # stack is disabled here; it built swu.RSTB layers for i_layer in
        # range(1, self.num_layers) with drop_path slices from dpr, followed by
        # self.norm = norm_layer(self.num_features).  Kept out of the live graph.

        LRCN = functools.partial(arch_util.LRSCWideActResGroup, n_resblocks=n_resblocks, nf=embed_dim)
        self.post_feature_extraction = nn.Sequential(arch_util.make_layer_idx(LRCN, back_RBs),
                                                     nn.Conv2d(embed_dim*(back_RBs+1), num_feat, 1))
        # self.post_feature_extraction = nn.Sequential(
        #     arch_util.make_layer(WARB, 20),
        #     nn.Conv2d(embed_dim, embed_dim, 3, 1, 1))

        # build the last conv layer in deep feature extraction
        # self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)

        #####################################################################################################
        ################################ 5, high quality image reconstruction ###############################
        self.upconv1 = nn.Conv2d(embed_dim, num_feat * 4, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
        self.pixel_shuffle = nn.PixelShuffle(2)
        self.HRconv = nn.Conv2d(64, 64, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(64, args.n_colors, 3, 1, 1, bias=True)

        #### skip #############
        self.skip_pixel_shuffle = nn.PixelShuffle(2)
        self.skipup1 = nn.Conv2d(num_in_ch//4, num_feat * 4, 3, 1, 1, bias=True)
        self.skipup2 = nn.Conv2d(num_feat, args.n_colors * 4, 3, 1, 1, bias=True)

        #### activation function
        self.lrelu = nn.LeakyReLU(0.1, inplace=True)
        self.lrelu2 = nn.LeakyReLU(0.1, inplace=True)

        # self.apply(self._init_weights)

    def _init_weights(self, m):
        # Swin-style init; currently unused (self.apply call above is disabled).
        if isinstance(m, nn.Linear):
            swu.trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameters excluded from weight decay by the optimizer setup.
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def
    no_weight_decay_keywords(self):
        # Keyword-matched parameters excluded from weight decay.
        return {'relative_position_bias_table'}

    def _upsample_add(self, x, y):
        # FPN merge: upsample x by 2 and add the lateral feature y.
        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + y

    def check_image_size(self, x):
        # Reflect-pad H/W up to multiples of the swin window size.
        _, _, h, w = x.size()
        mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
        mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
        return x

    def pre_forward_features(self, x):
        # Shallow feature extractor: swin pre-layers or the CNN fallback.
        if self.args.swinfeature:
            x_size = (x.shape[-2], x.shape[-1])
            x = self.patch_embed(x, use_norm=True)
            if self.ape:
                x = x + self.absolute_pos_embed
                x = self.pos_drop(x)

            for idx, layer in enumerate(self.pre_layers):
                x = layer(x, x_size)

            x = self.pre_norm(x)
            x = self.patch_unembed(x, x_size)
        else:
            x = self.feature_extraction(x)
        return x

    def forward_features(self, x):
        # Deep feature extraction.  The original swin body is disabled:
        # it ran patch_embed -> (ape) -> self.layers -> self.norm -> patch_unembed.
        x = self.post_feature_extraction(x)
        return x

    @autocast()
    def forward(self, x, print_time=False):
        # x: (B, N, C, H, W) burst of N packed RAW frames; frame self.center is the reference.
        B, N, C, H, W = x.size()  # N video frames
        x_center = x[:, self.center, :, :, :].contiguous()

        #### skip module ########
        # Global skip path: upsample the reference frame directly.
        skip1 = self.lrelu2(self.skip_pixel_shuffle(self.skipup1(self.skip_pixel_shuffle(x_center))))
        skip2 = self.skip_pixel_shuffle(self.skipup2(skip1))

        # Pseudo-RGB at 2x resolution for flow estimation.
        x_ = self.conv_flow(self.flow_ps(x.view(B*N, C, H, W))).view(B, N, -1, H*2, W*2)

        # calculate flows
        ref_flows = self.get_ref_flows(x_)

        # flows_backward, flows_forward = self.get_flow_2frames(x_)
        # # # warp input
        # x_backward, x_forward = self.get_aligned_image_2frames(x, flows_backward[1], flows_forward[1])
        # x = torch.cat([x, x_backward, x_forward], 2)

        #### extract LR features
        x = self.lrelu(self.conv_first(x.view(B*N, -1, H, W)))
        L1_fea = self.mid_ps(self.conv_after_pre_layer(self.pre_forward_features(x)))
        _, _, H, W = L1_fea.size()  # H/W doubled by the pixel shuffle above

        L2_fea = self.lrelu(self.fea_L2_conv1(L1_fea))
        L3_fea = self.lrelu(self.fea_L3_conv1(L2_fea))

        # FPN enhance features
        L3_fea = self.lrelu(self.toplayer(L3_fea))
        L2_fea = self.smooth1(self._upsample_add(L3_fea, self.latlayer1(L2_fea)))
        L1_fea = self.smooth2(self._upsample_add(L2_fea, self.latlayer2(L1_fea)))

        L1_fea = L1_fea.view(B, N, -1, H, W).contiguous()
        L2_fea = L2_fea.view(B, N, -1, H // 2, W // 2).contiguous()
        L3_fea = L3_fea.view(B, N, -1, H // 4, W // 4).contiguous()

        #### PCD align
        # ref feature list
        ref_fea_l = [
            L1_fea[:, self.center, :, :, :].clone(),
            L2_fea[:, self.center, :, :, :].clone(),
            L3_fea[:, self.center, :, :, :].clone()
        ]
        aligned_fea = []
        for i in range(N):
            nbr_fea_l = [
                L1_fea[:, i, :, :, :].clone(),
                L2_fea[:, i, :, :, :].clone(),
                L3_fea[:, i, :, :, :].clone()
            ]
            flows_l = [
                ref_flows[0][:, i, :, :, :].clone(),
                ref_flows[1][:, i, :, :, :].clone(),
                ref_flows[2][:, i, :, :, :].clone()
            ]
            # print(nbr_fea_l[0].shape, flows_l[0].shape)
            # Pre-warp neighbour features by the flow before DCN alignment.
            nbr_warped_l = [
                arch_util.flow_warp(nbr_fea_l[0], flows_l[0].permute(0, 2, 3, 1), 'bilinear'),
                arch_util.flow_warp(nbr_fea_l[1], flows_l[1].permute(0, 2, 3, 1), 'bilinear'),
                arch_util.flow_warp(nbr_fea_l[2], flows_l[2].permute(0, 2, 3, 1), 'bilinear')
            ]
            aligned_fea.append(self.align(nbr_fea_l, nbr_warped_l, ref_fea_l, flows_l))
            # aligned_fea.append(self.align(nbr_fea_l, ref_fea_l))

        aligned_fea = torch.stack(aligned_fea, dim=1)  # [B, N, C, H, W] --> [B, T, C, H, W]
        if not self.non_local:
            # Conv fusion consumes the frames flattened into channels.
            aligned_fea = aligned_fea.view(B, -1, H, W)

        x = self.lrelu(self.fusion(aligned_fea))
        x = self.forward_features(x)
        x = self.lrelu(self.pixel_shuffle(self.upconv1(x)))
        x = skip1 + x
        x = self.lrelu(self.pixel_shuffle(self.upconv2(x)))
        x = self.lrelu(self.HRconv(x))
        x = self.conv_last(x)
        x = skip2 + x

        return x

    def get_ref_flows(self, x):
        '''Get flow between frames ref and other'''
        b, n, c, h, w = x.size()
        x_nbr = x.reshape(-1, c, h, w)
        # Repeat the reference frame so SpyNet sees (ref, nbr) pairs for all n frames.
        x_ref = x[:, self.center:self.center+1, :, :, :].repeat(1, n, 1, 1, 1).reshape(-1, c,
                                                                                       h, w)

        # backward
        flows = self.spynet(x_ref, x_nbr)
        # One flow tensor per pyramid level (full, 1/2, 1/4 resolution).
        flows_list = [flow.view(b, n, 2, h // (2 ** (i)), w // (2 ** (i)))
                      for flow, i in zip(flows, range(3))]

        return flows_list

    def get_flow_2frames(self, x):
        '''Get flow between frames t and t+1 from x.'''
        b, n, c, h, w = x.size()
        x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
        x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)

        # backward
        flows_backward = self.spynet(x_1, x_2)
        flows_backward = [flow.view(b, n-1, 2, h // (2 ** (i)), w // (2 ** (i)))
                          for flow, i in zip(flows_backward, range(3))]

        # forward
        flows_forward = self.spynet(x_2, x_1)
        flows_forward = [flow.view(b, n-1, 2, h // (2 ** (i)), w // (2 ** (i)))
                         for flow, i in zip(flows_forward, range(3))]

        return flows_backward, flows_forward

    def get_aligned_image_2frames(self, x, flows_backward, flows_forward):
        '''Parallel feature warping for 2 frames.'''
        # backward
        n = x.size(1)
        x_backward = [torch.zeros_like(x[:, -1, ...]).repeat(1, 4, 1, 1)]
        for i in range(n - 1, 0, -1):
            x_i = x[:, i, ...]
            flow = flows_backward[:, i - 1, ...]
            x_backward.insert(0, arch_util.flow_warp(x_i, flow.permute(0, 2, 3, 1), 'nearest4'))  # frame i+1 aligned towards i

        # forward
        x_forward = [torch.zeros_like(x[:, 0, ...]).repeat(1, 4, 1, 1)]
        for i in range(0, n - 1):
            x_i = x[:, i, ...]
            flow = flows_forward[:, i, ...]
            x_forward.append(arch_util.flow_warp(x_i, flow.permute(0, 2, 3, 1), 'nearest4'))  # frame i-1 aligned towards i

        return [torch.stack(x_backward, 1), torch.stack(x_forward, 1)]

    def get_aligned_feature_2frames(self, x):
        '''Parallel feature warping for 2 frames.

        NOTE(review): dead code — self.offset_conv and self.FDCN are not
        defined in __init__, so calling this would raise AttributeError.
        Also the forward pass uses insert(0, ...) where the backward pass
        pattern suggests append(); verify before re-enabling.
        '''
        # backward
        n = x.size(1)
        x_backward = [torch.zeros_like(x[:, -1, ...])]
        for i in range(n - 1, 0, -1):
            # x_i = x[:, i, ...]
            # flow = flows_backward[0][:, i - 1, ...]
            # x_i_warped = arch_util.flow_warp(x_i, flow.permute(0, 2, 3, 1), 'bilinear')  # frame i+1 aligned towards i
            # x_backward.insert(0, self.FDCN(x_i, x_i_warped, x[:, i - 1, ...], flow))
            offset = self.offset_conv(torch.cat([x[:, i, ...], x[:, i - 1, ...]], dim=1))
            x_backward.insert(0, self.FDCN(x[:, i, ...].clone(), offset))

        # forward
        x_forward = [torch.zeros_like(x[:, 0, ...])]
        for i in range(0, n - 1):
            # x_i = x[:, i, ...]
            # flow = flows_forward[0][:, i, ...]
            # x_i_warped = arch_util.flow_warp(x_i, flow.permute(0, 2, 3, 1), 'bilinear')  # frame i-1 aligned towards i
            # x_forward.append(self.FDCN(x_i, x_i_warped, x[:, i + 1, ...], flow))
            offset = self.offset_conv(torch.cat([x[:, i, ...], x[:, i + 1, ...]], dim=1))
            x_forward.insert(0, self.FDCN(x[:, i, ...].clone(), offset))

        return [torch.stack(x_backward, 1), torch.stack(x_forward, 1)]


================================================
FILE: code/synthetic/bsrt/model/non_local/network.py
================================================
from torch import nn

# from lib.non_local_concatenation import NONLocalBlock2D
# from lib.non_local_gaussian import NONLocalBlock2D
from lib.non_local_embedded_gaussian import NONLocalBlock2D
# from lib.non_local_dot_product import NONLocalBlock2D


class Network(nn.Module):
    # Small demo classifier (MNIST-sized inputs) with non-local blocks
    # interleaved between conv stages.
    def __init__(self):
        super(Network, self).__init__()

        self.conv_1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.nl_1 = NONLocalBlock2D(in_channels=32)

        self.conv_2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.nl_2 = NONLocalBlock2D(in_channels=64)

        self.conv_3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.fc = nn.Sequential(
            nn.Linear(in_features=128*3*3, out_features=256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(in_features=256, out_features=10)
        )

    def forward(self, x):
        batch_size = x.size(0)

        feature_1 = self.conv_1(x)
        nl_feature_1 = self.nl_1(feature_1)

        feature_2 = self.conv_2(nl_feature_1)
        nl_feature_2 = self.nl_2(feature_2)

        output = self.conv_3(nl_feature_2).view(batch_size, -1)
        output = self.fc(output)

        return output

    def forward_with_nl_map(self, x):
        # Same as forward() but also returns the attention maps of both
        # non-local blocks for inspection.
        batch_size = x.size(0)

        feature_1 = self.conv_1(x)
        nl_feature_1, nl_map_1 = self.nl_1(feature_1, return_nl_map=True)

        feature_2 = self.conv_2(nl_feature_1)
        nl_feature_2, nl_map_2 = self.nl_2(feature_2, return_nl_map=True)

        output = self.conv_3(nl_feature_2).view(batch_size, -1)
        output = self.fc(output)

        return output, [nl_map_1, nl_map_2]


if __name__ == '__main__':
    import torch

    img = torch.randn(3, 1, 28, 28)
    net = Network()
    out = net(img)
    print(out.size())


================================================
FILE: code/synthetic/bsrt/model/non_local/non_local_concatenation.py
================================================
import torch
from torch import nn
from torch.nn import functional as F


class _NonLocalBlockND(nn.Module):
    # Non-local block, concatenation variant (Wang et al., 2018):
    # the pairwise function concatenates theta/phi embeddings and scores
    # them with a 1x1 conv + ReLU.
    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            # Default bottleneck: half the input channels, at least 1.
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            # Zero-init so the block starts as an identity mapping (z = x).
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        # Pairwise scoring of concatenated (theta, phi) embeddings.
        self.concat_project = nn.Sequential(
            nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
            nn.ReLU()
        )

        if sub_sample:
            # Subsample phi/g to reduce the attention matrix size.
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x, return_nl_map=False):
        '''
        :param x: (b, c, t, h, w)
        :param return_nl_map: if True return z, nl_map, else only return z.
        :return:
        '''
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # (b, c, N, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)
        # (b, c, 1, N)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1)

        # Broadcast to all (i, j) position pairs, then score each pair.
        h = theta_x.size(2)
        w = phi_x.size(3)
        theta_x = theta_x.repeat(1, 1, 1, w)
        phi_x = phi_x.repeat(1, 1, h, 1)

        concat_feature = torch.cat([theta_x, phi_x], dim=1)
        f = self.concat_project(concat_feature)
        b, _, h, w = f.size()
        f = f.view(b, h, w)

        # Normalise by the number of positions instead of a softmax.
        N = f.size(-1)
        f_div_C = f / N

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x  # residual connection

        if return_nl_map:
            return z, f_div_C
        return z


class NONLocalBlock1D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=1, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock2D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=2, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock3D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True,):
        super(NONLocalBlock3D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=3, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


if __name__ == '__main__':
    import torch

    # Smoke test over all sub_sample / bn_layer combinations and dimensions.
    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
        img = torch.zeros(2, 3, 20)
        net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())

        img = torch.zeros(2, 3, 20, 20)
        net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())

        img = torch.randn(2, 3, 8, 20, 20)
        net =
            NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())


================================================ FILE: code/synthetic/bsrt/model/non_local/non_local_cross_dot_product.py ================================================
import torch
from torch import nn
from torch.nn import functional as F


class _NonLocalBlockND(nn.Module):
    # Cross non-local block (dot-product variant): queries come from a
    # separate `ref` tensor while keys/values come from `x` — cross-attention
    # between two feature maps rather than self-attention.
    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        # Default bottleneck width: half of the input channels (at least 1).
        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        # This variant sub-samples more aggressively (factor 4) than the
        # other non-local files in this package (factor 2).
        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 4, 4))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(4, 4))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(4))
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            # Zero-init so the block starts as an identity mapping.
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x, ref, return_nl_map=False):
        """
        :param x: (b, c, t, h, w)
        :param return_nl_map: if True return z, nl_map, else only return z.
        :return:
        """
        # NOTE(review): `ref` (same layout as x) supplies the queries; it is
        # undocumented above — confirm the intended contract with callers.
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # Queries come from the reference tensor, keys/values from x.
        theta_ref = self.theta(ref).view(batch_size, self.inter_channels, -1)
        theta_ref = theta_ref.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        f = torch.matmul(theta_ref, phi_x)
        # 1/N normalisation (dot-product variant), not softmax.
        N = f.size(-1)
        f_div_C = f / N

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        if return_nl_map:
            return z, f_div_C
        return z


class NONLocalBlock1D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=1, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock2D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=2, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock3D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock3D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=3, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


if __name__ == '__main__':
    import torch

    # NOTE(review): this demo calls net(img) with a single argument, but this
    # file's forward() requires a second `ref` tensor — running it as-is
    # raises TypeError. The demo appears copied from the self-attention files.
    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
        img = torch.zeros(2, 3, 20)
        net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())

        img = torch.zeros(2, 3, 20, 20)
        net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())

        img = torch.randn(2, 3, 8, 20, 20)
        net = NONLocalBlock3D(3, sub_sample=sub_sample_,
                              bn_layer=bn_layer_)
        out = net(img)
        print(out.size())


================================================ FILE: code/synthetic/bsrt/model/non_local/non_local_dot_product.py ================================================
import torch
from torch import nn
from torch.nn import functional as F


class _NonLocalBlockND(nn.Module):
    # Generic N-D non-local block, *dot-product* pairwise function with 1/N
    # normalisation (no softmax) — self-attention over all positions of x.
    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        # Default bottleneck width: half of the input channels (at least 1).
        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 4, 4))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(4, 4))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            # Zero-init so the block starts as an identity mapping.
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x, return_nl_map=False):
        """
        :param x: (b, c, t, h, w)
        :param return_nl_map: if True return z, nl_map, else only return z.
        :return:
        """
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        # Normalise by the number of key positions instead of softmax.
        N = f.size(-1)
        f_div_C = f / N

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        # Residual connection back onto the input.
        z = W_y + x

        if return_nl_map:
            return z, f_div_C
        return z


class NONLocalBlock1D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=1, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock2D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=2, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock3D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock3D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=3, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


if __name__ == '__main__':
    import torch

    # Exercise every (sub_sample, bn_layer) combination for 1D/2D/3D blocks.
    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
        img = torch.zeros(2, 3, 20)
        net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())

        img = torch.zeros(2, 3, 20, 20)
        net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())

        img = torch.randn(2, 3, 8, 20, 20)
        net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())
================================================ FILE: code/synthetic/bsrt/model/non_local/non_local_embedded_gaussian.py ================================================
import torch
from torch import nn
from torch.nn import functional as F


class _NonLocalBlockND(nn.Module):
    # Embedded-Gaussian non-local block: affinities are softmax-normalised
    # dot products between embedded (theta/phi) features.
    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        """
        :param in_channels:
        :param inter_channels:
        :param dimension:
        :param sub_sample:
        :param bn_layer:
        """
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        # Default bottleneck width: half of the input channels (at least 1).
        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            # Zero-init so the block starts as an identity mapping.
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x, return_nl_map=False):
        """
        :param x: (b, c, t, h, w)
        :param return_nl_map: if True return z, nl_map, else only return z.
        :return:
        """
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        # Softmax over the key dimension — this is what distinguishes the
        # embedded-Gaussian variant from the 1/N dot-product variant.
        f_div_C = F.softmax(f, dim=-1)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        # Residual connection back onto the input.
        z = W_y + x

        if return_nl_map:
            return z, f_div_C
        return z


class NONLocalBlock1D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=1, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock2D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=2, sub_sample=sub_sample,
                                              bn_layer=bn_layer,)


class NONLocalBlock3D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock3D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=3, sub_sample=sub_sample,
                                              bn_layer=bn_layer,)


if __name__ == '__main__':
    import torch

    # Exercise every (sub_sample, bn_layer) combination for 1D/2D/3D blocks.
    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
        img = torch.zeros(2, 3, 20)
        net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())

        img = torch.zeros(2, 3, 20, 20)
        net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())

        img = torch.randn(2, 3, 8, 20, 20)
        net = NONLocalBlock3D(3,
                              sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())


================================================ FILE: code/synthetic/bsrt/model/non_local/non_local_gaussian.py ================================================
import torch
from torch import nn
from torch.nn import functional as F


class _NonLocalBlockND(nn.Module):
    # Gaussian non-local block: affinities are softmax(x_i . x_j) computed
    # directly on the raw input — no theta/phi embedding convolutions.
    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        # Default bottleneck width: half of the input channels (at least 1).
        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            # Zero-init so the block starts as an identity mapping.
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            # phi is a bare pooling layer here (no conv); it only exists when
            # sub_sample is on — forward() guards on self.sub_sample.
            self.phi = max_pool_layer

    def forward(self, x, return_nl_map=False):
        """
        :param x: (b, c, t, h, w)
        :param return_nl_map: if True return z, nl_map, else only return z.
        :return:
        """
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # Queries/keys use the raw input (full in_channels), not embeddings.
        theta_x = x.view(batch_size, self.in_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)

        if self.sub_sample:
            phi_x = self.phi(x).view(batch_size, self.in_channels, -1)
        else:
            phi_x = x.view(batch_size, self.in_channels, -1)

        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)

        # if self.store_last_batch_nl_map:
        #     self.nl_map = f_div_C

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        # Residual connection back onto the input.
        z = W_y + x

        if return_nl_map:
            return z, f_div_C
        return z


class NONLocalBlock1D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=1, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock2D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=2, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock3D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock3D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=3, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


if __name__ == '__main__':
    import torch

    # Exercise every (sub_sample, bn_layer) combination for 1D/2D/3D blocks.
    for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
        img = torch.zeros(2, 3, 20)
        net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())

        img = torch.zeros(2, 3, 20, 20)
        net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())

        img = torch.randn(2, 3, 8, 20, 20)
        net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
        out = net(img)
        print(out.size())
================================================ FILE: code/synthetic/bsrt/model/swin_util.py ================================================
# -----------------------------------------------------------------------------------
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
# Originally Written by Ze Liu, Modified by Jingyun Liang.
# -----------------------------------------------------------------------------------

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# import torch.utils.checkpoint as checkpoint
# NOTE(review): a project-local checkpoint implementation replaces
# torch.utils.checkpoint here; BasicLayer.forward calls checkpoint.apply(...).
from model.checkpoint import CheckpointFunction as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from functools import reduce, lru_cache
import time


class Mlp(nn.Module):
    # Two-layer feed-forward network (Linear -> act -> drop -> Linear -> drop)
    # used inside every Swin transformer block.
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)

        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        #  x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops


@lru_cache()
def calculate_mask(x_size, window_size, shift_size):
    # calculate attention mask for SW-MSA
    # NOTE(review): lru_cache requires hashable arguments, so x_size must be a
    # tuple; the cached CPU tensor is moved to the right device by the caller.
    H, W = x_size
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
    h_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    w_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    # Label each of the 9 shifted regions with a distinct id.
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

    mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, window_size * window_size)
    # Pairs of tokens from different regions get -100 (masked out by softmax).
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    return attn_mask


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.use_checkpoint = use_checkpoint

        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, x_size):
        # x is (B, H*W, C) tokens; x_size carries the spatial (H, W) so the
        # block works on images whose size differs from input_resolution.
        H, W = x_size
        B, L, C = x.shape
        # assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
        attn_mask = calculate_mask(x_size, self.window_size, self.shift_size).to(x.device)
        attn_windows = self.attn(x_windows, mask=attn_mask)

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops


class PatchMerging(nn.Module):
    r""" Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        # Gather the four 2x2 neighbours and concatenate along channels.
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        # NOTE(review): the `use_checkpoint` constructor argument is
        # deliberately? overridden to False here, so the checkpoint branch in
        # forward() is dead code — confirm whether this hard-coding is intended.
        self.use_checkpoint = False

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer,
                                 use_checkpoint=use_checkpoint)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, x_size):
        for i, blk in enumerate(self.blocks):
            if self.use_checkpoint:
                # x = checkpoint.checkpoint(blk, x, x_size)
                x = checkpoint.apply(blk, 2, x, x_size)
            else:
                x = blk(x, x_size)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops


class RSTB(nn.Module):
    """Residual Swin Transformer Block (RSTB).

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. img_size: Input image size. patch_size: Patch size. resi_connection: The convolutional block before residual connection. """ def __init__(self, dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, img_size=224, patch_size=4, resi_connection='1conv'): super(RSTB, self).__init__() # print(f'dim: {dim}, input_resolution: {input_resolution}, depth: {depth}, num_heads: {num_heads}, window_size: {window_size}, img_size: {img_size}. 
patch_size: {patch_size}') self.dim = dim self.input_resolution = input_resolution self.residual_group = BasicLayer(dim=dim, input_resolution=input_resolution, depth=depth, num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path, norm_layer=norm_layer, downsample=downsample, use_checkpoint=use_checkpoint) if resi_connection == '1conv': self.conv = nn.Conv2d(dim, dim, 3, 1, 1) elif resi_connection == '3conv': # to save parameters and memory self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True), nn.Conv2d(dim // 4, dim, 3, 1, 1)) self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None) self.patch_unembed = PatchUnEmbed( img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None) def forward(self, x, x_size): x = self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x return x def flops(self): flops = 0 flops += self.residual_group.flops() H, W = self.input_resolution flops += H * W * self.dim * self.dim * 9 flops += self.patch_embed.flops() flops += self.patch_unembed.flops() return flops class PatchEmbed(nn.Module): r""" Image to Patch Embedding Args: img_size (int): Image size. Default: 224. patch_size (int): Patch token size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. norm_layer (nn.Module, optional): Normalization layer. 
Default: None """ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] self.img_size = img_size self.patch_size = patch_size self.patches_resolution = patches_resolution self.num_patches = patches_resolution[0] * patches_resolution[1] self.in_chans = in_chans self.embed_dim = embed_dim if norm_layer is not None: self.norm = norm_layer(embed_dim) else: self.norm = None def forward(self, x, use_norm=True): x = x.flatten(2).transpose(1, 2) # B Ph*Pw C if use_norm and self.norm is not None: x = self.norm(x) return x def flops(self): flops = 0 H, W = self.img_size if self.norm is not None: flops += H * W * self.embed_dim return flops class PatchUnEmbed(nn.Module): r""" Image to Patch Unembedding Args: img_size (int): Image size. Default: 224. patch_size (int): Patch token size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. norm_layer (nn.Module, optional): Normalization layer. 
Default: None """ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] self.img_size = img_size self.patch_size = patch_size self.patches_resolution = patches_resolution self.num_patches = patches_resolution[0] * patches_resolution[1] self.in_chans = in_chans self.embed_dim = embed_dim def forward(self, x, x_size): B, HW, C = x.shape x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C return x def flops(self): flops = 0 return flops ================================================ FILE: code/synthetic/bsrt/model/utils/interp_methods.py ================================================ from math import pi try: import torch except ImportError: torch = None try: import numpy except ImportError: numpy = None if numpy is None and torch is None: raise ImportError("Must have either Numpy or PyTorch but both not found") def set_framework_dependencies(x): if type(x) is numpy.ndarray: to_dtype = lambda a: a fw = numpy else: to_dtype = lambda a: a.to(x.dtype) fw = torch eps = fw.finfo(fw.float32).eps return fw, to_dtype, eps def support_sz(sz): def wrapper(f): f.support_sz = sz return f return wrapper @support_sz(4) def cubic(x): fw, to_dtype, eps = set_framework_dependencies(x) absx = fw.abs(x) absx2 = absx ** 2 absx3 = absx ** 3 return ((1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) + (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) * to_dtype((1. 
< absx) & (absx <= 2.))) @support_sz(4) def lanczos2(x): fw, to_dtype, eps = set_framework_dependencies(x) return (((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) / ((pi**2 * x**2 / 2) + eps)) * to_dtype(abs(x) < 2)) @support_sz(6) def lanczos3(x): fw, to_dtype, eps = set_framework_dependencies(x) return (((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) / ((pi**2 * x**2 / 3) + eps)) * to_dtype(abs(x) < 3)) @support_sz(2) def linear(x): fw, to_dtype, eps = set_framework_dependencies(x) return ((x + 1) * to_dtype((-1 <= x) & (x < 0)) + (1 - x) * to_dtype((0 <= x) & (x <= 1))) @support_sz(1) def box(x): fw, to_dtype, eps = set_framework_dependencies(x) return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1)) ================================================ FILE: code/synthetic/bsrt/model/utils/psconv.py ================================================ import torch import torch.nn as nn class PyConv2d(nn.Module): """PyConv2d with padding (general case). Applies a 2D PyConv over an input signal composed of several input planes. Args: in_channels (int): Number of channels in the input image out_channels (list): Number of channels for each pyramid level produced by the convolution pyconv_kernels (list): Spatial size of the kernel for each pyramid level pyconv_groups (list): Number of blocked connections from input channels to output channels for each pyramid level stride (int or tuple, optional): Stride of the convolution. Default: 1 dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. 
Default: ``False`` Example:: >>> # PyConv with two pyramid levels, kernels: 3x3, 5x5 >>> m = PyConv2d(in_channels=64, out_channels=[32, 32], pyconv_kernels=[3, 5], pyconv_groups=[1, 4]) >>> input = torch.randn(4, 64, 56, 56) >>> output = m(input) >>> # PyConv with three pyramid levels, kernels: 3x3, 5x5, 7x7 >>> m = PyConv2d(in_channels=64, out_channels=[16, 16, 32], pyconv_kernels=[3, 5, 7], pyconv_groups=[1, 4, 8]) >>> input = torch.randn(4, 64, 56, 56) >>> output = m(input) """ def __init__(self, in_channels, out_channels, pyconv_kernels, pyconv_groups, stride=1, dilation=1, bias=False): super(PyConv2d, self).__init__() assert len(out_channels) == len(pyconv_kernels) == len(pyconv_groups) self.pyconv_levels = [None] * len(pyconv_kernels) for i in range(len(pyconv_kernels)): self.pyconv_levels[i] = nn.Conv2d(in_channels, out_channels[i], kernel_size=pyconv_kernels[i], stride=stride, padding=pyconv_kernels[i] // 2, groups=pyconv_groups[i], dilation=dilation, bias=bias) self.pyconv_levels = nn.ModuleList(self.pyconv_levels) def forward(self, x): out = [] for level in self.pyconv_levels: out.append(level(x)) return torch.cat(out, 1) ################################################################ class PSConv2d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, parts=4, bias=False): super(PSConv2d, self).__init__() self.gwconv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, dilation, dilation, groups=parts, bias=bias) self.gwconv_shift = nn.Conv2d(in_channels, out_channels, kernel_size, stride, 2 * dilation, 2 * dilation, groups=parts, bias=bias) self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) def backward_hook(grad): out = grad.clone() out[self.mask] = 0 return out self.mask = torch.zeros(self.conv.weight.shape).byte().cuda() _in_channels = in_channels // parts _out_channels = out_channels // parts for i in range(parts): self.mask[i * _out_channels: (i + 1) 
* _out_channels, i * _in_channels: (i + 1) * _in_channels, : , :] = 1 self.mask[(i + parts//2)%parts * _out_channels: ((i + parts//2)%parts + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, :, :] = 1 self.conv.weight.data[self.mask] = 0 self.conv.weight.register_hook(backward_hook) self.weight = self.conv.weight self.bias = self.conv.bias def forward(self, x): x1, x2 = x.chunk(2, dim=1) x_shift = self.gwconv_shift(torch.cat((x2, x1), dim=1)) return self.gwconv(x) + self.conv(x) + x_shift # PSConv-based Group Convolution class PSGConv2d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, parts=4, bias=False): super(PSGConv2d, self).__init__() self.gwconv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups=groups * parts, bias=bias) self.gwconv_shift = nn.Conv2d(in_channels, out_channels, kernel_size, stride, 2 * padding, 2 * dilation, groups=groups * parts, bias=bias) self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=bias) def backward_hook(grad): out = grad.clone() out[self.mask] = 0 return out self.mask = torch.zeros(self.conv.weight.shape).bool().cuda() _in_channels = in_channels // (groups * parts) _out_channels = out_channels // (groups * parts) for i in range(parts): for j in range(groups): self.mask[(i + j * groups) * _out_channels: (i + j * groups + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, : , :] = 1 self.mask[((i + parts // 2) % parts + j * groups) * _out_channels: ((i + parts // 2) % parts + j * groups + 1) * _out_channels, i * _in_channels: (i + 1) * _in_channels, :, :] = 1 self.conv.weight.data[self.mask] = 0 self.conv.weight.register_hook(backward_hook) self.groups = groups self.weight = self.conv.weight self.bias = self.conv.bias def forward(self, x): x_split = (z.chunk(2, dim=1) for z in x.chunk(self.groups, dim=1)) x_merge = torch.cat(tuple(torch.cat((x2, x1), dim=1) for 
(x1, x2) in x_split), dim=1) x_shift = self.gwconv_shift(x_merge) gx = self.gwconv(x) cx = self.conv(x) # print(x.shape, gx.shape, cx.shape, x_merge.shape, x_shift.shape) return gx + cx + x_shift ================================================ FILE: code/synthetic/bsrt/model/utils/resize_right.py ================================================ import warnings from math import ceil import model.utils.interp_methods as interp_methods class NoneClass: pass try: import torch from torch import nn nnModuleWrapped = nn.Module except ImportError: warnings.warn('No PyTorch found, will work only with Numpy') torch = None nnModuleWrapped = NoneClass try: import numpy except ImportError: warnings.warn('No Numpy found, will work only with PyTorch') numpy = None if numpy is None and torch is None: raise ImportError("Must have either Numpy or PyTorch but both not found") def resize(input, scale_factors=None, out_shape=None, interp_method=interp_methods.cubic, support_sz=None, antialiasing=True): # get properties of the input tensor in_shape, n_dims = input.shape, input.ndim # fw stands for framework that can be either numpy or torch, # determined by the input type fw = numpy if type(input) is numpy.ndarray else torch eps = fw.finfo(fw.float32).eps # set missing scale factors or output shapem one according to another, # scream if both missing scale_factors, out_shape = set_scale_and_out_sz(in_shape, out_shape, scale_factors, fw) # sort indices of dimensions according to scale of each dimension. # since we are going dim by dim this is efficient sorted_filtered_dims_and_scales = [(dim, scale_factors[dim]) for dim in sorted(range(n_dims), key=lambda ind: scale_factors[ind]) if scale_factors[dim] != 1.] 
# unless support size is specified by the user, it is an attribute # of the interpolation method if support_sz is None: support_sz = interp_method.support_sz # when using pytorch, we need to know what is the input tensor device if fw is torch: device = input.device # output begins identical to input and changes with each iteration output = input # iterate over dims for dim, scale_factor in sorted_filtered_dims_and_scales: # get 1d set of weights and fields of view for each output location # along this dim field_of_view, weights = prepare_weights_and_field_of_view_1d( dim, scale_factor, in_shape[dim], out_shape[dim], interp_method, support_sz, antialiasing, fw, eps, device) # multiply the weights by the values in the field of view and # aggreagate output = apply_weights(output, field_of_view, weights, dim, n_dims, fw) return output class ResizeLayer(nnModuleWrapped): def __init__(self, in_shape, scale_factors=None, out_shape=None, interp_method=interp_methods.cubic, support_sz=None, antialiasing=True): super(ResizeLayer, self).__init__() # fw stands for framework, that can be either numpy or torch. since # this is a torch layer, only one option in this case. fw = torch eps = fw.finfo(fw.float32).eps # set missing scale factors or output shapem one according to another, # scream if both missing scale_factors, out_shape = set_scale_and_out_sz(in_shape, out_shape, scale_factors, fw) # unless support size is specified by the user, it is an attribute # of the interpolation method if support_sz is None: support_sz = interp_method.support_sz self.n_dims = len(in_shape) # sort indices of dimensions according to scale of each dimension. # since we are going dim by dim this is efficient self.sorted_filtered_dims_and_scales = [(dim, scale_factors[dim]) for dim in sorted(range(self.n_dims), key=lambda ind: scale_factors[ind]) if scale_factors[dim] != 1.] 
# iterate over dims field_of_view_list = [] weights_list = [] for dim, scale_factor in self.sorted_filtered_dims_and_scales: # get 1d set of weights and fields of view for each output # location along this dim field_of_view, weights = prepare_weights_and_field_of_view_1d( dim, scale_factor, in_shape[dim], out_shape[dim], interp_method, support_sz, antialiasing, fw, eps, input.device) # keep weights and fields of views for all dims weights_list.append(nn.Parameter(weights, requires_grad=False)) field_of_view_list.append(nn.Parameter(field_of_view, requires_grad=False)) self.field_of_view = nn.ParameterList(field_of_view_list) self.weights = nn.ParameterList(weights_list) self.in_shape = in_shape def forward(self, input): # output begins identical to input and changes with each iteration output = input for (dim, scale_factor), field_of_view, weights in zip( self.sorted_filtered_dims_and_scales, self.field_of_view, self.weights): # multiply the weights by the values in the field of view and # aggreagate output = apply_weights(output, field_of_view, weights, dim, self.n_dims, torch) return output def prepare_weights_and_field_of_view_1d(dim, scale_factor, in_sz, out_sz, interp_method, support_sz, antialiasing, fw, eps, device=None): # If antialiasing is taking place, we modify the window size and the # interpolation method (see inside function) interp_method, cur_support_sz = apply_antialiasing_if_needed( interp_method, support_sz, scale_factor, antialiasing) # STEP 1- PROJECTED GRID: The non-integer locations of the projection of # output pixel locations to the input tensor projected_grid = get_projected_grid(in_sz, out_sz, scale_factor, fw, device) # STEP 2- FIELDS OF VIEW: for each output pixels, map the input pixels # that influence it field_of_view = get_field_of_view(projected_grid, cur_support_sz, in_sz, fw, eps) # STEP 3- CALCULATE WEIGHTS: Match a set of weights to the pixels in the # field of view for each output pixel weights = get_weights(interp_method, 
projected_grid, field_of_view) return field_of_view, weights def apply_weights(input, field_of_view, weights, dim, n_dims, fw): # STEP 4- APPLY WEIGHTS: Each output pixel is calculated by multiplying # its set of weights with the pixel values in its field of view. # We now multiply the fields of view with their matching weights. # We do this by tensor multiplication and broadcasting. # this step is separated to a different function, so that it can be # repeated with the same calculated weights and fields. # for this operations we assume the resized dim is the first one. # so we transpose and will transpose back after multiplying tmp_input = fw_swapaxes(input, dim, 0, fw) # field_of_view is a tensor of order 2: for each output (1d location # along cur dim)- a list of 1d neighbors locations. # note that this whole operations is applied to each dim separately, # this is why it is all in 1d. # neighbors = tmp_input[field_of_view] is a tensor of order image_dims+1: # for each output pixel (this time indicated in all dims), these are the # values of the neighbors in the 1d field of view. note that we only # consider neighbors along the current dim, but such set exists for every # multi-dim location, hence the final tensor order is image_dims+1. neighbors = tmp_input[field_of_view] # weights is an order 2 tensor: for each output location along 1d- a list # of weighs matching the field of view. we augment it with ones, for # broadcasting, so that when multiplies some tensor the weights affect # only its first dim. 
tmp_weights = fw.reshape(weights, (*weights.shape, * [1] * (n_dims - 1))) # now we simply multiply the weights with the neighbors, and then sum # along the field of view, to get a single value per out pixel tmp_output = (neighbors * tmp_weights).sum(1) # we transpose back the resized dim to its original position return fw_swapaxes(tmp_output, 0, dim, fw) def set_scale_and_out_sz(in_shape, out_shape, scale_factors, fw): # eventually we must have both scale-factors and out-sizes for all in/out # dims. however, we support many possible partial arguments if scale_factors is None and out_shape is None: raise ValueError("either scale_factors or out_shape should be " "provided") if out_shape is not None: # if out_shape has less dims than in_shape, we defaultly resize the # first dims for numpy and last dims for torch out_shape = (list(out_shape) + list(in_shape[:-len(out_shape)]) if fw is numpy else list(in_shape[:-len(out_shape)]) + list(out_shape)) if scale_factors is None: # if no scale given, we calculate it as the out to in ratio # (not recomended) scale_factors = [out_sz / in_sz for out_sz, in_sz in zip(out_shape, in_shape)] if scale_factors is not None: # by default, if a single number is given as scale, we assume resizing # two dims (most common are images with 2 spatial dims) scale_factors = (scale_factors if isinstance(scale_factors, (list, tuple)) else [scale_factors, scale_factors]) # if less scale_factors than in_shape dims, we defaultly resize the # first dims for numpy and last dims for torch scale_factors = (list(scale_factors) + [1] * (len(in_shape) - len(scale_factors)) if fw is numpy else [1] * (len(in_shape) - len(scale_factors)) + list(scale_factors)) if out_shape is None: # when no out_shape given, it is calculated by multiplying the # scale by the in_shape (not recomended) out_shape = [ceil(scale_factor * in_sz) for scale_factor, in_sz in zip(scale_factors, in_shape)] # next line intentionally after out_shape determined for stability scale_factors = 
[float(sf) for sf in scale_factors] return scale_factors, out_shape def get_projected_grid(in_sz, out_sz, scale_factor, fw, device=None): # we start by having the ouput coordinates which are just integer locations out_coordinates = fw.arange(out_sz) # if using torch we need to match the grid tensor device to the input device out_coordinates = fw_set_device(out_coordinates, device, fw) # This is projecting the ouput pixel locations in 1d to the input tensor, # as non-integer locations. # the following fomrula is derived in the paper # "From Discrete to Continuous Convolutions" by Shocher et al. return (out_coordinates / scale_factor + (in_sz - 1) / 2 - (out_sz - 1) / (2 * scale_factor)) def get_field_of_view(projected_grid, cur_support_sz, in_sz, fw, eps): # for each output pixel, map which input pixels influence it, in 1d. # we start by calculating the leftmost neighbor, using half of the window # size (eps is for when boundary is exact int) left_boundaries = fw_ceil(projected_grid - cur_support_sz / 2 - eps, fw) # then we simply take all the pixel centers in the field by counting # window size pixels from the left boundary ordinal_numbers = fw.arange(ceil(cur_support_sz - eps)) # in case using torch we need to match the device ordinal_numbers = fw_set_device(ordinal_numbers, projected_grid.device, fw) field_of_view = left_boundaries[:, None] + ordinal_numbers # next we do a trick instead of padding, we map the field of view so that # it would be like mirror padding, without actually padding # (which would require enlarging the input tensor) mirror = fw_cat((fw.arange(in_sz), fw.arange(in_sz - 1, -1, step=-1)), fw) field_of_view = mirror[fw.remainder(field_of_view, mirror.shape[0])] field_of_view = fw_set_device(field_of_view,projected_grid.device, fw) return field_of_view def get_weights(interp_method, projected_grid, field_of_view): # the set of weights per each output pixels is the result of the chosen # interpolation method applied to the distances between 
projected grid # locations and the pixel-centers in the field of view (distances are # directed, can be positive or negative) weights = interp_method(projected_grid[:, None] - field_of_view) # we now carefully normalize the weights to sum to 1 per each output pixel sum_weights = weights.sum(1, keepdims=True) sum_weights[sum_weights == 0] = 1 return weights / sum_weights def apply_antialiasing_if_needed(interp_method, support_sz, scale_factor, antialiasing): # antialiasing is "stretching" the field of view according to the scale # factor (only for downscaling). this is low-pass filtering. this # requires modifying both the interpolation (stretching the 1d # function and multiplying by the scale-factor) and the window size. if scale_factor >= 1.0 or not antialiasing: return interp_method, support_sz cur_interp_method = (lambda arg: scale_factor * interp_method(scale_factor * arg)) cur_support_sz = support_sz / scale_factor return cur_interp_method, cur_support_sz def fw_ceil(x, fw): if fw is numpy: return fw.int_(fw.ceil(x)) else: return x.ceil().long() def fw_cat(x, fw): if fw is numpy: return fw.concatenate(x) else: return fw.cat(x) def fw_swapaxes(x, ax_1, ax_2, fw): if fw is numpy: return fw.swapaxes(x, ax_1, ax_2) else: return x.transpose(ax_1, ax_2) def fw_set_device(x, device, fw): if fw is numpy: return x else: return x.to(device) ================================================ FILE: code/synthetic/bsrt/option.py ================================================ import argparse parser = argparse.ArgumentParser(description='EDSR and MDSR') parser.add_argument('--n_resblocks', type=int, default=16, help='number of residual blocks') parser.add_argument('--n_feats', type=int, default=64, help='number of feature maps') parser.add_argument('--n_colors', type=int, default=3, help='number of color channels to use') parser.add_argument('--lr', type=float, default=1e-4, help='learning rate') parser.add_argument('--burst_size', type=int, default=14, help='burst size, 
max 14') parser.add_argument('--burst_channel', type=int, default=4, help='burst size, max 14') parser.add_argument('--swinfeature', action='store_true', help='use swin transformer to extract features') parser.add_argument('--model_level', type=str, default='S', help='S: small, L: large') ################## fine-tune ################## parser.add_argument('--finetune', action='store_true', help='finetune model') parser.add_argument('--finetune_align', action='store_true', help='finetune alignment module') parser.add_argument('--finetune_swin', action='store_true', help='finetune swin trans module') parser.add_argument('--finetune_conv', action='store_true', help='finetune rest convs') parser.add_argument('--finetune_prelayer', action='store_true', help='finetune finetune pre feature extract layer') parser.add_argument('--finetune_upconv', action='store_true', help='finetune finetune up conv layer') parser.add_argument('--finetune_spynet', action='store_true', help='finetune finetune up conv layer') # Hardware specifications parser.add_argument('--n_threads', type=int, default=6, help='number of threads for data loading') parser.add_argument('--cpu', action='store_true', help='use cpu only') parser.add_argument('--n_GPUs', type=int, default=2, help='number of GPUs') parser.add_argument('--seed', type=int, default=1, help='random seed') parser.add_argument('--local_rank', type=int, default=-1, help='proc index') parser.add_argument('--fp16', action='store_true', help='use fp16 only') parser.add_argument('--use_checkpoint', action='store_true', help='use use_checkpoint in swin transformer') # Data specifications parser.add_argument('--root', type=str, default='/data/dataset/ntire21/burstsr/synthetic', help='dataset directory') parser.add_argument('--mode', type=str, default='train', help='demo image directory') parser.add_argument('--scale', type=str, default='4', help='super resolution scale') parser.add_argument('--patch_size', type=int, default=256, help='output 
patch size') parser.add_argument('--rgb_range', type=int, default=1, help='maximum value of RGB') parser.add_argument('--chop', action='store_true', help='enable memory-efficient forward') parser.add_argument('--no_augment', action='store_true', help='do not use data augmentation') # Model specifications parser.add_argument('--model', default='LRSC_EDVR', help='model name') parser.add_argument('--act', type=str, default='relu', help='activation function') parser.add_argument('--pre_train', type=str, default='', help='pre-trained model directory') parser.add_argument('--extend', type=str, default='.', help='pre-trained model directory') parser.add_argument('--res_scale', type=float, default=1, help='residual scaling') parser.add_argument('--shift_mean', default=True, help='subtract pixel mean from the input') parser.add_argument('--dilation', action='store_true', help='use dilated convolution') parser.add_argument('--precision', type=str, default='single', choices=('single', 'half'), help='FP precision for test (single | half)') # Option for Residual channel attention network (RCAN) parser.add_argument('--n_resgroups', type=int, default=20, help='number of residual groups') parser.add_argument('--reduction', type=int, default=16, help='number of feature maps reduction') parser.add_argument('--DA', action='store_true', help='use Dual Attention') parser.add_argument('--CA', action='store_true', help='use Channel Attention') parser.add_argument('--non_local', action='store_true', help='use Dual Attention') # Training specifications parser.add_argument('--reset', action='store_true', help='reset the training') parser.add_argument('--test_every', type=int, default=1000, help='do test per every N batches') parser.add_argument('--epochs', type=int, default=300, help='number of epochs to train') parser.add_argument('--batch_size', type=int, default=8, help='input batch size for training') parser.add_argument('--split_batch', type=int, default=1, help='split the batch into 
smaller chunks') parser.add_argument('--self_ensemble', action='store_true', help='use self-ensemble method for test') parser.add_argument('--test_only', action='store_true', help='set this option to test the model') parser.add_argument('--gan_k', type=int, default=1, help='k value for adversarial loss') # Optimization specifications parser.add_argument('--decay', type=str, default='100-200', help='learning rate decay type') parser.add_argument('--gamma', type=float, default=0.5, help='learning rate decay factor for step decay') parser.add_argument('--optimizer', default='ADAM', choices=('SGD', 'ADAM', 'RMSprop'), help='optimizer to use (SGD | ADAM | RMSprop)') parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum') parser.add_argument('--betas', type=tuple, default=(0.9, 0.999), help='ADAM beta') parser.add_argument('--epsilon', type=float, default=1e-8, help='ADAM epsilon for numerical stability') parser.add_argument('--weight_decay', type=float, default=0, help='weight decay') parser.add_argument('--gclip', type=float, default=0, help='gradient clipping threshold (0 = no clipping)') # Loss specifications parser.add_argument('--loss', type=str, default='1*L1', help='loss function configuration') parser.add_argument('--skip_threshold', type=float, default='1e8', help='skipping batch that has large error') # Log specifications parser.add_argument('--save', type=str, default='test', help='file name to save') parser.add_argument('--load', type=str, default='', help='file name to load') parser.add_argument('--resume', type=int, default=0, help='resume from specific checkpoint') parser.add_argument('--save_models', action='store_true', help='save all intermediate models') parser.add_argument('--print_every', type=int, default=20, help='how many batches to wait before logging training status') parser.add_argument('--save_results', action='store_true', help='save output results') parser.add_argument('--save_gt', action='store_true', help='save 
def download_burstsr_dataset(download_path):
    """Download and unpack the BurstSR train/val sets into
    ``<download_path>/burstsr_dataset``.

    Each archive is staged as ``tmp.zip`` and renamed only on completion,
    so an interrupted transfer is re-downloaded on the next run instead of
    being mistaken for a finished archive.
    """
    out_dir = os.path.join(download_path, 'burstsr_dataset')
    # Bug fix: the output directory must exist before urlretrieve can
    # write '{out_dir}/tmp.zip' into it on a fresh install.
    os.makedirs(out_dir, exist_ok=True)

    # Download train folders (the train split is shipped as 9 zip shards)
    for i in range(9):
        if not os.path.isfile('{}/train_{:02d}.zip'.format(out_dir, i)):
            print('Downloading train_{:02d}'.format(i))
            urllib.request.urlretrieve(
                'https://data.vision.ee.ethz.ch/bhatg/BurstSRChallenge/train_{:02d}.zip'.format(i),
                '{}/tmp.zip'.format(out_dir))
            os.rename('{}/tmp.zip'.format(out_dir), '{}/train_{:02d}.zip'.format(out_dir, i))

    # Download val folder
    if not os.path.isfile('{}/val.zip'.format(out_dir)):
        print('Downloading val')
        urllib.request.urlretrieve(
            'https://data.vision.ee.ethz.ch/bhatg/BurstSRChallenge/val.zip',
            '{}/tmp.zip'.format(out_dir))
        os.rename('{}/tmp.zip'.format(out_dir), '{}/val.zip'.format(out_dir))

    # Unpack train set
    for i in range(9):
        print('Unpacking train_{:02d}'.format(i))
        with zipfile.ZipFile('{}/train_{:02d}.zip'.format(out_dir, i), 'r') as zip_ref:
            zip_ref.extractall('{}'.format(out_dir))

    # Move the burst folders from the 9 shard directories into one 'train' dir
    os.makedirs('{}/train'.format(out_dir), exist_ok=True)
    for i in range(9):
        file_list = os.listdir('{}/train_{:02d}'.format(out_dir, i))
        for b in file_list:
            source_dir = '{}/train_{:02d}/{}'.format(out_dir, i, b)
            dst_dir = '{}/train/{}'.format(out_dir, b)
            if os.path.isdir(source_dir):
                shutil.move(source_dir, dst_dir)

    # Delete the now-empty shard directories
    for i in range(9):
        shutil.rmtree('{}/train_{:02d}'.format(out_dir, i))

    # Unpack val set
    print('Unpacking val')
    with zipfile.ZipFile('{}/val.zip'.format(out_dir), 'r') as zip_ref:
        zip_ref.extractall('{}'.format(out_dir))


def main():
    """CLI entry point: parse the destination path and run the download."""
    parser = argparse.ArgumentParser(description='Downloads and unpacks BurstSR dataset')
    parser.add_argument('path', type=str, help='Path where the dataset will be downloaded')
    args = parser.parse_args()

    download_burstsr_dataset(args.path)


if __name__ == '__main__':
    main()
import torch.nn.functional as F
from datasets.burstsr_dataset import BurstSRDataset
from utils.metrics import AlignedPSNR
from pwcnet.pwcnet import PWCNet

# Root of the real-world BurstSR validation data (machine-specific path).
root = '/data/dataset/ntire21/burstsr/real/NTIRE/burstsr_dataset'


class SimpleBaseline:
    """Trivial reference model: bilinearly upsample the base frame.

    Used as a placeholder network so the evaluation loop below can be run
    before a real model is plugged in.
    """
    def __init__(self):
        pass

    def __call__(self, burst):
        # Take the first (base) frame and 3 of its 4 packed RAW channels
        # (presumably R, G, B of an RGGB packing — TODO confirm against
        # the dataset's channel order).
        burst_rgb = burst[:, 0, [0, 1, 3]]
        burst_rgb = burst_rgb.view(-1, *burst_rgb.shape[-3:])
        # x8 upsampling matches the dataset's LR->HR scale factor.
        burst_rgb = F.interpolate(burst_rgb, scale_factor=8, mode='bilinear')
        return burst_rgb


def main():
    """Evaluate a model on the BurstSR val split with AlignedPSNR."""
    # Load dataset
    dataset = BurstSRDataset(root=root, split='val', burst_size=14, crop_sz=80, random_flip=False)

    # TODO Set your network here
    net = SimpleBaseline()
    device = 'cuda'

    # Load alignment network, used in AlignedPSNR
    alignment_net = PWCNet(load_pretrained=True, weights_path='PATH_TO_PWCNET_WEIGHTS')
    alignment_net = alignment_net.to(device)
    aligned_psnr_fn = AlignedPSNR(alignment_net=alignment_net, boundary_ignore=40)

    scores_all = []
    for idx in range(len(dataset)):
        burst, frame_gt, meta_info_burst, meta_info_gt = dataset[idx]
        # Add a batch dimension; the loop evaluates one burst at a time.
        burst = burst.unsqueeze(0).to(device)
        frame_gt = frame_gt.unsqueeze(0).to(device)

        net_pred = net(burst)

        # Calculate Aligned PSNR
        score = aligned_psnr_fn(net_pred, frame_gt, burst)
        scores_all.append(score)

    mean_psnr = sum(scores_all) / len(scores_all)
    print('Mean PSNR is {:0.3f}'.format(mean_psnr.item()))


if __name__ == '__main__':
    main()
def main():
    """Run a model over the SyntheticBurstVal set and save 16-bit PNGs.

    Predictions are written in the 14-bit range expected by the challenge
    submission format.
    """
    dataset = SyntheticBurstVal('PATH_TO_SyntheticBurstVal')
    out_dir = 'PATH_WHERE_RESULTS_ARE_SAVED'

    # TODO Set your network here
    net = SimpleBaseline()
    device = 'cuda'
    os.makedirs(out_dir, exist_ok=True)

    for idx in range(len(dataset)):
        burst, burst_name = dataset[idx]
        # One burst at a time; add the batch dimension the net expects.
        burst = burst.to(device).unsqueeze(0)

        with torch.no_grad():
            net_pred = net(burst)

        # Normalize to 0  2^14 range and convert to numpy array
        net_pred_np = (net_pred.squeeze(0).permute(1, 2, 0).clamp(0.0, 1.0) * 2 ** 14).cpu().numpy().astype(np.uint16)

        # Save predictions as png
        cv2.imwrite('{}/{}.png'.format(out_dir, burst_name), net_pred_np)


if __name__ == '__main__':
    main()
def main():
    """Visual sanity-check of the synthetic burst pipeline.

    Iterates the synthetic train set, scores a bilinear-upsampling baseline
    with PSNR, and shows the post-processed prediction next to the ground
    truth. Press 'q' in the display window to stop.
    """
    zurich_raw2rgb = ZurichRAW2RGB(root='PATH_TO_ZURICH_RAW_TO_RGB', split='test')
    dataset = SyntheticBurst(zurich_raw2rgb, burst_size=3, crop_sz=256)

    data_loader = DataLoader(dataset, batch_size=2)

    # Function to calculate PSNR. Note that the boundary pixels (40 pixels) will be ignored during PSNR computation
    psnr_fn = PSNR(boundary_ignore=40)

    # Postprocessing function to obtain sRGB images
    postprocess_fn = SimplePostProcess(return_np=True)

    for d in data_loader:
        burst, frame_gt, flow_vectors, meta_info = d

        # A simple baseline which upsamples the base image using bilinear upsampling
        # (channels [0, 1, 3] of the packed RAW — presumably R, G, B of an
        # RGGB layout; TODO confirm against the dataset packing).
        burst_rgb = burst[:, 0, [0, 1, 3]]
        burst_rgb = burst_rgb.view(-1, *burst_rgb.shape[-3:])
        burst_rgb = F.interpolate(burst_rgb, scale_factor=8, mode='bilinear')

        # Calculate PSNR
        score = psnr_fn(burst_rgb, frame_gt)

        print('PSNR is {:0.3f}'.format(score))

        # Re-batch the metadata so each sample gets its own dict.
        meta_info = convert_dict(meta_info, burst.shape[0])

        # Apply simple post-processing to obtain RGB images
        pred_0 = postprocess_fn.process(burst_rgb[0], meta_info[0])
        gt_0 = postprocess_fn.process(frame_gt[0], meta_info[0])

        pred_0 = cv2.cvtColor(pred_0, cv2.COLOR_RGB2BGR)
        gt_0 = cv2.cvtColor(gt_0, cv2.COLOR_RGB2BGR)

        # Visualize input, ground truth
        cv2.imshow('Input (Demosaicekd + Upsampled)', pred_0)
        cv2.imshow('GT', gt_0)

        input_key = cv2.waitKey(0)
        if input_key == ord('q'):
            return


if __name__ == '__main__':
    main()
def main_worker(local_rank, nprocs, args):
    """Per-process entry point: run the model over the synthetic test set.

    Applies transpose-based test-time augmentation (ttaup/ttadown), times
    each forward pass, and writes each prediction as a 16-bit PNG in the
    challenge's 14-bit value range.
    """
    device = 'cuda'
    cudnn.benchmark = True
    args.local_rank = local_rank
    # Initialize the (single-process) distributed group and pin the GPU.
    utility.setup(local_rank, nprocs)
    torch.cuda.set_device(local_rank)

    dataset = SyntheticBurstTest(args.root)
    out_dir = 'bsrt_synburst'
    os.makedirs(out_dir, exist_ok=True)

    _model = model.Model(args, checkpoint)

    tt = []
    for idx in tqdm(range(len(dataset))):
        burst, meta_info = dataset[idx]
        burst_name = meta_info['burst_name']
        burst = burst.to(device).unsqueeze(0)
        # TTA: original + transposed variant of the burst.
        bursts = ttaup(burst)

        srs = []
        with torch.no_grad():
            for x in bursts:
                tic = time.time()
                # Second argument is an index forwarded to model.Model
                # (presumably an idx_scale selector — TODO confirm).
                sr = _model(x, 0)
                toc = time.time()
                tt.append(toc-tic)
                srs.append(sr)
        # Undo the transpose on the augmented output and average.
        sr = ttadown(srs)

        # Normalize to 0  2^14 range and convert to numpy array
        net_pred_np = (sr.squeeze(0).permute(1, 2, 0).clamp(0.0, 1.0) * 2 ** 14).cpu().numpy().astype(np.uint16)
        cv2.imwrite('{}/{}.png'.format(out_dir, burst_name), net_pred_np)
    print('avg time: {:.4f}'.format(np.mean(tt)))
    utility.cleanup()
def ttadown(bursts):
    """Collapse the list of TTA outputs back into a single prediction.

    Test-time augmentation is currently disabled (see ttaup), so the list
    holds exactly one tensor and the reduction is just the first element.
    """
    # Averaging with the transposed variant is switched off:
    # out = (bursts[0] + bursts[1].permute(0, 1, 3, 2)) / 2
    return bursts[0]
class Trainer():
    """Training/validation driver for BSRT on synthetic bursts.

    Owns the optimizer, loss, timers and checkpoint directories, and runs
    the (optionally fp16/AMP and multi-GPU distributed) train/test loops.
    """
    def __init__(self, args, train_loader, train_sampler, valid_loader, my_model, my_loss, ckp):
        self.args = args
        self.scale = args.scale[0]
        self.ckp = ckp
        self.loader_train = train_loader
        self.loader_valid = valid_loader
        self.train_sampler = train_sampler
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)
        ###################################
        # When training from scratch, freeze the alignment modules
        # (spynet/dcnpack) for the first `fix_epoch` epochs; see train().
        if args.pre_train == "":
            self.fix_unflagged = True
        else:
            self.fix_unflagged = False
        self.fix_epoch = 5
        self.fix_keys = ["spynet", "dcnpack"]
        ###################################
        self.psnr_fn = PSNR(boundary_ignore=40)
        # Postprocessing function to obtain sRGB images
        self.postprocess_fn = SimplePostProcess(return_np=True)
        # Pick the pixel loss by substring match on args.loss (e.g. '1*L1').
        if 'L1' in args.loss:
            self.aligned_loss = L1(boundary_ignore=None).cuda(args.local_rank)
        elif 'MSE' in args.loss:
            self.aligned_loss = L2(boundary_ignore=None).cuda(args.local_rank)
        elif 'CB' in args.loss:
            self.aligned_loss = CharbonnierLoss(boundary_ignore=None).cuda(args.local_rank)
        elif 'MSSSIM' in args.loss:
            self.aligned_loss = MSSSIMLoss(boundary_ignore=None).cuda(args.local_rank)
        if self.args.fp16:
            self.scaler = GradScaler()
        self.best_psnr = 0.
        self.best_epoch = 0
        self.error_last = 1e8
        self.glob_iter = 0
        self.log_dir = LOG_DIR + "/" + args.save
        self.img_save_dir = IMG_SAVE_DIR + "/" + args.save
        # Where to load model
        self.load_model_dir = LOAD_MODEL_DIR + "/" + args.save
        # Where to save new model
        self.save_model_dir = SAVE_MODEL_DIR + "/" + args.save
        self.save_state_dir = SAVE_STATE_DIR + "/" + args.save
        # Where to save visualization images (for report)
        self.results_dir = RESULTS_DIR + "/" + args.save
        # Resuming: args.load is the epoch number to restore optimizer state from.
        if self.args.load != '':
            self.optimizer.load(self.save_state_dir, epoch=int(self.args.load))
        utility.mkdir(self.save_state_dir)
        utility.mkdir(self.save_model_dir)
        utility.mkdir(self.img_save_dir)
        utility.mkdir(self.log_dir)
        utility.mkdir('frames')
        # self.writer = SummaryWriter(log_dir=self.log_dir)
        if self.args.local_rank <= 0:
            number_parameters = sum(map(lambda x: x.numel(), self.model.parameters()))
            print("number of parameters: ", number_parameters)

    def train(self):
        """Run one training epoch, then validate and step the LR schedule."""
        self.loss.step()
        epoch = self.optimizer.get_last_epoch() + 1
        lr = self.optimizer.get_lr()
        if self.train_sampler:
            self.train_sampler.set_epoch(epoch)
        if epoch % 200 == 0:
            self.ckp.write_log(
                '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
            )
        self.loss.start_log()
        # train alignment module after 5 epochs.
        if self.args.pre_train == "":
            if self.fix_unflagged and epoch < self.fix_epoch:
                if self.args.local_rank <= 0:
                    print(f'Fix keys: {self.fix_keys} for the first {self.fix_epoch} epochs.')
                self.fix_unflagged = False
                for name, param in self.model.named_parameters():
                    if any([key in name for key in self.fix_keys]):
                        param.requires_grad_(False)
            elif epoch == self.fix_epoch:
                if self.args.local_rank <= 0:
                    print(f'Train all the parameters from {self.fix_epoch} epochs.')
                self.model.requires_grad_(True)
        # self.test()
        self.model.train()
        if self.args.local_rank == 0:
            timer_data, timer_model, timer_epoch = utility.timer(), utility.timer(), utility.timer()
            timer_epoch.tic()
        for batch, batch_value in enumerate(self.loader_train):
            burst, gt, flow_vectors, meta_info = batch_value
            burst, gt, flow_vectors = self.prepare(burst, gt, flow_vectors)
            # burst = flatten_raw_image_batch(burst)
            if self.args.local_rank == 0:
                timer_data.hold()
                timer_model.tic()
            if self.args.fp16:
                with autocast():
                    sr = self.model(burst, 0)
                    loss = self.aligned_loss(sr, gt)
            else:
                sr = self.model(burst, 0)
                loss = self.aligned_loss(sr, gt)
            # Average the loss across ranks for logging only.
            if self.args.n_GPUs > 1:
                torch.distributed.barrier()
                reduced_loss = utility.reduce_mean(loss, self.args.n_GPUs)
            else:
                reduced_loss = loss
            self.optimizer.zero_grad()
            if self.args.fp16:
                self.scaler.scale(loss).backward()
                # torch.nn.utils.clip_grad_value_(self.model.parameters(), .02)
                # Skip the optimizer step if the forward produced NaN/Inf.
                if torch.isinf(sr).sum() + torch.isnan(sr).sum() <= 0:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    print(f'Nan num: {torch.isnan(sr).sum()}, inf num: {torch.isinf(sr).sum()}')
                    reduced_loss = None
                    os._exit(0)
                    sys.exit(0)
            else:
                loss.backward()
                # torch.nn.utils.clip_grad_value_(self.model.parameters(), .02)
                if torch.isinf(sr).sum() + torch.isnan(sr).sum() <= 0:
                    self.optimizer.step()
                else:
                    print(f'Nan num: {torch.isnan(sr).sum()}, inf num: {torch.isinf(sr).sum()}')
                    reduced_loss = None
            if self.args.local_rank == 0:
                timer_model.hold()
                if (batch + 1) % self.args.print_every == 0:
                    self.ckp.write_log('[{}/{}]\t[{:.4f}]\t{:.1f}+{:.1f}s'.format(
                        (batch + 1) * self.args.batch_size,
                        len(self.loader_train.dataset),
                        reduced_loss.item(),
                        timer_model.release(),
                        timer_data.release()))
                self.glob_iter += 1
                timer_data.tic()
            # Periodic mid-epoch checkpoint on the primary rank.
            if self.args.local_rank <= 0 and (batch + 1) % 2000 == 0:
                if not self.args.test_only:
                    filename = exp_name + '_latest' + '.pth'
                    self.save_model(filename)
        if self.args.local_rank <= 0:
            timer_epoch.hold()
            print('Epoch {} cost time: {:.1f}s, lr: {:5f}'.format(epoch, timer_epoch.release(), lr))
            if (epoch) % 1 == 0 and not self.args.test_only:
                filename = exp_name + '_epoch_' + str(epoch) + '.pth'
                self.save_model(filename)
            if not self.args.test_only:
                filename = exp_name + '_latest' + '.pth'
                self.save_model(filename)
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        self.test()
        self.loss.end_log(len(self.loader_train))
        self.error_last = self.loss.log[-1, -1]
        self.optimizer.schedule()

    def test(self, print_time=False):
        """Validate on the validation loader (PSNR/SSIM/LPIPS).

        Saves a '<exp>_best_epoch.pth' checkpoint whenever the mean PSNR
        improves (after the first epoch).
        """
        def ttaup(burst):
            # Test-time augmentation is disabled during validation.
            # burst0 = flatten_raw_image_batch(burst)  # B, T, C, H, W
            # burst1 = utility.bayer_aug(burst0, flip_h=False, flip_w=False, transpose=True)
            # burst1 = pack_raw_image_batch(burst1)
            return [burst]

        def ttadown(bursts):
            burst0 = bursts[0]
            # burst1 = bursts[1].permute(0, 1, 3, 2)
            # out = (burst0 + burst1) / 2
            out = burst0
            return out

        torch.set_grad_enabled(False)
        epoch = self.optimizer.get_last_epoch() + 1
        self.model.eval()
        if self.args.local_rank == 0:
            timer_test = utility.timer()
        if epoch == 1 or epoch % 1 == 0:
            self.model.eval()
            total_psnr = 0
            total_ssim = 0
            total_lpips = 0
            count = 0
            if self.args.local_rank <= 0:
                print("Testing...")
            for i, batch_value in enumerate(self.loader_valid):
                burst_, gt, meta_info = batch_value
                burst_, gt = self.prepare(burst_, gt)
                bursts = ttaup(burst_)
                # burst_ = flatten_raw_image_batch(burst_)
                if print_time and self.args.local_rank <= 0:
                    tic = time.time()
                with torch.no_grad():
                    srs = []
                    for burst in bursts:
                        if self.args.fp16:
                            with autocast():
                                sr = self.model(burst, 0).float()
                        else:
                            sr = self.model(burst, 0).float()
                        srs.append(sr)
                    sr = ttadown(srs)
                if print_time and self.args.local_rank <= 0:
                    toc = time.time()
                    print(f'model pass time: {toc-tic:.4f}')
                psnr_score, ssim_score, lpips_score = self.psnr_fn(sr, gt)
                # Average the metrics across ranks before accumulating.
                if self.args.n_GPUs > 1:
                    torch.distributed.barrier()
                    psnr_score = utility.reduce_mean(psnr_score, self.args.n_GPUs)
                    ssim_score = utility.reduce_mean(ssim_score, self.args.n_GPUs)
                    lpips_score = utility.reduce_mean(lpips_score, self.args.n_GPUs)
                total_psnr += psnr_score
                total_ssim += ssim_score
                total_lpips += lpips_score
                count += 1
            total_psnr = total_psnr / count
            total_ssim = total_ssim / count
            total_lpips = total_lpips / count
            if self.args.local_rank == 0:
                print("[Epoch: {}][PSNR: {:.4f}][SSIM: {:.4f}][LPIPS: {:.4f}][Best PSNR: {:.4f}][Best Epoch: {}]"
                      .format(epoch, total_psnr, total_ssim, total_lpips, self.best_psnr, self.best_epoch))
                if epoch > 1 and total_psnr > self.best_psnr:
                    self.best_psnr = total_psnr
                    self.best_epoch = epoch
                    filename = exp_name + '_best_epoch.pth'
                    self.save_model(filename)
                # self.writer.add_scalars('PSNR', {tfboard_name + '_PSNR': total_psnr}, self.glob_iter)
                print('Forward: {:.2f}s\n'.format(timer_test.toc()))
        torch.cuda.synchronize()
        torch.set_grad_enabled(True)
        torch.cuda.empty_cache()

    def save_model(self, filename):
        """Save the bare network state_dict (unwrapping DDP if needed)."""
        print('save model...')
        net_save_path = os.path.join(self.save_model_dir, filename)
        model = self.model.model
        if self.args.n_GPUs > 1:
            model = model.module
        # self.optimizer.save(self.save_state_dir)
        torch.save(model.state_dict(), net_save_path)

    def prepare(self, *args):
        """Move the given tensors to the target device (half if requested)."""
        device = torch.device('cpu' if self.args.cpu else 'cuda:{}'.format(self.args.local_rank))

        def _prepare(tensor):
            if self.args.precision == 'half':
                tensor = tensor.half()
            return tensor.to(device)
        # print(_prepare(args[0]).device)
        return [_prepare(a) for a in args]

    def terminate(self):
        """Return True when training should stop (test-only or epochs done)."""
        if self.args.test_only:
            self.test()
            return True
        else:
            epoch = self.optimizer.get_last_epoch() + 1
            return epoch >= self.args.epochs
class timer():
    """A simple stopwatch that accumulates elapsed wall-clock time.

    tic()/toc() measure a single interval; hold() folds the current
    interval into the accumulator, and release() drains it.
    """

    def __init__(self):
        # Total seconds collected via hold(); constructing starts the clock.
        self.acc = 0
        self.tic()

    def tic(self):
        """Restart the interval reference point."""
        self.t0 = time.time()

    def toc(self, restart=False):
        """Return seconds since the last tic(); optionally restart it."""
        diff = time.time() - self.t0
        if restart:
            self.t0 = time.time()
        return diff

    def hold(self):
        """Fold the time since the last tic() into the accumulator."""
        self.acc += self.toc()

    def release(self):
        """Return the accumulated seconds and zero the accumulator."""
        total, self.acc = self.acc, 0
        return total

    def reset(self):
        """Discard any accumulated time."""
        self.acc = 0
    def plot_psnr(self, epoch):
        """Plot the logged PSNR curves and save one PDF per test dataset.

        Produces 'test_<dataset>.pdf' in the experiment directory, with one
        curve per scale. Assumes self.log is indexed as
        (epoch, dataset, scale) — TODO confirm against add_log callers.
        """
        axis = np.linspace(1, epoch, epoch)
        for idx_data, d in enumerate(self.args.data_test):
            label = 'SR on {}'.format(d)
            fig = plt.figure()
            plt.title(label)
            for idx_scale, scale in enumerate(self.args.scale):
                plt.plot(
                    axis,
                    self.log[:, idx_data, idx_scale].numpy(),
                    label='Scale {}'.format(scale)
                )
            plt.legend()
            plt.xlabel('Epochs')
            plt.ylabel('PSNR')
            plt.grid(True)
            plt.savefig(self.get_path('test_{}.pdf'.format(d)))
            # Close explicitly so repeated calls don't leak open figures.
            plt.close(fig)
def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
    """Compute PSNR (dB) between a super-resolved and a reference image.

    For benchmark datasets the residual is first converted to luma and a
    ``scale``-pixel border is ignored; otherwise a ``scale + 6`` border is
    shaved. A scalar ground truth (placeholder) yields a dummy score of 0.
    """
    if hr.nelement() == 1:
        return 0

    residual = (sr - hr) / rgb_range
    if dataset and dataset.dataset.benchmark:
        border = scale
        if residual.size(1) > 1:
            # BT.601-style luma weights, scaled to the 0..1 residual range.
            luma = residual.new_tensor([65.738, 129.057, 25.064]).view(1, 3, 1, 1) / 256
            residual = residual.mul(luma).sum(dim=1)
    else:
        border = scale + 6

    cropped = residual[..., border:-border, border:-border]
    mse = cropped.pow(2).mean()
    return -10 * math.log10(mse)
def bayer_unify(raw, input_pattern, target_pattern, mode) -> torch.Tensor:
    """
    Convert a bayer raw image from one bayer pattern to another.

    Args:
        raw: Bayer mosaic tensor. The "crop" branch unpacks five dims
            (B, T, C, H, W); the "pad" branch applies a 2-D reflect pad,
            which torch only accepts for 3-D/4-D inputs —
            NOTE(review): the two modes expect different ranks; confirm
            against the callers before using "pad" on burst tensors.
        input_pattern: source CFA layout, one of "RGGB", "BGGR", "GRBG", "GBRG".
        target_pattern: desired CFA layout, same choices.
        mode: {"crop", "pad"} — how to handle the sub-mosaic shift. "crop"
            abandons the outermost pixels, "pad" introduces extra pixels.
            Use "crop" in training and "pad" in testing.

    Returns:
        torch.Tensor whose mosaic matches ``target_pattern``. (Fixed the
        annotation: the function slices/pads tensors, it never returns a
        numpy array.)

    Raises:
        RuntimeError: for a pattern pair outside the four supported layouts.
        ValueError: for an unknown ``mode``.
    """
    if input_pattern == target_pattern:
        h_offset, w_offset = 0, 0
    elif input_pattern[0] == target_pattern[2] and input_pattern[1] == target_pattern[3]:
        h_offset, w_offset = 1, 0
    elif input_pattern[0] == target_pattern[1] and input_pattern[2] == target_pattern[3]:
        h_offset, w_offset = 0, 1
    elif input_pattern[0] == target_pattern[3] and input_pattern[1] == target_pattern[2]:
        h_offset, w_offset = 1, 1
    else:  # This is not happening in ["RGGB", "BGGR", "GRBG", "GBRG"]
        raise RuntimeError('Unexpected pair of input and target bayer pattern!')

    if mode == "pad":
        # out = np.pad(raw, [[h_offset, h_offset], [w_offset, w_offset]], 'reflect')
        out = F.pad(raw, (w_offset, w_offset, h_offset, h_offset), mode='reflect')
    elif mode == "crop":
        _, _, _, h, w = raw.shape
        # Dropping `h_offset` rows / `w_offset` cols on each side shifts the
        # mosaic phase while keeping even spatial dimensions.
        out = raw[..., h_offset:h - h_offset, w_offset:w - w_offset]
    else:
        raise ValueError('Unknown normalization mode!')

    return out
""" aug_pattern, target_pattern = input_pattern, input_pattern out = raw if flip_h: out = torch.flip(out, [3]) # GBRG, RGGB aug_pattern = aug_pattern[2] + aug_pattern[3] + aug_pattern[0] + aug_pattern[1] if flip_w: out = torch.flip(out, [4]) aug_pattern = aug_pattern[1] + aug_pattern[0] + aug_pattern[3] + aug_pattern[2] if transpose: out = out.permute(0, 1, 2, 4, 3) aug_pattern = aug_pattern[0] + aug_pattern[2] + aug_pattern[1] + aug_pattern[3] out = bayer_unify(out, aug_pattern, target_pattern, "crop") return out ================================================ FILE: code/synthetic/bsrt/utils/__init__.py ================================================ ================================================ FILE: code/synthetic/bsrt/utils/data_format_utils.py ================================================ import numpy as np import torch import cv2 as cv def numpy_to_torch(a: np.ndarray): return torch.from_numpy(a).float().permute(2, 0, 1) def torch_to_numpy(a: torch.Tensor): return a.permute(1, 2, 0).cpu().numpy() def torch_to_npimage(a: torch.Tensor, unnormalize=True): a_np = torch_to_numpy(a) if unnormalize: a_np = a_np * 255 a_np = a_np.astype(np.uint8) return cv.cvtColor(a_np, cv.COLOR_RGB2BGR) def npimage_to_torch(a, normalize=True, input_bgr=True): if input_bgr: a = cv.cvtColor(a, cv.COLOR_BGR2RGB) a_t = numpy_to_torch(a) if normalize: a_t = a_t / 255.0 return a_t def convert_dict(base_dict, batch_sz): out_dict = [] for b_elem in range(batch_sz): b_info = {} for k, v in base_dict.items(): if isinstance(v, (list, torch.Tensor)): b_info[k] = v[b_elem] out_dict.append(b_info) return out_dict ================================================ FILE: code/synthetic/bsrt/utils/debayer.py ================================================ import torch import torch.nn import torch.nn.functional class Debayer3x3(torch.nn.Module): '''Demosaicing of Bayer images using 3x3 convolutions. Requires BG-Bayer color filter array layout. That is, the image[1,1]='B', image[1,2]='G'. 
    This corresponds to OpenCV naming conventions.

    Compared to Debayer2x2 this method does not use upsampling.
    Instead, we identify five 3x3 interpolation kernels that
    are sufficient to reconstruct every color channel at every
    pixel location.

    We convolve the image with these 5 kernels using stride=1
    and a one pixel replication padding. Finally, we gather
    the correct channel values for each pixel location. Todo so,
    we recognize that the Bayer pattern repeats horizontally and
    vertically every 2 pixels. Therefore, we define the correct
    index lookups for a 2x2 grid cell and then repeat to image
    dimensions.

    Note, in every 2x2 grid cell we have red, blue and two greens
    (G1,G2). The lookups for the two greens differ.
    '''

    def __init__(self):
        super(Debayer3x3, self).__init__()

        # Five fixed interpolation kernels (frozen, not trained):
        # 0: identity, 1: 4-neighbour cross average, 2: diagonal average,
        # 3: horizontal pair average, 4: vertical pair average.
        self.kernels = torch.nn.Parameter(
            torch.tensor([
                [0,0,0],
                [0,1,0],
                [0,0,0],

                [0, 0.25, 0],
                [0.25, 0, 0.25],
                [0, 0.25, 0],

                [0.25, 0, 0.25],
                [0, 0, 0],
                [0.25, 0, 0.25],

                [0, 0, 0],
                [0.5, 0, 0.5],
                [0, 0, 0],

                [0, 0.5, 0],
                [0, 0, 0],
                [0, 0.5, 0],
            ]).view(5,1,3,3), requires_grad=False
        )

        # Per 2x2 mosaic cell: which of the 5 kernel responses supplies each
        # output channel at each pixel position.
        self.index = torch.nn.Parameter(
            torch.tensor([
                # dest channel r
                [0, 3],  # pixel is R,G1
                [4, 2],  # pixel is G2,B
                # dest channel g
                [1, 0],  # pixel is R,G1
                [0, 1],  # pixel is G2,B
                # dest channel b
                [2, 4],  # pixel is R,G1
                [3, 0],  # pixel is G2,B
            ]).view(1,3,2,2), requires_grad=False
        )

    def forward(self, x):
        '''Debayer image.

        Parameters
        ----------
        x : Bx1xHxW tensor
            Images to debayer

        Returns
        -------
        rgb : Bx3xHxW tensor
            Color images in RGB channel order.
        '''
        B,C,H,W = x.shape

        # Replication padding keeps output size equal to input size.
        x = torch.nn.functional.pad(x, (1,1,1,1), mode='replicate')
        c = torch.nn.functional.conv2d(x, self.kernels, stride=1)
        # Tile the 2x2 lookup table over the image and gather the right
        # kernel response per (channel, pixel).
        rgb = torch.gather(c, 1, self.index.repeat(B,1,H//2,W//2))
        return rgb


class Debayer2x2(torch.nn.Module):
    '''Demosaicing of Bayer images using 2x2 convolutions.

    Requires BG-Bayer color filter array layout. That is,
    the image[1,1]='B', image[1,2]='G'. This corresponds
    to OpenCV naming conventions.
    '''

    def __init__(self):
        super(Debayer2x2, self).__init__()

        # One 2x2 kernel per output channel: R from top-left sample, G as the
        # mean of the two green samples, B from the bottom-right sample.
        self.kernels = torch.nn.Parameter(
            torch.tensor([
                [1, 0],
                [0, 0],

                [0, 0.5],
                [0.5, 0],

                [0, 0],
                [0, 1],
            ]).view(3,1,2,2), requires_grad=False
        )

    def forward(self, x):
        '''Debayer image.

        Parameters
        ----------
        x : Bx1xHxW tensor
            Images to debayer

        Returns
        -------
        rgb : Bx3xHxW tensor
            Color images in RGB channel order.
        '''
        # stride=2 halves the resolution; bilinear upsampling restores it.
        x = torch.nn.functional.conv2d(x, self.kernels, stride=2)
        x = torch.nn.functional.interpolate(
            x, scale_factor=2, mode='bilinear', align_corners=False)
        return x


class DebayerSplit(torch.nn.Module):
    '''Demosaicing of Bayer images using 3x3 green convolution and red,blue upsampling.

    Requires BG-Bayer color filter array layout. That is,
    the image[1,1]='B', image[1,2]='G'. This corresponds
    to OpenCV naming conventions.
    '''

    def __init__(self):
        super().__init__()
        self.pad = torch.nn.ReflectionPad2d(1)
        # 4-neighbour green interpolation kernel.
        self.kernel = torch.nn.Parameter(
            torch.tensor([
                [0,1,0],
                [1,0,1],
                [0,1,0]
            ])[None, None] * 0.25)

    def forward(self, x):
        '''Debayer image.

        Parameters
        ----------
        x : Bx1xHxW tensor
            Images to debayer

        Returns
        -------
        rgb : Bx3xHxW tensor
            Color images in RGB channel order.
        '''
        B,_,H,W = x.shape

        red = x[:, :, ::2, ::2]
        blue = x[:, :, 1::2, 1::2]
        # Interpolate green everywhere, then overwrite the positions where
        # green is actually sampled with the exact sensor values.
        green = torch.nn.functional.conv2d(self.pad(x), self.kernel)
        green[:, :, ::2, 1::2] = x[:, :, ::2, 1::2]
        green[:, :, 1::2, ::2] = x[:, :, 1::2, ::2]

        return torch.cat((
            torch.nn.functional.interpolate(red, size=(H, W), mode='bilinear', align_corners=False),
            green,
            torch.nn.functional.interpolate(blue, size=(H, W), mode='bilinear', align_corners=False)), dim=1)



================================================
FILE: code/synthetic/bsrt/utils/interp_methods.py
================================================
from math import pi

try:
    import torch
except ImportError:
    torch = None

try:
    import numpy
except ImportError:
    numpy = None

if numpy is None and torch is None:
    raise ImportError("Must have either Numpy or PyTorch but both not found")


def set_framework_dependencies(x):
    # Dispatch on the array framework of `x`; returns the framework module
    # (numpy or torch), a dtype-cast helper, and float32 machine epsilon.
    if type(x) is numpy.ndarray:
        to_dtype = lambda a: a
        fw = numpy
    else:
        to_dtype = lambda a: a.to(x.dtype)
        fw = torch
    eps = fw.finfo(fw.float32).eps
    return fw, to_dtype, eps


def support_sz(sz):
    # Decorator attaching the kernel support size (in pixels) to an
    # interpolation method; read by the resize machinery.
    def wrapper(f):
        f.support_sz = sz
        return f
    return wrapper


@support_sz(4)
def cubic(x):
    # Keys' bicubic interpolation kernel (a = -0.5).
    fw, to_dtype, eps = set_framework_dependencies(x)
    absx = fw.abs(x)
    absx2 = absx ** 2
    absx3 = absx ** 3
    return ((1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) +
            (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) *
            to_dtype((1. < absx) & (absx <= 2.)))


@support_sz(4)
def lanczos2(x):
    # Lanczos kernel with a = 2; eps guards the 0/0 limit at x = 0.
    fw, to_dtype, eps = set_framework_dependencies(x)
    return (((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) /
             ((pi**2 * x**2 / 2) + eps)) * to_dtype(abs(x) < 2))


@support_sz(6)
def lanczos3(x):
    # Lanczos kernel with a = 3.
    fw, to_dtype, eps = set_framework_dependencies(x)
    return (((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) /
             ((pi**2 * x**2 / 3) + eps)) * to_dtype(abs(x) < 3))


@support_sz(2)
def linear(x):
    # Triangle (bilinear) kernel.
    fw, to_dtype, eps = set_framework_dependencies(x)
    return ((x + 1) * to_dtype((-1 <= x) & (x < 0)) +
            (1 - x) * to_dtype((0 <= x) & (x <= 1)))


@support_sz(1)
def box(x):
    # Nearest-neighbour (box) kernel.
    fw, to_dtype, eps = set_framework_dependencies(x)
    return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1))



================================================
FILE: code/synthetic/bsrt/utils/metrics.py
================================================
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils.spatial_color_alignment as sca_utils
from utils.spatial_color_alignment import get_gaussian_kernel, match_colors
from utils.warp import warp
from torch.cuda.amp import autocast
from loss.Charbonnier import CharbonnierLoss as CBLoss
from loss.mssim import MSSSIM
from pytorch_msssim import ssim
import lpips


class MSSSIMLoss(nn.Module):
    # Multi-scale SSIM loss; optionally crops `boundary_ignore` border pixels
    # from both tensors first. `valid` is accepted for interface parity but
    # unused.
    def __init__(self, boundary_ignore=None):
        super().__init__()
        self.boundary_ignore = boundary_ignore
        self.msssim = MSSSIM()

    def forward(self, pred, gt, valid=None):
        if self.boundary_ignore is not None:
            pred = pred[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]
            gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]

        pred_m = pred
        gt_m = gt

        loss = self.msssim(pred_m, gt_m)
        return loss


class CharbonnierLoss(nn.Module):
    # Charbonnier (differentiable L1-like) loss wrapper with optional border
    # cropping; `valid` is accepted but unused.
    def __init__(self, boundary_ignore=None):
        super().__init__()
        self.boundary_ignore = boundary_ignore
        self.charbonnier_loss = CBLoss(reduce=True)

    def forward(self, pred, gt, valid=None): if
# NOTE(review): this chunk opens mid-way through CharbonnierLoss.forward — its
# `def` line sits at the end of the previous dump line; tokens kept verbatim.
self.boundary_ignore is not None:
            pred = pred[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]
            gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]

        pred_m = pred
        gt_m = gt

        loss = self.charbonnier_loss(pred_m, gt_m)
        return loss


class L1(nn.Module):
    # Mean absolute error with optional border cropping and an optional
    # boolean `valid` mask (masked mean, normalized by valid fraction).
    def __init__(self, boundary_ignore=None):
        super().__init__()
        self.boundary_ignore = boundary_ignore

    def forward(self, pred, gt, valid=None):
        if self.boundary_ignore is not None:
            pred = pred[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]
            gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]

            if valid is not None:
                valid = valid[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]

        pred_m = pred
        gt_m = gt

        if valid is None:
            mse = F.l1_loss(pred_m, gt_m)
        else:
            mse = F.l1_loss(pred_m, gt_m, reduction='none')

            eps = 1e-12
            elem_ratio = mse.numel() / valid.numel()
            # Masked mean; elem_ratio compensates when the mask has fewer
            # elements (e.g. one channel) than the loss tensor.
            mse = (mse * valid.float()).sum() / (valid.float().sum()*elem_ratio + eps)

        return mse


class L2(nn.Module):
    # NOTE(review): despite the name, this returns a (MSE, SSIM, LPIPS)
    # triple — it is the metric workhorse behind PSNR below. Instantiating it
    # builds an AlexNet-based LPIPS model on the GPU (requires CUDA).
    def __init__(self, boundary_ignore=None):
        super().__init__()
        self.boundary_ignore = boundary_ignore
        self.loss_fn = lpips.LPIPS(net='alex').cuda()

    def forward(self, pred, gt, valid=None):
        if self.boundary_ignore is not None:
            pred = pred[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]
            gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]

            if valid is not None:
                valid = valid[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]

        pred_m = pred
        gt_m = gt

        if valid is None:
            mse = F.mse_loss(pred_m, gt_m)
        else:
            mse = F.mse_loss(pred_m, gt_m, reduction='none')

            eps = 1e-12
            elem_ratio = mse.numel() / valid.numel()
            mse = (mse * valid.float()).sum() / (valid.float().sum()*elem_ratio + eps)

        # SSIM assumes inputs in [0, 1] (data_range=1.0) — TODO confirm.
        ss = ssim(pred_m.contiguous(), gt_m.contiguous(), data_range=1.0, size_average=True)

        lp = self.loss_fn(pred_m.contiguous(), gt_m.contiguous()).squeeze()

        return mse, ss, lp


class PSNR(nn.Module):
    # Per-sample PSNR (plus SSIM and LPIPS piggybacked from L2), averaged
    # over the batch.
    def __init__(self, boundary_ignore=None, max_value=1.0):
        super().__init__()
        self.l2 = L2(boundary_ignore=boundary_ignore)
        self.max_value = max_value

    def psnr(self, pred, gt, valid=None):
        mse, ss, lp = self.l2(pred, gt, valid=valid)

        psnr = 20 * math.log10(self.max_value) - 10.0 * mse.log10()

        return psnr, ss, lp

    def forward(self, pred, gt, valid=None):
        assert pred.dim() == 4 and pred.shape == gt.shape
        if valid is None:
            all_scores = [self.psnr(p.unsqueeze(0), g.unsqueeze(0)) for p, g in zip(pred, gt)]
        else:
            all_scores = [self.psnr(p.unsqueeze(0), g.unsqueeze(0), v.unsqueeze(0)) for p, g, v in
                          zip(pred, gt, valid)]

        # psnr, ss, lp = sum(psnr_all) / len(psnr_all)
        psnr = sum([score[0] for score in all_scores]) / len(all_scores)
        ssim_ = sum([score[1] for score in all_scores]) / len(all_scores)
        lpips_ = sum([score[2] for score in all_scores]) / len(all_scores)

        return psnr, ssim_, lpips_


class AlignedL1(nn.Module):
    # L1 computed after warping the prediction onto the (misaligned) ground
    # truth with an optical-flow network and matching color spaces — used for
    # real-world burst data.
    def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None):
        super().__init__()
        self.sr_factor = sr_factor
        self.boundary_ignore = boundary_ignore
        self.alignment_net = alignment_net
        self.gauss_kernel, self.ksz = get_gaussian_kernel(sd=1.5)

    def forward(self, pred, gt, burst_input):
        # Estimate flow between the prediction and the ground truth
        with torch.no_grad():
            flow = self.alignment_net(pred / (pred.max() + 1e-6), gt / (gt.max() + 1e-6))

        # Warp the prediction to the ground truth coordinates
        pred_warped = warp(pred, flow)

        # Warp the base input frame to the ground truth.
        # This will be used to estimate the color transformation between
        # the input and the ground truth
        sr_factor = self.sr_factor
        ds_factor = 1.0 / float(2.0 * sr_factor)  # raw mosaic is half SR resolution
        flow_ds = F.interpolate(flow, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True,
                                align_corners=False) * ds_factor

        # Channels [0, 1, 3] of the packed mosaic: R, one G, B (G2 dropped).
        burst_0 = burst_input[:, 0, [0, 1, 3]].contiguous()
        burst_0_warped = warp(burst_0, flow_ds)
        frame_gt_ds = F.interpolate(gt, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True,
                                    align_corners=False)

        # Match the colorspace between the prediction and ground truth
        pred_warped_m, valid = match_colors(frame_gt_ds, burst_0_warped, pred_warped, self.ksz, self.gauss_kernel)

        # Ignore boundary pixels if specified
        if self.boundary_ignore is not None:
            pred_warped_m = pred_warped_m[..., self.boundary_ignore:-self.boundary_ignore,
                                          self.boundary_ignore:-self.boundary_ignore]
            gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]

            valid = valid[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]

        pred_warped_m = pred_warped_m.contiguous()
        gt = gt.contiguous()

        # Estimate MSE
        l1 = F.l1_loss(pred_warped_m, gt, reduction='none')

        eps = 1e-12
        elem_ratio = l1.numel() / valid.numel()
        l1 = (l1 * valid.float()).sum() / (valid.float().sum()*elem_ratio + eps)
        return l1


class AlignedL2(nn.Module):
    # Alignment-aware (MSE, SSIM, LPIPS): warps pred onto gt via a flow
    # network and matches colors before scoring. LPIPS model lives on GPU.
    def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None):
        super().__init__()
        self.sr_factor = sr_factor
        self.boundary_ignore = boundary_ignore
        self.alignment_net = alignment_net
        self.loss_fn = lpips.LPIPS(net='alex').cuda()
        self.gauss_kernel, self.ksz = sca_utils.get_gaussian_kernel(sd=1.5)

    def forward(self, pred, gt, burst_input):
        # Estimate flow between the prediction and the ground truth
        with torch.no_grad():
            flow = self.alignment_net(pred / (pred.max() + 1e-6), gt / (gt.max() + 1e-6))

        # Warp the prediction to the ground truth coordinates
        pred_warped = warp(pred, flow)

        # Warp the base input frame to the ground truth. This will be used to
        # estimate the color transformation between the input and the ground
        # truth.
        sr_factor = self.sr_factor
        ds_factor = 1.0 / float(2.0 * sr_factor)
        flow_ds = F.interpolate(flow, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True,
                                align_corners=False) * ds_factor

        burst_0 = burst_input[:, 0, [0, 1, 3]].contiguous()
        burst_0_warped = warp(burst_0, flow_ds)
        frame_gt_ds = F.interpolate(gt, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True,
                                    align_corners=False)

        # Match the colorspace between the prediction and ground truth
        pred_warped_m, valid = sca_utils.match_colors(frame_gt_ds, burst_0_warped, pred_warped, self.ksz,
                                                      self.gauss_kernel)

        # Ignore boundary pixels if specified
        if self.boundary_ignore is not None:
            pred_warped_m = pred_warped_m[..., self.boundary_ignore:-self.boundary_ignore,
                                          self.boundary_ignore:-self.boundary_ignore]
            gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]

            valid = valid[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]

        # Estimate MSE
        mse = F.mse_loss(pred_warped_m.contiguous(), gt.contiguous(), reduction='none')

        eps = 1e-12
        elem_ratio = mse.numel() / valid.numel()
        mse = (mse * valid.float()).sum() / (valid.float().sum()*elem_ratio + eps)

        ss = ssim(pred_warped_m.contiguous(), gt.contiguous(), data_range=1.0, size_average=True)
        # eps = 1e-12
        # elem_ratio = ss.numel() / valid.numel()
        # ss = (ss * valid.float()).sum() / (valid.float().sum()*elem_ratio + eps)

        lp = self.loss_fn(pred_warped_m.contiguous(), gt.contiguous()).squeeze()

        return mse, ss, lp


class AlignedPSNR(nn.Module):
    # Batch-averaged PSNR/SSIM/LPIPS over alignment-corrected predictions.
    def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None, max_value=1.0):
        super().__init__()
        self.l2 = AlignedL2(alignment_net=alignment_net, sr_factor=sr_factor, boundary_ignore=boundary_ignore)
        self.max_value = max_value

    def psnr(self, pred, gt, burst_input):
        mse, ss, lp = self.l2(pred, gt, burst_input)
        psnr = 20 * math.log10(self.max_value) - 10.0 * mse.log10()

        return psnr, ss, lp

    def forward(self, pred, gt, burst_input):
        all_scores = [self.psnr(p.unsqueeze(0), g.unsqueeze(0), bi.unsqueeze(0)) for p, g, bi in
                      zip(pred, gt, burst_input)]

        psnr = sum([score[0] for score in all_scores]) / len(all_scores)
        ssim_ = sum([score[1] for score in all_scores]) / len(all_scores)
        lpips_ = sum([score[2] for score in all_scores]) / len(all_scores)

        return psnr, ssim_, lpips_


class AlignedSSIM(nn.Module):
    # Batch-averaged SSIM over alignment-corrected predictions.
    def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None):
        super().__init__()
        self.sr_factor = sr_factor
        self.boundary_ignore = boundary_ignore
        self.alignment_net = alignment_net
        self.gauss_kernel, self.ksz = sca_utils.get_gaussian_kernel(sd=1.5)

    def _ssim(self, pred, gt, burst_input):
        # Estimate flow between the prediction and the ground truth
        with torch.no_grad():
            flow = self.alignment_net(pred / (pred.max() + 1e-6), gt / (gt.max() + 1e-6))

        # Warp the prediction to the ground truth coordinates
        pred_warped = warp(pred, flow)

        # Warp the base input frame to the ground truth. This will be used to
        # estimate the color transformation between the input and the ground
        # truth.
        sr_factor = self.sr_factor
        ds_factor = 1.0 / float(2.0 * sr_factor)
        flow_ds = F.interpolate(flow, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True,
                                align_corners=False) * ds_factor

        burst_0 = burst_input[:, 0, [0, 1, 3]].contiguous()
        burst_0_warped = warp(burst_0, flow_ds)
        frame_gt_ds = F.interpolate(gt, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True,
                                    align_corners=False)

        # Match the colorspace between the prediction and ground truth
        pred_warped_m, valid = sca_utils.match_colors(frame_gt_ds, burst_0_warped, pred_warped, self.ksz,
                                                      self.gauss_kernel)

        # Ignore boundary pixels if specified
        if self.boundary_ignore is not None:
            pred_warped_m = pred_warped_m[..., self.boundary_ignore:-self.boundary_ignore,
                                          self.boundary_ignore:-self.boundary_ignore]
            gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]

            valid = valid[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]

        # Estimate MSE
        mse = ssim(pred_warped_m.contiguous(), gt.contiguous(), data_range=1.0, size_average=True)
        # print(mse.shape)
        # eps = 1e-12
        # elem_ratio = mse.numel() / valid.numel()
        # mse = (mse * valid.float()).sum() / (valid.float().sum()*elem_ratio + eps)

        return mse

    def forward(self, pred, gt, burst_input):
        ssim_all = [self._ssim(p.unsqueeze(0), g.unsqueeze(0), bi.unsqueeze(0)) for p, g, bi in
                    zip(pred, gt, burst_input)]
        _ssim = sum(ssim_all) / len(ssim_all)
        return _ssim


class AlignedLPIPS(nn.Module):
    # Batch-averaged LPIPS over alignment-corrected predictions (GPU model).
    def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None):
        super().__init__()
        self.sr_factor = sr_factor
        self.boundary_ignore = boundary_ignore
        self.alignment_net = alignment_net
        self.loss_fn = lpips.LPIPS(net='alex').cuda()
        self.gauss_kernel, self.ksz = sca_utils.get_gaussian_kernel(sd=1.5)

    def _lpips(self, pred, gt, burst_input):
        # Estimate flow between the prediction and the ground truth
        with torch.no_grad():
            flow = self.alignment_net(pred / (pred.max() + 1e-6), gt / (gt.max() + 1e-6))

        # Warp the prediction to the ground truth coordinates
        pred_warped = warp(pred, flow)

        # Warp the base input frame to the ground truth. This will be used to
        # estimate the color transformation between the input and the ground
        # truth.
        sr_factor = self.sr_factor
        ds_factor = 1.0 / float(2.0 * sr_factor)
        flow_ds = F.interpolate(flow, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True,
                                align_corners=False) * ds_factor

        burst_0 = burst_input[:, 0, [0, 1, 3]].contiguous()
        burst_0_warped = warp(burst_0, flow_ds)
        frame_gt_ds = F.interpolate(gt, scale_factor=ds_factor, mode='bilinear', recompute_scale_factor=True,
                                    align_corners=False)

        # Match the colorspace between the prediction and ground truth
        pred_warped_m, valid = sca_utils.match_colors(frame_gt_ds, burst_0_warped, pred_warped, self.ksz,
                                                      self.gauss_kernel)

        # Ignore boundary pixels if specified
        if self.boundary_ignore is not None:
            pred_warped_m = pred_warped_m[..., self.boundary_ignore:-self.boundary_ignore,
                                          self.boundary_ignore:-self.boundary_ignore]
            gt = gt[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]

            valid = valid[..., self.boundary_ignore:-self.boundary_ignore, self.boundary_ignore:-self.boundary_ignore]

        # Estimate MSE
        mse = self.loss_fn(pred_warped_m.contiguous(), gt.contiguous()).squeeze()

        return mse

    def forward(self, pred, gt, burst_input):
        lpips_all = [self._lpips(p.unsqueeze(0), g.unsqueeze(0), bi.unsqueeze(0)) for p, g, bi in
                     zip(pred, gt, burst_input)]
        _lpips = sum(lpips_all) / len(lpips_all)
        return _lpips



================================================
FILE: code/synthetic/bsrt/utils/postprocessing_functions.py
================================================
import torch
import numpy as np
import utils.data_format_utils as df_utils
from data_processing.camera_pipeline import apply_gains, apply_ccm, apply_smoothstep, gamma_compression
class SimplePostProcess:
    # Applies the synthetic camera ISP (gains -> CCM -> gamma -> smoothstep)
    # to a linear RGB image, using per-image `meta_info`.
    def __init__(self, gains=True, ccm=True, gamma=True, smoothstep=True, return_np=False):
        self.gains = gains
        self.ccm = ccm
        self.gamma = gamma
        self.smoothstep = smoothstep
        self.return_np = return_np

    def process(self, image, meta_info):
        return process_linear_image_rgb(image, meta_info, self.gains, self.ccm,
                                        self.gamma, self.smoothstep, self.return_np)


def process_linear_image_rgb(image, meta_info, gains=True, ccm=True, gamma=True, smoothstep=True, return_np=False):
    """Run the forward ISP on a linear RGB tensor; each stage is applied only
    if both the caller flag and the corresponding `meta_info` flag allow it.
    Output is clamped to [0, 1] (optionally converted to a uint8 BGR image).
    """
    if gains:
        image = apply_gains(image, meta_info['rgb_gain'], meta_info['red_gain'], meta_info['blue_gain'])
    if ccm:
        image = apply_ccm(image, meta_info['cam2rgb'])
    if meta_info['gamma'] and gamma:
        image = gamma_compression(image)
    if meta_info['smoothstep'] and smoothstep:
        image = apply_smoothstep(image)

    image = image.clamp(0.0, 1.0)

    if return_np:
        image = df_utils.torch_to_npimage(image)
    return image


class BurstSRPostProcess:
    # ISP for real BurstSR frames: black-level subtraction, white balance,
    # normalization, gamma and smoothstep.
    def __init__(self, no_white_balance=False, gamma=True, smoothstep=True, return_np=False):
        self.no_white_balance = no_white_balance
        self.gamma = gamma
        self.smoothstep = smoothstep
        self.return_np = return_np

    def process(self, image, meta_info, external_norm_factor=None):
        return process_burstsr_image_rgb(image, meta_info, external_norm_factor=external_norm_factor,
                                         no_white_balance=self.no_white_balance,
                                         gamma=self.gamma, smoothstep=self.smoothstep,
                                         return_np=self.return_np)


def process_burstsr_image_rgb(im, meta_info, return_np=False, external_norm_factor=None, gamma=True, smoothstep=True,
                              no_white_balance=False):
    """Convert a 3-channel BurstSR raw-RGB tensor to a display-ready image
    using the per-frame `meta_info` dict; `im` is assumed to use channels
    (R, G, B) taken from mosaic positions [0, 1, -1] — TODO confirm."""
    im = im * meta_info.get('norm_factor', 1.0)

    if not meta_info.get('black_level_subtracted', False):
        im = (im - torch.tensor(meta_info['black_level'])[[0, 1, -1]].view(3, 1, 1).to(im.device))

    # NOTE(review): 'while_balance_applied' reads like a typo of
    # 'white_balance_applied', but it must match the key written by the
    # dataset loader — verify against the metadata producer before renaming.
    if not meta_info.get('while_balance_applied', False) and not no_white_balance:
        im = im * (meta_info['cam_wb'][[0, 1, -1]].view(3, 1, 1) / meta_info['cam_wb'][1]).to(im.device)

    im_out = im

    # Normalize to [0, 1] either by the image max or an external factor
    # (external factor keeps brightness consistent across a burst).
    if external_norm_factor is None:
        im_out = im_out / im_out.max()
    else:
        im_out = im_out / external_norm_factor

    im_out = im_out.clamp(0.0, 1.0)

    if gamma:
        im_out = im_out ** (1.0 / 2.2)

    if smoothstep:
        # Smooth curve
        im_out = 3 * im_out ** 2 - 2 * im_out ** 3

    if return_np:
        im_out = im_out.permute(1, 2, 0).cpu().numpy() * 255.0
        im_out = im_out.astype(np.uint8)
    return im_out



================================================
FILE: code/synthetic/bsrt/utils/resize_right.py
================================================
import warnings
from math import ceil
import interp_methods


class NoneClass:
    pass


try:
    import torch
    from torch import nn
    nnModuleWrapped = nn.Module
except ImportError:
    warnings.warn('No PyTorch found, will work only with Numpy')
    torch = None
    nnModuleWrapped = NoneClass

try:
    import numpy
except ImportError:
    warnings.warn('No Numpy found, will work only with PyTorch')
    numpy = None

if numpy is None and torch is None:
    raise ImportError("Must have either Numpy or PyTorch but both not found")


def resize(input, scale_factors=None, out_shape=None,
           interp_method=interp_methods.cubic, support_sz=None,
           antialiasing=True):
    # get properties of the input tensor
    in_shape, n_dims = input.shape, input.ndim

    # fw stands for framework that can be either numpy or torch,
    # determined by the input type
    fw = numpy if type(input) is numpy.ndarray else torch
    eps = fw.finfo(fw.float32).eps

    # set missing scale factors or output shapem one according to another,
    # scream if both missing
    scale_factors, out_shape = set_scale_and_out_sz(in_shape, out_shape,
                                                    scale_factors, fw)

    # sort indices of dimensions according to scale of each dimension.
    # since we are going dim by dim this is efficient
    sorted_filtered_dims_and_scales = [(dim, scale_factors[dim])
                                       for dim in sorted(range(n_dims),
                                       key=lambda ind: scale_factors[ind])
                                       if scale_factors[dim] != 1.]
# unless support size is specified by the user, it is an attribute # of the interpolation method if support_sz is None: support_sz = interp_method.support_sz # when using pytorch, we need to know what is the input tensor device if fw is torch: device = input.device # output begins identical to input and changes with each iteration output = input # iterate over dims for dim, scale_factor in sorted_filtered_dims_and_scales: # get 1d set of weights and fields of view for each output location # along this dim field_of_view, weights = prepare_weights_and_field_of_view_1d( dim, scale_factor, in_shape[dim], out_shape[dim], interp_method, support_sz, antialiasing, fw, eps, device) # multiply the weights by the values in the field of view and # aggreagate output = apply_weights(output, field_of_view, weights, dim, n_dims, fw) return output class ResizeLayer(nnModuleWrapped): def __init__(self, in_shape, scale_factors=None, out_shape=None, interp_method=interp_methods.cubic, support_sz=None, antialiasing=True): super(ResizeLayer, self).__init__() # fw stands for framework, that can be either numpy or torch. since # this is a torch layer, only one option in this case. fw = torch eps = fw.finfo(fw.float32).eps # set missing scale factors or output shapem one according to another, # scream if both missing scale_factors, out_shape = set_scale_and_out_sz(in_shape, out_shape, scale_factors, fw) # unless support size is specified by the user, it is an attribute # of the interpolation method if support_sz is None: support_sz = interp_method.support_sz self.n_dims = len(in_shape) # sort indices of dimensions according to scale of each dimension. # since we are going dim by dim this is efficient self.sorted_filtered_dims_and_scales = [(dim, scale_factors[dim]) for dim in sorted(range(self.n_dims), key=lambda ind: scale_factors[ind]) if scale_factors[dim] != 1.] 
# iterate over dims field_of_view_list = [] weights_list = [] for dim, scale_factor in self.sorted_filtered_dims_and_scales: # get 1d set of weights and fields of view for each output # location along this dim field_of_view, weights = prepare_weights_and_field_of_view_1d( dim, scale_factor, in_shape[dim], out_shape[dim], interp_method, support_sz, antialiasing, fw, eps, input.device) # keep weights and fields of views for all dims weights_list.append(nn.Parameter(weights, requires_grad=False)) field_of_view_list.append(nn.Parameter(field_of_view, requires_grad=False)) self.field_of_view = nn.ParameterList(field_of_view_list) self.weights = nn.ParameterList(weights_list) self.in_shape = in_shape def forward(self, input): # output begins identical to input and changes with each iteration output = input for (dim, scale_factor), field_of_view, weights in zip( self.sorted_filtered_dims_and_scales, self.field_of_view, self.weights): # multiply the weights by the values in the field of view and # aggreagate output = apply_weights(output, field_of_view, weights, dim, self.n_dims, torch) return output def prepare_weights_and_field_of_view_1d(dim, scale_factor, in_sz, out_sz, interp_method, support_sz, antialiasing, fw, eps, device=None): # If antialiasing is taking place, we modify the window size and the # interpolation method (see inside function) interp_method, cur_support_sz = apply_antialiasing_if_needed( interp_method, support_sz, scale_factor, antialiasing) # STEP 1- PROJECTED GRID: The non-integer locations of the projection of # output pixel locations to the input tensor projected_grid = get_projected_grid(in_sz, out_sz, scale_factor, fw, device) # STEP 2- FIELDS OF VIEW: for each output pixels, map the input pixels # that influence it field_of_view = get_field_of_view(projected_grid, cur_support_sz, in_sz, fw, eps) # STEP 3- CALCULATE WEIGHTS: Match a set of weights to the pixels in the # field of view for each output pixel weights = get_weights(interp_method, 
projected_grid, field_of_view) return field_of_view, weights def apply_weights(input, field_of_view, weights, dim, n_dims, fw): # STEP 4- APPLY WEIGHTS: Each output pixel is calculated by multiplying # its set of weights with the pixel values in its field of view. # We now multiply the fields of view with their matching weights. # We do this by tensor multiplication and broadcasting. # this step is separated to a different function, so that it can be # repeated with the same calculated weights and fields. # for this operations we assume the resized dim is the first one. # so we transpose and will transpose back after multiplying tmp_input = fw_swapaxes(input, dim, 0, fw) # field_of_view is a tensor of order 2: for each output (1d location # along cur dim)- a list of 1d neighbors locations. # note that this whole operations is applied to each dim separately, # this is why it is all in 1d. # neighbors = tmp_input[field_of_view] is a tensor of order image_dims+1: # for each output pixel (this time indicated in all dims), these are the # values of the neighbors in the 1d field of view. note that we only # consider neighbors along the current dim, but such set exists for every # multi-dim location, hence the final tensor order is image_dims+1. neighbors = tmp_input[field_of_view] # weights is an order 2 tensor: for each output location along 1d- a list # of weighs matching the field of view. we augment it with ones, for # broadcasting, so that when multiplies some tensor the weights affect # only its first dim. 
def set_scale_and_out_sz(in_shape, out_shape, scale_factors, fw):
    """Resolve the (scale_factors, out_shape) pair from partial arguments.

    Either argument may be None; the missing one is derived from the other
    together with in_shape. fw is the array framework (numpy or torch): it
    decides whether unspecified dims default to the leading (numpy) or
    trailing (torch) positions.
    """
    if scale_factors is None and out_shape is None:
        raise ValueError("either scale_factors or out_shape should be "
                         "provided")
    if out_shape is not None:
        # Fill missing dims from in_shape: numpy keeps the given sizes
        # first, torch keeps them last.
        missing = list(in_shape[:-len(out_shape)])
        out_shape = (list(out_shape) + missing if fw is numpy
                     else missing + list(out_shape))
        if scale_factors is None:
            # Derive scale as the out/in ratio (not recommended).
            scale_factors = [out_sz / in_sz
                             for out_sz, in_sz in zip(out_shape, in_shape)]
    if scale_factors is not None:
        # A scalar scale means "resize the two spatial dims".
        if not isinstance(scale_factors, (list, tuple)):
            scale_factors = [scale_factors, scale_factors]
        # Pad the untouched dims with a scale of 1 (numpy: trailing pad,
        # torch: leading pad).
        pad = [1] * (len(in_shape) - len(scale_factors))
        scale_factors = (list(scale_factors) + pad if fw is numpy
                         else pad + list(scale_factors))
        if out_shape is None:
            # Derive out sizes by scaling in_shape (not recommended).
            out_shape = [ceil(sf * in_sz)
                         for sf, in_sz in zip(scale_factors, in_shape)]
    # Normalize scales to floats only after out_shape has been fixed.
    scale_factors = [float(sf) for sf in scale_factors]
    return scale_factors, out_shape


def get_projected_grid(in_sz, out_sz, scale_factor, fw, device=None):
    """Project the integer output pixel coordinates onto the input axis.

    Returns the (generally fractional) 1-D input locations corresponding to
    each output pixel, per Shocher et al., "From Discrete to Continuous
    Convolutions".
    """
    out_coords = fw.arange(out_sz)
    # For torch, keep the grid on the caller's device (no-op for numpy).
    out_coords = fw_set_device(out_coords, device, fw)
    return (out_coords / scale_factor
            + (in_sz - 1) / 2
            - (out_sz - 1) / (2 * scale_factor))


def get_field_of_view(projected_grid, cur_support_sz, in_sz, fw, eps):
    """For each projected output location, list the input pixels (1-D) that
    influence it, with out-of-range indices folded back so the result acts
    like mirror padding without physically enlarging the input."""
    # Leftmost contributing pixel; eps guards exact-integer boundaries.
    left_boundaries = fw_ceil(projected_grid - cur_support_sz / 2 - eps, fw)
    # The rest of the window is consecutive offsets from the left edge.
    # NOTE(review): numpy arrays only expose `.device` from numpy 2.0 —
    # the numpy path here relies on that; confirm the supported version.
    ordinals = fw.arange(ceil(cur_support_sz - eps))
    ordinals = fw_set_device(ordinals, projected_grid.device, fw)
    field_of_view = left_boundaries[:, None] + ordinals
    # Index into a forward-then-backward ramp to emulate mirror padding.
    mirror = fw_cat((fw.arange(in_sz), fw.arange(in_sz - 1, -1, step=-1)), fw)
    field_of_view = mirror[fw.remainder(field_of_view, mirror.shape[0])]
    return fw_set_device(field_of_view, projected_grid.device, fw)
def get_weights(interp_method, projected_grid, field_of_view):
    """Interpolation weights per output pixel, normalized to sum to 1.

    Weights are interp_method applied to the signed distances between each
    projected grid location and the pixel centers in its field of view.
    """
    raw = interp_method(projected_grid[:, None] - field_of_view)
    totals = raw.sum(1, keepdims=True)
    # Guard against division by zero for all-zero weight rows.
    totals[totals == 0] = 1
    return raw / totals


def apply_antialiasing_if_needed(interp_method, support_sz, scale_factor,
                                 antialiasing):
    """When downscaling with antialiasing on, stretch the interpolation
    kernel by the scale factor (low-pass filtering) and widen the support
    window accordingly; otherwise return the inputs untouched."""
    if scale_factor >= 1.0 or not antialiasing:
        return interp_method, support_sz
    stretched = lambda arg: scale_factor * interp_method(scale_factor * arg)
    return stretched, support_sz / scale_factor


def fw_ceil(x, fw):
    # Framework-agnostic ceil that lands on an integer dtype.
    return fw.int_(fw.ceil(x)) if fw is numpy else x.ceil().long()


def fw_cat(x, fw):
    # Framework-agnostic concatenation along the first axis.
    return fw.concatenate(x) if fw is numpy else fw.cat(x)


def fw_swapaxes(x, ax_1, ax_2, fw):
    # Framework-agnostic axis swap.
    if fw is numpy:
        return fw.swapaxes(x, ax_1, ax_2)
    return x.transpose(ax_1, ax_2)


def fw_set_device(x, device, fw):
    # numpy has no device notion; torch tensors are moved explicitly.
    return x if fw is numpy else x.to(device)


def gauss_1d(sz, sigma, center, end_pad=0, density=False):
    """Return a 1-D Gaussian of length sz (+ end_pad), one row per center."""
    coords = torch.arange(-(sz - 1) / 2, (sz + 1) / 2 + end_pad).reshape(1, -1)
    out = torch.exp(-1.0 / (2 * sigma ** 2)
                    * (coords - center.reshape(-1, 1)) ** 2)
    if density:
        # Scale so the continuous Gaussian integrates to one.
        out /= math.sqrt(2 * math.pi) * sigma
    return out
def gauss_2d(sz, sigma, center, end_pad=(0, 0), density=False):
    """Return a 2-D Gaussian as the outer product of two 1-D Gaussians.

    sz and sigma may be scalars (used for both axes); center may be a
    2-element list/tuple or a (N, 2) tensor of centers.
    """
    if isinstance(sigma, (float, int)):
        sigma = (sigma, sigma)
    if isinstance(sz, int):
        sz = (sz, sz)
    if isinstance(center, (list, tuple)):
        center = torch.tensor(center).view(1, 2)
    return gauss_1d(sz[0], sigma[0], center[:, 0], end_pad[0], density).reshape(center.shape[0], 1, -1) * \
        gauss_1d(sz[1], sigma[1], center[:, 1], end_pad[1], density).reshape(center.shape[0], -1, 1)


def get_gaussian_kernel(sd):
    """Return a (1, 1, ksz, ksz) Gaussian kernel with standard deviation sd
    (normalized to sum to 1) together with its odd kernel size ksz."""
    ksz = int(4 * sd + 1)
    assert ksz % 2 == 1
    K = gauss_2d(ksz, sd, (0.0, 0.0), density=True)
    K = K / K.sum()
    return K.unsqueeze(0), ksz


def apply_kernel(im, ksz, gauss_kernel):
    """Convolve every channel of im with gauss_kernel using reflect padding,
    preserving im's shape."""
    shape = im.shape
    # Fold batch and channel together so a single-channel conv applies to all.
    im = im.view(-1, 1, *im.shape[-2:])
    pad = [ksz // 2, ksz // 2, ksz // 2, ksz // 2]
    im = F.pad(im, pad, mode='reflect')
    im_mean = F.conv2d(im, gauss_kernel).view(shape)
    return im_mean


def match_colors(im_ref, im_q, im_test, ksz, gauss_kernel):
    """Estimate a per-batch 3x3 color transformation mapping im_q to im_ref
    and apply it to im_test.

    The fit is done on Gaussian-smoothed, border-cropped versions of the
    images to suppress noise and boundary effects.

    Returns:
        im_t_conv: im_test after the estimated color transform.
        valid: boolean mask at im_test resolution marking pixels where the
            fitted transform reproduces im_ref within a fixed threshold.
    """
    gauss_kernel = gauss_kernel.to(im_ref.device)
    bi = 5  # border pixels ignored when fitting

    # Apply Gaussian smoothing, then crop the border.
    im_ref_mean = apply_kernel(im_ref, ksz, gauss_kernel)[:, :, bi:-bi, bi:-bi].contiguous()
    im_q_mean = apply_kernel(im_q, ksz, gauss_kernel)[:, :, bi:-bi, bi:-bi].contiguous()

    im_ref_mean_re = im_ref_mean.view(*im_ref_mean.shape[:2], -1)
    im_q_mean_re = im_q_mean.view(*im_q_mean.shape[:2], -1)

    # Estimate the color transformation matrix by minimizing the least
    # squares error. Bug fix: torch.lstsq was deprecated in 1.9 and removed
    # in 1.13; torch.linalg.lstsq takes (A, B) for A @ X = B (the reverse
    # argument order) and returns the exact (3, 3) solution, so the old
    # `.solution[:3]` row-slice is no longer needed.
    c_mat_all = []
    for ir, iq in zip(im_ref_mean_re, im_q_mean_re):
        c = torch.linalg.lstsq(iq.t(), ir.t()).solution
        c_mat_all.append(c)

    c_mat = torch.stack(c_mat_all, dim=0)
    im_q_mean_conv = torch.matmul(im_q_mean_re.permute(0, 2, 1), c_mat).permute(0, 2, 1)
    im_q_mean_conv = im_q_mean_conv.view(im_q_mean.shape)

    # Per-pixel fit error in 0-255 units.
    err = ((im_q_mean_conv - im_ref_mean) * 255.0).norm(dim=1)

    thresh = 20

    # If error is larger than the threshold, ignore these pixels.
    valid = err < thresh

    # Pad the mask back to the pre-crop size (padded border counts as invalid),
    # then upsample it to im_test resolution.
    pad = (im_q.shape[-1] - valid.shape[-1]) // 2
    pad = [pad, pad, pad, pad]
    valid = F.pad(valid, pad)

    upsample_factor = im_test.shape[-1] / valid.shape[-1]
    valid = F.interpolate(valid.unsqueeze(1).float(), scale_factor=upsample_factor,
                          mode='bilinear', align_corners=False)
    valid = valid > 0.9

    # Apply the estimated transformation to the test image.
    im_test_re = im_test.view(*im_test.shape[:2], -1)
    im_t_conv = torch.matmul(im_test_re.permute(0, 2, 1), c_mat).permute(0, 2, 1)
    im_t_conv = im_t_conv.view(im_test.shape)

    return im_t_conv, valid
class SpatialTransformer(nn.Module):
    """
    [SpatialTransformer] represents a spatial transformation block that
    uses a dense flow field (e.g. the output from a U-Net) to perform a
    grid_sample of the source image.
    https://pytorch.org/docs/stable/nn.functional.html#grid-sample
    """

    def __init__(self, size, mode='bilinear'):
        """
        Instantiate the block
        :param size: spatial size (int or (H, W)) of the input to the block
        :param mode: method of interpolation for grid_sample
        """
        # Bug fix: the original called super(OldSpatialTransformer, self),
        # which raises NameError because no class of that name exists.
        super(SpatialTransformer, self).__init__()

        if isinstance(size, int):
            size = (size, size)

        # Build the identity sampling grid once; registered as a buffer so
        # it moves with the module across devices.
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors)  # default 'ij': channel 0 is rows
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0)  # add batch dim
        grid = grid.type(torch.FloatTensor)
        self.register_buffer('grid', grid)

        self.mode = mode

    def forward(self, src, flow):
        """
        Warp src by the dense displacement field flow.
        :param src: the original moving image, (B, C, H, W)
        :param flow: displacement field in pixel units, (B, 2, H, W)
        """
        new_locs = self.grid + flow
        shape = flow.shape[2:]

        # grid_sample expects coordinates normalized to [-1, 1]
        # (align_corners=True convention: -1 maps to pixel 0, +1 to the last).
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)

        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]  # (row, col) -> (x, y)
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]

        return F.grid_sample(src, new_locs, mode=self.mode, align_corners=True)


def warp(feat, flow, mode='bilinear', padding_mode='zeros'):
    """
    Warp an image/tensor (im2) back to im1, according to the optical flow
    im1 --> im2.

    The flow must be in (x, y) order at every pixel.
    feat: [B, C, H, W] (im2)
    flow: [B, 2, H, W] flow (x, y)
    """
    B, C, H, W = feat.size()

    # Pixel-center mesh grid, stacked as (x, y) to match the flow layout.
    rowv, colv = torch.meshgrid([torch.arange(0.5, H + 0.5),
                                 torch.arange(0.5, W + 0.5)])
    grid = torch.stack((colv, rowv), dim=0).unsqueeze(0).float().to(flow.device)
    grid = grid + flow

    # Scale pixel centers to [-1, 1]; this matches align_corners=False below.
    grid_norm_c = 2.0 * grid[:, 0] / W - 1.0
    grid_norm_r = 2.0 * grid[:, 1] / H - 1.0
    grid_norm = torch.stack((grid_norm_c, grid_norm_r), dim=1).to(flow.device)
    grid_norm = grid_norm.permute(0, 2, 3, 1)

    output = F.grid_sample(feat, grid_norm, mode=mode,
                           align_corners=False, padding_mode=padding_mode)
    return output