Repository: Guaishou74851/AdcSR
Branch: main
Commit: d0b2871e3de9
Files: 54
Total size: 518.0 KB
Directory structure:
gitextract_c9hslnpw/
├── LICENSE
├── README.md
├── bsr/
│ ├── degradations.py
│ ├── transforms.py
│ └── utils/
│ ├── __init__.py
│ ├── color_util.py
│ ├── diffjpeg.py
│ ├── dist_util.py
│ ├── download_util.py
│ ├── file_client.py
│ ├── flow_util.py
│ ├── img_process_util.py
│ ├── img_util.py
│ ├── lmdb_util.py
│ ├── logger.py
│ ├── matlab_functions.py
│ ├── misc.py
│ ├── options.py
│ ├── plot_util.py
│ └── registry.py
├── config.yml
├── dataset.py
├── evaluate.py
├── evaluate_debug.sh
├── forward.py
├── model.py
├── ram/
│ ├── configs/
│ │ ├── condition_config.json
│ │ ├── med_config.json
│ │ ├── q2l_config.json
│ │ └── swin/
│ │ ├── config_swinB_384.json
│ │ ├── config_swinL_384.json
│ │ └── config_swinL_444.json
│ ├── data/
│ │ ├── ram_tag_list.txt
│ │ ├── ram_tag_list_chinese.txt
│ │ ├── ram_tag_list_threshold.txt
│ │ └── tag_list.txt
│ └── models/
│ ├── __init__.py
│ ├── bert.py
│ ├── bert_lora.py
│ ├── ram.py
│ ├── ram_lora.py
│ ├── swin_transformer.py
│ ├── swin_transformer_lora.py
│ ├── tag2text.py
│ ├── tag2text_lora.py
│ ├── utils.py
│ └── vit.py
├── requirements.txt
├── test.py
├── test_debug.sh
├── train.py
├── train.sh
├── train_debug.sh
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
<p align="center">
<img src="assets/icon.png" alt="icon" width="200px"/>
</p>
# (CVPR 2025) Adversarial Diffusion Compression for Real-World Image Super-Resolution [PyTorch]
[arXiv](https://arxiv.org/abs/2411.13383) [Hugging Face](https://huggingface.co/Guaishou74851/AdcSR)
[Bin Chen](https://scholar.google.com/citations?user=aZDNm98AAAAJ)<sup>1,3,\*</sup>
| [Gehui Li](https://github.com/cvsym)<sup>1,\*</sup>
| [Rongyuan Wu](https://scholar.google.com/citations?user=A-U8zE8AAAAJ)<sup>2,3,\*</sup>
| [Xindong Zhang](https://scholar.google.com/citations?user=q76RnqIAAAAJ)<sup>3</sup>
| [Jie Chen](https://aimia-pku.github.io/)<sup>1,†</sup>
| [Jian Zhang](https://jianzhang.tech/)<sup>1,†</sup>
| [Lei Zhang](https://www4.comp.polyu.edu.hk/~cslzhang/)<sup>2,3</sup>
<sup>1</sup> *School of Electronic and Computer Engineering, Peking University*
<sup>2</sup> *The Hong Kong Polytechnic University*, <sup>3</sup> *OPPO Research Institute*
<sup>*</sup> Equal Contribution. <sup>†</sup> Corresponding Authors.
⭐ **If AdcSR is helpful to you, please star this repo. Thanks!** 🤗
## 📝 Overview
### Highlights
- **Adversarial Diffusion Compression (ADC).** We remove and prune redundant modules from the one-step diffusion network [OSEDiff](https://github.com/cswry/OSEDiff) and apply adversarial distillation to retain generative capabilities despite reduced capacity.
- **Real-Time [Stable Diffusion](https://huggingface.co/stabilityai/stable-diffusion-2-1)-Based Image Super-Resolution.** AdcSR super-resolves a 128×128 image to 512×512 **in just 0.03s 🚀** on an A100 GPU.
- **Competitive Visual Quality.** Despite **74% fewer parameters 📉** than [OSEDiff](https://github.com/cswry/OSEDiff), AdcSR achieves **competitive image quality** across multiple benchmarks.
### Framework
1. **Structural Compression**
   - **Removable modules** (VAE encoder, text prompt extractor, cross-attention, time embeddings) are eliminated.
   - **Prunable modules** (UNet, VAE decoder) are **channel-pruned** to optimize efficiency while preserving performance.
<p align="center">
<img src="assets/teaser.png" alt="teaser" width="55%"/>
</p>

2. **Two-Stage Training**
   1. **Pretraining a Pruned VAE Decoder** to maintain its ability to decode latent representations.
   2. **Adversarial Distillation** to align compressed network features with the teacher model (e.g., [OSEDiff](https://github.com/cswry/OSEDiff)) and ground truth images.
<p align="center">
<img src="assets/method.png" alt="method" />
</p>
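The channel pruning mentioned under Structural Compression can be illustrated with a minimal NumPy sketch. The importance score used here (the L1 norm of each output filter) and the helper names `prune_conv_out_channels` and `prune_next_in_channels` are assumptions for illustration only; the criterion and procedure actually used for AdcSR's UNet and VAE decoder are described in the paper and may differ. What the sketch shows is the mechanical constraint: removing output channels from one layer requires removing the matching input channels of the next layer.

```python
import numpy as np


def prune_conv_out_channels(weight, bias, keep_ratio):
    """Keep the output channels of a conv layer with the largest L1 norms.

    weight: array of shape (out_ch, in_ch, kh, kw); bias: array of shape (out_ch,).
    Returns the pruned weight, the pruned bias, and the sorted indices of the
    kept channels, which the next layer needs to slice its input channels.
    """
    out_ch = weight.shape[0]
    n_keep = max(1, int(round(out_ch * keep_ratio)))
    # Assumed importance score: L1 norm of each output filter.
    scores = np.abs(weight).reshape(out_ch, -1).sum(axis=1)
    # Indices of the n_keep highest scores, re-sorted to preserve channel order.
    keep = np.sort(np.argsort(scores)[::-1][:n_keep])
    return weight[keep], bias[keep], keep


def prune_next_in_channels(next_weight, keep):
    """Drop the input channels of the following conv layer that were pruned above."""
    return next_weight[:, keep]
```

For example, calling `prune_conv_out_channels(w, b, 0.5)` on a weight of shape `(64, 32, 3, 3)` returns a `(32, 32, 3, 3)` weight together with the kept indices, which are then passed to `prune_next_in_channels` for the following layer. A pruned network generally needs further training to recover quality; in this repo that is the job of the two training stages listed above.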
## 😍 Visual Results
[<img src="assets/demo1.png" height="240px"/>](https://imgsli.com/MzU2MjU1) [<img src="assets/demo2.png" height="240px"/>](https://imgsli.com/MzU2MjU2) [<img src="assets/demo3.png" height="240px"/>](https://imgsli.com/MzU2MjU3)
[<img src="assets/demo4.png" height="242px"/>](https://imgsli.com/MzU2NTg4) [<img src="assets/demo5.png" height="242px"/>](https://imgsli.com/MzU2NTkw) [<img src="assets/demo6.png" height="242px"/>](https://imgsli.com/MzU2NTk1)
[<img src="assets/demo7.png" height="319px"/>](https://imgsli.com/MzU2OTE0) [<img src="assets/demo8.png" height="319px"/>](https://imgsli.com/MzU2OTE1)
https://github.com/user-attachments/assets/1211cefa-8704-47f5-82cd-ec4ef084b9ec
<img src="assets/comp.png" alt="comp" width="840px" />
## ⚙ Installation
```shell
git clone https://github.com/Guaishou74851/AdcSR.git
cd AdcSR
conda create -n AdcSR python=3.10
conda activate AdcSR
pip install --upgrade pip
pip install -r requirements.txt
chmod +x train.sh train_debug.sh test_debug.sh evaluate_debug.sh
```
## ⚡ Test
1. **Download test datasets** (`DIV2K-Val.zip`, `DRealSR.zip`, `RealSR.zip`) from [Hugging Face](https://huggingface.co/Guaishou74851/AdcSR) or [PKU Disk](https://disk.pku.edu.cn/link/AAD499197CBF054392BC4061F904CC4026).
2. **Unzip** them into `./testset/`, ensuring the structure:
```
./testset/DIV2K-Val/LR/xxx.png
./testset/DIV2K-Val/HR/xxx.png
./testset/DRealSR/LR/xxx.png
./testset/DRealSR/HR/xxx.png
./testset/RealSR/LR/xxx.png
./testset/RealSR/HR/xxx.png
```
3. **Download model weights** (`net_params_200.pkl`) from the same link and place it in `./weight/`.
4. **Run the test script** (or modify and execute `./test_debug.sh` for convenience):
```bash
python test.py --LR_dir=path_to_LR_images --SR_dir=path_to_SR_images
```
The results will be saved in `path_to_SR_images`.
5. **Test Your Own Images**:
- Place your **low-resolution (LR)** images into `./testset/xxx/`.
- Run the command with `--LR_dir=./testset/xxx/ --SR_dir=./yyy/`, and the model will perform **x4 super-resolution**.
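Before running `test.py` on the benchmark sets, it can help to confirm that every LR image has an HR counterpart. The sketch below assumes, as the `xxx.png` placeholders in step 2 suggest, that paired LR and HR images share the same file name; `find_unpaired` and the extension list are hypothetical and not part of this repo.

```python
from pathlib import Path

# Assumed set of image extensions to consider.
IMG_EXTS = {".png", ".jpg", ".jpeg", ".bmp"}


def find_unpaired(dataset_dir):
    """Return (lr_only, hr_only): sorted file names found in only one of LR/ and HR/."""
    root = Path(dataset_dir)
    lr = {p.name for p in (root / "LR").iterdir() if p.suffix.lower() in IMG_EXTS}
    hr = {p.name for p in (root / "HR").iterdir() if p.suffix.lower() in IMG_EXTS}
    return sorted(lr - hr), sorted(hr - lr)
```

A possible usage is looping over `["DIV2K-Val", "DRealSR", "RealSR"]` and printing `find_unpaired(Path("./testset") / name)` for each; two empty lists mean the folder is fully paired. For your own images only `LR/` inputs are needed, so this check applies to the benchmark sets only.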
## 🍭 Evaluation
Run the evaluation script (or modify and execute `./evaluate_debug.sh` for convenience):
```bash
python evaluate.py --HR_dir=path_to_HR_images --SR_dir=path_to_SR_images
```
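This excerpt does not show which metrics `evaluate.py` computes. As an illustration only, the sketch below computes PSNR, a common full-reference metric for super-resolution, between one HR image and its SR result, both given as arrays in the 0 to 255 range. Conventions such as evaluating on the Y channel or cropping image borders vary between papers, so values from this sketch should not be expected to match the numbers the repo reports.

```python
import numpy as np


def psnr(hr, sr, max_val=255.0):
    """Peak signal-to-noise ratio in dB between two same-shaped images."""
    hr = np.asarray(hr, dtype=np.float64)
    sr = np.asarray(sr, dtype=np.float64)
    if hr.shape != sr.shape:
        raise ValueError(f"shape mismatch: {hr.shape} vs {sr.shape}")
    mse = np.mean((hr - sr) ** 2)
    if mse == 0:
        # Identical images: PSNR is unbounded.
        return float("inf")
    return 10.0 * np.log10(max_val ** 2 / mse)
```

Because PSNR is computed per image pair, a dataset score is typically the mean over all matching file names in `HR_dir` and `SR_dir`.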
## 🔥 Train
This repo provides code for **Stage 2** training (**adversarial distillation**). For **Stage 1** (pretraining the channel-pruned VAE decoder), refer to our paper and use the code from the [Latent Diffusion Models](https://github.com/CompVis/latent-diffusion) repo.
1. **Download pretrained model weights** (`DAPE.pth`, `halfDecoder.ckpt`, `osediff.pkl`, `ram_swin_large_14m.pth`) from [Hugging Face](https://huggingface.co/Guaishou74851/AdcSR) or [PKU Disk](https://disk.pku.edu.cn/link/AAD499197CBF054392BC4061F904CC4026), and place them in `./weight/pretrained/`.
2. **Download the [LSDIR](https://huggingface.co/ofsoundof/LSDIR) dataset** and store it in your preferred location.
3. **Modify the dataset path** in `config.yml`:
```yaml
dataroot_gt: path_to_HR_images_of_LSDIR
```
4. **Run the training script** (or modify and execute `./train.sh` or `./train_debug.sh`):
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --nproc_per_node=8 --master_port=23333 train.py
```
The trained model will be saved in `./weight/`.
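Step 3 can also be scripted. The sketch below rewrites the value of every `dataroot_gt:` line in the text of `config.yml` with a regular expression, so it needs no YAML library. `set_dataroot_gt` is a hypothetical helper, and it assumes the key sits on its own line as `dataroot_gt: <path>`; any trailing comment on that line is replaced as well.

```python
import re

# Match "dataroot_gt:" at the start of a line (after indentation) and its value.
_DATAROOT_GT = re.compile(r"^([ \t]*dataroot_gt:[ \t]*).*$", re.MULTILINE)


def set_dataroot_gt(config_text, new_path):
    """Return config_text with every dataroot_gt value replaced by new_path."""
    # A function replacement avoids backslash escapes in Windows-style paths.
    new_text, n = _DATAROOT_GT.subn(lambda m: m.group(1) + new_path, config_text)
    if n == 0:
        raise KeyError("no dataroot_gt key found")
    return new_text
```

For example, `Path("config.yml").write_text(set_dataroot_gt(Path("config.yml").read_text(), "/data/LSDIR/HR"))` updates the file in place, where `/data/LSDIR/HR` is a placeholder for your LSDIR HR directory.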
## 🥰 Acknowledgement
This project is built upon the codes of [Latent Diffusion Models](https://github.com/CompVis/latent-diffusion), [Diffusers](https://github.com/huggingface/diffusers), [BasicSR](https://github.com/XPixelGroup/BasicSR), and [OSEDiff](https://github.com/cswry/OSEDiff). We sincerely thank the authors of these repos for their significant contributions.
## 🎓 Citation
If you find our work helpful, please consider citing:
```bibtex
@inproceedings{chen2025adversarial,
title={Adversarial Diffusion Compression for Real-World Image Super-Resolution},
author={Chen, Bin and Li, Gehui and Wu, Rongyuan and Zhang, Xindong and Chen, Jie and Zhang, Jian and Zhang, Lei},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2025}
}
```
================================================
FILE: bsr/degradations.py
================================================
import cv2
import math
import numpy as np
import random
import torch
from scipy import special
from scipy.stats import multivariate_normal
from torchvision.transforms._functional_tensor import rgb_to_grayscale
# -------------------------------------------------------------------- #
# --------------------------- blur kernels --------------------------- #
# -------------------------------------------------------------------- #
# --------------------------- util functions --------------------------- #
def sigma_matrix2(sig_x, sig_y, theta):
    """Calculate the rotated sigma matrix (two dimensional matrix).
    Args:
        sig_x (float): Standard deviation along the x axis before rotation.
        sig_y (float): Standard deviation along the y axis before rotation.
        theta (float): Rotation angle in radians.
    Returns:
        ndarray: Rotated sigma matrix.
    """
    d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
    u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
    return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
def mesh_grid(kernel_size):
    """Generate the mesh grid, centering at zero.
    Args:
        kernel_size (int):
    Returns:
        xy (ndarray): with the shape (kernel_size, kernel_size, 2)
        xx (ndarray): with the shape (kernel_size, kernel_size)
        yy (ndarray): with the shape (kernel_size, kernel_size)
    """
    ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
    xx, yy = np.meshgrid(ax, ax)
    xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)),
                    yy.reshape(kernel_size * kernel_size, 1))).reshape(kernel_size, kernel_size, 2)
    return xy, xx, yy
def pdf2(sigma_matrix, grid):
    """Calculate PDF of the bivariate Gaussian distribution.
    Args:
        sigma_matrix (ndarray): with the shape (2, 2)
        grid (ndarray): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size.
    Returns:
        kernel (ndarray): un-normalized kernel.
    """
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
    return kernel
def cdf2(d_matrix, grid):
    """Calculate the CDF of the standard bivariate Gaussian distribution.
    Used in skewed Gaussian distribution.
    Args:
        d_matrix (ndarray): skew matrix.
        grid (ndarray): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size.
    Returns:
        cdf (ndarray): skewed cdf.
    """
    rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
    grid = np.dot(grid, d_matrix)
    cdf = rv.cdf(grid)
    return cdf
def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
    """Generate a bivariate isotropic or anisotropic Gaussian kernel.
    In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are ignored.
    Args:
        kernel_size (int):
        sig_x (float):
        sig_y (float):
        theta (float): Radian measurement.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
        isotropic (bool): Whether to generate an isotropic kernel. Default: True
    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    kernel = pdf2(sigma_matrix, grid)
    kernel = kernel / np.sum(kernel)
    return kernel
def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
    """Generate a bivariate generalized Gaussian kernel.
    ``Paper: Parameter Estimation For Multivariate Generalized Gaussian Distributions``
    In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are ignored.
    Args:
        kernel_size (int):
        sig_x (float):
        sig_y (float):
        theta (float): Radian measurement.
        beta (float): shape parameter, beta = 1 is the normal distribution.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
        isotropic (bool): Whether to generate an isotropic kernel. Default: True
    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
    kernel = kernel / np.sum(kernel)
    return kernel
def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
    """Generate a plateau-like anisotropic kernel.
    1 / (1+x^(beta))
    Reference: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution
    In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are ignored.
    Args:
        kernel_size (int):
        sig_x (float):
        sig_y (float):
        theta (float): Radian measurement.
        beta (float): shape parameter; larger values give a flatter top
            and a sharper fall-off.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
        isotropic (bool): Whether to generate an isotropic kernel. Default: True
    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
    kernel = kernel / np.sum(kernel)
    return kernel
def random_bivariate_Gaussian(kernel_size,
                              sigma_x_range,
                              sigma_y_range,
                              rotation_range,
                              noise_range=None,
                              isotropic=True):
    """Randomly generate bivariate isotropic or anisotropic Gaussian kernels.
    In the isotropic mode, only `sigma_x_range` is used; `sigma_y_range` and `rotation_range` are ignored.
    Args:
        kernel_size (int): must be odd.
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
        rotation_range (tuple): [-math.pi, math.pi]
        noise_range (tuple, optional): multiplicative kernel noise,
            [0.75, 1.25]. Default: None
        isotropic (bool): Whether to generate an isotropic kernel. Default: True
    Returns:
        kernel (ndarray): normalized kernel.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0
    kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)
    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    return kernel
def random_bivariate_generalized_Gaussian(kernel_size,
                                          sigma_x_range,
                                          sigma_y_range,
                                          rotation_range,
                                          beta_range,
                                          noise_range=None,
                                          isotropic=True):
    """Randomly generate bivariate generalized Gaussian kernels.
    In the isotropic mode, only `sigma_x_range` is used; `sigma_y_range` and `rotation_range` are ignored.
    Args:
        kernel_size (int): must be odd.
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
        rotation_range (tuple): [-math.pi, math.pi]
        beta_range (tuple): [0.5, 8]
        noise_range (tuple, optional): multiplicative kernel noise,
            [0.75, 1.25]. Default: None
        isotropic (bool): Whether to generate an isotropic kernel. Default: True
    Returns:
        kernel (ndarray): normalized kernel.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0
    # assume beta_range[0] < 1 < beta_range[1]
    if np.random.uniform() < 0.5:
        beta = np.random.uniform(beta_range[0], 1)
    else:
        beta = np.random.uniform(1, beta_range[1])
    kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    return kernel
def random_bivariate_plateau(kernel_size,
                             sigma_x_range,
                             sigma_y_range,
                             rotation_range,
                             beta_range,
                             noise_range=None,
                             isotropic=True):
    """Randomly generate bivariate plateau kernels.
    In the isotropic mode, only `sigma_x_range` is used; `sigma_y_range` and `rotation_range` are ignored.
    Args:
        kernel_size (int): must be odd.
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
        rotation_range (tuple): [-math.pi/2, math.pi/2]
        beta_range (tuple): [1, 4]
        noise_range (tuple, optional): multiplicative kernel noise,
            [0.75, 1.25]. Default: None
        isotropic (bool): Whether to generate an isotropic kernel. Default: True
    Returns:
        kernel (ndarray): normalized kernel.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0
    # TODO: this may not be proper
    if np.random.uniform() < 0.5:
        beta = np.random.uniform(beta_range[0], 1)
    else:
        beta = np.random.uniform(1, beta_range[1])
    kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    return kernel
def random_mixed_kernels(kernel_list,
                         kernel_prob,
                         kernel_size=21,
                         sigma_x_range=(0.6, 5),
                         sigma_y_range=(0.6, 5),
                         rotation_range=(-math.pi, math.pi),
                         betag_range=(0.5, 8),
                         betap_range=(0.5, 8),
                         noise_range=None):
    """Randomly generate mixed kernels.
    Args:
        kernel_list (tuple): names of the kernel types to sample from,
            supported: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso',
            'plateau_iso', 'plateau_aniso']
        kernel_prob (tuple): corresponding kernel probability for each
            kernel type
        kernel_size (int):
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
        rotation_range (tuple): [-math.pi, math.pi]
        betag_range (tuple): beta range for generalized Gaussian kernels, [0.5, 8]
        betap_range (tuple): beta range for plateau kernels, [0.5, 8]
        noise_range (tuple, optional): multiplicative kernel noise,
            [0.75, 1.25]; not applied to plateau kernels. Default: None
    Returns:
        kernel (ndarray): normalized kernel.
    """
    kernel_type = random.choices(kernel_list, kernel_prob)[0]
    if kernel_type == 'iso':
        kernel = random_bivariate_Gaussian(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True)
    elif kernel_type == 'aniso':
        kernel = random_bivariate_Gaussian(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False)
    elif kernel_type == 'generalized_iso':
        kernel = random_bivariate_generalized_Gaussian(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            betag_range,
            noise_range=noise_range,
            isotropic=True)
    elif kernel_type == 'generalized_aniso':
        kernel = random_bivariate_generalized_Gaussian(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            betag_range,
            noise_range=noise_range,
            isotropic=False)
    elif kernel_type == 'plateau_iso':
        kernel = random_bivariate_plateau(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True)
    elif kernel_type == 'plateau_aniso':
        kernel = random_bivariate_plateau(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False)
    else:
        raise ValueError(f'Unsupported kernel type: {kernel_type}')
    return kernel
np.seterr(divide='ignore', invalid='ignore')
def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
    """2D sinc filter
    Reference: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter
    Args:
        cutoff (float): cutoff frequency in radians (pi is max)
        kernel_size (int): horizontal and vertical size, must be odd.
        pad_to (int): pad kernel size to desired size, must be odd or zero.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    kernel = np.fromfunction(
        lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
            (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt(
                (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size])
    kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi)
    kernel = kernel / np.sum(kernel)
    if pad_to > kernel_size:
        pad_size = (pad_to - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
    return kernel
# ------------------------------------------------------------- #
# --------------------------- noise --------------------------- #
# ------------------------------------------------------------- #
# ----------------------- Gaussian Noise ----------------------- #
def generate_gaussian_noise(img, sigma=10, gray_noise=False):
    """Generate Gaussian noise.
    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        sigma (float): Noise scale (measured in range 255). Default: 10.
        gray_noise (bool): Whether to use the same noise for all channels
            (the gray noise is repeated over 3 channels). Default: False.
    Returns:
        (Numpy array): Generated noise, shape (h, w, c), float32.
    """
    if gray_noise:
        noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
        noise = np.expand_dims(noise, axis=2).repeat(3, axis=2)
    else:
        noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.
    return noise
def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False):
"""Add Gaussian noise.
Args:
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
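sigma (float): Noise scale (measured in range 255). Default: 10.
clip (bool): Whether to clip the output to [0, 1]. Default: True.
rounds (bool): Whether to round the output to multiples of 1/255. Default: False.
gray_noise (bool): Whether to add gray noise. Default: False.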
Returns:
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
float32.
"""
noise = generate_gaussian_noise(img, sigma, gray_noise)
out = img + noise
if clip and rounds:
out = np.clip((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = np.clip(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
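# Illustrative sketch (not part of the original module): `sigma` is given on the
# 0-255 scale, so noise for a [0, 1] image has standard deviation sigma / 255. With
# gray_noise=True a single (h, w) field is drawn and copied to every channel, as in
# generate_gaussian_noise. The explicit `rng` argument is an assumption added for
# reproducibility; the module itself uses the global np.random state.
import numpy as np


def _sketch_gaussian_noise(shape, sigma, gray_noise, rng):
    """Hypothetical re-statement of generate_gaussian_noise with an explicit Generator."""
    h, w, c = shape
    if gray_noise:
        noise = rng.standard_normal((h, w)).astype(np.float32) * sigma / 255.
        return np.repeat(noise[:, :, np.newaxis], c, axis=2)
    return rng.standard_normal(shape).astype(np.float32) * sigma / 255.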
def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
"""Generate Gaussian noise (PyTorch version).
Args:
img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
sigma (float | Tensor): Noise scale (measured in range 255). Number or
Tensor with shape (b). Default: 10.
gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
0 for False, 1 for True. Default: 0.
Returns:
(Tensor): Generated noise, shape (b, c, h, w), float32.
"""
b, _, h, w = img.size()
if not isinstance(sigma, (float, int)):
sigma = sigma.view(img.size(0), 1, 1, 1)
if isinstance(gray_noise, (float, int)):
cal_gray_noise = gray_noise > 0
else:
gray_noise = gray_noise.view(b, 1, 1, 1)
cal_gray_noise = torch.sum(gray_noise) > 0
if cal_gray_noise:
# Draw one gray field per sample with shape (b, 1, h, w), so sigma (a number or a
# (b, 1, 1, 1) tensor) broadcasts correctly for any batch size.
noise_gray = torch.randn(b, 1, h, w, dtype=img.dtype, device=img.device) * sigma / 255.
# always calculate color noise
noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.
if cal_gray_noise:
noise = noise * (1 - gray_noise) + noise_gray * gray_noise
return noise
def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
"""Add Gaussian noise (PyTorch version).
Args:
img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
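sigma (float | Tensor): Noise scale (measured in range 255). Number or
Tensor with shape (b). Default: 10.
gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
0 for False, 1 for True. Default: 0.
clip (bool): Whether to clip the output to [0, 1]. Default: True.
rounds (bool): Whether to round the output to multiples of 1/255. Default: False.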
Returns:
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
float32.
"""
noise = generate_gaussian_noise_pt(img, sigma, gray_noise)
out = img + noise
if clip and rounds:
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = torch.clamp(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
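# Illustrative sketch (not part of the original module): how the batched routines
# mix color and gray noise. sigma and gray_noise are per-sample vectors reshaped to
# (b, 1, 1, 1) so they broadcast over (c, h, w); samples flagged 1 get a
# single-channel field broadcast to every channel, samples flagged 0 keep independent
# per-channel noise. NumPy stands in for PyTorch, and drawing an independent gray
# field per sample is an assumption of this sketch.
import numpy as np


def _sketch_batched_gaussian_noise(b, c, h, w, sigma, gray_flags, rng):
    sigma = np.asarray(sigma, dtype=np.float32).reshape(b, 1, 1, 1)
    gray = np.asarray(gray_flags, dtype=np.float32).reshape(b, 1, 1, 1)
    color_noise = rng.standard_normal((b, c, h, w)).astype(np.float32) * sigma / 255.
    gray_noise = rng.standard_normal((b, 1, h, w)).astype(np.float32) * sigma / 255.
    # Per-sample selection: gray == 1 picks the gray field, gray == 0 the color one.
    return color_noise * (1 - gray) + gray_noise * gray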
# ----------------------- Random Gaussian Noise ----------------------- #
def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
sigma = np.random.uniform(sigma_range[0], sigma_range[1])
if np.random.uniform() < gray_prob:
gray_noise = True
else:
gray_noise = False
return generate_gaussian_noise(img, sigma, gray_noise)
def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
noise = random_generate_gaussian_noise(img, sigma_range, gray_prob)
out = img + noise
if clip and rounds:
out = np.clip((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = np.clip(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
sigma = torch.rand(
img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
gray_noise = (gray_noise < gray_prob).float()
return generate_gaussian_noise_pt(img, sigma, gray_noise)
def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
out = img + noise
if clip and rounds:
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = torch.clamp(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
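# Illustrative sketch (not part of the original module): the `clip` / `rounds`
# post-processing shared by every add_* helper in this file. With both flags set, the
# noisy image is snapped to the 8-bit grid {0, 1/255, ..., 1}; `clip` alone only bounds
# values to [0, 1]; `rounds` alone quantizes without bounding.
import numpy as np


def _sketch_postprocess(out, clip=True, rounds=False):
    if clip and rounds:
        return np.clip((out * 255.0).round(), 0, 255) / 255.
    if clip:
        return np.clip(out, 0, 1)
    if rounds:
        return (out * 255.0).round() / 255.
    return out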
# ----------------------- Poisson (Shot) Noise ----------------------- #
def generate_poisson_noise(img, scale=1.0, gray_noise=False):
"""Generate poisson noise.
Reference: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219
Args:
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
scale (float): Noise scale. Default: 1.0.
gray_noise (bool): Whether to generate gray noise. Default: False.
Returns:
(Numpy array): Generated noise, shape (h, w, c), float32.
"""
if gray_noise:
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# round and clip image for counting vals correctly
img = np.clip((img * 255.0).round(), 0, 255) / 255.
vals = len(np.unique(img))
vals = 2**np.ceil(np.log2(vals))
out = np.float32(np.random.poisson(img * vals) / float(vals))
noise = out - img
if gray_noise:
noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2)
return noise * scale
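# Illustrative sketch (not part of the original module): the Poisson model used by
# generate_poisson_noise. The image is snapped to 8-bit levels, the number of distinct
# values is rounded up to a power of two (`vals`), and each pixel is resampled as
# Poisson(img * vals) / vals. The resulting noise is zero-mean with variance
# img / vals. The explicit `rng` argument is an assumption of this sketch.
import numpy as np


def _sketch_poisson_vals(img):
    """Number of intensity levels generate_poisson_noise would use for this image."""
    img = np.clip((img * 255.0).round(), 0, 255) / 255.
    vals = len(np.unique(img))
    return 2 ** np.ceil(np.log2(vals))


def _sketch_poisson_noise(img, rng, scale=1.0):
    img = np.clip((img * 255.0).round(), 0, 255) / 255.
    vals = _sketch_poisson_vals(img)
    out = np.float32(rng.poisson(img * vals) / float(vals))
    return (out - img) * scale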
def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False):
"""Add poisson noise.
Args:
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
scale (float): Noise scale. Default: 1.0.
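clip (bool): Whether to clip the output to [0, 1]. Default: True.
rounds (bool): Whether to round the output to multiples of 1/255. Default: False.
gray_noise (bool): Whether to add gray noise. Default: False.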
Returns:
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
float32.
"""
noise = generate_poisson_noise(img, scale, gray_noise)
out = img + noise
if clip and rounds:
out = np.clip((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = np.clip(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
"""Generate a batch of poisson noise (PyTorch version)
Args:
img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
Default: 1.0.
gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
0 for False, 1 for True. Default: 0.
Returns:
(Tensor): Generated noise, shape (b, c, h, w), float32.
"""
b, _, h, w = img.size()
if isinstance(gray_noise, (float, int)):
cal_gray_noise = gray_noise > 0
else:
gray_noise = gray_noise.view(b, 1, 1, 1)
cal_gray_noise = torch.sum(gray_noise) > 0
if cal_gray_noise:
img_gray = rgb_to_grayscale(img, num_output_channels=1)
# round and clip image for counting vals correctly
img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
# use for-loop to get the unique values for each sample
vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
out = torch.poisson(img_gray * vals) / vals
noise_gray = out - img_gray
noise_gray = noise_gray.expand(b, 3, h, w)
# always calculate color noise
# round and clip image for counting vals correctly
img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
# use for-loop to get the unique values for each sample
vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
out = torch.poisson(img * vals) / vals
noise = out - img
if cal_gray_noise:
noise = noise * (1 - gray_noise) + noise_gray * gray_noise
if not isinstance(scale, (float, int)):
scale = scale.view(b, 1, 1, 1)
return noise * scale
def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
"""Add poisson noise to a batch of images (PyTorch version).
Args:
img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
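Default: 1.0.
clip (bool): Whether to clip the output to [0, 1]. Default: True.
rounds (bool): Whether to round the output to multiples of 1/255. Default: False.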
gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
0 for False, 1 for True. Default: 0.
Returns:
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
float32.
"""
noise = generate_poisson_noise_pt(img, scale, gray_noise)
out = img + noise
if clip and rounds:
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = torch.clamp(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
# ----------------------- Random Poisson (Shot) Noise ----------------------- #
def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
scale = np.random.uniform(scale_range[0], scale_range[1])
if np.random.uniform() < gray_prob:
gray_noise = True
else:
gray_noise = False
return generate_poisson_noise(img, scale, gray_noise)
def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
noise = random_generate_poisson_noise(img, scale_range, gray_prob)
out = img + noise
if clip and rounds:
out = np.clip((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = np.clip(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
scale = torch.rand(
img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
gray_noise = (gray_noise < gray_prob).float()
return generate_poisson_noise_pt(img, scale, gray_noise)
def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
out = img + noise
if clip and rounds:
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = torch.clamp(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
# ------------------------------------------------------------------------ #
# --------------------------- JPEG compression --------------------------- #
# ------------------------------------------------------------------------ #
def add_jpg_compression(img, quality=90):
"""Add JPG compression artifacts.
Args:
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
quality (int): JPG compression quality. 0 for lowest quality, 100 for
best quality. Non-integer values are truncated. Default: 90.
Returns:
(Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
float32.
"""
img = np.clip(img, 0, 1)
# OpenCV expects integer encoder parameters; random_add_jpg_compression passes a float.
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)]
_, encimg = cv2.imencode('.jpg', img * 255., encode_param)
img = np.float32(cv2.imdecode(encimg, 1)) / 255.
return img
def random_add_jpg_compression(img, quality_range=(90, 100)):
"""Randomly add JPG compression artifacts.
Args:
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
quality_range (tuple[float] | list[float]): JPG compression quality
range. 0 for lowest quality, 100 for best quality.
Default: (90, 100).
Returns:
(Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
float32.
"""
quality = np.random.uniform(quality_range[0], quality_range[1])
return add_jpg_compression(img, quality)
================================================
FILE: bsr/transforms.py
================================================
import cv2
import random
import torch
def mod_crop(img, scale):
"""Mod crop images, used during testing.
Args:
img (ndarray): Input image.
scale (int): Scale factor.
Returns:
ndarray: Result image.
"""
img = img.copy()
if img.ndim in (2, 3):
h, w = img.shape[0], img.shape[1]
h_remainder, w_remainder = h % scale, w % scale
img = img[:h - h_remainder, :w - w_remainder, ...]
else:
raise ValueError(f'Wrong img ndim: {img.ndim}.')
return img
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
"""Paired random crop. Support Numpy array and Tensor inputs.
It crops lists of lq and gt images with corresponding locations.
Args:
img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
img_lqs (list[ndarray] | ndarray | list[Tensor] | Tensor): LQ images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
gt_patch_size (int): GT patch size.
scale (int): Scale factor.
gt_path (str): Path to ground-truth. Default: None.
Returns:
list[ndarray] | ndarray: GT images and LQ images. If returned results
only have one element, just return ndarray.
"""
if not isinstance(img_gts, list):
img_gts = [img_gts]
if not isinstance(img_lqs, list):
img_lqs = [img_lqs]
# determine input type: Numpy array or Tensor
input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
if input_type == 'Tensor':
h_lq, w_lq = img_lqs[0].size()[-2:]
h_gt, w_gt = img_gts[0].size()[-2:]
else:
h_lq, w_lq = img_lqs[0].shape[0:2]
h_gt, w_gt = img_gts[0].shape[0:2]
lq_patch_size = gt_patch_size // scale
if h_gt != h_lq * scale or w_gt != w_lq * scale:
raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
f'multiple of LQ ({h_lq}, {w_lq}).')
if h_lq < lq_patch_size or w_lq < lq_patch_size:
raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
f'({lq_patch_size}, {lq_patch_size}). '
f'Please remove {gt_path}.')
# randomly choose top and left coordinates for lq patch
top = random.randint(0, h_lq - lq_patch_size)
left = random.randint(0, w_lq - lq_patch_size)
# crop lq patch
if input_type == 'Tensor':
img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
else:
img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
# crop corresponding gt patch
top_gt, left_gt = int(top * scale), int(left * scale)
if input_type == 'Tensor':
img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
else:
img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
if len(img_gts) == 1:
img_gts = img_gts[0]
if len(img_lqs) == 1:
img_lqs = img_lqs[0]
return img_gts, img_lqs
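# Illustrative sketch (not part of the original module): the coordinate mapping in
# paired_random_crop. The patch corner is drawn on the LQ grid and multiplied by
# `scale` for the GT crop, so both patches cover the same region. Here the LQ image is
# a hypothetical strided subsample of the GT, which makes the alignment checkable:
# gt_patch[::scale, ::scale] equals lq_patch.
import random

import numpy as np


def _sketch_paired_crop(gt, lq, gt_patch_size, scale, rng=random):
    lq_patch_size = gt_patch_size // scale
    h_lq, w_lq = lq.shape[:2]
    top = rng.randint(0, h_lq - lq_patch_size)
    left = rng.randint(0, w_lq - lq_patch_size)
    lq_patch = lq[top:top + lq_patch_size, left:left + lq_patch_size, ...]
    top_gt, left_gt = top * scale, left * scale
    gt_patch = gt[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...]
    return gt_patch, lq_patch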
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
"""Augment: horizontal flips and/or rotations (0, 90, 180, 270 degrees).
We use vertical flip and transpose for rotation implementation.
All the images in the list use the same augmentation. Flips are done
in place with cv2.flip, so the input arrays may be modified.
Args:
imgs (list[ndarray] | ndarray): Images to be augmented. If the input
is an ndarray, it will be transformed to a list.
hflip (bool): Horizontal flip. Default: True.
rotation (bool): Rotation. Default: True.
flows (list[ndarray]): Flows to be augmented. If the input is an
ndarray, it will be transformed to a list.
Dimension is (h, w, 2). Default: None.
return_status (bool): Return the status of flip and rotation.
Default: False.
Returns:
list[ndarray] | ndarray: Augmented images and flows. If returned
results only have one element, just return ndarray.
"""
hflip = hflip and random.random() < 0.5
vflip = rotation and random.random() < 0.5
rot90 = rotation and random.random() < 0.5
def _augment(img):
if hflip: # horizontal
cv2.flip(img, 1, img)
if vflip: # vertical
cv2.flip(img, 0, img)
if rot90:
img = img.transpose(1, 0, 2)
return img
def _augment_flow(flow):
if hflip: # horizontal
cv2.flip(flow, 1, flow)
flow[:, :, 0] *= -1
if vflip: # vertical
cv2.flip(flow, 0, flow)
flow[:, :, 1] *= -1
if rot90:
flow = flow.transpose(1, 0, 2)
flow = flow[:, :, [1, 0]]
return flow
if not isinstance(imgs, list):
imgs = [imgs]
imgs = [_augment(img) for img in imgs]
if len(imgs) == 1:
imgs = imgs[0]
if flows is not None:
if not isinstance(flows, list):
flows = [flows]
flows = [_augment_flow(flow) for flow in flows]
if len(flows) == 1:
flows = flows[0]
return imgs, flows
else:
if return_status:
return imgs, (hflip, vflip, rot90)
else:
return imgs
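# Illustrative sketch (not part of the original module): why augment obtains every
# rotation from two flips plus one transpose. Applying (hflip, vflip, rot90) in the
# same order as `_augment` yields the eight symmetries of a square. np.flip replaces
# the in-place cv2.flip so the check does not mutate its input.
import itertools

import numpy as np


def _sketch_augment(img, hflip, vflip, rot90):
    if hflip:
        img = np.flip(img, axis=1)  # horizontal flip (mirror columns)
    if vflip:
        img = np.flip(img, axis=0)  # vertical flip (mirror rows)
    if rot90:
        img = img.transpose(1, 0, 2)  # swap height and width
    return img


def _sketch_all_augmentations(img):
    return {flags: _sketch_augment(img, *flags) for flags in itertools.product([False, True], repeat=3)}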
def img_rotate(img, angle, center=None, scale=1.0):
"""Rotate image.
Args:
img (ndarray): Image to be rotated.
angle (float): Rotation angle in degrees. Positive values mean
counter-clockwise rotation.
center (tuple[int]): Rotation center. If the center is None,
initialize it as the center of the image. Default: None.
scale (float): Isotropic scale factor. Default: 1.0.
Returns:
ndarray: Rotated image with the same (h, w) size as the input.
"""
(h, w) = img.shape[:2]
if center is None:
center = (w // 2, h // 2)
matrix = cv2.getRotationMatrix2D(center, angle, scale)
rotated_img = cv2.warpAffine(img, matrix, (w, h))
return rotated_img
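# Illustrative sketch (not part of the original module): the 2x3 affine matrix that
# the OpenCV documentation gives for cv2.getRotationMatrix2D, written out in NumPy.
# With alpha = scale * cos(angle) and beta = scale * sin(angle), the rotation centre is
# a fixed point; img_rotate uses (w // 2, h // 2) when no centre is given.
import math

import numpy as np


def _sketch_rotation_matrix(center, angle_deg, scale=1.0):
    cx, cy = center
    a = scale * math.cos(math.radians(angle_deg))
    b = scale * math.sin(math.radians(angle_deg))
    return np.array([[a, b, (1 - a) * cx - b * cy],
                     [-b, a, b * cx + (1 - a) * cy]])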
================================================
FILE: bsr/utils/__init__.py
================================================
from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb
from .diffjpeg import DiffJPEG
from .file_client import FileClient
from .img_process_util import USMSharp, usm_sharp
from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
from .options import yaml_load
__all__ = [
# color_util.py
'bgr2ycbcr',
'rgb2ycbcr',
'rgb2ycbcr_pt',
'ycbcr2bgr',
'ycbcr2rgb',
# file_client.py
'FileClient',
# img_util.py
'img2tensor',
'tensor2img',
'imfrombytes',
'imwrite',
'crop_border',
# logger.py
'MessageLogger',
'AvgTimer',
'init_tb_logger',
'init_wandb_logger',
'get_root_logger',
'get_env_info',
# misc.py
'set_random_seed',
'get_time_str',
'mkdir_and_rename',
'make_exp_dirs',
'scandir',
'check_resume',
'sizeof_fmt',
# diffjpeg
'DiffJPEG',
# img_process_util
'USMSharp',
'usm_sharp',
# options
'yaml_load'
]
================================================
FILE: bsr/utils/color_util.py
================================================
import numpy as np
import torch
def rgb2ycbcr(img, y_only=False):
"""Convert an RGB image to a YCbCr image.
This function produces the same results as Matlab's `rgb2ycbcr` function.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
y_only (bool): Whether to only return Y channel. Default: False.
Returns:
ndarray: The converted YCbCr image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img)
if y_only:
out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
else:
out_img = np.matmul(
img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]
out_img = _convert_output_type_range(out_img, img_type)
return out_img
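# Illustrative sketch (not part of the original module): the BT.601 ranges produced
# by rgb2ycbcr. For float input in [0, 1], Y spans [16, 235] / 255
# (16 + 65.481 + 128.553 + 24.966 = 235), and any gray pixel has Cb = Cr = 128 / 255
# because both chroma columns of the matrix sum to zero.
import numpy as np

_SKETCH_BT601 = np.array([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]])


def _sketch_rgb2ycbcr_float(img):
    """Float-only re-statement of rgb2ycbcr: [0, 1] RGB in, [0, 1] YCbCr out."""
    return (np.matmul(img, _SKETCH_BT601) + [16, 128, 128]) / 255.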
def bgr2ycbcr(img, y_only=False):
"""Convert a BGR image to YCbCr image.
The bgr version of rgb2ycbcr.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
y_only (bool): Whether to only return Y channel. Default: False.
Returns:
ndarray: The converted YCbCr image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img)
if y_only:
out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
else:
out_img = np.matmul(
img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
out_img = _convert_output_type_range(out_img, img_type)
return out_img
def ycbcr2rgb(img):
"""Convert a YCbCr image to RGB image.
This function produces the same results as Matlab's ycbcr2rgb function.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
Returns:
ndarray: The converted RGB image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img) * 255
out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
[0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126
out_img = _convert_output_type_range(out_img, img_type)
return out_img
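# Illustrative sketch (not part of the original module): ycbcr2rgb inverts rgb2ycbcr
# up to the rounding of its printed coefficients (errors of a few 1e-6 on the [0, 1]
# scale). Both conversions are restated for float [0, 1] input only; the uint8 path
# handled by _convert_input_type_range is omitted.
import numpy as np


def _sketch_rgb2ycbcr(img):
    m = np.array([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]])
    return (np.matmul(img, m) + [16, 128, 128]) / 255.


def _sketch_ycbcr2rgb(img):
    m = np.array([[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
                  [0.00625893, -0.00318811, 0]])
    return (np.matmul(img * 255, m) * 255.0 + [-222.921, 135.576, -276.836]) / 255.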
def ycbcr2bgr(img):
"""Convert a YCbCr image to BGR image.
The bgr version of ycbcr2rgb.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
Returns:
ndarray: The converted BGR image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img) * 255
out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
[0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126
out_img = _convert_output_type_range(out_img, img_type)
return out_img
def _convert_input_type_range(img):
"""Convert the type and range of the input image.
It converts the input image to np.float32 type and range of [0, 1].
It is mainly used for pre-processing the input image in colorspace
conversion functions such as rgb2ycbcr and ycbcr2rgb.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
Returns:
(ndarray): The converted image with type of np.float32 and range of
[0, 1].
"""
img_type = img.dtype
img = img.astype(np.float32)
if img_type == np.float32:
pass
elif img_type == np.uint8:
img /= 255.
else:
raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}')
return img
def _convert_output_type_range(img, dst_type):
"""Convert the type and range of the image according to dst_type.
It converts the image to desired type and range. If `dst_type` is np.uint8,
images will be converted to np.uint8 type with range [0, 255]. If
`dst_type` is np.float32, it converts the image to np.float32 type with
range [0, 1].
It is mainly used for post-processing images in colorspace conversion
functions such as rgb2ycbcr and ycbcr2rgb.
Args:
img (ndarray): The image to be converted with np.float32 type and
range [0, 255].
dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
converts the image to np.uint8 type with range [0, 255]. If
dst_type is np.float32, it converts the image to np.float32 type
with range [0, 1].
Returns:
(ndarray): The converted image with desired type and range.
"""
if dst_type not in (np.uint8, np.float32):
raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}')
if dst_type == np.uint8:
img = img.round()
else:
img /= 255.
return img.astype(dst_type)
def rgb2ycbcr_pt(img, y_only=False):
"""Convert RGB images to YCbCr images (PyTorch version).
It implements the ITU-R BT.601 conversion for standard-definition television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
Args:
img (Tensor): Images with shape (n, 3, h, w), the range [0, 1], float, RGB format.
y_only (bool): Whether to only return Y channel. Default: False.
Returns:
(Tensor): converted images with the shape (n, 3/1, h, w), the range [0, 1], float.
"""
if y_only:
weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img)
out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0
else:
weight = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img)
bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img)
out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias
out_img = out_img / 255.
return out_img
================================================
FILE: bsr/utils/diffjpeg.py
================================================
"""
Modified from https://github.com/mlomnitz/DiffJPEG
For images not divisible by 8
https://dsp.stackexchange.com/questions/35339/jpeg-dct-padding/35343#35343
"""
import itertools
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
# ------------------------ utils ------------------------#
y_table = np.array(
[[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, 55], [14, 13, 16, 24, 40, 57, 69, 56],
[14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, 77], [24, 35, 55, 64, 81, 104, 113, 92],
[49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]],
dtype=np.float32).T
y_table = nn.Parameter(torch.from_numpy(y_table))
c_table = np.empty((8, 8), dtype=np.float32)
c_table.fill(99)
c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]).T
c_table = nn.Parameter(torch.from_numpy(c_table))
def diff_round(x):
""" Differentiable rounding function
"""
return torch.round(x) + (x - torch.round(x))**3
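# Illustrative sketch (not part of the original module): the behaviour of diff_round.
# The cubic term keeps the output within 0.5**3 = 0.125 of the hard rounding while
# giving a non-zero derivative 3 * (x - round(x))**2 between integers; hard rounding
# has zero gradient almost everywhere. NumPy and a central finite difference stand in
# for torch autograd.
import numpy as np


def _sketch_diff_round(x):
    return np.round(x) + (x - np.round(x)) ** 3


def _sketch_numeric_grad(f, x, h=1e-5):
    return (f(x + h) - f(x - h)) / (2 * h)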
def quality_to_factor(quality):
""" Calculate factor corresponding to quality
Args:
quality(float): Quality for jpeg compression.
Returns:
float: Compression factor.
"""
if quality < 50:
quality = 5000. / quality
else:
quality = 200. - quality * 2
return quality / 100.
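# Illustrative sketch (not part of the original module): worked values of the
# quality scaling in quality_to_factor. Quality 50 keeps the base tables, lower
# qualities scale them up (coarser quantization) and higher ones scale them down.
# The DC entry of y_table (16) is used to show the effective step size. Quality 100
# gives a factor of 0, i.e. a zero quantization step.
def _sketch_quality_to_factor(quality):
    if quality < 50:
        return (5000. / quality) / 100.
    return (200. - quality * 2) / 100.


def _sketch_dc_step(quality, base=16.0):
    return base * _sketch_quality_to_factor(quality)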
# ------------------------ compression ------------------------#
class RGB2YCbCrJpeg(nn.Module):
""" Converts RGB image to YCbCr
"""
def __init__(self):
super(RGB2YCbCrJpeg, self).__init__()
matrix = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]],
dtype=np.float32).T
self.shift = nn.Parameter(torch.tensor([0., 128., 128.]))
self.matrix = nn.Parameter(torch.from_numpy(matrix))
def forward(self, image):
"""
Args:
image(Tensor): batch x 3 x height x width
Returns:
Tensor: batch x height x width x 3
"""
image = image.permute(0, 2, 3, 1)
result = torch.tensordot(image, self.matrix, dims=1) + self.shift
return result.view(image.shape)
class ChromaSubsampling(nn.Module):
""" Chroma subsampling on CbCr channels
"""
def __init__(self):
super(ChromaSubsampling, self).__init__()
def forward(self, image):
"""
Args:
image(tensor): batch x height x width x 3
Returns:
y(tensor): batch x height x width
cb(tensor): batch x height/2 x width/2
cr(tensor): batch x height/2 x width/2
"""
image_2 = image.permute(0, 3, 1, 2).clone()
cb = F.avg_pool2d(image_2[:, 1, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
cr = F.avg_pool2d(image_2[:, 2, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
cb = cb.permute(0, 2, 3, 1)
cr = cr.permute(0, 2, 3, 1)
return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3)
class BlockSplitting(nn.Module):
""" Splitting image into patches
"""
def __init__(self):
super(BlockSplitting, self).__init__()
self.k = 8
def forward(self, image):
"""
Args:
image(tensor): batch x height x width
Returns:
Tensor: batch x h*w/64 x 8 x 8
"""
height, _ = image.shape[1:3]
batch_size = image.shape[0]
image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k)
image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
return image_transposed.contiguous().view(batch_size, -1, self.k, self.k)
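# Illustrative sketch (not part of the original module): the reshape/permute trick
# behind BlockSplitting and its inverse BlockMerging. A (batch, h, w) plane becomes
# (h/8 * w/8) row-major 8x8 blocks; h and w are assumed to be multiples of 8.
import numpy as np


def _sketch_split_blocks(image, k=8):
    b, h, w = image.shape
    return image.reshape(b, h // k, k, w // k, k).transpose(0, 1, 3, 2, 4).reshape(b, -1, k, k)


def _sketch_merge_blocks(patches, h, w, k=8):
    b = patches.shape[0]
    return patches.reshape(b, h // k, w // k, k, k).transpose(0, 1, 3, 2, 4).reshape(b, h, w)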
class DCT8x8(nn.Module):
""" Discrete Cosine Transformation
"""
def __init__(self):
super(DCT8x8, self).__init__()
tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
for x, y, u, v in itertools.product(range(8), repeat=4):
tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16)
alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())
def forward(self, image):
"""
Args:
image(tensor): batch x h*w/64 x 8 x 8 pixel blocks in [0, 255]
Returns:
Tensor: DCT coefficients, batch x h*w/64 x 8 x 8
"""
image = image - 128
result = self.scale * torch.tensordot(image, self.tensor, dims=2)
return result
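# Illustrative sketch (not part of the original module): DCT8x8 and iDCT8x8 form the
# orthonormal JPEG DCT-II pair. The basis tensors and the alpha / 0.25 scaling mirror
# the two modules, with NumPy float64 in place of torch float32. For a flat block of
# value c, the DC coefficient is 8 * (c - 128) and all other coefficients vanish.
import itertools

import numpy as np


def _sketch_dct_tensors():
    fwd = np.zeros((8, 8, 8, 8))
    inv = np.zeros((8, 8, 8, 8))
    for x, y, u, v in itertools.product(range(8), repeat=4):
        fwd[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16)
        inv[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16)
    alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
    return fwd, inv, np.outer(alpha, alpha)


def _sketch_dct(block, fwd, outer):
    return outer * 0.25 * np.tensordot(block - 128, fwd, axes=2)


def _sketch_idct(coeffs, inv, outer):
    return 0.25 * np.tensordot(coeffs * outer, inv, axes=2) + 128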
class YQuantize(nn.Module):
""" JPEG Quantization for Y channel
Args:
rounding(function): rounding function to use
"""
def __init__(self, rounding):
super(YQuantize, self).__init__()
self.rounding = rounding
self.y_table = y_table
def forward(self, image, factor=1):
"""
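Args:
image(tensor): batch x h*w/64 x 8 x 8 DCT coefficients
factor(float | tensor): quality scale factor, a number or a tensor of shape (batch)
Returns:
Tensor: quantized coefficients, batch x h*w/64 x 8 x 8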
"""
if isinstance(factor, (int, float)):
image = image.float() / (self.y_table * factor)
else:
b = factor.size(0)
table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
image = image.float() / table
image = self.rounding(image)
return image
class CQuantize(nn.Module):
""" JPEG Quantization for CbCr channels
Args:
rounding(function): rounding function to use
"""
def __init__(self, rounding):
super(CQuantize, self).__init__()
self.rounding = rounding
self.c_table = c_table
def forward(self, image, factor=1):
"""
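Args:
image(tensor): batch x h*w/64 x 8 x 8 DCT coefficients
factor(float | tensor): quality scale factor, a number or a tensor of shape (batch)
Returns:
Tensor: quantized coefficients, batch x h*w/64 x 8 x 8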
"""
if isinstance(factor, (int, float)):
image = image.float() / (self.c_table * factor)
else:
b = factor.size(0)
table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
image = image.float() / table
image = self.rounding(image)
return image
class CompressJpeg(nn.Module):
"""Full JPEG compression algorithm
Args:
rounding(function): rounding function to use
"""
def __init__(self, rounding=torch.round):
super(CompressJpeg, self).__init__()
self.l1 = nn.Sequential(RGB2YCbCrJpeg(), ChromaSubsampling())
self.l2 = nn.Sequential(BlockSplitting(), DCT8x8())
self.c_quantize = CQuantize(rounding=rounding)
self.y_quantize = YQuantize(rounding=rounding)
def forward(self, image, factor=1):
"""
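Args:
image(tensor): batch x 3 x height x width, RGB in [0, 1]
factor(float | tensor): quality scale factor, a number or a tensor of shape (batch)
Returns:
tuple(tensor): Quantized y, cb and cr coefficients, each batch x n_blocks x 8 x 8;
cb and cr have a quarter as many blocks as y because of chroma subsampling.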
"""
y, cb, cr = self.l1(image * 255)
components = {'y': y, 'cb': cb, 'cr': cr}
for k in components.keys():
comp = self.l2(components[k])
if k in ('cb', 'cr'):
comp = self.c_quantize(comp, factor=factor)
else:
comp = self.y_quantize(comp, factor=factor)
components[k] = comp
return components['y'], components['cb'], components['cr']
# ------------------------ decompression ------------------------#
class YDequantize(nn.Module):
"""Dequantize Y channel
"""
def __init__(self):
super(YDequantize, self).__init__()
self.y_table = y_table
def forward(self, image, factor=1):
"""
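Args:
image(tensor): batch x h*w/64 x 8 x 8 quantized coefficients
factor(float | tensor): quality scale factor, a number or a tensor of shape (batch)
Returns:
Tensor: dequantized coefficients, batch x h*w/64 x 8 x 8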
"""
if isinstance(factor, (int, float)):
out = image * (self.y_table * factor)
else:
b = factor.size(0)
table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
out = image * table
return out
class CDequantize(nn.Module):
"""Dequantize CbCr channels
"""
def __init__(self):
super(CDequantize, self).__init__()
self.c_table = c_table
def forward(self, image, factor=1):
"""
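Args:
image(tensor): batch x h*w/64 x 8 x 8 quantized coefficients
factor(float | tensor): quality scale factor, a number or a tensor of shape (batch)
Returns:
Tensor: dequantized coefficients, batch x h*w/64 x 8 x 8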
"""
if isinstance(factor, (int, float)):
out = image * (self.c_table * factor)
else:
b = factor.size(0)
table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
out = image * table
return out
class iDCT8x8(nn.Module):
"""Inverse discrete Cosine Transformation
"""
def __init__(self):
super(iDCT8x8, self).__init__()
alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
for x, y, u, v in itertools.product(range(8), repeat=4):
tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16)
self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
def forward(self, image):
"""
Args:
image(tensor): batch x h*w/64 x 8 x 8 DCT coefficients
Returns:
Tensor: pixel blocks, batch x h*w/64 x 8 x 8
"""
image = image * self.alpha
result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
return result
class BlockMerging(nn.Module):
"""Merge patches into image
"""
def __init__(self):
super(BlockMerging, self).__init__()
def forward(self, patches, height, width):
"""
Args:
patches (Tensor): batch x height*width/64 x 8 x 8
height (int): image height, a multiple of 8
width (int): image width, a multiple of 8
Returns:
Tensor: batch x height x width
"""
k = 8
batch_size = patches.shape[0]
image_reshaped = patches.view(batch_size, height // k, width // k, k, k)
image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
return image_transposed.contiguous().view(batch_size, height, width)
class ChromaUpsampling(nn.Module):
"""Upsample chroma layers
"""
def __init__(self):
super(ChromaUpsampling, self).__init__()
def forward(self, y, cb, cr):
"""
Args:
y(tensor): y channel image
cb(tensor): cb channel
cr(tensor): cr channel
Returns:
Tensor: batch x height x width x 3
"""
def repeat(x, k=2):
height, width = x.shape[1:3]
x = x.unsqueeze(-1)
x = x.repeat(1, 1, k, k)
x = x.view(-1, height * k, width * k)
return x
cb = repeat(cb)
cr = repeat(cr)
return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3)
class YCbCr2RGBJpeg(nn.Module):
"""Converts YCbCr image to RGB JPEG
"""
def __init__(self):
super(YCbCr2RGBJpeg, self).__init__()
matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T
self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
self.matrix = nn.Parameter(torch.from_numpy(matrix))
def forward(self, image):
"""
Args:
image(tensor): batch x height x width x 3
Returns:
Tensor: batch x 3 x height x width
"""
result = torch.tensordot(image + self.shift, self.matrix, dims=1)
return result.view(image.shape).permute(0, 3, 1, 2)
class DeCompressJpeg(nn.Module):
"""Full JPEG decompression algorithm
Args:
rounding(function): rounding function to use
"""
def __init__(self, rounding=torch.round):
super(DeCompressJpeg, self).__init__()
self.c_dequantize = CDequantize()
self.y_dequantize = YDequantize()
self.idct = iDCT8x8()
self.merging = BlockMerging()
self.chroma = ChromaUpsampling()
self.colors = YCbCr2RGBJpeg()
def forward(self, y, cb, cr, imgh, imgw, factor=1):
"""
Args:
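y (Tensor): batch x imgh*imgw/64 x 8 x 8 quantized luma blocks
cb (Tensor): batch x (imgh/2)*(imgw/2)/64 x 8 x 8 quantized Cb blocks
cr (Tensor): batch x (imgh/2)*(imgw/2)/64 x 8 x 8 quantized Cr blocks
imgh (int): padded image height
imgw (int): padded image width
factor (float | Tensor): quality scaling factor, a scalar or one value per batch element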
Returns:
Tensor: batch x 3 x height x width
"""
components = {'y': y, 'cb': cb, 'cr': cr}
for k in components.keys():
if k in ('cb', 'cr'):
comp = self.c_dequantize(components[k], factor=factor)
height, width = int(imgh / 2), int(imgw / 2)
else:
comp = self.y_dequantize(components[k], factor=factor)
height, width = imgh, imgw
comp = self.idct(comp)
components[k] = self.merging(comp, height, width)
#
image = self.chroma(components['y'], components['cb'], components['cr'])
image = self.colors(image)
image = torch.min(255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image))
return image / 255
# ------------------------ main DiffJPEG ------------------------ #
class DiffJPEG(nn.Module):
"""The results of this JPEG implementation differ slightly from those of cv2.
DiffJPEG supports batch processing.
Args:
differentiable(bool): If True, uses custom differentiable rounding function, if False, uses standard torch.round
"""
def __init__(self, differentiable=True):
super(DiffJPEG, self).__init__()
if differentiable:
rounding = diff_round
else:
rounding = torch.round
self.compress = CompressJpeg(rounding=rounding)
self.decompress = DeCompressJpeg(rounding=rounding)
def forward(self, x, quality):
"""
Args:
x (Tensor): Input image, bchw, rgb, [0, 1]
quality (float | Tensor): JPEG quality, either a scalar or a tensor with one value per image.
"""
if isinstance(quality, (int, float)):
factor = quality_to_factor(quality)
else:
# clone so that the caller's quality tensor is not modified in place
factor = quality.clone()
for i in range(factor.size(0)):
factor[i] = quality_to_factor(factor[i])
h, w = x.size()[-2:]
h_pad, w_pad = 0, 0
# pad to a multiple of 16: chroma is subsampled by 2 and then split into 8x8 blocks
if h % 16 != 0:
h_pad = 16 - h % 16
if w % 16 != 0:
w_pad = 16 - w % 16
x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0)
y, cb, cr = self.compress(x, factor=factor)
recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor)
recovered = recovered[:, :, 0:h, 0:w]
return recovered
if __name__ == '__main__':
import cv2
from bsr.utils import img2tensor, tensor2img
img_gt = cv2.imread('test.png') / 255.
# -------------- cv2 -------------- #
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 20]
_, encimg = cv2.imencode('.jpg', img_gt * 255., encode_param)
img_lq = np.float32(cv2.imdecode(encimg, 1))
cv2.imwrite('cv2_JPEG_20.png', img_lq)
# -------------- DiffJPEG -------------- #
jpeger = DiffJPEG(differentiable=False).cuda()
img_gt = img2tensor(img_gt)
img_gt = torch.stack([img_gt, img_gt]).cuda()
quality = img_gt.new_tensor([20, 40])
out = jpeger(img_gt, quality=quality)
cv2.imwrite('pt_JPEG_20.png', tensor2img(out[0]))
cv2.imwrite('pt_JPEG_40.png', tensor2img(out[1]))
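`iDCT8x8` hard-codes the 8x8 cosine basis and the `alpha` weights. The NumPy sketch below checks that this inverse undoes a matching forward transform. It is an illustration under assumptions, not the module's code: the forward helper assumes the usual JPEG DCT-II with a -128 level shift (the `DCT8x8` module is not shown in this section), and `dct8x8`/`idct8x8` are hypothetical names. Without quantization, the round trip should reproduce the input up to floating-point error.

```python
import itertools

import numpy as np

# BASIS[x, y, u, v] = cos((2x + 1) u pi / 16) * cos((2y + 1) v pi / 16)
# (x, y) index pixels inside the block, (u, v) index frequencies.
BASIS = np.zeros((8, 8, 8, 8))
for x, y, u, v in itertools.product(range(8), repeat=4):
    BASIS[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16)

# Same weighting as iDCT8x8.alpha: 1/sqrt(2) for the DC row/column, 1 elsewhere.
ALPHA = np.array([1.0 / np.sqrt(2)] + [1.0] * 7)
SCALE = np.outer(ALPHA, ALPHA)


def dct8x8(blocks):
    """Forward 2-D DCT of (..., 8, 8) pixel blocks in [0, 255] (hypothetical helper)."""
    return 0.25 * SCALE * np.tensordot(blocks - 128.0, BASIS, axes=2)


def idct8x8(coeffs):
    """Inverse transform written like iDCT8x8.forward (hypothetical helper)."""
    # iDCT8x8's tensor is indexed [freq_x, freq_y, pix_x, pix_y], i.e. BASIS transposed.
    return 0.25 * np.tensordot(coeffs * SCALE, BASIS.transpose(2, 3, 0, 1), axes=2) + 128.0


rng = np.random.default_rng(0)
blocks = rng.uniform(0, 255, size=(4, 8, 8))
print(np.abs(idct8x8(dct8x8(blocks)) - blocks).max())  # floating-point error only
```

In `DiffJPEG` the coefficients are divided by the quantization tables and rounded between these two transforms; that rounding, not the transforms themselves, is what makes JPEG lossy.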
================================================
FILE: bsr/utils/dist_util.py
================================================
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
import functools
import os
import subprocess
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
def init_dist(launcher, backend='nccl', **kwargs):
if mp.get_start_method(allow_none=True) is None:
mp.set_start_method('spawn')
if launcher == 'pytorch':
_init_dist_pytorch(backend, **kwargs)
elif launcher == 'slurm':
_init_dist_slurm(backend, **kwargs)
else:
raise ValueError(f'Invalid launcher type: {launcher}')
def _init_dist_pytorch(backend, **kwargs):
rank = int(os.environ['RANK'])
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
dist.init_process_group(backend=backend, **kwargs)
def _init_dist_slurm(backend, port=None):
"""Initialize slurm distributed training environment.
If argument ``port`` is not specified, then the master port will be system
environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
environment variable, then a default port ``29500`` will be used.
Args:
backend (str): Backend of torch.distributed.
port (int, optional): Master port. Defaults to None.
"""
proc_id = int(os.environ['SLURM_PROCID'])
ntasks = int(os.environ['SLURM_NTASKS'])
node_list = os.environ['SLURM_NODELIST']
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(proc_id % num_gpus)
addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
# specify master port
if port is not None:
os.environ['MASTER_PORT'] = str(port)
elif 'MASTER_PORT' in os.environ:
pass # use MASTER_PORT in the environment variable
else:
# 29500 is torch.distributed default port
os.environ['MASTER_PORT'] = '29500'
os.environ['MASTER_ADDR'] = addr
os.environ['WORLD_SIZE'] = str(ntasks)
os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
os.environ['RANK'] = str(proc_id)
dist.init_process_group(backend=backend)
def get_dist_info():
if dist.is_available():
initialized = dist.is_initialized()
else:
initialized = False
if initialized:
rank = dist.get_rank()
world_size = dist.get_world_size()
else:
rank = 0
world_size = 1
return rank, world_size
def master_only(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
rank, _ = get_dist_info()
if rank == 0:
return func(*args, **kwargs)
return wrapper
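`_init_dist_slurm` derives the process-group environment from SLURM variables before calling `dist.init_process_group`. The pure-Python sketch below reproduces only that mapping so the port precedence and local-rank arithmetic can be tested without a cluster. `slurm_dist_env` is a hypothetical helper: it takes the values as arguments instead of reading `os.environ`, `torch.cuda.device_count()` or `scontrol`.

```python
def slurm_dist_env(proc_id, ntasks, num_gpus, master_addr, port=None, environ=None):
    """Return the variables _init_dist_slurm would export (hypothetical helper).

    proc_id, ntasks: values of SLURM_PROCID and SLURM_NTASKS.
    num_gpus: GPUs visible on this node (torch.cuda.device_count() in the original).
    master_addr: first host of SLURM_NODELIST (resolved via scontrol in the original).
    port: explicit master port; takes precedence over environ['MASTER_PORT'].
    """
    environ = {} if environ is None else environ
    if port is not None:
        master_port = str(port)
    elif 'MASTER_PORT' in environ:
        master_port = environ['MASTER_PORT']
    else:
        master_port = '29500'  # torch.distributed default port
    return {
        'MASTER_PORT': master_port,
        'MASTER_ADDR': master_addr,
        'WORLD_SIZE': str(ntasks),
        'LOCAL_RANK': str(proc_id % num_gpus),
        'RANK': str(proc_id),
    }


# Global rank 5 with 4 GPUs per node -> LOCAL_RANK 1.
env = slurm_dist_env(proc_id=5, ntasks=8, num_gpus=4, master_addr='node01')
print(env['LOCAL_RANK'], env['MASTER_PORT'])  # 1 29500
```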
================================================
FILE: bsr/utils/download_util.py
================================================
import math
import os
import requests
from torch.hub import download_url_to_file, get_dir
from tqdm import tqdm
from urllib.parse import urlparse
from .misc import sizeof_fmt
def download_file_from_google_drive(file_id, save_path):
"""Download files from google drive.
Reference: https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive
Args:
file_id (str): File id.
save_path (str): Save path.
"""
session = requests.Session()
URL = 'https://docs.google.com/uc?export=download'
params = {'id': file_id}
response = session.get(URL, params=params, stream=True)
token = get_confirm_token(response)
if token:
params['confirm'] = token
response = session.get(URL, params=params, stream=True)
# get file size
response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
if 'Content-Range' in response_file_size.headers:
file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
else:
file_size = None
save_response_content(response, save_path, file_size)
def get_confirm_token(response):
for key, value in response.cookies.items():
if key.startswith('download_warning'):
return value
return None
def save_response_content(response, destination, file_size=None, chunk_size=32768):
if file_size is not None:
pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
readable_file_size = sizeof_fmt(file_size)
else:
pbar = None
with open(destination, 'wb') as f:
downloaded_size = 0
for chunk in response.iter_content(chunk_size):
downloaded_size += len(chunk)
if pbar is not None:
pbar.update(1)
pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
if chunk: # filter out keep-alive new chunks
f.write(chunk)
if pbar is not None:
pbar.close()
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
"""Load a file from an http url, downloading it if it is not already cached.
Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
Args:
url (str): URL to be downloaded.
model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
Default: None.
progress (bool): Whether to show the download progress. Default: True.
file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
Returns:
str: The path to the downloaded file.
"""
if model_dir is None: # use the pytorch hub_dir
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
os.makedirs(model_dir, exist_ok=True)
parts = urlparse(url)
filename = os.path.basename(parts.path)
if file_name is not None:
filename = file_name
cached_file = os.path.abspath(os.path.join(model_dir, filename))
if not os.path.exists(cached_file):
print(f'Downloading: "{url}" to {cached_file}\n')
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
return cached_file
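`load_file_from_url` downloads only when the target file is missing, so it helps to know which path it checks. The sketch below reproduces just that path resolution with the standard library (`cached_path_for` is a hypothetical name). Because the file name is taken from the URL path, any query string is ignored.

```python
import os
from urllib.parse import urlparse


def cached_path_for(url, model_dir, file_name=None):
    """Path that load_file_from_url would check before downloading (hypothetical helper)."""
    filename = os.path.basename(urlparse(url).path)
    if file_name is not None:
        filename = file_name  # an explicit name overrides the one in the URL
    return os.path.abspath(os.path.join(model_dir, filename))


url = 'https://example.com/releases/v1/net_g.pth?download=1'
print(cached_path_for(url, '/tmp/checkpoints'))
```

When `model_dir` is None, the original falls back to the `checkpoints` folder under the torch hub directory.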
================================================
FILE: bsr/utils/file_client.py
================================================
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
from abc import ABCMeta, abstractmethod
class BaseStorageBackend(metaclass=ABCMeta):
"""Abstract class of storage backends.
All backends need to implement two apis: ``get()`` and ``get_text()``.
``get()`` reads the file as a byte stream and ``get_text()`` reads the file
as texts.
"""
@abstractmethod
def get(self, filepath):
pass
@abstractmethod
def get_text(self, filepath):
pass
class MemcachedBackend(BaseStorageBackend):
"""Memcached storage backend.
Attributes:
server_list_cfg (str): Config file for memcached server list.
client_cfg (str): Config file for memcached client.
sys_path (str | None): Additional path to be appended to `sys.path`.
Default: None.
"""
def __init__(self, server_list_cfg, client_cfg, sys_path=None):
if sys_path is not None:
import sys
sys.path.append(sys_path)
try:
import mc
except ImportError:
raise ImportError('Please install memcached to enable MemcachedBackend.')
self.server_list_cfg = server_list_cfg
self.client_cfg = client_cfg
self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
# mc.pyvector serves as a pointer to a memory cache
self._mc_buffer = mc.pyvector()
def get(self, filepath):
filepath = str(filepath)
import mc
self._client.Get(filepath, self._mc_buffer)
value_buf = mc.ConvertBuffer(self._mc_buffer)
return value_buf
def get_text(self, filepath):
raise NotImplementedError
class HardDiskBackend(BaseStorageBackend):
"""Raw hard disks storage backend."""
def get(self, filepath):
filepath = str(filepath)
with open(filepath, 'rb') as f:
value_buf = f.read()
return value_buf
def get_text(self, filepath):
filepath = str(filepath)
with open(filepath, 'r') as f:
value_buf = f.read()
return value_buf
class LmdbBackend(BaseStorageBackend):
"""Lmdb storage backend.
Args:
db_paths (str | list[str]): Lmdb database paths.
client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
readonly (bool, optional): Lmdb environment parameter. If True,
disallow any write operations. Default: True.
lock (bool, optional): Lmdb environment parameter. If False, when
concurrent access occurs, do not lock the database. Default: False.
readahead (bool, optional): Lmdb environment parameter. If False,
disable the OS filesystem readahead mechanism, which may improve
random read performance when a database is larger than RAM.
Default: False.
Attributes:
db_paths (list): Lmdb database path.
_client (list): A list of several lmdb envs.
"""
def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
try:
import lmdb
except ImportError:
raise ImportError('Please install lmdb to enable LmdbBackend.')
if isinstance(client_keys, str):
client_keys = [client_keys]
if isinstance(db_paths, list):
self.db_paths = [str(v) for v in db_paths]
elif isinstance(db_paths, str):
self.db_paths = [str(db_paths)]
assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
f'but received {len(client_keys)} and {len(self.db_paths)}.')
self._client = {}
for client, path in zip(client_keys, self.db_paths):
self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
def get(self, filepath, client_key):
"""Get values according to the filepath from one lmdb named client_key.
Args:
filepath (str | obj:`Path`): Here, filepath is the lmdb key.
client_key (str): Used for distinguishing different lmdb envs.
"""
filepath = str(filepath)
assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.')
client = self._client[client_key]
with client.begin(write=False) as txn:
value_buf = txn.get(filepath.encode('ascii'))
return value_buf
def get_text(self, filepath):
raise NotImplementedError
class FileClient(object):
"""A general file client to access files in different backend.
The client loads a file or text from its path in the specified backend
and returns the raw bytes (or a str for ``get_text``, where supported).
Other backends can be supported by adding a name -> class entry to ``_backends``.
Attributes:
backend (str): The storage backend type. Options are "disk",
"memcached" and "lmdb".
client (:obj:`BaseStorageBackend`): The backend object.
"""
_backends = {
'disk': HardDiskBackend,
'memcached': MemcachedBackend,
'lmdb': LmdbBackend,
}
def __init__(self, backend='disk', **kwargs):
if backend not in self._backends:
raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
f' are {list(self._backends.keys())}')
self.backend = backend
self.client = self._backends[backend](**kwargs)
def get(self, filepath, client_key='default'):
# client_key is used only for lmdb, where different fileclients have
# different lmdb environments.
if self.backend == 'lmdb':
return self.client.get(filepath, client_key)
else:
return self.client.get(filepath)
def get_text(self, filepath):
return self.client.get_text(filepath)
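`FileClient` is a small strategy pattern: a backend class is looked up by name in `_backends`, and `get`/`get_text` are delegated to it. The standalone sketch below mirrors that dispatch with a hypothetical in-memory `DictBackend` (handy for tests); `StorageBackend` and `MiniFileClient` are stripped-down stand-ins, not the classes above. With the real class, a custom backend would presumably be added by inserting it into `FileClient._backends`, since no registration method is defined here.

```python
from abc import ABCMeta, abstractmethod


class StorageBackend(metaclass=ABCMeta):
    """Same two-method contract as BaseStorageBackend."""

    @abstractmethod
    def get(self, filepath):
        pass

    @abstractmethod
    def get_text(self, filepath):
        pass


class DictBackend(StorageBackend):
    """Hypothetical in-memory backend, e.g. for unit tests."""

    def __init__(self, store):
        self._store = store

    def get(self, filepath):
        return self._store[str(filepath)]

    def get_text(self, filepath):
        return self.get(filepath).decode('utf-8')


class MiniFileClient:
    """Stripped-down FileClient: look up the backend class by name, then delegate."""

    _backends = {'dict': DictBackend}

    def __init__(self, backend='dict', **kwargs):
        if backend not in self._backends:
            raise ValueError(f'Backend {backend} is not supported. '
                             f'Currently supported ones are {list(self._backends)}')
        self.client = self._backends[backend](**kwargs)

    def get(self, filepath):
        return self.client.get(filepath)

    def get_text(self, filepath):
        return self.client.get_text(filepath)


client = MiniFileClient('dict', store={'meta_info.txt': b'000.png (8,8,3) 1'})
print(client.get_text('meta_info.txt'))
```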
================================================
FILE: bsr/utils/flow_util.py
================================================
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501
import cv2
import numpy as np
import os
def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
"""Read an optical flow map.
Args:
flow_path (ndarray or str): Flow path.
quantize (bool): whether to read quantized pair, if set to True,
remaining args will be passed to :func:`dequantize_flow`.
concat_axis (int): The axis that dx and dy are concatenated,
can be either 0 or 1. Ignored if quantize is False.
Returns:
ndarray: Optical flow represented as a (h, w, 2) numpy array
"""
if quantize:
assert concat_axis in [0, 1]
cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
if cat_flow.ndim != 2:
raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.')
assert cat_flow.shape[concat_axis] % 2 == 0
dx, dy = np.split(cat_flow, 2, axis=concat_axis)
flow = dequantize_flow(dx, dy, *args, **kwargs)
else:
with open(flow_path, 'rb') as f:
try:
header = f.read(4).decode('utf-8')
except Exception:
raise IOError(f'Invalid flow file: {flow_path}')
else:
if header != 'PIEH':
raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH')
w = np.fromfile(f, np.int32, 1).squeeze()
h = np.fromfile(f, np.int32, 1).squeeze()
flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
return flow.astype(np.float32)
def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
"""Write optical flow to file.
If the flow is not quantized, it will be saved as a .flo file losslessly,
otherwise a jpeg image which is lossy but of much smaller size. (dx and dy
will be concatenated horizontally into a single image if quantize is True.)
Args:
flow (ndarray): (h, w, 2) array of optical flow.
filename (str): Output filepath.
quantize (bool): Whether to quantize the flow and save it to 2 jpeg
images. If set to True, remaining args will be passed to
:func:`quantize_flow`.
concat_axis (int): The axis that dx and dy are concatenated,
can be either 0 or 1. Ignored if quantize is False.
"""
if not quantize:
with open(filename, 'wb') as f:
f.write('PIEH'.encode('utf-8'))
np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
flow = flow.astype(np.float32)
flow.tofile(f)
f.flush()
else:
assert concat_axis in [0, 1]
dx, dy = quantize_flow(flow, *args, **kwargs)
dxdy = np.concatenate((dx, dy), axis=concat_axis)
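dir_name = os.path.dirname(filename)
if dir_name:  # os.makedirs('') raises, so skip it when writing to the current directory
os.makedirs(dir_name, exist_ok=True)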
cv2.imwrite(filename, dxdy)
def quantize_flow(flow, max_val=0.02, norm=True):
"""Quantize flow to [0, 255].
After this step, the size of flow will be much smaller, and can be
dumped as jpeg images.
Args:
flow (ndarray): (h, w, 2) array of optical flow.
max_val (float): Maximum value of flow, values beyond
[-max_val, max_val] will be truncated.
norm (bool): Whether to divide flow values by image width/height.
Returns:
tuple[ndarray]: Quantized dx and dy.
"""
h, w, _ = flow.shape
dx = flow[..., 0]
dy = flow[..., 1]
if norm:
dx = dx / w # avoid inplace operations
dy = dy / h
# use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]]
return tuple(flow_comps)
def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
"""Recover from quantized flow.
Args:
dx (ndarray): Quantized dx.
dy (ndarray): Quantized dy.
max_val (float): Maximum value used when quantizing.
denorm (bool): Whether to multiply flow values with width/height.
Returns:
ndarray: Dequantized flow.
"""
assert dx.shape == dy.shape
assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
if denorm:
dx *= dx.shape[1]
dy *= dx.shape[0]
flow = np.dstack((dx, dy))
return flow
def quantize(arr, min_val, max_val, levels, dtype=np.int64):
"""Quantize an array of (-inf, inf) to [0, levels-1].
Args:
arr (ndarray): Input array.
min_val (scalar): Minimum value to be clipped.
max_val (scalar): Maximum value to be clipped.
levels (int): Quantization levels.
dtype (np.type): The type of the quantized array.
Returns:
ndarray: Quantized array.
"""
if not (isinstance(levels, int) and levels > 1):
raise ValueError(f'levels must be an integer greater than 1, but got {levels}')
if min_val >= max_val:
raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
arr = np.clip(arr, min_val, max_val) - min_val
quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
return quantized_arr
def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
"""Dequantize an array.
Args:
arr (ndarray): Input array.
min_val (scalar): Minimum value to be clipped.
max_val (scalar): Maximum value to be clipped.
levels (int): Quantization levels.
dtype (np.type): The type of the dequantized array.
Returns:
ndarray: Dequantized array.
"""
if not (isinstance(levels, int) and levels > 1):
raise ValueError(f'levels must be an integer greater than 1, but got {levels}')
if min_val >= max_val:
raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val
return dequantized_arr
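`quantize` maps each value to a bin index and `dequantize` maps the index back to the centre of its bin, so the reconstruction error is at most half a bin width. With an odd number of levels, 0 falls exactly in the centre of the middle bin, which is why `quantize_flow` uses 255 rather than 256 levels. The sketch below copies the two formulas into hypothetical standalone functions (`quantize_arr`, `dequantize_arr`) so the round trip can be checked without importing the module.

```python
import numpy as np


def quantize_arr(arr, min_val, max_val, levels, dtype=np.int64):
    # Same arithmetic as quantize(): clip, shift to 0, bucket into `levels` bins.
    arr = np.clip(arr, min_val, max_val) - min_val
    return np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)


def dequantize_arr(arr, min_val, max_val, levels, dtype=np.float64):
    # Same arithmetic as dequantize(): map each bin index to the centre of its bin.
    return (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val


max_val, levels = 0.02, 255
flow_dx = np.linspace(-max_val, max_val, 1001)
restored = dequantize_arr(quantize_arr(flow_dx, -max_val, max_val, levels, np.uint8),
                          -max_val, max_val, levels)
step = 2 * max_val / levels  # width of one quantization bin
print(np.abs(restored - flow_dx).max() <= step / 2 + 1e-12)  # True
```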
================================================
FILE: bsr/utils/img_process_util.py
================================================
import cv2
import numpy as np
import torch
from torch.nn import functional as F
def filter2D(img, kernel):
"""PyTorch version of cv2.filter2D
Args:
img (Tensor): (b, c, h, w)
kernel (Tensor): (b, k, k)
"""
k = kernel.size(-1)
b, c, h, w = img.size()
if k % 2 == 1:
img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
else:
raise ValueError(f'Kernel size must be odd, but got {k}')
ph, pw = img.size()[-2:]
if kernel.size(0) == 1:
# apply the same kernel to all batch images
img = img.view(b * c, 1, ph, pw)
kernel = kernel.view(1, 1, k, k)
return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
else:
img = img.view(1, b * c, ph, pw)
kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)
def usm_sharp(img, weight=0.5, radius=50, threshold=10):
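"""USM (unsharp mask) sharpening.
Input image: I; Blurry image: B.
1. sharp = clip(I + weight * (I - B), 0, 1)
2. Mask = 1 if abs(I - B) > threshold, else 0
3. SoftMask = Gaussian blur of Mask
4. Out = SoftMask * sharp + (1 - SoftMask) * I
Args:
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
weight (float): Sharp weight. Default: 0.5.
radius (int): Kernel size of Gaussian blur; an even value is increased by 1. Default: 50.
threshold (int): Threshold on abs(I - B) in the [0, 255] scale; only pixels above it are sharpened. Default: 10.
"""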
if radius % 2 == 0:
radius += 1
blur = cv2.GaussianBlur(img, (radius, radius), 0)
residual = img - blur
mask = np.abs(residual) * 255 > threshold
mask = mask.astype('float32')
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
sharp = img + weight * residual
sharp = np.clip(sharp, 0, 1)
return soft_mask * sharp + (1 - soft_mask) * img
class USMSharp(torch.nn.Module):
def __init__(self, radius=50, sigma=0):
super(USMSharp, self).__init__()
if radius % 2 == 0:
radius += 1
self.radius = radius
kernel = cv2.getGaussianKernel(radius, sigma)
kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0)
self.register_buffer('kernel', kernel)
def forward(self, img, weight=0.5, threshold=10):
blur = filter2D(img, self.kernel)
residual = img - blur
mask = torch.abs(residual) * 255 > threshold
mask = mask.float()
soft_mask = filter2D(mask, self.kernel)
sharp = img + weight * residual
sharp = torch.clip(sharp, 0, 1)
return soft_mask * sharp + (1 - soft_mask) * img
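The four steps in the `usm_sharp` docstring can be traced on a single-channel array. The sketch below is NumPy-only and makes a few assumptions: it replaces `cv2.GaussianBlur` with a hand-written separable Gaussian with an explicit sigma, uses NumPy's `'reflect'` padding (which corresponds to OpenCV's default BORDER_REFLECT_101), and uses a 5-tap kernel instead of the default radius of 50 so the effect is visible on a 16x16 image. `gaussian_kernel1d`, `separable_blur` and `usm_sharp_gray` are hypothetical names.

```python
import numpy as np


def gaussian_kernel1d(size, sigma):
    """Normalised 1-D Gaussian; even sizes are bumped to odd, as in usm_sharp."""
    if size % 2 == 0:
        size += 1
    x = np.arange(size) - size // 2
    kernel = np.exp(-(x ** 2) / (2.0 * sigma ** 2))
    return kernel / kernel.sum()


def separable_blur(img, kernel):
    """Blur a 2-D array row-wise then column-wise; 'reflect' padding keeps the shape."""
    pad = len(kernel) // 2
    padded = np.pad(img, pad, mode='reflect')
    rows = np.apply_along_axis(np.convolve, 1, padded, kernel, mode='valid')
    return np.apply_along_axis(np.convolve, 0, rows, kernel, mode='valid')


def usm_sharp_gray(img, weight=0.5, size=5, sigma=1.0, threshold=10):
    """Single-channel version of the four usm_sharp steps (hypothetical helper)."""
    kernel = gaussian_kernel1d(size, sigma)
    residual = img - separable_blur(img, kernel)               # I - B
    sharp = np.clip(img + weight * residual, 0, 1)             # step 1
    mask = (np.abs(residual) * 255 > threshold).astype(float)  # step 2
    soft_mask = separable_blur(mask, kernel)                   # step 3
    return soft_mask * sharp + (1 - soft_mask) * img           # step 4


edge = np.full((16, 16), 0.2)
edge[:, 8:] = 0.8  # vertical step edge
out = usm_sharp_gray(edge)
print(out[0, 6:10].round(3))  # dark side dips below 0.2, bright side overshoots 0.8
```

The soft mask is what keeps flat regions untouched: away from the edge the residual stays under the threshold, so the output there equals the input.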
================================================
FILE: bsr/utils/img_util.py
================================================
import cv2
import math
import numpy as np
import os
import torch
from torchvision.utils import make_grid
def img2tensor(imgs, bgr2rgb=True, float32=True):
"""Numpy array to tensor.
Args:
imgs (list[ndarray] | ndarray): Input images.
bgr2rgb (bool): Whether to change bgr to rgb.
float32 (bool): Whether to change to float32.
Returns:
list[tensor] | tensor: Tensor images. If returned results only have
one element, just return tensor.
"""
def _totensor(img, bgr2rgb, float32):
if img.shape[2] == 3 and bgr2rgb:
if img.dtype == 'float64':
img = img.astype('float32')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = torch.from_numpy(img.transpose(2, 0, 1))
if float32:
img = img.float()
return img
if isinstance(imgs, list):
return [_totensor(img, bgr2rgb, float32) for img in imgs]
else:
return _totensor(imgs, bgr2rgb, float32)
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
"""Convert torch Tensors into image numpy arrays.
After clamping to [min, max], values will be normalized to [0, 1].
Args:
tensor (Tensor or list[Tensor]): Accept shapes:
1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
2) 3D Tensor of shape (3/1 x H x W);
3) 2D Tensor of shape (H x W).
Tensor channel should be in RGB order.
rgb2bgr (bool): Whether to change rgb to bgr.
out_type (numpy type): output types. If ``np.uint8``, transform outputs
to uint8 type with range [0, 255]; otherwise, float type with
range [0, 1]. Default: ``np.uint8``.
min_max (tuple[int]): min and max values for clamp.
Returns:
(ndarray or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
shape (H x W). The channel order is BGR when ``rgb2bgr`` is True.
"""
if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
if torch.is_tensor(tensor):
tensor = [tensor]
result = []
for _tensor in tensor:
_tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
_tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
n_dim = _tensor.dim()
if n_dim == 4:
img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
img_np = img_np.transpose(1, 2, 0)
if rgb2bgr:
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
elif n_dim == 3:
img_np = _tensor.numpy()
img_np = img_np.transpose(1, 2, 0)
if img_np.shape[2] == 1: # gray image
img_np = np.squeeze(img_np, axis=2)
else:
if rgb2bgr:
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
elif n_dim == 2:
img_np = _tensor.numpy()
else:
raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
if out_type == np.uint8:
# Unlike MATLAB, numpy.uint8() WILL NOT round by default.
img_np = (img_np * 255.0).round()
img_np = img_np.astype(out_type)
result.append(img_np)
if len(result) == 1:
result = result[0]
return result
def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
"""This implementation is slightly faster than tensor2img.
It now only supports torch tensor with shape (1, c, h, w).
Args:
tensor (Tensor): Now only support torch tensor with (1, c, h, w).
rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
min_max (tuple[int]): min and max values for clamp.
"""
output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
output = output.type(torch.uint8).cpu().numpy()
if rgb2bgr:
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
def imfrombytes(content, flag='color', float32=False):
"""Read an image from bytes.
Args:
content (bytes): Image bytes got from files or other streams.
flag (str): Flags specifying the color type of a loaded image,
candidates are `color`, `grayscale` and `unchanged`.
float32 (bool): Whether to convert to float32. If True, the image is also
normalized to [0, 1]. Default: False.
Returns:
ndarray: Loaded image array.
"""
img_np = np.frombuffer(content, np.uint8)
imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
img = cv2.imdecode(img_np, imread_flags[flag])
if float32:
img = img.astype(np.float32) / 255.
return img
def imwrite(img, file_path, params=None, auto_mkdir=True):
"""Write image to file.
Args:
img (ndarray): Image array to be written.
file_path (str): Image file path.
params (None or list): Same as opencv's :func:`imwrite` interface.
auto_mkdir (bool): If the parent folder of `file_path` does not exist,
whether to create it automatically.
Raises:
IOError: If OpenCV fails to write the image.
"""
if auto_mkdir:
dir_name = os.path.abspath(os.path.dirname(file_path))
os.makedirs(dir_name, exist_ok=True)
ok = cv2.imwrite(file_path, img, params)
if not ok:
raise IOError('Failed in writing images.')
def crop_border(imgs, crop_border):
"""Crop borders of images.
Args:
imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
crop_border (int): Crop border for each end of height and width.
Returns:
list[ndarray]: Cropped images.
"""
if crop_border == 0:
return imgs
else:
if isinstance(imgs, list):
return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
else:
return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
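The uint8 branch of `tensor2img` clamps to `min_max`, normalizes to [0, 1], scales by 255 and rounds before casting, because a bare `astype(np.uint8)` truncates. The NumPy-only sketch below (hypothetical helper `to_uint8`, operating on arrays rather than tensors) shows the one-level difference this makes.

```python
import numpy as np


def to_uint8(img, min_max=(0, 1)):
    """Float image -> uint8 the way tensor2img does it (hypothetical NumPy-only helper)."""
    img = np.clip(img, *min_max)
    img = (img - min_max[0]) / (min_max[1] - min_max[0])  # normalise to [0, 1]
    return (img * 255.0).round().astype(np.uint8)  # round first: astype alone truncates


pixels = np.array([0.0, 0.999, 1.2])
print(to_uint8(pixels).tolist())                      # [0, 255, 255]
print((pixels[:2] * 255.0).astype(np.uint8).tolist())  # [0, 254]: truncation loses a level
```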
================================================
FILE: bsr/utils/lmdb_util.py
================================================
import cv2
import lmdb
import sys
from multiprocessing import Pool
from os import path as osp
from tqdm import tqdm
def make_lmdb_from_imgs(data_path,
lmdb_path,
img_path_list,
keys,
batch=5000,
compress_level=1,
multiprocessing_read=False,
n_thread=40,
map_size=None):
"""Make lmdb from images.
Contents of lmdb. The file structure is:
::
example.lmdb
├── data.mdb
├── lock.mdb
├── meta_info.txt
The data.mdb and lock.mdb are standard lmdb files and you can refer to
https://lmdb.readthedocs.io/en/release/ for more details.
The meta_info.txt is a specified txt file to record the meta information
of our datasets. It will be automatically created when preparing
datasets by our provided dataset tools.
Each line in the txt file records 1) image name (with extension),
2) image shape, and 3) compression level, separated by a white space.
For example, the meta information could be:
`000_00000000.png (720,1280,3) 1`, which means:
1) image name (with extension): 000_00000000.png;
2) image shape: (720,1280,3);
3) compression level: 1
We use the image name without extension as the lmdb key.
If `multiprocessing_read` is True, it will read all the images to memory
using multiprocessing. Thus, your server needs to have enough memory.
Args:
data_path (str): Data path for reading images.
lmdb_path (str): Lmdb save path.
img_path_list (list[str]): Image paths, relative to ``data_path``.
keys (list[str]): Lmdb keys, one per image.
batch (int): After processing batch images, lmdb commits.
Default: 5000.
compress_level (int): Compress level when encoding images. Default: 1.
multiprocessing_read (bool): Whether use multiprocessing to read all
the images to memory. Default: False.
n_thread (int): For multiprocessing.
map_size (int | None): Map size for lmdb env. If None, use the
estimated size from images. Default: None
"""
assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
f'but got {len(img_path_list)} and {len(keys)}')
print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
print(f'Total images: {len(img_path_list)}')
if not lmdb_path.endswith('.lmdb'):
raise ValueError("lmdb_path must end with '.lmdb'.")
if osp.exists(lmdb_path):
print(f'Folder {lmdb_path} already exists. Exit.')
sys.exit(1)
if multiprocessing_read:
# read all the images to memory (multiprocessing)
dataset = {} # use dict to keep the order for multiprocessing
shapes = {}
print(f'Read images with multiprocessing, #thread: {n_thread} ...')
pbar = tqdm(total=len(img_path_list), unit='image')
def callback(arg):
"""get the image data and update pbar."""
key, dataset[key], shapes[key] = arg
pbar.update(1)
pbar.set_description(f'Read {key}')
pool = Pool(n_thread)
for path, key in zip(img_path_list, keys):
pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
pool.close()
pool.join()
pbar.close()
print(f'Finish reading {len(img_path_list)} images.')
# create lmdb environment
if map_size is None:
# obtain data size for one image
img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
_, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
data_size_per_img = img_byte.nbytes
print('Data size per image is: ', data_size_per_img)
data_size = data_size_per_img * len(img_path_list)
map_size = data_size * 10
env = lmdb.open(lmdb_path, map_size=map_size)
# write data to lmdb
pbar = tqdm(total=len(img_path_list), unit='chunk')
txn = env.begin(write=True)
txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
for idx, (path, key) in enumerate(zip(img_path_list, keys)):
pbar.update(1)
pbar.set_description(f'Write {key}')
key_byte = key.encode('ascii')
if multiprocessing_read:
img_byte = dataset[key]
h, w, c = shapes[key]
else:
_, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
h, w, c = img_shape
txn.put(key_byte, img_byte)
# write meta information
txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
if idx % batch == 0:
txn.commit()
txn = env.begin(write=True)
pbar.close()
txn.commit()
env.close()
txt_file.close()
print('\nFinish writing lmdb.')
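When `map_size` is None, `make_lmdb_from_imgs` sizes the LMDB map as the PNG-encoded size of the first image, times the number of images, times 10. It also commits whenever `idx % batch == 0`. The hypothetical helper below is not part of the repository. It reproduces only that bookkeeping so the schedule is easy to check: the first commit happens right after the first image, and the trailing `txn.commit()` flushes whatever remains.

```python
def plan_lmdb_write(bytes_first_img, num_imgs, batch, headroom=10):
    """Hypothetical helper mirroring make_lmdb_from_imgs bookkeeping.

    Returns the map_size chosen when map_size is None, and the loop
    indices after which an intermediate commit happens.
    """
    # map_size = per-image PNG size * number of images * 10 (safety factor)
    map_size = bytes_first_img * num_imgs * headroom
    # `if idx % batch == 0` fires at idx 0, batch, 2 * batch, ...
    commit_after = [idx for idx in range(num_imgs) if idx % batch == 0]
    return map_size, commit_after


map_size, commits = plan_lmdb_write(bytes_first_img=300_000, num_imgs=12, batch=5)
print(map_size)  # 36000000
print(commits)   # [0, 5, 10]
```

Because the estimate relies on the first image alone, a dataset whose later images encode much larger may need an explicit `map_size`.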
def read_img_worker(path, key, compress_level):
"""Read image worker.
Args:
path (str): Image path.
key (str): Image key.
compress_level (int): Compress level when encoding images.
Returns:
str: Image key.
byte: Image byte.
tuple[int]: Image shape.
"""
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
if img.ndim == 2:
h, w = img.shape
c = 1
else:
h, w, c = img.shape
_, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
return (key, img_byte, (h, w, c))
class LmdbMaker():
"""LMDB Maker.
Args:
lmdb_path (str): Lmdb save path.
map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
batch (int): After processing batch images, lmdb commits.
Default: 5000.
compress_level (int): Compress level when encoding images. Default: 1.
"""
def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
if not lmdb_path.endswith('.lmdb'):
raise ValueError("lmdb_path must end with '.lmdb'.")
if osp.exists(lmdb_path):
print(f'Folder {lmdb_path} already exists. Exit.')
sys.exit(1)
self.lmdb_path = lmdb_path
self.batch = batch
self.compress_level = compress_level
self.env = lmdb.open(lmdb_path, map_size=map_size)
self.txn = self.env.begin(write=True)
self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
self.counter = 0
def put(self, img_byte, key, img_shape):
self.counter += 1
key_byte = key.encode('ascii')
self.txn.put(key_byte, img_byte)
# write meta information
h, w, c = img_shape
self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
if self.counter % self.batch == 0:
self.txn.commit()
self.txn = self.env.begin(write=True)
def close(self):
self.txn.commit()
self.env.close()
self.txt_file.close()
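`LmdbMaker` increments its counter before checking `counter % batch`, so it commits after every `batch` puts. Only `close()` commits the remainder, and it also closes the environment and the meta file. Wrapping the maker in `contextlib.closing` guarantees that `close()` runs even if the writing loop raises. To keep this sketch runnable without LMDB, the hypothetical `FakeMaker` records committed keys instead of writing a database.

```python
from contextlib import closing


class FakeMaker:
    """Hypothetical stand-in for LmdbMaker that records commits instead of writing LMDB."""

    def __init__(self, batch):
        self.batch = batch
        self.counter = 0
        self.pending = []    # keys in the currently open transaction
        self.committed = []  # one list of keys per commit

    def put(self, img_byte, key, img_shape):
        self.counter += 1
        self.pending.append(key)
        # Same rule as LmdbMaker.put: commit after every `batch` puts
        if self.counter % self.batch == 0:
            self.committed.append(self.pending)
            self.pending = []

    def close(self):
        # Same as LmdbMaker.close: flush whatever is still pending
        self.committed.append(self.pending)
        self.pending = []


# closing() calls maker.close() on exit, even if the loop raises
with closing(FakeMaker(batch=2)) as maker:
    for i in range(5):
        maker.put(b'', f'{i:04d}', (1, 1, 3))

print(maker.committed)  # [['0000', '0001'], ['0002', '0003'], ['0004']]
```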
================================================
FILE: bsr/utils/logger.py
================================================
import datetime
import logging
import time
from .dist_util import get_dist_info, master_only
initialized_logger = {}
class AvgTimer():
def __init__(self, window=200):
self.window = window # average window
self.current_time = 0
self.total_time = 0
self.count = 0
self.avg_time = 0
self.start()
def start(self):
self.start_time = self.tic = time.time()
def record(self):
self.count += 1
self.toc = time.time()
self.current_time = self.toc - self.tic
self.total_time += self.current_time
# calculate average time
self.avg_time = self.total_time / self.count
# reset
if self.count > self.window:
self.count = 0
self.total_time = 0
self.tic = time.time()
def get_current_time(self):
return self.current_time
def get_avg_time(self):
return self.avg_time
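`AvgTimer.record()` keeps a running mean and resets its accumulators only once `count` exceeds `window`, so each averaging period covers `window + 1` records. The hypothetical `WindowedAverage` below reproduces that bookkeeping with explicit durations instead of `time.time()`, which makes the reset point easy to observe.

```python
class WindowedAverage:
    """Hypothetical clock-free version of AvgTimer.record()'s bookkeeping."""

    def __init__(self, window=200):
        self.window = window
        self.count = 0
        self.total_time = 0.0
        self.avg_time = 0.0

    def record(self, duration):
        self.count += 1
        self.total_time += duration
        self.avg_time = self.total_time / self.count
        # Like AvgTimer, reset only after count exceeds the window
        if self.count > self.window:
            self.count = 0
            self.total_time = 0.0
        return self.avg_time


avg = WindowedAverage(window=2)
history = [avg.record(d) for d in (1.0, 2.0, 3.0, 10.0)]
print(history)  # [1.0, 1.5, 2.0, 10.0]
```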
class MessageLogger():
"""Message logger for printing.
Args:
opt (dict): Config. It contains the following keys:
name (str): Exp name.
logger (dict): Contains 'print_freq' (int) for logger interval.
train (dict): Contains 'total_iter' (int) for total iters.
use_tb_logger (bool): Use tensorboard logger.
start_iter (int): Start iter. Default: 1.
tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
"""
def __init__(self, opt, start_iter=1, tb_logger=None):
self.exp_name = opt['name']
self.interval = opt['logger']['print_freq']
self.start_iter = start_iter
self.max_iters = opt['train']['total_iter']
self.use_tb_logger = opt['logger']['use_tb_logger']
self.tb_logger = tb_logger
self.start_time = time.time()
self.logger = get_root_logger()
def reset_start_time(self):
self.start_time = time.time()
@master_only
def __call__(self, log_vars):
"""Format logging message.
Args:
log_vars (dict): It contains the following keys:
epoch (int): Epoch number.
iter (int): Current iter.
lrs (list): List for learning rates.
time (float): Iter time.
data_time (float): Data time for each iter.
"""
# epoch, iter, learning rates
epoch = log_vars.pop('epoch')
current_iter = log_vars.pop('iter')
lrs = log_vars.pop('lrs')
message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(')
for v in lrs:
message += f'{v:.3e},'
message += ')] '
# time and estimated time
if 'time' in log_vars.keys():
iter_time = log_vars.pop('time')
data_time = log_vars.pop('data_time')
total_time = time.time() - self.start_time
time_sec_avg = total_time / (current_iter - self.start_iter + 1)
eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
message += f'[eta: {eta_str}, '
message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
# other items, especially losses
for k, v in log_vars.items():
message += f'{k}: {v:.4e} '
# tensorboard logger
if self.use_tb_logger and 'debug' not in self.exp_name:
if k.startswith('l_'):
self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
else:
self.tb_logger.add_scalar(k, v, current_iter)
self.logger.info(message)
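`MessageLogger` estimates the remaining time from the mean wall-clock time per iteration since it was created (or since `reset_start_time()`), multiplied by `max_iters - current_iter - 1`. The hypothetical helper below isolates that arithmetic.

```python
import datetime


def eta_string(total_time, current_iter, start_iter, max_iters):
    """Reproduce MessageLogger's ETA arithmetic (hypothetical helper)."""
    # Average seconds per iteration since start_iter (inclusive)
    time_sec_avg = total_time / (current_iter - start_iter + 1)
    # Remaining iterations, counted the same way as MessageLogger
    eta_sec = time_sec_avg * (max_iters - current_iter - 1)
    return str(datetime.timedelta(seconds=int(eta_sec)))


print(eta_string(total_time=100.0, current_iter=10, start_iter=1, max_iters=110))  # 0:16:30
```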
@master_only
def init_tb_logger(log_dir):
from torch.utils.tensorboard import SummaryWriter
tb_logger = SummaryWriter(log_dir=log_dir)
return tb_logger
@master_only
def init_wandb_logger(opt):
"""We now only use wandb to sync tensorboard log."""
import wandb
logger = get_root_logger()
project = opt['logger']['wandb']['project']
resume_id = opt['logger']['wandb'].get('resume_id')
if resume_id:
wandb_id = resume_id
resume = 'allow'
logger.warning(f'Resume wandb logger with id={wandb_id}.')
else:
wandb_id = wandb.util.generate_id()
resume = 'never'
wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)
logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
"""Get the root logger.
The logger will be initialized if it has not been initialized. By default a
StreamHandler will be added. If `log_file` is specified, a FileHandler will
also be added.
Args:
logger_name (str): root logger name. Default: 'basicsr'.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the root logger.
log_level (int): The root logger level. Note that only the process of
rank 0 is affected, while other processes will set the level to
"Error" and be silent most of the time.
Returns:
logging.Logger: The root logger.
"""
logger = logging.getLogger(logger_name)
# if the logger has been initialized, just return it
if logger_name in initialized_logger:
return logger
format_str = '%(asctime)s %(levelname)s: %(message)s'
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter(format_str))
logger.addHandler(stream_handler)
logger.propagate = False
rank, _ = get_dist_info()
if rank != 0:
logger.setLevel('ERROR')
elif log_file is not None:
logger.setLevel(log_level)
# add file handler
file_handler = logging.FileHandler(log_file, 'w')
file_handler.setFormatter(logging.Formatter(format_str))
file_handler.setLevel(log_level)
logger.addHandler(file_handler)
initialized_logger[logger_name] = True
return logger
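`get_root_logger` configures each named logger once, disables propagation so records are not printed twice through the root logger, and sets the level to ERROR on non-zero ranks. On rank 0 it only calls `setLevel(log_level)` when a `log_file` is passed. If the first call omits `log_file`, the logger keeps level NOTSET and inherits the root logger's level (WARNING by default), so INFO messages are dropped. The hypothetical sketch below sets the rank-0 level unconditionally and omits the file handler for brevity.

```python
import logging

_initialized = {}


def get_rank_aware_logger(name, rank, log_level=logging.INFO):
    """Hypothetical sketch of get_root_logger's init-once, rank-aware setup."""
    logger = logging.getLogger(name)
    if name in _initialized:  # later calls return the same logger untouched
        return logger
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s: %(message)s'))
    logger.addHandler(handler)
    logger.propagate = False  # avoid duplicate lines via the root logger
    # Non-zero ranks only emit errors; rank 0 always gets log_level
    logger.setLevel(log_level if rank == 0 else logging.ERROR)
    _initialized[name] = True
    return logger


log = get_rank_aware_logger('adcsr_demo', rank=0)
log.info('INFO messages are shown on rank 0')
```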
def get_env_info():
"""Get environment information.
Currently, only log the software version.
"""
import torch
import torchvision
try:
from basicsr.version import __version__
except ImportError:  # no version module ships with bsr/
__version__ = 'unknown'
msg = r"""
____ _ _____ ____
/ __ ) ____ _ _____ (_)_____/ ___/ / __ \
/ __ |/ __ `// ___// // ___/\__ \ / /_/ /
/ /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
/_____/ \__,_//____//_/ \___//____//_/ |_|
______ __ __ __ __
/ ____/____ ____ ____/ / / / __ __ _____ / /__ / /
/ / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
/ /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
\____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
"""
msg += ('\nVersion Information: '
f'\n\tBasicSR: {__version__}'
f'\n\tPyTorch: {torch.__version__}'
f'\n\tTorchVision: {torchvision.__version__}')
return msg
================================================
FILE: bsr/utils/matlab_functions.py
================================================
import math
import numpy as np
import torch
def cubic(x):
"""cubic function used for calculate_weights_indices."""
absx = torch.abs(x)
absx2 = absx**2
absx3 = absx**3
return (1.5 * absx3 - 2.5 * absx2 + 1) * (
(absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) *
(absx <= 2)).type_as(absx))
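`cubic` is the cubic convolution kernel of Keys with a = -0.5, the kernel MATLAB's bicubic `imresize` uses. The NumPy port below is for illustration only and is not repository code. It makes two properties easy to check. First, the kernel is 1 at 0 and 0 at every other integer, so sampling at integer offsets reproduces the input. Second, the four taps around any position sum to 1.

```python
import numpy as np


def cubic(x):
    """NumPy port of cubic() above: Keys kernel with a = -0.5."""
    absx = np.abs(x)
    absx2, absx3 = absx**2, absx**3
    inner = (1.5 * absx3 - 2.5 * absx2 + 1) * (absx <= 1)
    outer = (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * ((absx > 1) & (absx <= 2))
    return inner + outer


samples = cubic(np.array([0.0, 0.5, 1.0, 1.5, 2.0]))
# samples -> 1.0, 0.5625, 0.0, -0.0625, 0.0

# Partition of unity: the four taps around t sum to 1
t = 0.3
tap_sum = cubic(t - np.arange(-1, 3)).sum()  # 1.0 up to rounding
```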
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
"""Calculate weights and indices, used for imresize function.
Args:
in_length (int): Input length.
out_length (int): Output length.
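scale (float): Scale factor.
kernel (str): Kernel type. Not used; the cubic kernel is always applied.
kernel_width (int): Kernel width.
antialiasing (bool): Whether to apply anti-aliasing when downsampling.
Returns:
tuple: (weights, indices, sym_len_s, sym_len_e): per-output-pixel weights and input indices, plus the number of samples to mirror at the start and end of the input.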
"""
if (scale < 1) and antialiasing:
# Use a modified kernel (larger kernel width) to simultaneously
# interpolate and antialias
kernel_width = kernel_width / scale
# Output-space coordinates
x = torch.linspace(1, out_length, out_length)
# Input-space coordinates. Calculate the inverse mapping such that 0.5
# in output space maps to 0.5 in input space, and 0.5 + scale in output
# space maps to 1.5 in input space.
u = x / scale + 0.5 * (1 - 1 / scale)
# What is the left-most pixel that can be involved in the computation?
left = torch.floor(u - kernel_width / 2)
# What is the maximum number of pixels that can be involved in the
# computation? Note: it's OK to use an extra pixel here; if the
# corresponding weights are all zero, it will be eliminated at the end
# of this function.
p = math.ceil(kernel_width) + 2
# The indices of the input pixels involved in computing the k-th output
# pixel are in row k of the indices matrix.
indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
out_length, p)
# The weights used to compute the k-th output pixel are in row k of the
# weights matrix.
distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
# apply cubic kernel
if (scale < 1) and antialiasing:
weights = scale * cubic(distance_to_center * scale)
else:
weights = cubic(distance_to_center)
# Normalize the weights matrix so that each row sums to 1.
weights_sum = torch.sum(weights, 1).view(out_length, 1)
weights = weights / weights_sum.expand(out_length, p)
# If a column in weights is all zero, get rid of it. only consider the
# first and last column.
weights_zero_tmp = torch.sum((weights == 0), 0)
if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
indices = indices.narrow(1, 1, p - 2)
weights = weights.narrow(1, 1, p - 2)
if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
indices = indices.narrow(1, 0, p - 2)
weights = weights.narrow(1, 0, p - 2)
weights = weights.contiguous()
indices = indices.contiguous()
sym_len_s = -indices.min() + 1
sym_len_e = indices.max() - in_length
indices = indices + sym_len_s - 1
return weights, indices, int(sym_len_s), int(sym_len_e)
@torch.no_grad()
def imresize(img, scale, antialiasing=True):
"""imresize function same as MATLAB.
It now only supports bicubic.
The same scale applies for both height and width.
Args:
img (Tensor | Numpy array):
Tensor: Input image with shape (c, h, w), [0, 1] range.
Numpy: Input image with shape (h, w, c), [0, 1] range.
scale (float): Scale factor. The same scale applies for both height
and width.
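antialiasing (bool): Whether to apply anti-aliasing when downsampling.
Default: True.
Returns:
Tensor | Numpy array: Resized image in the same type and layout as the input ((c, h, w) for tensors, (h, w, c) for numpy arrays, 2-D stays 2-D), not rounded or clipped.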
"""
squeeze_flag = False
if type(img).__module__ == np.__name__: # numpy type
numpy_type = True
if img.ndim == 2:
img = img[:, :, None]
squeeze_flag = True
img = torch.from_numpy(img.transpose(2, 0, 1)).float()
else:
numpy_type = False
if img.ndim == 2:
img = img.unsqueeze(0)
squeeze_flag = True
in_c, in_h, in_w = img.size()
out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
kernel_width = 4
kernel = 'cubic'
# get weights and indices
weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
antialiasing)
weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
antialiasing)
# process H dimension
# symmetric copying
img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
sym_patch = img[:, :sym_len_hs, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
sym_patch = img[:, -sym_len_he:, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
out_1 = torch.FloatTensor(in_c, out_h, in_w)
kernel_width = weights_h.size(1)
for i in range(out_h):
idx = int(indices_h[i][0])
for j in range(in_c):
out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
# process W dimension
# symmetric copying
out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
sym_patch = out_1[:, :, :sym_len_ws]
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(2, inv_idx)
out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
sym_patch = out_1[:, :, -sym_len_we:]
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(2, inv_idx)
out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
out_2 = torch.FloatTensor(in_c, out_h, out_w)
kernel_width = weights_w.size(1)
for i in range(out_w):
idx = int(indices_w[i][0])
for j in range(in_c):
out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
if squeeze_flag:
out_2 = out_2.squeeze(0)
if numpy_type:
out_2 = out_2.numpy()
if not squeeze_flag:
out_2 = out_2.transpose(1, 2, 0)
return out_2
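`imresize` is separable. It resizes along H and then along W, and each pass computes weights and indices with `calculate_weights_indices`, mirrors the borders, and takes weighted sums. The hypothetical `resize_1d` below sketches one such pass on a 1-D signal with NumPy, using `np.pad(..., mode='symmetric')` for the mirrored border (edge sample repeated, as in the copy loops above). Unlike `calculate_weights_indices`, it does not drop all-zero edge columns, which only changes how many zero weights are carried. It has not been checked against MATLAB output.

```python
import math

import numpy as np


def cubic(x):
    """Keys cubic kernel (a = -0.5), same coefficients as cubic() above."""
    absx = np.abs(x)
    absx2, absx3 = absx**2, absx**3
    return ((1.5 * absx3 - 2.5 * absx2 + 1) * (absx <= 1)
            + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * ((absx > 1) & (absx <= 2)))


def resize_1d(signal, scale, antialiasing=True):
    """Hypothetical 1-D version of one imresize pass (bicubic, symmetric borders)."""
    signal = np.asarray(signal, dtype=np.float64)
    in_len = signal.shape[0]
    out_len = math.ceil(in_len * scale)
    kernel_width = 4.0
    if scale < 1 and antialiasing:
        kernel_width /= scale  # wider kernel: interpolate and low-pass at once
    # 1-based output coordinates mapped back into input space
    x = np.arange(1, out_len + 1, dtype=np.float64)
    u = x / scale + 0.5 * (1 - 1 / scale)
    left = np.floor(u - kernel_width / 2)
    p = math.ceil(kernel_width) + 2
    indices = left[:, None] + np.arange(p)[None, :]  # 1-based input indices
    distance = u[:, None] - indices
    if scale < 1 and antialiasing:
        weights = scale * cubic(distance * scale)
    else:
        weights = cubic(distance)
    weights /= weights.sum(axis=1, keepdims=True)  # each row sums to 1
    # Resolve out-of-range indices by mirroring (edge sample repeated)
    zero_based = indices.astype(int) - 1
    pad = max(0, -zero_based.min(), zero_based.max() - in_len + 1)
    padded = np.pad(signal, pad, mode='symmetric')
    return (padded[zero_based + pad] * weights).sum(axis=1)


row = np.linspace(0.0, 1.0, 8)
half = resize_1d(row, 0.5)  # 4 samples
double = resize_1d(row, 2)  # 16 samples
```

Applying `resize_1d` along each axis of an image in turn follows the same H-then-W order as `imresize`.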
================================================
FILE: bsr/utils/misc.py
================================================
import numpy as np
import os
import random
import time
import torch
from os import path as osp
from .dist_util import master_only
def set_random_seed(seed):
"""Set random seeds."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def get_time_str():
return time.strftime('%Y%m%d_%H%M%S', time.localtime())
def mkdir_and_rename(path):
"""mkdirs. If path exists, rename it with timestamp and create a new one.
Args:
path (str): Folder path.
"""
if osp.exists(path):
new_name = path + '_archived_' + get_time_str()
print(f'Path already exists. Rename it to {new_name}', flush=True)
os.rename(path, new_name)
os.makedirs(path, exist_ok=True)
@master_only
def make_exp_dirs(opt):
"""Make dirs for experiments."""
path_opt = opt['path'].copy()
if opt['is_train']:
mkdir_and_rename(path_opt.pop('experiments_root'))
else:
mkdir_and_rename(path_opt.pop('results_root'))
for key, path in path_opt.items():
if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key):
continue
else:
os.makedirs(path, exist_ok=True)
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
"""Scan a directory to find the interested files.
Args:
dir_path (str): Path of the directory.
suffix (str | tuple(str), optional): File suffix that we are
interested in. Default: None.
recursive (bool, optional): If set to True, recursively scan the
directory. Default: False.
full_path (bool, optional): If set to True, include the dir_path.
Default: False.
Returns:
A generator for all the interested files with relative paths.
"""
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
raise TypeError('"suffix" must be a string or tuple of strings')
root = dir_path
def _scandir(dir_path, suffix, recursive):
for entry in os.scandir(dir_path):
if not entry.name.startswith('.') and entry.is_file():
if full_path:
return_path = entry.path
else:
return_path = osp.relpath(entry.path, root)
if suffix is None:
yield return_path
elif return_path.endswith(suffix):
yield return_path
else:
if recursive and entry.is_dir():
yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
else:
continue
return _scandir(dir_path, suffix=suffix, recursive=recursive)
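`scandir` yields non-hidden files, optionally filtered by suffix, and can descend into sub-directories. The hypothetical `scan_files` below is a `pathlib`-based variant with the same filtering. It skips dot-files, only descends into directories, and sorts its results, whereas `os.scandir` order is arbitrary.

```python
import os
import tempfile
from pathlib import Path


def scan_files(dir_path, suffix=None, recursive=False, full_path=False):
    """Hypothetical pathlib-based variant of scandir() above."""
    root = Path(dir_path)
    pattern = '**/*' if recursive else '*'
    for p in sorted(root.glob(pattern)):
        # Only regular, non-hidden files are reported
        if p.is_file() and not p.name.startswith('.'):
            if suffix is None or p.name.endswith(suffix):  # str or tuple of str
                yield str(p) if full_path else str(p.relative_to(root))


with tempfile.TemporaryDirectory() as tmp:
    for rel in ('a.png', 'sub/b.png', 'sub/c.txt', '.hidden.png'):
        os.makedirs(os.path.join(tmp, os.path.dirname(rel)), exist_ok=True)
        open(os.path.join(tmp, rel), 'w').close()
    found = list(scan_files(tmp, suffix='.png', recursive=True))
    found_top = list(scan_files(tmp, suffix='.png'))

print(found)  # ['a.png', 'sub/b.png'] on POSIX
```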
def check_resume(opt, resume_iter):
"""Check resume states and pretrain_network paths.
Args:
opt (dict): Options.
resume_iter (int): Resume iteration.
"""
if opt['path']['resume_state']:
# get all the networks
networks = [key for key in opt.keys() if key.startswith('network_')]
flag_pretrain = False
for network in networks:
if opt['path'].get(f'pretrain_{network}') is not None:
flag_pretrain = True
if flag_pretrain:
print('pretrain_network path will be ignored during resuming.')
# set pretrained model paths
for network in networks:
name = f'pretrain_{network}'
basename = network.replace('network_', '')
if opt['path'].get('ignore_resume_networks') is None or (network
not in opt['path']['ignore_resume_networks']):
opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
print(f"Set {name} to {opt['path'][name]}")
# change param_key to params in resume
param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')]
for param_key in param_keys:
if opt['path'][param_key] == 'params_ema':
opt['path'][param_key] = 'params'
print(f'Set {param_key} to params')
def sizeof_fmt(size, suffix='B'):
"""Get human readable file size.
Args:
size (int): File size.
suffix (str): Suffix. Default: 'B'.
Return:
str: Formatted file size.
"""
for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
if abs(size) < 1024.0:
return f'{size:3.1f} {unit}{suffix}'
size /= 1024.0
return f'{size:3.1f} Y{suffix}'
================================================
FILE: bsr/utils/options.py
================================================
import argparse
import os
import random
import torch
import yaml
from collections import OrderedDict
from os import path as osp
from bsr.utils import set_random_seed
from bsr.utils.dist_util import get_dist_info, init_dist, master_only
def ordered_yaml():
"""Support OrderedDict for yaml.
Returns:
tuple: yaml Loader and Dumper.
"""
try:
from yaml import CDumper as Dumper
from yaml import CLoader as Loader
except ImportError:
from yaml import Dumper, Loader
_mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
def dict_representer(dumper, data):
return dumper.represent_dict(data.items())
def dict_constructor(loader, node):
return OrderedDict(loader.construct_pairs(node))
Dumper.add_representer(OrderedDict, dict_representer)
Loader.add_constructor(_mapping_tag, dict_constructor)
return Loader, Dumper
def yaml_load(f):
"""Load yaml file or string.
Args:
f (str): File path or a python string.
Returns:
dict: Loaded dict.
"""
if os.path.isfile(f):
with open(f, 'r') as f:
return yaml.load(f, Loader=ordered_yaml()[0])
else:
return yaml.load(f, Loader=ordered_yaml()[0])
def dict2str(opt, indent_level=1):
"""dict to string for printing options.
Args:
opt (dict): Option dict.
indent_level (int): Indent level. Default: 1.
Return:
(str): Option string for printing.
"""
msg = '\n'
for k, v in opt.items():
if isinstance(v, dict):
msg += ' ' * (indent_level * 2) + k + ':['
msg += dict2str(v, indent_level + 1)
msg += ' ' * (indent_level * 2) + ']\n'
else:
msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
return msg
def _postprocess_yml_value(value):
# None
if value == '~' or value.lower() == 'none':
return None
# bool
if value.lower() == 'true':
return True
elif value.lower() == 'false':
return False
# !!float number
if value.startswith('!!float'):
return float(value.replace('!!float', ''))
# number
if value.isdigit():
return int(value)
elif value.replace('.', '', 1).isdigit() and value.count('.') < 2:
return float(value)
# list
if value.startswith('['):
return eval(value)
# str
return value
def parse_options(root_path, is_train=True):
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
parser.add_argument('--auto_resume', action='store_true')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument(
'--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
args = parser.parse_args()
# parse yml to dict
opt = yaml_load(args.opt)
# distributed settings
if args.launcher == 'none':
opt['dist'] = False
print('Disable distributed.', flush=True)
else:
opt['dist'] = True
if args.launcher == 'slurm' and 'dist_params' in opt:
init_dist(args.launcher, **opt['dist_params'])
else:
init_dist(args.launcher)
opt['rank'], opt['world_size'] = get_dist_info()
# random seed
seed = opt.get('manual_seed')
if seed is None:
seed = random.randint(1, 10000)
opt['manual_seed'] = seed
set_random_seed(seed + opt['rank'])
# force to update yml options
if args.force_yml is not None:
for entry in args.force_yml:
# intermediate keys must already exist; only the final key may be new
keys, value = entry.split('=')
keys, value = keys.strip(), value.strip()
value = _postprocess_yml_value(value)
eval_str = 'opt'
for key in keys.split(':'):
eval_str += f'["{key}"]'
eval_str += '=value'
# using exec function
exec(eval_str)
opt['auto_resume'] = args.auto_resume
opt['is_train'] = is_train
# debug setting
if args.debug and not opt['name'].startswith('debug'):
opt['name'] = 'debug_' + opt['name']
if opt['num_gpu'] == 'auto':
opt['num_gpu'] = torch.cuda.device_count()
# datasets
for phase, dataset in opt['datasets'].items():
# for multiple datasets, e.g., val_1, val_2; test_1, test_2
phase = phase.split('_')[0]
dataset['phase'] = phase
if 'scale' in opt:
dataset['scale'] = opt['scale']
if dataset.get('dataroot_gt') is not None:
dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
if dataset.get('dataroot_lq') is not None:
dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
# paths
for key, val in opt['path'].items():
if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
opt['path'][key] = osp.expanduser(val)
if is_train:
experiments_root = opt['path'].get('experiments_root')
if experiments_root is None:
experiments_root = osp.join(root_path, 'experiments')
experiments_root = osp.join(experiments_root, opt['name'])
opt['path']['experiments_root'] = experiments_root
opt['path']['models'] = osp.join(experiments_root, 'models')
opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
opt['path']['log'] = experiments_root
opt['path']['visualization'] = osp.join(experiments_root, 'visualization')
# change some options for debug mode
if 'debug' in opt['name']:
if 'val' in opt:
opt['val']['val_freq'] = 8
opt['logger']['print_freq'] = 1
opt['logger']['save_checkpoint_freq'] = 8
else: # test
results_root = opt['path'].get('results_root')
if results_root is None:
results_root = osp.join(root_path, 'results')
results_root = osp.join(results_root, opt['name'])
opt['path']['results_root'] = results_root
opt['path']['log'] = results_root
opt['path']['visualization'] = osp.join(results_root, 'visualization')
return opt, args
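`parse_options` applies each `--force_yml` entry by building a string such as `opt["train"]["ema_decay"]=value` and running it with `exec`, and `_postprocess_yml_value` falls back to `eval` for lists. The sketch below is not repository code. It walks the nested dict instead, and the hypothetical `parse_value` uses `ast.literal_eval`, which only accepts Python literals. It also splits on the first `=` only, so values that contain `=` survive.

```python
import ast


def parse_value(value):
    """Mirror of _postprocess_yml_value, using ast.literal_eval instead of eval for lists."""
    if value == '~' or value.lower() == 'none':
        return None
    if value.lower() == 'true':
        return True
    if value.lower() == 'false':
        return False
    if value.startswith('!!float'):
        return float(value.replace('!!float', ''))
    if value.isdigit():
        return int(value)
    if value.replace('.', '', 1).isdigit() and value.count('.') < 2:
        return float(value)
    if value.startswith('['):
        return ast.literal_eval(value)  # accepts literals only, unlike eval
    return value


def apply_override(opt, entry):
    """Apply one '--force_yml' entry such as 'train:ema_decay=0.999' without exec()."""
    keys, value = entry.split('=', 1)
    *parents, last = keys.strip().split(':')
    node = opt
    for key in parents:
        node = node[key]  # a missing parent raises KeyError, as with exec()
    node[last] = parse_value(value.strip())


opt = {'name': 'demo', 'train': {'ema_decay': 0.9, 'milestones': [1]}}
apply_override(opt, 'train:ema_decay=0.999')
apply_override(opt, 'train:milestones=[100, 200]')
print(opt['train'])  # {'ema_decay': 0.999, 'milestones': [100, 200]}
```

As with the `exec` version, a missing intermediate key raises `KeyError`, while a missing final key is created.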
@master_only
def copy_opt_file(opt_file, experiments_root):
# copy the yml file to the experiment root
import sys
import time
from shutil import copyfile
cmd = ' '.join(sys.argv)
filename = osp.join(experiments_root, osp.basename(opt_file))
copyfile(opt_file, filename)
with open(filename, 'r+') as f:
lines = f.readlines()
lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n')
f.seek(0)
f.writelines(lines)
================================================
FILE: bsr/utils/plot_util.py
================================================
import re
def read_data_from_tensorboard(log_path, tag):
"""Get raw data (steps and values) from tensorboard events.
Args:
log_path (str): Path to the tensorboard log.
tag (str): tag to be read.
"""
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
# tensorboard event
event_acc = EventAccumulator(log_path)
event_acc.Reload()
scalar_list = event_acc.Tags()['scalars']
print('tag list: ', scalar_list)
steps = [int(s.step) for s in event_acc.Scalars(tag)]
values = [s.value for s in event_acc.Scalars(tag)]
return steps, values
def read_data_from_txt_2v(path, pattern, step_one=False):
"""Read data from txt with 2 returned values (usually [step, value]).
Args:
path (str): path to the txt file.
pattern (str): re (regular expression) pattern.
step_one (bool): add 1 to steps. Default: False.
"""
with open(path) as f:
lines = f.readlines()
lines = [line.strip() for line in lines]
steps = []
values = []
pattern = re.compile(pattern)
for line in lines:
match = pattern.match(line)
if match:
steps.append(int(match.group(1)))
values.append(float(match.group(2)))
if step_one:
steps = [v + 1 for v in steps]
return steps, values
def read_data_from_txt_1v(path, pattern):
"""Read data from txt with 1 returned value.
Args:
path (str): path to the txt file.
pattern (str): re (regular expression) pattern.
"""
with open(path) as f:
lines = f.readlines()
lines = [line.strip() for line in lines]
data = []
pattern = re.compile(pattern)
for line in lines:
match = pattern.match(line)
if match:
data.append(float(match.group(1)))
return data
def smooth_data(values, smooth_weight):
""" Smooth data using 1st-order IIR low-pass filter (what tensorflow does).
Reference: https://github.com/tensorflow/tensorboard/blob/f801ebf1f9fbfe2baee1ddd65714d0bccc640fb1/tensorboard/plugins/scalar/vz_line_chart/vz-line-chart.ts#L704 # noqa: E501
Args:
values (list): A list of values to be smoothed.
smooth_weight (float): Smooth weight.
"""
values_sm = []
last_sm_value = values[0]
for value in values:
value_sm = last_sm_value * smooth_weight + (1 - smooth_weight) * value
values_sm.append(value_sm)
last_sm_value = value_sm
return values_sm
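`read_data_from_txt_2v` and `smooth_data` are typically combined to plot a loss curve from a training log. `MessageLogger` formats the iteration with a thousands separator (`{current_iter:8,d}`), so a pattern that captures only `\d+` would read `1,200` as `1`. The sketch below uses two hypothetical log lines in that format, captures the separator, strips it before `int()`, and then applies the same first-order IIR filter as `smooth_data`.

```python
import re


def smooth(values, weight):
    """Same first-order IIR filter as smooth_data above."""
    out, last = [], values[0]
    for v in values:
        last = last * weight + (1 - weight) * v
        out.append(last)
    return out


# Two hypothetical lines in the format MessageLogger writes
LOG_LINES = [
    "2024-01-01 12:00:00,000 INFO: [train..][epoch:  1, iter:     100, lr:(1.000e-04,)] "
    "[eta: 1:00:00, time (data): 0.500 (0.010)] l_pix: 4.0000e-02 ",
    "2024-01-01 12:10:00,000 INFO: [train..][epoch:  1, iter:   1,200, lr:(1.000e-04,)] "
    "[eta: 0:50:00, time (data): 0.500 (0.010)] l_pix: 2.0000e-02 ",
]
PATTERN = re.compile(r'.*iter:\s*([\d,]+),.*l_pix: ([-+.\deE]+)')

steps, values = [], []
for line in LOG_LINES:
    match = PATTERN.match(line)
    if match:
        steps.append(int(match.group(1).replace(',', '')))  # drop the ',' separator
        values.append(float(match.group(2)))

print(steps)                # [100, 1200]
print(smooth(values, 0.5))  # approximately [0.04, 0.03]
```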
================================================
FILE: bsr/utils/registry.py
================================================
# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
class Registry():
"""
The registry that provides name -> object mapping, to support third-party
users' custom modules.
To create a registry (e.g. a backbone registry):
.. code-block:: python
BACKBONE_REGISTRY = Registry('BACKBONE')
To register an object:
.. code-block:: python
@BACKBONE_REGISTRY.register()
class MyBackbone():
...
Or:
.. code-block:: python
BACKBONE_REGISTRY.register(MyBackbone)
"""
def __init__(self, name):
"""
Args:
name (str): the name of this registry
"""
self._name = name
self._obj_map = {}
def _do_register(self, name, obj, suffix=None):
if isinstance(suffix, str):
name = name + '_' + suffix
assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
f"in '{self._name}' registry!")
self._obj_map[name] = obj
def register(self, obj=None, suffix=None):
"""
Register the given object under the name `obj.__name__`.
Can be used as either a decorator or not.
See docstring of this class for usage.
"""
if obj is None:
# used as a decorator
def deco(func_or_class):
name = func_or_class.__name__
self._do_register(name, func_or_class, suffix)
return func_or_class
return deco
# used as a function call
name = obj.__name__
self._do_register(name, obj, suffix)
def get(self, name, suffix='basicsr'):
ret = self._obj_map.get(name)
if ret is None:
ret = self._obj_map.get(name + '_' + suffix)
print(f'Name {name} is not found, use name: {name}_{suffix}!')
if ret is None:
raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
return ret
def __contains__(self, name):
return name in self._obj_map
def __iter__(self):
return iter(self._obj_map.items())
def keys(self):
return self._obj_map.keys()
DATASET_REGISTRY = Registry('dataset')
ARCH_REGISTRY = Registry('arch')
MODEL_REGISTRY = Registry('model')
LOSS_REGISTRY = Registry('loss')
METRIC_REGISTRY = Registry('metric')
================================================
FILE: config.yml
================================================
dataroot_gt: path_to_HR_images_of_LSDIR
scale: 4
# the first degradation process
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
resize_range: [0.3, 1.5]
gaussian_noise_prob: 0.5
noise_range: [1, 15]
poisson_scale_range: [0.05, 2.0]
gray_noise_prob: 0.4
jpeg_range: [60, 95]
# the second degradation process
second_blur_prob: 0.5
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
resize_range2: [0.6, 1.2]
gaussian_noise_prob2: 0.5
noise_range2: [1, 12]
poisson_scale_range2: [0.05, 1.0]
gray_noise_prob2: 0.4
jpeg_range2: [60, 100]
gt_size: 512
blur_kernel_size: 21
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
sinc_prob: 0.1
blur_sigma: [0.2, 1.5]
betag_range: [0.5, 2.0]
betap_range: [1, 1.5]
blur_kernel_size2: 11
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
sinc_prob2: 0.1
blur_sigma2: [0.2, 1.0]
betag_range2: [0.5, 2.0]
betap_range2: [1, 1.5]
final_sinc_prob: 0.8
use_hflip: True
use_rot: False
iter_num: 1000
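`resize_prob` lists the probabilities of the up / down / keep branches, and `resize_range` bounds the random scale factor. The code that consumes these keys is not in this section. Assuming it follows the Real-ESRGAN convention (upscale by U(1, max), downscale by U(min, 1), or keep), one draw can be sketched with the hypothetical `sample_resize_scale` below.

```python
import random


def sample_resize_scale(resize_prob, resize_range, rng=random):
    """Hypothetical helper: pick up/down/keep, then a scale factor (Real-ESRGAN style)."""
    mode = rng.choices(['up', 'down', 'keep'], weights=resize_prob)[0]
    if mode == 'up':
        return mode, rng.uniform(1, resize_range[1])
    if mode == 'down':
        return mode, rng.uniform(resize_range[0], 1)
    return mode, 1.0


rng = random.Random(0)
# First-stage values from config.yml: resize_prob and resize_range
samples = [sample_resize_scale([0.2, 0.7, 0.1], [0.3, 1.5], rng) for _ in range(1000)]
```

The second stage would presumably use `resize_prob2` and `resize_range2` in the same way.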
================================================
FILE: dataset.py
================================================
import torch, random, cv2, os, math, glob
import torch.nn.functional as F
import numpy as np
from bsr.degradations import circular_lowpass_kernel, random_mixed_kernels, random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from bsr.transforms import augment, paired_random_crop
from bsr.utils import FileClient, imfrombytes, img2tensor, DiffJPEG
from bsr.utils.img_process_util import filter2D
class RealESRGANDataset(torch.utils.data.Dataset):
def __init__(self, opt, bsz):
super(RealESRGANDataset, self).__init__()
self.opt = opt
self.file_client = FileClient("disk")
self.gt_folder = opt["dataroot_gt"]
self.len = bsz * opt["iter_num"]
self.paths = glob.glob(os.path.join(self.gt_folder, "**/*"), recursive=True)
# blur settings for the first degradation
self.blur_kernel_size = opt["blur_kernel_size"]
self.kernel_list = opt["kernel_list"]
self.kernel_prob = opt["kernel_prob"] # a list for each kernel probability
self.blur_sigma = opt["blur_sigma"]
self.betag_range = opt["betag_range"] # betag used in generalized Gaussian blur kernels
self.betap_range = opt["betap_range"] # betap used in plateau blur kernels
self.sinc_prob = opt["sinc_prob"] # the probability for sinc filters
# blur settings for the second degradation
self.blur_kernel_size2 = opt["blur_kernel_size2"]
self.kernel_list2 = opt["kernel_list2"]
self.kernel_prob2 = opt["kernel_prob2"]
self.blur_sigma2 = opt["blur_sigma2"]
self.betag_range2 = opt["betag_range2"]
self.betap_range2 = opt["betap_range2"]
self.sinc_prob2 = opt["sinc_prob2"]
# a final sinc filter
self.final_sinc_prob = opt["final_sinc_prob"]
self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
# TODO: kernel range is now hard-coded, should be in the configure file
self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
self.pulse_tensor[10, 10] = 1
def __getitem__(self, index):
index = random.randint(0, len(self.paths) - 1)
gt_path = self.paths[index]
img_gt = imfrombytes(self.file_client.get(gt_path, "gt"), float32=True)
img_gt = augment(img_gt, self.opt["use_hflip"], self.opt["use_rot"])
h, w = img_gt.shape[0:2]
crop_pad_size = self.opt["gt_size"]
if h < crop_pad_size or w < crop_pad_size:
pad_h = max(0, crop_pad_size - h)
pad_w = max(0, crop_pad_size - w)
img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
h, w = img_gt.shape[0:2]
top = random.randint(0, h - crop_pad_size)
left = random.randint(0, w - crop_pad_size)
img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]
# ------------------------ Generate kernels (used in the first degradation) ------------------------ #
kernel_size = random.choice(self.kernel_range)
if np.random.uniform() < self.opt["sinc_prob"]:
# this sinc filter setting is for kernels ranging from [7, 21]
if kernel_size < 13:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
else:
kernel = random_mixed_kernels(
self.kernel_list,
self.kernel_prob,
kernel_size,
self.blur_sigma,
self.blur_sigma, [-math.pi, math.pi],
self.betag_range,
self.betap_range,
noise_range=None)
# pad kernel
pad_size = (21 - kernel_size) // 2
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
# ------------------------ Generate kernels (used in the second degradation) ------------------------ #
kernel_size = random.choice(self.kernel_range)
if np.random.uniform() < self.opt["sinc_prob2"]:
if kernel_size < 13:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
else:
kernel2 = random_mixed_kernels(
self.kernel_list2,
self.kernel_prob2,
kernel_size,
self.blur_sigma2,
self.blur_sigma2, [-math.pi, math.pi],
self.betag_range2,
self.betap_range2,
noise_range=None)
# pad kernel
pad_size = (21 - kernel_size) // 2
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
# ------------------------------------- the final sinc kernel ------------------------------------- #
if np.random.uniform() < self.opt["final_sinc_prob"]:
kernel_size = random.choice(self.kernel_range)
omega_c = np.random.uniform(np.pi / 3, np.pi)
sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
sinc_kernel = torch.FloatTensor(sinc_kernel)
else:
sinc_kernel = self.pulse_tensor
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
kernel = torch.FloatTensor(kernel)
kernel2 = torch.FloatTensor(kernel2)
return_d = {"gt": img_gt, "kernel1": kernel, "kernel2": kernel2, "sinc_kernel": sinc_kernel, "gt_path": gt_path}
return return_d
def __len__(self):
return self.len
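# Illustrative sketch (hypothetical helpers, not called by the training code):
# every blur kernel above is zero-padded to 21x21 so kernels of different sizes
# can be stacked in one batch, and `pulse_tensor` (a single 1 at the centre of a
# 21x21 grid) is used when no final sinc filter is drawn, because correlating an
# image with it returns the image unchanged. The NumPy correlation below is a
# naive stand-in for the `filter2D` call in RealESRGANDegrader and is assumed to
# use reflect padding; it only illustrates the shapes and the identity property.
import numpy as np


def sketch_pad_kernel_to_21(kernel):
    """Centre an odd-sized square kernel (7..21) in a 21x21 array of zeros."""
    k = kernel.shape[0]
    assert kernel.shape == (k, k) and k % 2 == 1 and k <= 21
    pad = (21 - k) // 2
    return np.pad(kernel, ((pad, pad), (pad, pad)))


def sketch_correlate_same(img, kernel):
    """Naive 'same'-size 2-D correlation of a single-channel image, reflect-padded."""
    k = kernel.shape[0]
    r = k // 2
    padded = np.pad(img, r, mode="reflect")
    out = np.empty(img.shape, dtype=np.float64)
    for i in range(img.shape[0]):
        for j in range(img.shape[1]):
            out[i, j] = np.sum(padded[i:i + k, j:j + k] * kernel)
    return out

# Example: a 7x7 box kernel keeps its unit sum after padding, and the pulse is an identity.
# k21 = sketch_pad_kernel_to_21(np.full((7, 7), 1 / 49))      # shape (21, 21)
# pulse = np.zeros((21, 21)); pulse[10, 10] = 1
# img = np.random.rand(32, 32)
# np.allclose(sketch_correlate_same(img, pulse), img)          # True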
class RealESRGANDegrader:
def __init__(self, opt, device):
self.opt = opt
self.device = device
self.jpeger = DiffJPEG(differentiable=False).to(device) # simulate JPEG compression artifacts
self.queue_size = 1200
@torch.no_grad()
def _dequeue_and_enqueue(self):
"""It is the training pair pool for increasing the diversity in a batch.
Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
batch could not have different resize scaling factors. Therefore, we employ this training pair pool
to increase the degradation diversity in a batch.
"""
# initialize
b, c, h, w = self.lq.size()
if not hasattr(self, "queue_lr"):
assert self.queue_size % b == 0, f"queue size {self.queue_size} should be divisible by batch size {b}"
self.queue_lr = torch.zeros(self.queue_size, c, h, w).to(self.device)
_, c, h, w = self.gt.size()
self.queue_gt = torch.zeros(self.queue_size, c, h, w).to(self.device)
self.queue_ptr = 0
if self.queue_ptr == self.queue_size: # the pool is full
# do dequeue and enqueue
# shuffle
idx = torch.randperm(self.queue_size)
self.queue_lr = self.queue_lr[idx]
self.queue_gt = self.queue_gt[idx]
# get first b samples
lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
# update the queue
self.queue_lr[0:b, :, :, :] = self.lq.clone()
self.queue_gt[0:b, :, :, :] = self.gt.clone()
self.lq = lq_dequeue
self.gt = gt_dequeue
else:
# only do enqueue
self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
self.queue_ptr = self.queue_ptr + b
@torch.no_grad()
def degrade(self, data):
"""Take a batch from the dataloader and apply the two-stage (second-order) degradation to obtain LQ images.
"""
# training data synthesis
self.gt = data["gt"].to(self.device)
self.kernel1 = data["kernel1"].to(self.device)
self.kernel2 = data["kernel2"].to(self.device)
self.sinc_kernel = data["sinc_kernel"].to(self.device)
ori_h, ori_w = self.gt.size()[2:4]
# ----------------------- The first degradation process ----------------------- #
# blur
out = filter2D(self.gt, self.kernel1)
# random resize
updown_type = random.choices(["up", "down", "keep"], self.opt["resize_prob"])[0]
if updown_type == "up":
scale = np.random.uniform(1, self.opt["resize_range"][1])
elif updown_type == "down":
scale = np.random.uniform(self.opt["resize_range"][0], 1)
else:
scale = 1
mode = random.choice(["area", "bilinear", "bicubic"])
out = F.interpolate(out, scale_factor=scale, mode=mode)
# add noise
gray_noise_prob = self.opt["gray_noise_prob"]
if np.random.uniform() < self.opt["gaussian_noise_prob"]:
out = random_add_gaussian_noise_pt(
out, sigma_range=self.opt["noise_range"], clip=True, rounds=False, gray_prob=gray_noise_prob)
else:
out = random_add_poisson_noise_pt(
out,
scale_range=self.opt["poisson_scale_range"],
gray_prob=gray_noise_prob,
clip=True,
rounds=False)
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt["jpeg_range"])
out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
out = self.jpeger(out, quality=jpeg_p)
# ----------------------- The second degradation process ----------------------- #
# blur
if np.random.uniform() < self.opt["second_blur_prob"]:
out = filter2D(out, self.kernel2)
# random resize
updown_type = random.choices(["up", "down", "keep"], self.opt["resize_prob2"])[0]
if updown_type == "up":
scale = np.random.uniform(1, self.opt["resize_range2"][1])
elif updown_type == "down":
scale = np.random.uniform(self.opt["resize_range2"][0], 1)
else:
scale = 1
mode = random.choice(["area", "bilinear", "bicubic"])
out = F.interpolate(
out, size=(int(ori_h / self.opt["scale"] * scale), int(ori_w / self.opt["scale"] * scale)), mode=mode)
# add noise
gray_noise_prob = self.opt["gray_noise_prob2"]
if np.random.uniform() < self.opt["gaussian_noise_prob2"]:
out = random_add_gaussian_noise_pt(
out, sigma_range=self.opt["noise_range2"], clip=True, rounds=False, gray_prob=gray_noise_prob)
else:
out = random_add_poisson_noise_pt(
out,
scale_range=self.opt["poisson_scale_range2"],
gray_prob=gray_noise_prob,
clip=True,
rounds=False)
# JPEG compression + the final sinc filter
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
# as one operation.
# We consider two orders:
# 1. [resize back + sinc filter] + JPEG compression
# 2. JPEG compression + [resize back + sinc filter]
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
if np.random.uniform() < 0.5:
# resize back + the final sinc filter
mode = random.choice(["area", "bilinear", "bicubic"])
out = F.interpolate(out, size=(ori_h // self.opt["scale"], ori_w // self.opt["scale"]), mode=mode)
out = filter2D(out, self.sinc_kernel)
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt["jpeg_range2"])
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
else:
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt["jpeg_range2"])
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
# resize back + the final sinc filter
mode = random.choice(["area", "bilinear", "bicubic"])
out = F.interpolate(out, size=(ori_h // self.opt["scale"], ori_w // self.opt["scale"]), mode=mode)
out = filter2D(out, self.sinc_kernel)
# clamp and round
self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
# random crop
gt_size = self.opt["gt_size"]
self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt["scale"])
# training pair pool
self._dequeue_and_enqueue()
# self.lq/self.gt may now be pairs taken from the pool rather than the current batch; no sharpening of self.gt is applied here
self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
return self.lq, self.gt
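# Illustrative sketch (hypothetical class, not used by the training code): a
# NumPy re-statement of RealESRGANDegrader._dequeue_and_enqueue. Until the pool
# holds `queue_size` pairs, each batch is only copied into it and returned
# unchanged; once it is full, the pool is shuffled with one shared permutation
# (so LQ/GT pairs stay aligned), its first b pairs are returned, and the new
# batch takes their place.
import numpy as np


class SketchPairPool:
    def __init__(self, queue_size, seed=0):
        self.queue_size = queue_size
        self.rng = np.random.default_rng(seed)
        self.queue_lq = None
        self.queue_gt = None
        self.ptr = 0

    def __call__(self, lq, gt):
        b = lq.shape[0]
        if self.queue_lq is None:
            assert self.queue_size % b == 0, "queue size must be divisible by batch size"
            self.queue_lq = np.zeros((self.queue_size,) + lq.shape[1:], dtype=lq.dtype)
            self.queue_gt = np.zeros((self.queue_size,) + gt.shape[1:], dtype=gt.dtype)
        if self.ptr == self.queue_size:
            # pool full: shuffle LQ and GT with the same permutation
            idx = self.rng.permutation(self.queue_size)
            self.queue_lq, self.queue_gt = self.queue_lq[idx], self.queue_gt[idx]
            out_lq, out_gt = self.queue_lq[:b].copy(), self.queue_gt[:b].copy()
            # the handed-out slots are refilled with the current batch
            self.queue_lq[:b], self.queue_gt[:b] = lq, gt
            return out_lq, out_gt
        # pool not yet full: enqueue only and keep training on the current batch
        self.queue_lq[self.ptr:self.ptr + b] = lq
        self.queue_gt[self.ptr:self.ptr + b] = gt
        self.ptr += b
        return lq, gt

# With queue_size=1200 as above, the first 1200 / b iterations only fill the pool;
# afterwards every batch mixes pairs degraded with different resize factors.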
================================================
FILE: evaluate.py
================================================
import torch, os, glob, pyiqa
from argparse import ArgumentParser
import numpy as np
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
parser = ArgumentParser()
parser.add_argument("--HR_dir", type=str, default="testset/RealSR/HR")
parser.add_argument("--SR_dir", type=str, default="result/RealSR")
args = parser.parse_args()
device = torch.device("cuda")
psnr = pyiqa.create_metric("psnr", test_y_channel=True, color_space="ycbcr", device=device)
ssim = pyiqa.create_metric("ssim", test_y_channel=True, color_space="ycbcr", device=device)
lpips = pyiqa.create_metric("lpips", device=device)
dists = pyiqa.create_metric("dists", device=device)
fid = pyiqa.create_metric("fid", device=device)
niqe = pyiqa.create_metric("niqe", device=device)
maniqa = pyiqa.create_metric("maniqa-pipal", device=device)
clipiqa = pyiqa.create_metric("clipiqa", device=device)
musiq = pyiqa.create_metric("musiq", device=device)
test_SR_paths = list(sorted(glob.glob(os.path.join(args.SR_dir, "*"))))
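test_HR_paths = list(sorted(glob.glob(os.path.join(args.HR_dir, "*"))))
# SR and HR images are paired by sorted position, so both folders must hold the same number of files
assert len(test_SR_paths) == len(test_HR_paths), f"found {len(test_SR_paths)} SR images but {len(test_HR_paths)} HR images"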
metrics = {"psnr": [], "ssim": [], "lpips": [], "dists": [], "niqe": [], "maniqa": [], "musiq": [], "clipiqa": []}
for SR_path, HR_path in tqdm(zip(test_SR_paths, test_HR_paths), total=len(test_SR_paths)):
SR = Image.open(SR_path).convert("RGB")
SR = transforms.ToTensor()(SR).to(device).unsqueeze(0)
HR = Image.open(HR_path).convert("RGB")
HR = transforms.ToTensor()(HR).to(device).unsqueeze(0)
metrics["psnr"].append(psnr(SR, HR).item())
metrics["ssim"].append(ssim(SR, HR).item())
metrics["lpips"].append(lpips(SR, HR).item())
metrics["dists"].append(dists(SR, HR).item())
metrics["niqe"].append(niqe(SR).item())
metrics["maniqa"].append(maniqa(SR).item())
metrics["clipiqa"].append(clipiqa(SR).item())
metrics["musiq"].append(musiq(SR).item())
for k in metrics.keys():
metrics[k] = np.mean(metrics[k])
metrics["fid"] = fid(args.SR_dir, args.HR_dir)
for k, v in metrics.items():
if k == "niqe":
print(k, f"{v:.3g}")
elif k == "fid":
print(k, f"{v:.5g}")
else:
print(k, f"{v:.4g}")
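# Illustrative sketch (hypothetical helpers, not called by this script): the
# psnr/ssim metrics above are created with test_y_channel=True and
# color_space="ycbcr", i.e. they are computed on the luma channel only. Below is
# a minimal NumPy version of Y-channel PSNR under the common ITU-R BT.601
# convention, Y = (16 + 65.481 R + 128.553 G + 24.966 B) / 255 for RGB in [0, 1].
# pyiqa's own implementation may differ in details (rounding, border cropping),
# and an HWC layout is assumed here instead of the CHW tensors used above.
import numpy as np


def sketch_rgb_to_y(img):
    """BT.601 luma of an RGB image in [0, 1] with shape (..., 3); result in [16/255, 235/255]."""
    r, g, b = img[..., 0], img[..., 1], img[..., 2]
    return (16.0 + 65.481 * r + 128.553 * g + 24.966 * b) / 255.0


def sketch_psnr_y(sr, hr):
    """PSNR in dB between the Y channels of two RGB images, data range 1."""
    mse = np.mean((sketch_rgb_to_y(sr) - sketch_rgb_to_y(hr)) ** 2)
    if mse == 0:
        return float("inf")
    return 10.0 * np.log10(1.0 / mse)

# Because the Y weights sum to 219/255, two flat gray images differing by d in
# every RGB channel differ by d * 219 / 255 in Y.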
================================================
FILE: evaluate_debug.sh
================================================
HF_ENDPOINT=https://hf-mirror.com \
CUDA_VISIBLE_DEVICES=0 \
python -u evaluate.py \
--HR_dir=testset/RealSR/HR \
--SR_dir=result/RealSR
================================================
FILE: forward.py
================================================
import torch
def MyUNet2DConditionModel_SD_forward(self, x):
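# `skip` is a module-level stack shared by all patched block forwards: down blocks push features, up blocks pop them
global skip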
x = self.conv_in(x)
skip = [x]
x = self.body(x)
return x
def MyCrossAttnDownBlock2D_SD_forward(self, x):
for i in range(2):
x = self.resnets[i](x)
x = self.attentions[i](x)
skip.append(x)
if self.downsamplers is not None:
x = self.downsamplers[0](x)
skip.append(x)
return x
def MyCrossAttnUpBlock2D_SD_forward(self, x):
for i in range(3):
x = self.resnets[i](torch.cat([x, skip.pop()], dim=1))
x = self.attentions[i](x)
if self.upsamplers is not None:
x = self.upsamplers[0](x)
return x
def MyDownBlock2D_SD_forward(self, x):
for i in range(2):
x = self.resnets[i](x)
skip.append(x)
return x
def MyUNetMidBlock2DCrossAttn_SD_forward(self, x):
x = self.resnets[0](x)
x = self.attentions[0](x)
x = self.resnets[1](x)
return x
def MyUpBlock2D_SD_forward(self, x):
for i in range(3):
x = self.resnets[i](torch.cat([x, skip.pop()], dim=1))
x = self.upsamplers[0](x)
return x
def MyResnetBlock2D_SD_forward(self, x_in):
x = self.norm1(x_in)
x = self.nonlinearity(x)
x = self.conv1(x)
x = self.norm2(x)
x = self.nonlinearity(x)
x = self.conv2(x)
if self.in_channels == self.out_channels:
return x + x_in
return x + self.conv_shortcut(x_in)
def MyTransformer2DModel_SD_forward(self, x_in):
b, c, h, w = x_in.shape
x = self.norm(x_in)
x = x.permute(0, 2, 3, 1).reshape(b, h * w, c).contiguous()
x = self.proj_in(x)
for block in self.transformer_blocks:
x = x + block.attn1(block.norm1(x))
x = x + block.ff(block.norm3(x))
x = self.proj_out(x)
x = x.reshape(b, h, w, c).permute(0, 3, 1, 2).contiguous()
return x + x_in
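# Illustrative sketch (hypothetical, not used by the model): the patched forwards
# above communicate through the module-level list `skip`, used as a LIFO stack.
# MyUNet2DConditionModel_SD_forward pushes the conv_in output, every down block
# pushes after each resnet (and after its downsampler, if any), the mid block
# pushes nothing, and every up block pops one entry per resnet. The toy below
# replays only that push/pop order, assuming an SD-style layout of four down
# blocks with two resnets each (the last without a downsampler) and four up
# blocks with three resnets each.


def sketch_skip_order(num_blocks=4, resnets_down=2, resnets_up=3,
                      has_downsampler=(True, True, True, False)):
    stack = ["conv_in"]
    for d in range(num_blocks):
        for i in range(resnets_down):
            stack.append(f"down{d}.resnet{i}")
        if has_downsampler[d]:
            stack.append(f"down{d}.downsampler")
    pairs = []
    for u in range(num_blocks):
        for i in range(resnets_up):
            # each up-block resnet concatenates the most recently pushed feature
            pairs.append((f"up{u}.resnet{i}", stack.pop()))
    return pairs, stack

# Pushes (12) and pops (12) balance, so the stack is empty after a full pass.
# Because the stack is a single global, two forward passes must not interleave
# (for example from different threads).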
================================================
FILE: model.py
================================================
import torch, types, copy
from torch import nn
import torch.nn.functional as F
from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, \
CrossAttnUpBlock2D, \
DownBlock2D, \
UpBlock2D, \
UNetMidBlock2DCrossAttn
from diffusers.models.resnet import ResnetBlock2D
from diffusers.models.transformers.transformer_2d import Transformer2DModel
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.downsampling import Downsample2D
from diffusers.models.upsampling import Upsample2D
from forward import MyUNet2DConditionModel_SD_forward, \
MyCrossAttnDownBlock2D_SD_forward, \
MyDownBlock2D_SD_forward, \
MyUNetMidBlock2DCrossAttn_SD_forward, \
MyCrossAttnUpBlock2D_SD_forward, \
MyUpBlock2D_SD_forward, \
MyResnetBlock2D_SD_forward, \
MyTransformer2DModel_SD_forward
def find_parent(model, module_name):
components = module_name.split(".")
parent = model
for comp in components[:-1]:
parent = getattr(parent, comp)
return parent, components[-1]
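# Despite its name, this prunes every Conv2d, Linear, GroupNorm and LayerNorm to 75% of its
# channels by keeping the leading slice of the pretrained weights, and updates the `channels`
# attribute of Downsample2D/Upsample2D; modules already marked `pruned` are skipped.
def halve_channels(model):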
for name, module in model.named_modules():
if hasattr(module, "pruned"):
continue
if isinstance(module, nn.Conv2d):
in_channels = int(module.in_channels * 0.75)
out_channels = int(module.out_channels * 0.75)
new_conv = nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=module.kernel_size,
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
bias=module.bias is not None)
with torch.no_grad():
new_conv.weight.copy_(module.weight[:out_channels, :in_channels])
if module.bias is not None:
new_conv.bias.copy_(module.bias[:out_channels])
parent, last_name = find_parent(model, name)
setattr(parent, last_name, new_conv)
new_conv.pruned = True
elif isinstance(module, nn.Linear):
in_features = int(module.in_features * 0.75)
out_features = int(module.out_features * 0.75)
new_linear = nn.Linear(in_features=in_features,
out_features=out_features,
bias=module.bias is not None)
with torch.no_grad():
new_linear.weight.copy_(module.weight[:out_features, :in_features])
if module.bias is not None:
new_linear.bias.copy_(module.bias[:out_features])
parent, last_name = find_parent(model, name)
setattr(parent, last_name, new_linear)
new_linear.pruned = True
elif isinstance(module, nn.GroupNorm):
num_channels = int(module.num_channels * 0.75)
for num_groups in [32, 24, 16, 12, 8, 6, 4, 2, 1]:
if num_channels % num_groups == 0:
break
new_gn = nn.GroupNorm(num_groups=num_groups,
num_channels=num_channels,
eps=module.eps,
affine=module.affine)
with torch.no_grad():
new_gn.weight.copy_(module.weight[:num_channels])
new_gn.bias.copy_(module.bias[:num_channels])
parent, last_name = find_parent(model, name)
setattr(parent, last_name, new_gn)
new_gn.pruned = True
elif isinstance(module, nn.LayerNorm):
normalized_shape = int(module.normalized_shape[0] * 0.75)
new_ln = nn.LayerNorm(normalized_shape,
eps=module.eps,
elementwise_affine=module.elementwise_affine)
with torch.no_grad():
new_ln.weight.copy_(module.weight[:normalized_shape])
new_ln.bias.copy_(module.bias[:normalized_shape])
parent, last_name = find_parent(model, name)
setattr(parent, last_name, new_ln)
new_ln.pruned = True
elif isinstance(module, Downsample2D) or isinstance(module, Upsample2D):
module.channels = int(module.channels * 0.75)
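# Illustrative sketch (hypothetical helpers, not used by the model): the pruning
# rule in halve_channels keeps the first 75% of the output and input channels of
# each weight tensor (int() truncates), which works for both Conv2d weights
# (out, in, kh, kw) and Linear weights (out, in). For GroupNorm it uses the first
# entry of [32, 24, 16, 12, 8, 6, 4, 2, 1] that divides the pruned channel count.
import numpy as np

SKETCH_GROUP_CANDIDATES = (32, 24, 16, 12, 8, 6, 4, 2, 1)


def sketch_prune_weight(weight, ratio=0.75):
    """Keep the leading int(ratio * C) output and input channels of a weight tensor."""
    out_c, in_c = weight.shape[:2]
    return weight[: int(out_c * ratio), : int(in_c * ratio)]


def sketch_num_groups(num_channels, ratio=0.75):
    """Return (pruned channel count, GroupNorm group count) as chosen in halve_channels."""
    pruned = int(num_channels * ratio)
    for g in SKETCH_GROUP_CANDIDATES:
        if pruned % g == 0:
            return pruned, g

# Example: 320 channels become 240, and 240 is not divisible by 32, so 24 groups are used.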
class Net(nn.Module):
def __init__(self, unet, decoder):
super().__init__()
del unet.time_embedding
new_conv_in = nn.Conv2d(16, 320, 3, padding=1)
new_conv_in.weight.data = unet.conv_in.weight.data.repeat(1, 4, 1, 1)
new_conv_in.bias.data = unet.conv_in.bias.data
unet.conv_in = new_conv_in
new_conv_out = nn.Conv2d(320, 342, 3, padding=1)
new_conv_out.weight.data = unet.conv_out.weight.data.repeat(86, 1, 1, 1)[:342]
new_conv_out.bias.data = unet.conv_out.bias.data.repeat(86,)[:342]
unet.conv_out = new_conv_out
def ResnetBlock2D_remove_time_emb_proj(module):
if isinstance(module, ResnetBlock2D):
del module.time_emb_proj
unet.apply(ResnetBlock2D_remove_time_emb_proj)
def BasicTransformerBlock_remove_cross_attn(module):
if isinstance(module, BasicTransformerBlock):
del module.attn2, module.norm2
unet.apply(BasicTransformerBlock_remove_cross_attn)
def set_inplace_to_true(module):
if isinstance(module, nn.Dropout) or isinstance(module, nn.SiLU):
module.inplace = True
unet.apply(set_inplace_to_true)
def replace_forward_methods(module):
if isinstance(module, CrossAttnDownBlock2D):
module.forward = types.MethodType(MyCrossAttnDownBlock2D_SD_forward, module)
elif isinstance(module, DownBlock2D):
module.forward = types.MethodType(MyDownBlock2D_SD_forward, module)
elif isinstance(module, UNetMidBlock2DCrossAttn):
module.forward = types.MethodType(MyUNetMidBlock2DCrossAttn_SD_forward, module)
elif isinstance(module, UpBlock2D):
module.forward = types.MethodType(MyUpBlock2D_SD_forward, module)
elif isinstance(module, CrossAttnUpBlock2D):
module.forward = types.MethodType(MyCrossAttnUpBlock2D_SD_forward, module)
elif isinstance(module, ResnetBlock2D):
module.forward = types.MethodType(MyResnetBlock2D_SD_forward, module)
elif isinstance(module, Transformer2DModel):
module.forward = types.MethodType(MyTransformer2DModel_SD_forward, module)
unet.apply(replace_forward_methods)
unet.forward = types.MethodType(MyUNet2DConditionModel_SD_forward, unet)
halve_channels(unet)
unet.body = nn.Sequential(
*unet.down_blocks,
unet.mid_block,
*unet.up_blocks,
unet.conv_norm_out,
unet.conv_act,
unet.conv_out,
)
del decoder.conv_in, decoder.up_blocks, decoder.conv_norm_out, decoder.conv_act, decoder.conv_out
self.body = nn.Sequential(
nn.PixelUnshuffle(2),
unet,
decoder.mid_block,
)
def forward(self, x):
return self.body(x)
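# Illustrative sketch (hypothetical helper, not used by the model): Net.body
# starts with nn.PixelUnshuffle(2), which folds every 2x2 spatial block into
# channels, so a (C, H, W) input reaches the UNet's conv_in as (4C, H/2, W/2).
# new_conv_in is likewise created with 4x the input channels of the pretrained
# conv_in (16 instead of 4), its weights tiled with repeat(1, 4, 1, 1). A NumPy
# version of the rearrangement, following PyTorch's channel order
# out[c*r*r + i*r + j, h, w] = in[c, h*r + i, w*r + j]:
import numpy as np


def sketch_pixel_unshuffle(x, r=2):
    c, h, w = x.shape
    assert h % r == 0 and w % r == 0
    x = x.reshape(c, h // r, r, w // r, r)  # axes: (c, h', i, w', j)
    x = x.transpose(0, 2, 4, 1, 3)          # axes: (c, i, j, h', w')
    return x.reshape(c * r * r, h // r, w // r)

# Example: a (3, 4, 4) array becomes (12, 2, 2); channel 1 holds the top-right
# pixel of each 2x2 block of input channel 0.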
================================================
FILE: ram/configs/condition_config.json
================================================
{
"nf": 64
}
================================================
FILE: ram/configs/med_config.json
================================================
{
"architectures": [
"BertModel"
],
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"type_vocab_size": 2,
"vocab_size": 30524,
"encoder_width": 768,
"add_cross_attention": true
}
================================================
FILE: ram/configs/q2l_config.json
================================================
{
"architectures": [
"BertModel"
],
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 4,
"num_hidden_layers": 2,
"pad_token_id": 0,
"type_vocab_size": 2,
"vocab_size": 30522,
"encoder_width": 768,
"add_cross_attention": true,
"add_tag_cross_attention": false
}
================================================
FILE: ram/configs/swin/config_swinB_384.json
================================================
{
"ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth",
"vision_width": 1024,
"image_res": 384,
"window_size": 12,
"embed_dim": 128,
"depths": [ 2, 2, 18, 2 ],
"num_heads": [ 4, 8, 16, 32 ]
}
================================================
FILE: ram/configs/swin/config_swinL_384.json
================================================
{
"ckpt": "pretrain_model/swin_large_patch4_window12_384_22k.pth",
"vision_width": 1536,
"image_res": 384,
"window_size": 12,
"embed_dim": 192,
"depths": [ 2, 2, 18, 2 ],
"num_heads": [ 6, 12, 24, 48 ]
}
================================================
FILE: ram/configs/swin/config_swinL_444.json
================================================
{
"ckpt": "pretrain_model/swin_large_patch4_window12_384_22k.pth",
"vision_width": 1536,
"image_res": 444,
"window_size": 12,
"embed_dim": 192,
"depths": [ 2, 2, 18, 2 ],
"num_heads": [ 6, 12, 24, 48 ]
}
================================================
FILE: ram/data/ram_tag_list.txt
================================================
3D CG rendering
3D glasses
abacus
abalone
monastery
belly
academy
accessory
accident
accordion
acorn
acrylic paint
act
action
action film
activity
actor
adaptation
add
adhesive tape
adjust
adult
adventure
advertisement
antenna
aerobics
spray can
afro
agriculture
aid
air conditioner
air conditioning
air sock
aircraft cabin
aircraft model
air field
air line
airliner
airman
plane
airplane window
airport
airport runway
airport terminal
airship
airshow
aisle
alarm
alarm clock
mollymawk
album
album cover
alcohol
alcove
algae
alley
almond
aloe vera
alp
alpaca
alphabet
german shepherd
altar
amber
ambulance
bald eagle
American shorthair
amethyst
amphitheater
amplifier
amusement park
amusement ride
anchor
ancient
anemone
angel
angle
animal
animal sculpture
animal shelter
animation
animation film
animator
anime
ankle
anklet
anniversary
trench coat
ant
antelope
antique
antler
anvil
apartment
ape
app
app icon
appear
appearance
appetizer
applause
apple
apple juice
apple pie
apple tree
applesauce
appliance
appointment
approach
apricot
apron
aqua
aquarium
aquarium fish
aqueduct
arcade
arcade machine
arch
arch bridge
archaelogical excavation
archery
archipelago
architect
architecture
archive
archway
area
arena
argument
arm
armadillo
armband
armchair
armoire
armor
army
army base
army tank
array
arrest
arrow
art
art exhibition
art gallery
art print
art school
art studio
art vector illustration
artichoke
article
artifact
artist
artists loft
ash
ashtray
asia temple
asparagus
asphalt road
assemble
assembly
assembly line
association
astronaut
astronomer
athlete
athletic
atlas
atm
atmosphere
atrium
attach
fighter jet
attend
attraction
atv
eggplant
auction
audi
audio
auditorium
aurora
author
auto factory
auto mechanic
auto part
auto show
auto showroom
car battery
automobile make
automobile model
motor vehicle
autumn
autumn forest
autumn leave
autumn park
autumn tree
avatar
avenue
aviator sunglasses
avocado
award
award ceremony
award winner
shed
ax
azalea
baboon
baby
baby bottle
baby carriage
baby clothe
baby elephant
baby food
baby seat
baby shower
back
backdrop
backlight
backpack
backyard
bacon
badge
badger
badlands
badminton
badminton racket
bag
bagel
bagpipe
baguette
bait
baked goods
baker
bakery
baking
baking sheet
balance
balance car
balcony
ball
ball pit
ballerina
ballet
ballet dancer
ballet skirt
balloon
balloon arch
baseball player
ballroom
bamboo
bamboo forest
banana
banana bread
banana leaf
banana tree
band
band aid
bandage
headscarf
bandeau
bangs
bracelet
balustrade
banjo
bank
bank card
bank vault
banknote
banner
banquet
banquet hall
banyan tree
baozi
baptism
bar
bar code
bar stool
barbecue
barbecue grill
barbell
barber
barber shop
barbie
barge
barista
bark
barley
barn
barn owl
barn door
barrel
barricade
barrier
handcart
bartender
baseball
baseball base
baseball bat
baseball hat
baseball stadium
baseball game
baseball glove
baseball pitcher
baseball team
baseball uniform
basement
basil
basin
basket
basket container
basketball
basketball backboard
basketball coach
basketball court
basketball game
basketball hoop
basketball player
basketball stadium
basketball team
bass
bass guitar
bass horn
bassist
bat
bath
bath heater
bath mat
bath towel
swimwear
bathrobe
bathroom
bathroom accessory
bathroom cabinet
bathroom door
bathroom mirror
bathroom sink
toilet paper
bathroom window
batman
wand
batter
battery
battle
battle rope
battleship
bay
bay bridge
bay window
bayberry
bazaar
beach
beach ball
beach chair
beach house
beach hut
beach towel
beach volleyball
lighthouse
bead
beagle
beak
beaker
beam
bean
bean bag chair
beanbag
bear
bear cub
beard
beast
beat
beautiful
beauty
beauty salon
beaver
bed
bedcover
bed frame
bedroom
bedding
bedpan
bedroom window
bedside lamp
bee
beech tree
beef
beekeeper
beeper
beer
beer bottle
beer can
beer garden
beer glass
beer hall
beet
beetle
beige
clock
bell pepper
bell tower
belt
belt buckle
bench
bend
bengal tiger
bento
beret
berry
berth
beverage
bib
bibimbap
bible
bichon
bicycle
bicycle helmet
bicycle wheel
biker
bidet
big ben
bike lane
bike path
bike racing
bike ride
bikini
bikini top
bill
billard
billboard
billiard table
bin
binder
binocular
biology laboratory
biplane
birch
birch tree
bird
bird bath
bird feeder
bird house
bird nest
birdbath
bird cage
birth
birthday
birthday cake
birthday candle
birthday card
birthday party
biscuit
bishop
bison
bit
bite
black
black sheep
blackberry
blackbird
blackboard
blacksmith
blade
blanket
sports coat
bleacher
blender
blessing
blind
eye mask
flasher
snowstorm
block
blog
blood
bloom
blossom
blouse
blow
hair drier
blowfish
blue
blue artist
blue jay
blue sky
blueberry
bluebird
pig
board
board eraser
board game
boardwalk
boat
boat deck
boat house
paddle
boat ride
bobfloat
bobcat
body
bodyboard
bodybuilder
boiled egg
boiler
bolo tie
bolt
bomb
bomber
bonasa umbellu
bone
bonfire
bonnet
bonsai
book
book cover
bookcase
folder
bookmark
bookshelf
bookstore
boom microphone
boost
boot
border
Border collie
botanical garden
bottle
bottle cap
bottle opener
bottle screw
bougainvillea
boulder
bouquet
boutique
boutique hotel
bow
bow tie
bow window
bowl
bowling
bowling alley
bowling ball
bowling equipment
box
box girder bridge
box turtle
boxer
underdrawers
boxing
boxing glove
boxing ring
boy
brace
bracket
braid
brain
brake
brake light
branch
brand
brandy
brass
brass plaque
bread
breadbox
break
breakfast
seawall
chest
brewery
brick
brick building
wall
brickwork
wedding dress
bride
groom
bridesmaid
bridge
bridle
briefcase
bright
brim
broach
broadcasting
broccoli
bronze
bronze medal
bronze sculpture
bronze statue
brooch
creek
broom
broth
brown
brown bear
brownie
brunch
brunette
brush
coyote
brussels sprout
bubble
bubble gum
bubble tea
bucket cabinet
shield
bud
buddha
buffalo
buffet
bug
build
builder
building
building block
building facade
building material
lamp
bull
bulldog
bullet
bullet train
bulletin board
bulletproof vest
bullfighting
megaphone
bullring
bumblebee
bumper
roll
bundle
bungee
bunk bed
bunker
bunny
buoy
bureau
burial chamber
burn
burrito
bus
bus driver
bus interior
bus station
bus stop
bus window
bush
business
business card
business executive
business suit
business team
business woman
businessman
bust
butcher
butchers shop
butte
butter
cream
butterfly
butterfly house
button
buttonwood
buy
taxi
cabana
cabbage
cabin
cabin car
cabinet
cabinetry
cable
cable car
cactus
cafe
canteen
cage
cake
cake stand
calculator
caldron
calendar
calf
call
phone box
calligraphy
calm
camcorder
camel
camera
camera lens
camouflage
camp
camper
campfire
camping
campsite
campus
can
can opener
canal
canary
cancer
candle
candle holder
candy
candy bar
candy cane
candy store
cane
jar
cannon
canopy
canopy bed
cantaloupe
cantilever bridge
canvas
canyon
cap
cape
cape cod
cappuccino
capsule
captain
capture
car
car dealership
car door
car interior
car logo
car mirror
parking lot
car seat
car show
car wash
car window
caramel
card
card game
cardboard
cardboard box
cardigan
cardinal
cargo
cargo aircraft
cargo ship
caribbean
carnation
carnival
carnivore
carousel
carp
carpenter
carpet
slipper
house finch
coach
dalmatian
aircraft carrier
carrot
carrot cake
carry
cart
carton
cartoon
cartoon character
cartoon illustration
cartoon style
carve
case
cash
cashew
casino
casserole
cassette
cassette deck
plaster bandage
casting
castle
cat
cat bed
cat food
cat furniture
cat tree
catacomb
catamaran
catamount
catch
catcher
caterpillar
catfish
cathedral
cattle
catwalk
catwalk show
cauliflower
cave
caviar
CD
CD player
cedar
ceiling
ceiling fan
celebrate
celebration
celebrity
celery
cello
smartphone
cement
graveyard
centerpiece
centipede
ceramic
ceramic tile
cereal
ceremony
certificate
chain
chain saw
chair
chairlift
daybed
chalet
chalice
chalk
chamber
chameleon
champagne
champagne flute
champion
championship
chandelier
changing table
channel
chap
chapel
character sculpture
charcoal
charge
charger
chariot
charity
charity event
charm
graph
chase
chassis
check
checkbook
chessboard
checklist
cheer
cheerlead
cheese
cheeseburger
cheesecake
cheetah
chef
chemical compound
chemist
chemistry
chemistry lab
cheongsam
cherry
cherry blossom
cherry tomato
cherry tree
chess
chestnut
chicken
chicken breast
chicken coop
chicken salad
chicken wing
garbanzo
chiffonier
chihuahua
child
child actor
childs room
chile
chili dog
chimney
chimpanzee
chinaware
chinese cabbage
chinese garden
chinese knot
chinese rose
chinese tower
chip
chipmunk
chisel
chocolate
chocolate bar
chocolate cake
chocolate chip
chocolate chip cookie
chocolate milk
chocolate mousse
truffle
choir
kitchen knife
cutting board
chopstick
christmas
christmas ball
christmas card
christmas decoration
christmas dinner
christmas eve
christmas hat
christmas light
christmas market
christmas ornament
christmas tree
chrysanthemum
church
church tower
cider
cigar
cigar box
cigarette
cigarette case
waistband
cinema
photographer
cinnamon
circle
circuit
circuit board
circus
water tank
citrus fruit
city
city bus
city hall
city nightview
city park
city skyline
city square
city street
city wall
city view
clam
clarinet
clasp
class
classic
classroom
clavicle
claw
clay
pottery
clean
clean room
cleaner
cleaning product
clear
cleat
clementine
client
cliff
climb
climb mountain
climber
clinic
clip
clip art
clipboard
clipper
clivia
cloak
clogs
close-up
closet
cloth
clothe
clothing
clothespin
clothesline
clothing store
cloud
cloud forest
cloudy
clover
joker
clown fish
club
clutch
clutch bag
coal
coast
coat
coatrack
cob
cock
cockatoo
cocker
cockpit
roach
cocktail
cocktail dress
cocktail shaker
cocktail table
cocoa
coconut
coconut tree
coffee
coffee bean
coffee cup
coffee machine
coffee shop
coffeepot
coffin
cognac
spiral
coin
coke
colander
cold
slaw
collaboration
collage
collection
college student
sheepdog
crash
color
coloring book
coloring material
pony
pillar
comb
combination lock
comic
comedy
comedy film
comet
comfort
comfort food
comic book
comic book character
comic strip
commander
commentator
community
commuter
company
compass
compete
contest
competitor
composer
composition
compost
computer
computer box
computer chair
computer desk
keyboard
computer monitor
computer room
computer screen
computer tower
concept car
concert
concert hall
conch
concrete
condiment
condom
condominium
conductor
cone
meeting
conference center
conference hall
meeting room
confetti
conflict
confluence
connect
connector
conservatory
constellation
construction site
construction worker
contain
container
container ship
continent
profile
contract
control
control tower
convenience store
convention
conversation
converter
convertible
transporter
cook
cooking
cooking spray
cooker
cool
cooler
copper
copy
coral
coral reef
rope
corded phone
liquor
corgi
cork
corkboard
cormorant
corn
corn field
cornbread
corner
trumpet
cornice
cornmeal
corral
corridor
corset
cosmetic
cosmetics brush
cosmetics mirror
cosplay
costume
costumer film designer
infant bed
cottage
cotton
cotton candy
couch
countdown
counter
counter top
country artist
country house
country lane
country pop artist
countryside
coupe
couple
couple photo
courgette
course
court
courthouse
courtyard
cousin
coverall
cow
cowbell
cowboy
cowboy boot
cowboy hat
crab
crabmeat
crack
cradle
craft
craftsman
cranberry
crane
crape
crapper
crate
crater lake
lobster
crayon
cream cheese
cream pitcher
create
creature
credit card
crescent
croissant
crest
crew
cricket
cricket ball
cricket team
cricketer
crochet
crock pot
crocodile
crop
crop top
cross
crossbar
crossroad
crosstalk
crosswalk
crouton
crow
crowbar
crowd
crowded
crown
crt screen
crucifix
cruise
cruise ship
cruiser
crumb
crush
crutch
crystal
cub
cube
cucumber
cue
cuff
cufflink
cuisine
farmland
cup
cupcake
cupid
curb
curl
hair roller
currant
currency
curry
curtain
curve
pad
customer
cut
cutlery
cycle
cycling
cyclone
cylinder
cymbal
cypress
cypress tree
dachshund
daffodil
dagger
dahlia
daikon
dairy
daisy
dam
damage
damp
dance
dance floor
dance room
dancer
dandelion
dark
darkness
dart
dartboard
dashboard
date
daughter
dawn
day bed
daylight
deadbolt
death
debate
debris
decanter
deck
decker bus
decor
decorate
decorative picture
deer
defender
deity
delicatessen
deliver
demolition
monster
demonstration
den
denim jacket
dentist
department store
depression
derby
dermopathy
desert
desert road
design
designer
table
table lamp
desktop
desktop computer
dessert
destruction
detective
detergent
dew
dial
diamond
diaper
diaper bag
journal
die
diet
excavator
number
digital clock
dill
dinner
rowboat
dining room
dinner party
dinning table
dinosaur
dip
diploma
direct
director
dirt
dirt bike
dirt field
dirt road
dirt track
disaster
disciple
disco
disco ball
discotheque
disease
plate
dish antenna
dish washer
dishrag
dishes
dishsoap
Disneyland
dispenser
display
display window
trench
dive
diver
diving board
paper cup
dj
doberman
dock
doctor
document
documentary
dog
dog bed
dog breed
dog collar
dog food
dog house
doll
dollar
dollhouse
dolly
dolphin
dome
domicile
domino
donkey
donut
doodle
door
door handle
doormat
doorplate
doorway
dormitory
dough
downtown
dozer
drag
dragon
dragonfly
drain
drama
drama film
draw
drawer
drawing
drawing pin
pigtail
dress
dress hat
dress shirt
dress shoe
dress suit
dresser
dressing room
dribble
drift
driftwood
drill
drink
drinking water
drive
driver
driveway
drone
drop
droplight
dropper
drought
medicine
pharmacy
drum
drummer
drumstick
dry
duchess
duck
duckbill
duckling
duct tape
dude
duet
duffel
canoe
dumbbell
dumpling
dune
dunk
durian
dusk
dust
garbage truck
dustpan
duvet
DVD
dye
eagle
ear
earmuff
earphone
earplug
earring
earthquake
easel
easter
easter bunny
easter egg
eat
restaurant
eclair
eclipse
ecosystem
edit
education
educator
eel
egg
egg roll
egg tart
eggbeater
egret
Eiffel tower
elastic band
senior
electric chair
electric drill
electrician
electricity
electron
electronic
elephant
elevation map
elevator
elevator car
elevator door
elevator lobby
elevator shaft
embankment
embassy
embellishment
ember
emblem
embroidery
emerald
emergency
emergency service
emergency vehicle
emotion
Empire State Building
enamel
enclosure
side table
energy
engagement
engagement ring
engine
engine room
engineer
engineering
english shorthair
ensemble
enter
entertainer
entertainment
entertainment center
entrance
entrance hall
envelope
equestrian
equipment
eraser
erhu
erosion
escalator
escargot
espresso
estate
estuary
eucalyptus tree
evening
evening dress
evening light
evening sky
evening sun
event
evergreen
ewe
excavation
exercise
exhaust hood
exhibition
exit
explorer
explosion
extension cord
extinguisher
extractor
extrude
eye
eye shadow
eyebrow
eyeliner
fabric
fabric store
facade
face
face close-up
face powder
face towel
facial tissue holder
facility
factory
factory workshop
fair
fairground
fairy
falcon
fall
family
family car
family photo
family room
fan
fang
farm
farmer
farmer market
farmhouse
fashion
fashion accessory
fashion designer
fashion girl
fashion illustration
fashion look
fashion model
fashion show
fast food
fastfood restaurant
father
faucet
fault
fauna
fawn
fax
feast
feather
fedora
feed
feedbag
feeding
feeding chair
feline
mountain lion
fence
fender
fern
ferret
ferris wheel
ferry
fertilizer
festival
fiber
fiction
fiction book
field
field road
fig
fight
figure skater
figurine
file
file photo
file cabinet
fill
film camera
film director
film format
film premiere
film producer
filming
filter
fin
hand
finish line
fir
fir tree
fire
fire alarm
fire department
fire truck
fire escape
fire hose
fire pit
fire station
firecracker
fireman
fireplace
firework
firework display
first-aid kit
fish
fish boat
fish market
fish pond
fishbowl
fisherman
fishing
fishing boat
fishing net
fishing pole
fishing village
fitness
fitness course
five
fixture
fjord
flag
flag pole
flake
flame
flamingo
flannel
flap
flare
flash
flask
flat
flatfish
flavor
flea
flea market
fleet
flight
flight attendant
flip
flip-flop
flipchart
float
flock
flood
floor
floor fan
floor mat
floor plan
floor window
floral arrangement
florist
floss
flour
flow
flower
flower basket
flower bed
flower box
flower field
flower girl
flower market
fluid
flush
flute
fly
fly fishing
flyer
horse
foam
fog
foggy
foie gra
foil
folding chair
leaf
folk artist
folk dance
folk rock artist
fondant
hotpot
font
food
food coloring
food court
food processor
food stand
food truck
foosball
foot
foot bridge
football
football coach
football college game
football match
football field
football game
football helmet
football player
football stadium
football team
path
footprint
footrest
footstall
footwear
forbidden city
ford
forehead
forest
forest fire
forest floor
forest path
forest road
forge
fork
forklift
form
formal garden
formation
formula 1
fort
fortification
forward
fossil
foundation
fountain
fountain pen
fox
frame
freckle
highway
lorry
French
French bulldog
French fries
French toast
freshener
fridge
fried chicken
fried egg
fried rice
friendship
frisbee
frog
frost
frosting
frosty
frozen
fruit
fruit cake
fruit dish
fruit market
fruit salad
fruit stand
fruit tree
fruits shop
fry
frying pan
fudge
fuel
fume hood
fun
funeral
fungi
funnel
fur
fur coat
furniture
futon
gadget
muzzle
galaxy
gallery
game
game board
game controller
ham
gang
garage
garage door
garage kit
garbage
garden
garden asparagus
garden hose
garden spider
gardener
gardening
garfield
gargoyle
wreath
garlic
garment
gas
gas station
gas stove
gasmask
collect
gathering
gauge
gazebo
gear
gecko
geisha
gel
general store
generator
geranium
ghost
gift
gift bag
gift basket
gift box
gift card
gift shop
gift wrap
gig
gin
ginger
gingerbread
gingerbread house
ginkgo tree
giraffe
girl
give
glacier
gladiator
glass bead
glass bottle
glass bowl
glass box
glass building
glass door
glass floor
glass house
glass jar
glass plate
glass table
glass vase
glass wall
glass window
glasses
glaze
glider
earth
glove
glow
glue pudding
go
go for
goal
goalkeeper
goat
goat cheese
gobi
goggles
gold
gold medal
Golden Gate Bridge
golden retriever
goldfish
golf
golf cap
golf cart
golf club
golf course
golfer
goose
gorilla
gothic
gourd
government
government agency
gown
graduate
graduation
grain
grampus
grand prix
grandfather
grandmother
grandparent
granite
granola
grape
grapefruit
wine
grass
grasshopper
grassland
grassy
grater
grave
gravel
gravestone
gravy
gravy boat
gray
graze
grazing
green
greenery
greet
greeting
greeting card
greyhound
grid
griddle
grill
grille
grilled eel
grind
grinder
grits
grocery bag
grotto
ground squirrel
group
group photo
grove
grow
guacamole
guard
guard dog
guest house
guest room
guide
guinea pig
guitar
guitarist
gulf
gull
gun
gundam
gurdwara
guzheng
gym
gymnast
habitat
hacker
hail
hair
hair color
hair spray
hairbrush
haircut
hairgrip
hairnet
hairpin
hairstyle
half
hall
halloween
halloween costume
halloween pumpkin
halter top
hamburg
hamburger
hami melon
hammer
hammock
hamper
hamster
hand dryer
hand glass
hand towel
handbag
handball
handcuff
handgun
handkerchief
handle
handsaw
handshake
handstand
handwriting
hanfu
hang
hangar
hanger
happiness
harbor
harbor seal
hard rock artist
hardback book
safety helmet
hardware
hardware store
hardwood
hardwood floor
mouth organ
pipe organ
harpsichord
harvest
harvester
hassock
hat
hatbox
hautboy
hawthorn
hay
hayfield
hazelnut
head
head coach
headlight
headboard
headdress
headland
headquarter
hearing
heart
heart shape
heat
heater
heather
hedge
hedgehog
heel
helicopter
heliport
helmet
help
hen
henna
herb
herd
hermit crab
hero
heron
hibiscus
hibiscus flower
hide
high bar
high heel
highland
highlight
hike
hiker
hiking boot
hiking equipment
hill
hill country
hill station
hillside
hindu temple
hinge
hip
hip hop artist
hippo
historian
historic
history
hockey
hockey arena
hockey game
hockey player
hockey stick
hoe
hole
vacation
holly
holothurian
home
home appliance
home base
home decor
home interior
home office
home theater
homework
hummus
honey
beehive
honeymoon
hood
hoodie
hook
jump
horizon
hornbill
horned cow
hornet
horror
horror film
horse blanket
horse cart
horse farm
horse ride
horseback
horseshoe
hose
hospital
hospital bed
hospital room
host
inn
hot
hot air balloon
hot dog
hot sauce
hot spring
hotel
hotel lobby
hotel room
hotplate
hourglass
house
house exterior
houseplant
hoverboard
howler
huddle
hug
hula hoop
person
humidifier
hummingbird
humpback whale
hunt
hunting lodge
hurdle
hurricane
husky
hut
hyaena
hybrid
hydrangea
hydrant
seaplane
ice
ice bag
polar bear
ice cave
icecream
ice cream cone
ice cream parlor
ice cube
ice floe
ice hockey player
ice hockey team
lollipop
ice maker
rink
ice sculpture
ice shelf
skate
ice skating
iceberg
icicle
icing
icon
id photo
identity card
igloo
light
iguana
illuminate
illustration
image
impala
incense
independence day
individual
indoor
indoor rower
induction cooker
industrial area
industry
infantry
inflatable boat
information desk
infrastructure
ingredient
inhalator
injection
injury
ink
inking pad
inlet
inscription
insect
install
instrument
insulated cup
interaction
interior design
website
intersection
interview
invertebrate
invitation
ipad
iphone
ipod
iris
iron
ironing board
irrigation system
island
islet
isopod
ivory
ivy
izakaya
jack
jackcrab
jacket
jacuzzi
jade
jaguar
jail cell
jam
japanese garden
jasmine
jaw
jay
jazz
jazz artist
jazz fusion artist
jeans
jeep
jelly
jelly bean
jellyfish
jet
motorboat
jewel
jewellery
jewelry shop
jigsaw puzzle
rickshaw
jockey
jockey cap
jog
joint
journalist
joystick
judge
jug
juggle
juice
juicer
jujube
jump rope
jumpsuit
jungle
junkyard
kale
kaleidoscope
kangaroo
karaoke
karate
karting
kasbah
kayak
kebab
key
keycard
khaki
kick
kilt
kimono
kindergarden classroom
kindergarten
king
king crab
kiss
kit
kitchen
kitchen cabinet
kitchen counter
kitchen floor
kitchen hood
kitchen island
kitchen sink
kitchen table
kitchen utensil
kitchen window
kitchenware
kite
kiwi
knee pad
kneel
knife
rider
knit
knitting needle
knob
knocker
knot
koala
koi
ktv
laboratory
lab coat
label
labrador
maze
lace
lace dress
ladder
ladle
ladybird
lagoon
lake
lake district
lake house
lakeshore
lamb
lamb chop
lamp post
lamp shade
spear
land
land vehicle
landfill
landing
landing deck
landmark
landscape
landslide
lanyard
lantern
lap
laptop
laptop keyboard
larva
lasagne
laser
lash
lasso
latch
latex
latte
laugh
launch
launch event
launch party
laundromat
laundry
laundry basket
laundry room
lava
lavender
lawn
lawn wedding
lawyer
lay
lead
lead singer
lead to
leader
leak
lean
learn
leash
leather
leather jacket
leather shoe
speech
lecture hall
lecture room
ledge
leftover
leg
legend
legging
legislative chamber
lego
legume
lemon
lemon juice
lemonade
lemur
lens
lens flare
lentil
leopard
leotard
tights
leprechaun
lesson
letter
mailbox
letter logo
lettering
lettuce
level
library
license
license plate
lichen
lick
lid
lie
life belt
life jacket
lifeboat
lifeguard
lift
light fixture
light show
light switch
lighting
lightning
lightning rod
lilac
lily
limb
lime
limestone
limo
line
line art
line up
linen
liner
lion
lip balm
lipstick
liquid
liquor store
list
litchi
live
livestock
living room
living space
lizard
load
loading dock
loafer
hallway
locate
lock
lock chamber
locker
loft
log
log cabin
logo
loki
long hair
longboard
loom
loop
lose
lottery
lotus
love
loveseat
luggage
lumber
lumberjack
lunch
lunch box
lush
luxury
luxury yacht
mac
macadamia
macaque
macaroni
macaw
machete
machine
machine gun
magazine
magic
magician
magnet
magnifying glass
magnolia
magpie
mahjong
mahout
maid
chain mail
mail slot
make
makeover
makeup artist
makeup tool
mallard
mallard duck
mallet
mammal
mammoth
man
management
manager
manatee
mandala
mandarin orange
mandarine
mane
manga
manger
mango
mangosteen
mangrove
manhattan
manhole
manhole cover
manicure
mannequin
manor house
mansion
mantid
mantle
manufactured home
manufacturing
manuscript
map
maple
maple leaf
maple syrup
maraca
marathon
marble
march
marching band
mare
marigold
marine
marine invertebrate
marine mammal
puppet
mark
market
market square
market stall
marriage
martial
martial artist
martial arts gym
martini
martini glass
mascara
mascot
mashed potato
masher
mask
massage
mast
mat
matador
match
matchbox
material
mattress
mausoleum
maxi dress
meal
measuring cup
measuring tape
meat
meatball
mechanic
mechanical fan
medal
media
medical equipment
medical image
medical staff
medicine cabinet
medieval
medina
meditation
meerkat
meet
melon
monument
menu
mermaid
net
mess
messenger bag
metal
metal artist
metal detector
meter
mezzanine
microphone
microscope
microwave
midnight
milestone
military uniform
milk
milk can
milk tea
milkshake
mill
mine
miner
mineral
mineral water
miniskirt
miniature
minibus
minister
minivan
mint
mint candy
mirror
miss
missile
mission
mistletoe
mix
mixer
mixing bowl
mixture
moat
mobility scooter
model
model car
modern
modern tower
moisture
mold
molding
mole
monarch
money
monitor
monk
monkey
monkey wrench
monochrome
monocycle
monster truck
moon
moon cake
moonlight
moor
moose
swab
moped
morning
morning fog
morning light
morning sun
mortar
mosaic
mosque
mosquito
moss
motel
moth
mother
motherboard
motif
sport
motor
motorbike
motorcycle
motorcycle helmet
motorcycle racer
motorcyclist
motorsport
mound
mountain
mountain bike
mountain biker
mountain biking
mountain gorilla
mountain lake
mountain landscape
mountain pass
mountain path
mountain range
mountain river
mountain snowy
mountain stream
mountain view
mountain village
mountaineer
mountaineering bag
mouse
mousepad
mousetrap
mouth
mouthwash
move
movie poster
movie ticket
mower
mp3 player
mr
mud
muffin
mug
mulberry
mulch
mule
municipality
mural
muscle
muscle car
museum
mushroom
music
music festival
music stool
music
gitextract_c9hslnpw/ ├── LICENSE ├── README.md ├── bsr/ │ ├── degradations.py │ ├── transforms.py │ └── utils/ │ ├── __init__.py │ ├── color_util.py │ ├── diffjpeg.py │ ├── dist_util.py │ ├── download_util.py │ ├── file_client.py │ ├── flow_util.py │ ├── img_process_util.py │ ├── img_util.py │ ├── lmdb_util.py │ ├── logger.py │ ├── matlab_functions.py │ ├── misc.py │ ├── options.py │ ├── plot_util.py │ └── registry.py ├── config.yml ├── dataset.py ├── evaluate.py ├── evaluate_debug.sh ├── forward.py ├── model.py ├── ram/ │ ├── configs/ │ │ ├── condition_config.json │ │ ├── med_config.json │ │ ├── q2l_config.json │ │ └── swin/ │ │ ├── config_swinB_384.json │ │ ├── config_swinL_384.json │ │ └── config_swinL_444.json │ ├── data/ │ │ ├── ram_tag_list.txt │ │ ├── ram_tag_list_chinese.txt │ │ ├── ram_tag_list_threshold.txt │ │ └── tag_list.txt │ └── models/ │ ├── __init__.py │ ├── bert.py │ ├── bert_lora.py │ ├── ram.py │ ├── ram_lora.py │ ├── swin_transformer.py │ ├── swin_transformer_lora.py │ ├── tag2text.py │ ├── tag2text_lora.py │ ├── utils.py │ └── vit.py ├── requirements.txt ├── test.py ├── test_debug.sh ├── train.py ├── train.sh ├── train_debug.sh └── utils.py
SYMBOL INDEX (473 symbols across 31 files)
FILE: bsr/degradations.py
function sigma_matrix2 (line 16) | def sigma_matrix2(sig_x, sig_y, theta):
function mesh_grid (line 32) | def mesh_grid(kernel_size):
function pdf2 (line 50) | def pdf2(sigma_matrix, grid):
function cdf2 (line 66) | def cdf2(d_matrix, grid):
function bivariate_Gaussian (line 84) | def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isot...
function bivariate_generalized_Gaussian (line 112) | def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, bet...
function bivariate_plateau (line 143) | def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None,...
function random_bivariate_Gaussian (line 176) | def random_bivariate_Gaussian(kernel_size,
function random_bivariate_generalized_Gaussian (line 220) | def random_bivariate_generalized_Gaussian(kernel_size,
function random_bivariate_plateau (line 272) | def random_bivariate_plateau(kernel_size,
function random_mixed_kernels (line 324) | def random_mixed_kernels(kernel_list,
function circular_lowpass_kernel (line 389) | def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
function generate_gaussian_noise (line 419) | def generate_gaussian_noise(img, sigma=10, gray_noise=False):
function add_gaussian_noise (line 438) | def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_nois...
function generate_gaussian_noise_pt (line 460) | def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
function add_gaussian_noise_pt (line 492) | def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds...
function random_generate_gaussian_noise (line 515) | def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
function random_add_gaussian_noise (line 524) | def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, cl...
function random_generate_gaussian_noise_pt (line 536) | def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_pro...
function random_add_gaussian_noise_pt (line 544) | def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0,...
function generate_poisson_noise (line 559) | def generate_poisson_noise(img, scale=1.0, gray_noise=False):
function add_poisson_noise (line 586) | def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_nois...
function generate_poisson_noise_pt (line 609) | def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
function add_poisson_noise_pt (line 657) | def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_n...
function random_generate_poisson_noise (line 685) | def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
function random_add_poisson_noise (line 694) | def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, cli...
function random_generate_poisson_noise_pt (line 706) | def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_pro...
function random_add_poisson_noise_pt (line 714) | def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, ...
function add_jpg_compression (line 731) | def add_jpg_compression(img, quality=90):
function random_add_jpg_compression (line 750) | def random_add_jpg_compression(img, quality_range=(90, 100)):
FILE: bsr/transforms.py
function mod_crop (line 6) | def mod_crop(img, scale):
function paired_random_crop (line 26) | def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=N...
function augment (line 94) | def augment(imgs, hflip=True, rotation=True, flows=None, return_status=F...
function img_rotate (line 161) | def img_rotate(img, angle, center=None, scale=1.0):
FILE: bsr/utils/color_util.py
function rgb2ycbcr (line 5) | def rgb2ycbcr(img, y_only=False):
function bgr2ycbcr (line 38) | def bgr2ycbcr(img, y_only=False):
function ycbcr2rgb (line 71) | def ycbcr2rgb(img):
function ycbcr2bgr (line 100) | def ycbcr2bgr(img):
function _convert_input_type_range (line 129) | def _convert_input_type_range(img):
function _convert_output_type_range (line 156) | def _convert_output_type_range(img, dst_type):
function rgb2ycbcr_pt (line 186) | def rgb2ycbcr_pt(img, y_only=False):
FILE: bsr/utils/diffjpeg.py
function diff_round (line 26) | def diff_round(x):
function quality_to_factor (line 32) | def quality_to_factor(quality):
class RGB2YCbCrJpeg (line 49) | class RGB2YCbCrJpeg(nn.Module):
method __init__ (line 53) | def __init__(self):
method forward (line 60) | def forward(self, image):
class ChromaSubsampling (line 73) | class ChromaSubsampling(nn.Module):
method __init__ (line 77) | def __init__(self):
method forward (line 80) | def forward(self, image):
class BlockSplitting (line 98) | class BlockSplitting(nn.Module):
method __init__ (line 102) | def __init__(self):
method forward (line 106) | def forward(self, image):
class DCT8x8 (line 121) | class DCT8x8(nn.Module):
method __init__ (line 125) | def __init__(self):
method forward (line 134) | def forward(self, image):
class YQuantize (line 148) | class YQuantize(nn.Module):
method __init__ (line 155) | def __init__(self, rounding):
method forward (line 160) | def forward(self, image, factor=1):
class CQuantize (line 178) | class CQuantize(nn.Module):
method __init__ (line 185) | def __init__(self, rounding):
method forward (line 190) | def forward(self, image, factor=1):
class CompressJpeg (line 208) | class CompressJpeg(nn.Module):
method __init__ (line 215) | def __init__(self, rounding=torch.round):
method forward (line 222) | def forward(self, image, factor=1):
class YDequantize (line 247) | class YDequantize(nn.Module):
method __init__ (line 251) | def __init__(self):
method forward (line 255) | def forward(self, image, factor=1):
class CDequantize (line 272) | class CDequantize(nn.Module):
method __init__ (line 276) | def __init__(self):
method forward (line 280) | def forward(self, image, factor=1):
class iDCT8x8 (line 297) | class iDCT8x8(nn.Module):
method __init__ (line 301) | def __init__(self):
method forward (line 310) | def forward(self, image):
class BlockMerging (line 324) | class BlockMerging(nn.Module):
method __init__ (line 328) | def __init__(self):
method forward (line 331) | def forward(self, patches, height, width):
class ChromaUpsampling (line 348) | class ChromaUpsampling(nn.Module):
method __init__ (line 352) | def __init__(self):
method forward (line 355) | def forward(self, y, cb, cr):
class YCbCr2RGBJpeg (line 378) | class YCbCr2RGBJpeg(nn.Module):
method __init__ (line 382) | def __init__(self):
method forward (line 389) | def forward(self, image):
class DeCompressJpeg (line 401) | class DeCompressJpeg(nn.Module):
method __init__ (line 408) | def __init__(self, rounding=torch.round):
method forward (line 417) | def forward(self, y, cb, cr, imgh, imgw, factor=1):
class DiffJPEG (line 449) | class DiffJPEG(nn.Module):
method __init__ (line 457) | def __init__(self, differentiable=True):
method forward (line 467) | def forward(self, x, quality):
FILE: bsr/utils/dist_util.py
function init_dist (line 10) | def init_dist(launcher, backend='nccl', **kwargs):
function _init_dist_pytorch (line 21) | def _init_dist_pytorch(backend, **kwargs):
function _init_dist_slurm (line 28) | def _init_dist_slurm(backend, port=None):
function get_dist_info (line 60) | def get_dist_info():
function master_only (line 74) | def master_only(func):
FILE: bsr/utils/download_util.py
function download_file_from_google_drive (line 11) | def download_file_from_google_drive(file_id, save_path):
function get_confirm_token (line 41) | def get_confirm_token(response):
function save_response_content (line 48) | def save_response_content(response, destination, file_size=None, chunk_s...
function load_file_from_url (line 69) | def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
FILE: bsr/utils/file_client.py
class BaseStorageBackend (line 5) | class BaseStorageBackend(metaclass=ABCMeta):
method get (line 14) | def get(self, filepath):
method get_text (line 18) | def get_text(self, filepath):
class MemcachedBackend (line 22) | class MemcachedBackend(BaseStorageBackend):
method __init__ (line 32) | def __init__(self, server_list_cfg, client_cfg, sys_path=None):
method get (line 47) | def get(self, filepath):
method get_text (line 54) | def get_text(self, filepath):
class HardDiskBackend (line 58) | class HardDiskBackend(BaseStorageBackend):
method get (line 61) | def get(self, filepath):
method get_text (line 67) | def get_text(self, filepath):
class LmdbBackend (line 74) | class LmdbBackend(BaseStorageBackend):
method __init__ (line 94) | def __init__(self, db_paths, client_keys='default', readonly=True, loc...
method get (line 114) | def get(self, filepath, client_key):
method get_text (line 128) | def get_text(self, filepath):
class FileClient (line 132) | class FileClient(object):
method __init__ (line 151) | def __init__(self, backend='disk', **kwargs):
method get (line 158) | def get(self, filepath, client_key='default'):
method get_text (line 166) | def get_text(self, filepath):
FILE: bsr/utils/flow_util.py
function flowread (line 7) | def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
function flowwrite (line 45) | def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kw...
function quantize_flow (line 76) | def quantize_flow(flow, max_val=0.02, norm=True):
function dequantize_flow (line 102) | def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
function quantize (line 126) | def quantize(arr, min_val, max_val, levels, dtype=np.int64):
function dequantize (line 150) | def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
FILE: bsr/utils/img_process_util.py
function filter2D (line 7) | def filter2D(img, kernel):
function usm_sharp (line 34) | def usm_sharp(img, weight=0.5, radius=50, threshold=10):
class USMSharp (line 63) | class USMSharp(torch.nn.Module):
method __init__ (line 65) | def __init__(self, radius=50, sigma=0):
method forward (line 74) | def forward(self, img, weight=0.5, threshold=10):
FILE: bsr/utils/img_util.py
function img2tensor (line 9) | def img2tensor(imgs, bgr2rgb=True, float32=True):
function tensor2img (line 38) | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
function tensor2img_fast (line 97) | def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
function imfrombytes (line 114) | def imfrombytes(content, flag='color', float32=False):
function imwrite (line 135) | def imwrite(img, file_path, params=None, auto_mkdir=True):
function crop_border (line 156) | def crop_border(imgs, crop_border):
FILE: bsr/utils/lmdb_util.py
function make_lmdb_from_imgs (line 9) | def make_lmdb_from_imgs(data_path,
function read_img_worker (line 135) | def read_img_worker(path, key, compress_level):
class LmdbMaker (line 159) | class LmdbMaker():
method __init__ (line 170) | def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_l...
method put (line 185) | def put(self, img_byte, key, img_shape):
method close (line 196) | def close(self):
FILE: bsr/utils/logger.py
class AvgTimer (line 10) | class AvgTimer():
method __init__ (line 12) | def __init__(self, window=200):
method start (line 20) | def start(self):
method record (line 23) | def record(self):
method get_current_time (line 38) | def get_current_time(self):
method get_avg_time (line 41) | def get_avg_time(self):
class MessageLogger (line 45) | class MessageLogger():
method __init__ (line 58) | def __init__(self, opt, start_iter=1, tb_logger=None):
method reset_start_time (line 68) | def reset_start_time(self):
method __call__ (line 72) | def __call__(self, log_vars):
function init_tb_logger (line 119) | def init_tb_logger(log_dir):
function init_wandb_logger (line 126) | def init_wandb_logger(opt):
function get_root_logger (line 146) | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_f...
function get_env_info (line 188) | def get_env_info():
FILE: bsr/utils/matlab_functions.py
function cubic (line 6) | def cubic(x):
function calculate_weights_indices (line 16) | def calculate_weights_indices(in_length, out_length, scale, kernel, kern...
function imresize (line 86) | def imresize(img, scale, antialiasing=True):
FILE: bsr/utils/misc.py
function set_random_seed (line 11) | def set_random_seed(seed):
function get_time_str (line 20) | def get_time_str():
function mkdir_and_rename (line 24) | def mkdir_and_rename(path):
function make_exp_dirs (line 38) | def make_exp_dirs(opt):
function scandir (line 52) | def scandir(dir_path, suffix=None, recursive=False, full_path=False):
function check_resume (line 94) | def check_resume(opt, resume_iter):
function sizeof_fmt (line 127) | def sizeof_fmt(size, suffix='B'):
FILE: bsr/utils/options.py
function ordered_yaml (line 13) | def ordered_yaml():
function yaml_load (line 38) | def yaml_load(f):
function dict2str (line 54) | def dict2str(opt, indent_level=1):
function _postprocess_yml_value (line 75) | def _postprocess_yml_value(value):
function parse_options (line 99) | def parse_options(root_path, is_train=True):
function copy_opt_file (line 205) | def copy_opt_file(opt_file, experiments_root):
FILE: bsr/utils/plot_util.py
function read_data_from_tensorboard (line 4) | def read_data_from_tensorboard(log_path, tag):
function read_data_from_txt_2v (line 23) | def read_data_from_txt_2v(path, pattern, step_one=False):
function read_data_from_txt_1v (line 48) | def read_data_from_txt_1v(path, pattern):
function smooth_data (line 68) | def smooth_data(values, smooth_weight):
FILE: bsr/utils/registry.py
class Registry (line 4) | class Registry():
method __init__ (line 30) | def __init__(self, name):
method _do_register (line 38) | def _do_register(self, name, obj, suffix=None):
method register (line 46) | def register(self, obj=None, suffix=None):
method get (line 65) | def get(self, name, suffix='basicsr'):
method __contains__ (line 74) | def __contains__(self, name):
method __iter__ (line 77) | def __iter__(self):
method keys (line 80) | def keys(self):
FILE: dataset.py
class RealESRGANDataset (line 9) | class RealESRGANDataset(torch.utils.data.Dataset):
method __init__ (line 10) | def __init__(self, opt, bsz):
method __getitem__ (line 44) | def __getitem__(self, index):
method __len__ (line 124) | def __len__(self):
class RealESRGANDegrader (line 127) | class RealESRGANDegrader:
method __init__ (line 128) | def __init__(self, opt, device):
method _dequeue_and_enqueue (line 135) | def _dequeue_and_enqueue(self):
method degrade (line 172) | def degrade(self, data):
FILE: forward.py
function MyUNet2DConditionModel_SD_forward (line 3) | def MyUNet2DConditionModel_SD_forward(self, x):
function MyCrossAttnDownBlock2D_SD_forward (line 10) | def MyCrossAttnDownBlock2D_SD_forward(self, x):
function MyCrossAttnUpBlock2D_SD_forward (line 20) | def MyCrossAttnUpBlock2D_SD_forward(self, x):
function MyDownBlock2D_SD_forward (line 28) | def MyDownBlock2D_SD_forward(self, x):
function MyUNetMidBlock2DCrossAttn_SD_forward (line 34) | def MyUNetMidBlock2DCrossAttn_SD_forward(self, x):
function MyUpBlock2D_SD_forward (line 40) | def MyUpBlock2D_SD_forward(self, x):
function MyResnetBlock2D_SD_forward (line 46) | def MyResnetBlock2D_SD_forward(self, x_in):
function MyTransformer2DModel_SD_forward (line 57) | def MyTransformer2DModel_SD_forward(self, x_in):
FILE: model.py
function find_parent (line 23) | def find_parent(model, module_name):
function halve_channels (line 30) | def halve_channels(model):
class Net (line 94) | class Net(nn.Module):
method __init__ (line 95) | def __init__(self, unet, decoder):
method forward (line 151) | def forward(self, x):
FILE: ram/models/bert.py
class BertEmbeddings_nopos (line 52) | class BertEmbeddings_nopos(nn.Module):
method __init__ (line 55) | def __init__(self, config):
method forward (line 71) | def forward(
class BertEmbeddings (line 100) | class BertEmbeddings(nn.Module):
method __init__ (line 103) | def __init__(self, config):
method forward (line 119) | def forward(
class BertSelfAttention (line 146) | class BertSelfAttention(nn.Module):
method __init__ (line 147) | def __init__(self, config, is_cross_attention):
method save_attn_gradients (line 175) | def save_attn_gradients(self, attn_gradients):
method get_attn_gradients (line 178) | def get_attn_gradients(self):
method save_attention_map (line 181) | def save_attention_map(self, attention_map):
method get_attention_map (line 184) | def get_attention_map(self):
method transpose_for_scores (line 187) | def transpose_for_scores(self, x):
method forward (line 192) | def forward(
class BertSelfOutput (line 284) | class BertSelfOutput(nn.Module):
method __init__ (line 285) | def __init__(self, config):
method forward (line 291) | def forward(self, hidden_states, input_tensor):
class BertAttention (line 298) | class BertAttention(nn.Module):
method __init__ (line 299) | def __init__(self, config, is_cross_attention=False):
method prune_heads (line 305) | def prune_heads(self, heads):
method forward (line 323) | def forward(
class BertIntermediate (line 347) | class BertIntermediate(nn.Module):
method __init__ (line 348) | def __init__(self, config):
method forward (line 356) | def forward(self, hidden_states):
class BertOutput (line 362) | class BertOutput(nn.Module):
method __init__ (line 363) | def __init__(self, config):
method forward (line 369) | def forward(self, hidden_states, input_tensor):
class BertLayer (line 376) | class BertLayer(nn.Module):
method __init__ (line 377) | def __init__(self, config, layer_num):
method forward (line 389) | def forward(
method feed_forward_chunk (line 455) | def feed_forward_chunk(self, attention_output):
class BertEncoder (line 461) | class BertEncoder(nn.Module):
method __init__ (line 462) | def __init__(self, config):
method forward (line 468) | def forward(
class BertPooler (line 561) | class BertPooler(nn.Module):
method __init__ (line 562) | def __init__(self, config):
method forward (line 567) | def forward(self, hidden_states):
class BertPredictionHeadTransform (line 576) | class BertPredictionHeadTransform(nn.Module):
method __init__ (line 577) | def __init__(self, config):
method forward (line 586) | def forward(self, hidden_states):
class BertLMPredictionHead (line 593) | class BertLMPredictionHead(nn.Module):
method __init__ (line 594) | def __init__(self, config):
method forward (line 607) | def forward(self, hidden_states):
class BertOnlyMLMHead (line 613) | class BertOnlyMLMHead(nn.Module):
method __init__ (line 614) | def __init__(self, config):
method forward (line 618) | def forward(self, sequence_output):
class BertPreTrainedModel (line 623) | class BertPreTrainedModel(PreTrainedModel):
method _init_weights (line 633) | def _init_weights(self, module):
class BertModel (line 646) | class BertModel(BertPreTrainedModel):
method __init__ (line 656) | def __init__(self, config, add_pooling_layer=True):
method get_input_embeddings (line 669) | def get_input_embeddings(self):
method set_input_embeddings (line 672) | def set_input_embeddings(self, value):
method _prune_heads (line 675) | def _prune_heads(self, heads_to_prune):
method get_extended_attention_mask (line 684) | def get_extended_attention_mask(self, attention_mask: Tensor, input_sh...
method forward (line 745) | def forward(
class BertLMHeadModel (line 885) | class BertLMHeadModel(BertPreTrainedModel):
method __init__ (line 890) | def __init__(self, config):
method get_output_embeddings (line 898) | def get_output_embeddings(self):
method set_output_embeddings (line 901) | def set_output_embeddings(self, new_embeddings):
method forward (line 904) | def forward(
method prepare_inputs_for_generation (line 1010) | def prepare_inputs_for_generation(self, input_ids, past=None, attentio...
method _reorder_cache (line 1029) | def _reorder_cache(self, past, beam_idx):
FILE: ram/models/bert_lora.py
class BertEmbeddings_nopos (line 54) | class BertEmbeddings_nopos(nn.Module):
method __init__ (line 57) | def __init__(self, config):
method forward (line 73) | def forward(
class BertEmbeddings (line 102) | class BertEmbeddings(nn.Module):
method __init__ (line 105) | def __init__(self, config):
method forward (line 121) | def forward(
class BertSelfAttention (line 148) | class BertSelfAttention(nn.Module):
method __init__ (line 149) | def __init__(self, config, is_cross_attention):
method save_attn_gradients (line 180) | def save_attn_gradients(self, attn_gradients):
method get_attn_gradients (line 183) | def get_attn_gradients(self):
method save_attention_map (line 186) | def save_attention_map(self, attention_map):
method get_attention_map (line 189) | def get_attention_map(self):
method transpose_for_scores (line 192) | def transpose_for_scores(self, x):
method forward (line 197) | def forward(
class BertSelfOutput (line 289) | class BertSelfOutput(nn.Module):
method __init__ (line 290) | def __init__(self, config):
method forward (line 296) | def forward(self, hidden_states, input_tensor):
class BertAttention (line 303) | class BertAttention(nn.Module):
method __init__ (line 304) | def __init__(self, config, is_cross_attention=False):
method prune_heads (line 310) | def prune_heads(self, heads):
method forward (line 328) | def forward(
class BertIntermediate (line 352) | class BertIntermediate(nn.Module):
method __init__ (line 353) | def __init__(self, config):
method forward (line 361) | def forward(self, hidden_states):
class BertOutput (line 367) | class BertOutput(nn.Module):
method __init__ (line 368) | def __init__(self, config):
method forward (line 374) | def forward(self, hidden_states, input_tensor):
class BertLayer (line 381) | class BertLayer(nn.Module):
method __init__ (line 382) | def __init__(self, config, layer_num):
method forward (line 394) | def forward(
method feed_forward_chunk (line 460) | def feed_forward_chunk(self, attention_output):
class BertEncoder (line 466) | class BertEncoder(nn.Module):
method __init__ (line 467) | def __init__(self, config):
method forward (line 473) | def forward(
class BertPooler (line 566) | class BertPooler(nn.Module):
method __init__ (line 567) | def __init__(self, config):
method forward (line 572) | def forward(self, hidden_states):
class BertPredictionHeadTransform (line 581) | class BertPredictionHeadTransform(nn.Module):
method __init__ (line 582) | def __init__(self, config):
method forward (line 591) | def forward(self, hidden_states):
class BertLMPredictionHead (line 598) | class BertLMPredictionHead(nn.Module):
method __init__ (line 599) | def __init__(self, config):
method forward (line 612) | def forward(self, hidden_states):
class BertOnlyMLMHead (line 618) | class BertOnlyMLMHead(nn.Module):
method __init__ (line 619) | def __init__(self, config):
method forward (line 623) | def forward(self, sequence_output):
class BertPreTrainedModel (line 628) | class BertPreTrainedModel(PreTrainedModel):
method _init_weights (line 638) | def _init_weights(self, module):
class BertModel (line 651) | class BertModel(BertPreTrainedModel):
method __init__ (line 661) | def __init__(self, config, add_pooling_layer=True):
method get_input_embeddings (line 674) | def get_input_embeddings(self):
method set_input_embeddings (line 677) | def set_input_embeddings(self, value):
method _prune_heads (line 680) | def _prune_heads(self, heads_to_prune):
method get_extended_attention_mask (line 689) | def get_extended_attention_mask(self, attention_mask: Tensor, input_sh...
method forward (line 750) | def forward(
class BertLMHeadModel (line 890) | class BertLMHeadModel(BertPreTrainedModel):
method __init__ (line 895) | def __init__(self, config):
method get_output_embeddings (line 903) | def get_output_embeddings(self):
method set_output_embeddings (line 906) | def set_output_embeddings(self, new_embeddings):
method forward (line 909) | def forward(
method prepare_inputs_for_generation (line 1015) | def prepare_inputs_for_generation(self, input_ids, past=None, attentio...
method _reorder_cache (line 1034) | def _reorder_cache(self, past, beam_idx):
FILE: ram/models/ram.py
class RAM (line 20) | class RAM(nn.Module):
method __init__ (line 21) | def __init__(self,
method load_tag_list (line 160) | def load_tag_list(self, tag_list_file):
method del_selfattention (line 167) | def del_selfattention(self):
method condition_forward (line 172) | def condition_forward(self,
method generate_tag (line 212) | def generate_tag(self,
method generate_tag_openset (line 261) | def generate_tag_openset(self,
function ram (line 306) | def ram(pretrained='', **kwargs):
FILE: ram/models/ram_lora.py
class RAMLora (line 21) | class RAMLora(nn.Module):
method __init__ (line 22) | def __init__(self,
method load_tag_list (line 171) | def load_tag_list(self, tag_list_file):
method del_selfattention (line 178) | def del_selfattention(self):
method generate_image_embeds (line 183) | def generate_image_embeds(self,
method generate_tag (line 192) | def generate_tag(self,
method condition_forward (line 243) | def condition_forward(self,
method generate_tag_openset (line 283) | def generate_tag_openset(self,
function ram (line 328) | def ram(pretrained='', pretrained_condition='', **kwargs):
FILE: ram/models/swin_transformer.py
class Mlp (line 17) | class Mlp(nn.Module):
method __init__ (line 18) | def __init__(self, in_features, hidden_features=None, out_features=Non...
method forward (line 27) | def forward(self, x):
function window_partition (line 36) | def window_partition(x, window_size):
function window_reverse (line 51) | def window_reverse(windows, window_size, H, W):
class WindowAttention (line 68) | class WindowAttention(nn.Module):
method __init__ (line 82) | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scal...
method forward (line 116) | def forward(self, x, mask=None):
method extra_repr (line 149) | def extra_repr(self) -> str:
method flops (line 152) | def flops(self, N):
class SwinTransformerBlock (line 166) | class SwinTransformerBlock(nn.Module):
method __init__ (line 185) | def __init__(self, dim, input_resolution, num_heads, window_size=7, sh...
method forward (line 247) | def forward(self, x, condition=None):
method extra_repr (line 312) | def extra_repr(self) -> str:
method flops (line 316) | def flops(self):
class PatchMerging (line 331) | class PatchMerging(nn.Module):
method __init__ (line 340) | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
method forward (line 347) | def forward(self, x):
method extra_repr (line 370) | def extra_repr(self) -> str:
method flops (line 373) | def flops(self):
class BasicLayer (line 380) | class BasicLayer(nn.Module):
method __init__ (line 400) | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
method forward (line 428) | def forward(self, x, condition=None):
method extra_repr (line 438) | def extra_repr(self) -> str:
method flops (line 441) | def flops(self):
class PatchEmbed (line 450) | class PatchEmbed(nn.Module):
method __init__ (line 461) | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=9...
method forward (line 480) | def forward(self, x):
method flops (line 490) | def flops(self):
class SwinTransformer (line 498) | class SwinTransformer(nn.Module):
method __init__ (line 524) | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes...
method _init_weights (line 582) | def _init_weights(self, m):
method no_weight_decay (line 592) | def no_weight_decay(self):
method no_weight_decay_keywords (line 596) | def no_weight_decay_keywords(self):
method forward (line 599) | def forward(self, x, idx_to_group_img=None, image_atts=None, condition...
method flops (line 623) | def flops(self):
function interpolate_relative_pos_embed (line 633) | def interpolate_relative_pos_embed(rel_pos_bias, dst_num_pos, param_name...
function zero_module (line 693) | def zero_module(module):
FILE: ram/models/swin_transformer_lora.py
class Mlp (line 19) | class Mlp(nn.Module):
method __init__ (line 20) | def __init__(self, in_features, hidden_features=None, out_features=Non...
method forward (line 31) | def forward(self, x):
function window_partition (line 40) | def window_partition(x, window_size):
function window_reverse (line 55) | def window_reverse(windows, window_size, H, W):
class WindowAttention (line 72) | class WindowAttention(nn.Module):
method __init__ (line 86) | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scal...
method forward (line 122) | def forward(self, x, mask=None):
method extra_repr (line 155) | def extra_repr(self) -> str:
method flops (line 158) | def flops(self, N):
class SwinTransformerBlock (line 172) | class SwinTransformerBlock(nn.Module):
method __init__ (line 191) | def __init__(self, dim, input_resolution, num_heads, window_size=7, sh...
method forward (line 242) | def forward(self, x):
method extra_repr (line 281) | def extra_repr(self) -> str:
method flops (line 285) | def flops(self):
class PatchMerging (line 300) | class PatchMerging(nn.Module):
method __init__ (line 309) | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
method forward (line 316) | def forward(self, x):
method extra_repr (line 339) | def extra_repr(self) -> str:
method flops (line 342) | def flops(self):
class BasicLayer (line 349) | class BasicLayer(nn.Module):
method __init__ (line 369) | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
method forward (line 397) | def forward(self, x):
method extra_repr (line 407) | def extra_repr(self) -> str:
method flops (line 410) | def flops(self):
class PatchEmbed (line 419) | class PatchEmbed(nn.Module):
method __init__ (line 430) | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=9...
method forward (line 449) | def forward(self, x):
method flops (line 459) | def flops(self):
class SwinTransformer (line 467) | class SwinTransformer(nn.Module):
method __init__ (line 493) | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes...
method _init_weights (line 551) | def _init_weights(self, m):
method no_weight_decay (line 561) | def no_weight_decay(self):
method no_weight_decay_keywords (line 565) | def no_weight_decay_keywords(self):
method forward (line 568) | def forward(self, x, idx_to_group_img=None, image_atts=None, **kwargs):
method flops (line 592) | def flops(self):
function interpolate_relative_pos_embed (line 602) | def interpolate_relative_pos_embed(rel_pos_bias, dst_num_pos, param_name...
FILE: ram/models/tag2text.py
class Tag2Text (line 19) | class Tag2Text(nn.Module):
method __init__ (line 21) | def __init__(self,
method load_tag_list (line 128) | def load_tag_list(self, tag_list_file):
method del_selfattention (line 135) | def del_selfattention(self):
method forward (line 141) | def forward(self, image, caption, tag):
method generate_image_embeds (line 230) | def generate_image_embeds(self,
method condition_forward (line 239) | def condition_forward(self,
method generate (line 280) | def generate(self,
function tag2text (line 409) | def tag2text(pretrained='', **kwargs):
FILE: ram/models/tag2text_lora.py
class Tag2Text (line 19) | class Tag2Text(nn.Module):
method __init__ (line 21) | def __init__(self,
method load_tag_list (line 128) | def load_tag_list(self, tag_list_file):
method del_selfattention (line 135) | def del_selfattention(self):
method forward (line 141) | def forward(self, image, caption, tag):
method generate_image_embeds (line 230) | def generate_image_embeds(self,
method condition_forward (line 239) | def condition_forward(self,
method generate (line 280) | def generate(self,
function tag2text (line 409) | def tag2text(pretrained='', **kwargs):
FILE: ram/models/utils.py
function read_json (line 16) | def read_json(rpath):
function tie_encoder_decoder_weights (line 21) | def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module,
class GroupWiseLinear (line 99) | class GroupWiseLinear(nn.Module):
method __init__ (line 103) | def __init__(self, num_class, hidden_dim, bias=True):
method reset_parameters (line 114) | def reset_parameters(self):
method forward (line 122) | def forward(self, x):
function init_tokenizer (line 130) | def init_tokenizer():
function create_vit (line 138) | def create_vit(vit,
function is_url (line 170) | def is_url(url_or_filename):
function load_checkpoint (line 175) | def load_checkpoint(model, url_or_filename):
function load_checkpoint_swinlarge_condition (line 203) | def load_checkpoint_swinlarge_condition(model, url_or_filename, kwargs):
function load_checkpoint_swinbase (line 241) | def load_checkpoint_swinbase(model, url_or_filename, kwargs):
function load_checkpoint_swinlarge (line 279) | def load_checkpoint_swinlarge(model, url_or_filename, kwargs):
class AsymmetricLoss (line 319) | class AsymmetricLoss(nn.Module):
method __init__ (line 320) | def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disa...
method forward (line 329) | def forward(self, x, y):
FILE: ram/models/vit.py
class Mlp (line 23) | class Mlp(nn.Module):
method __init__ (line 26) | def __init__(self, in_features, hidden_features=None, out_features=Non...
method forward (line 35) | def forward(self, x):
class Attention (line 44) | class Attention(nn.Module):
method __init__ (line 45) | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, at...
method save_attn_gradients (line 58) | def save_attn_gradients(self, attn_gradients):
method get_attn_gradients (line 61) | def get_attn_gradients(self):
method save_attention_map (line 64) | def save_attention_map(self, attention_map):
method get_attention_map (line 67) | def get_attention_map(self):
method forward (line 70) | def forward(self, x, register_hook=False):
class Block (line 89) | class Block(nn.Module):
method __init__ (line 91) | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_sc...
method forward (line 107) | def forward(self, x, register_hook=False):
class VisionTransformer (line 113) | class VisionTransformer(nn.Module):
method __init__ (line 118) | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classe...
method _init_weights (line 167) | def _init_weights(self, m):
method no_weight_decay (line 177) | def no_weight_decay(self):
method forward (line 180) | def forward(self, x, register_blk=-1):
method load_pretrained (line 197) | def load_pretrained(self, checkpoint_path, prefix=''):
function _load_weights (line 202) | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix...
function interpolate_pos_embed (line 281) | def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
FILE: utils.py
function add_lora_to_unet (line 4) | def add_lora_to_unet(unet, rank=4):
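The index lists `window_partition(x, window_size)` and `window_reverse(windows, window_size, H, W)` in both `ram/models/swin_transformer.py` and `ram/models/swin_transformer_lora.py`. Their bodies are not shown here. The sketch below assumes they follow the standard Swin Transformer helpers, which split a `(B, H, W, C)` feature map into non-overlapping `window_size x window_size` windows and then stitch them back together. The repo operates on torch tensors with `view`/`permute`. This version is a NumPy stand-in that uses `reshape`/`transpose` so it runs without torch. It also assumes `H` and `W` are divisible by `window_size`.

```python
import numpy as np


def window_partition(x: np.ndarray, window_size: int) -> np.ndarray:
    """Split a (B, H, W, C) array into (B * nH * nW, ws, ws, C) windows.

    Assumes H and W are multiples of window_size (as in standard Swin).
    """
    B, H, W, C = x.shape
    # (B, nH, ws, nW, ws, C): separate the window grid from in-window offsets
    x = x.reshape(B, H // window_size, window_size, W // window_size, window_size, C)
    # (B, nH, nW, ws, ws, C), then fold the window grid into the batch axis
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)


def window_reverse(windows: np.ndarray, window_size: int, H: int, W: int) -> np.ndarray:
    """Inverse of window_partition: (B * nH * nW, ws, ws, C) -> (B, H, W, C)."""
    num_windows = (H // window_size) * (W // window_size)
    B = windows.shape[0] // num_windows
    x = windows.reshape(B, H // window_size, W // window_size, window_size, window_size, -1)
    # Undo the transpose so rows and columns interleave back into (H, W)
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(B, H, W, -1)


x = np.arange(2 * 8 * 8 * 3).reshape(2, 8, 8, 3)
w = window_partition(x, 4)
print(w.shape)  # (8, 4, 4, 3): 2 images x 2 x 2 windows each
```

Windows are ordered row-major over the window grid, so `w[0]` is the top-left window of the first image and `w[1]` is the window to its right. Attention is then computed independently inside each window, and `window_reverse` restores the spatial layout afterwards.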
Condensed preview: 54 files, each showing path, character count, and a content snippet.
[
{
"path": "LICENSE",
"chars": 11357,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "README.md",
"chars": 6913,
"preview": "<p align=\"center\">\n <img src=\"assets/icon.png\" alt=\"icon\" width=\"200px\"/>\n</p>\n\n# (CVPR 2025) Adversarial Diffusion Co"
},
{
"path": "bsr/degradations.py",
"chars": 28202,
"preview": "import cv2\nimport math\nimport numpy as np\nimport random\nimport torch\nfrom scipy import special\nfrom scipy.stats import m"
},
{
"path": "bsr/transforms.py",
"chars": 6225,
"preview": "import cv2\nimport random\nimport torch\n\n\ndef mod_crop(img, scale):\n \"\"\"Mod crop images, used during testing.\n\n Args"
},
{
"path": "bsr/utils/__init__.py",
"chars": 1220,
"preview": "from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb\nfrom .diffjpeg import DiffJPEG\nfrom .fi"
},
{
"path": "bsr/utils/color_util.py",
"chars": 7981,
"preview": "import numpy as np\nimport torch\n\n\ndef rgb2ycbcr(img, y_only=False):\n \"\"\"Convert a RGB image to YCbCr image.\n\n This"
},
{
"path": "bsr/utils/diffjpeg.py",
"chars": 15662,
"preview": "\"\"\"\nModified from https://github.com/mlomnitz/DiffJPEG\n\nFor images not divisible by 8\nhttps://dsp.stackexchange.com/ques"
},
{
"path": "bsr/utils/dist_util.py",
"chars": 2608,
"preview": "# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501\nimport functools\n"
},
{
"path": "bsr/utils/download_util.py",
"chars": 3341,
"preview": "import math\nimport os\nimport requests\nfrom torch.hub import download_url_to_file, get_dir\nfrom tqdm import tqdm\nfrom url"
},
{
"path": "bsr/utils/file_client.py",
"chars": 6014,
"preview": "# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501\nfrom abc import "
},
{
"path": "bsr/utils/flow_util.py",
"chars": 6159,
"preview": "# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501\nimport cv2\nimport num"
},
{
"path": "bsr/utils/img_process_util.py",
"chars": 2563,
"preview": "import cv2\nimport numpy as np\nimport torch\nfrom torch.nn import functional as F\n\n\ndef filter2D(img, kernel):\n \"\"\"PyTo"
},
{
"path": "bsr/utils/img_util.py",
"chars": 6195,
"preview": "import cv2\nimport math\nimport numpy as np\nimport os\nimport torch\nfrom torchvision.utils import make_grid\n\n\ndef img2tenso"
},
{
"path": "bsr/utils/lmdb_util.py",
"chars": 7130,
"preview": "import cv2\nimport lmdb\nimport sys\nfrom multiprocessing import Pool\nfrom os import path as osp\nfrom tqdm import tqdm\n\n\nde"
},
{
"path": "bsr/utils/logger.py",
"chars": 7148,
"preview": "import datetime\nimport logging\nimport time\n\nfrom .dist_util import get_dist_info, master_only\n\ninitialized_logger = {}\n\n"
},
{
"path": "bsr/utils/matlab_functions.py",
"chars": 6962,
"preview": "import math\nimport numpy as np\nimport torch\n\n\ndef cubic(x):\n \"\"\"cubic function used for calculate_weights_indices.\"\"\""
},
{
"path": "bsr/utils/misc.py",
"chars": 4655,
"preview": "import numpy as np\nimport os\nimport random\nimport time\nimport torch\nfrom os import path as osp\n\nfrom .dist_util import m"
},
{
"path": "bsr/utils/options.py",
"chars": 6990,
"preview": "import argparse\nimport os\nimport random\nimport torch\nimport yaml\nfrom collections import OrderedDict\nfrom os import path"
},
{
"path": "bsr/utils/plot_util.py",
"chars": 2525,
"preview": "import re\n\n\ndef read_data_from_tensorboard(log_path, tag):\n \"\"\"Get raw data (steps and values) from tensorboard event"
},
{
"path": "bsr/utils/registry.py",
"chars": 2477,
"preview": "# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501\n\n\nclass "
},
{
"path": "config.yml",
"chars": 1139,
"preview": "dataroot_gt: path_to_HR_images_of_LSDIR\n\nscale: 4\n\n# the first degradation process\nresize_prob: [0.2, 0.7, 0.1] # up, d"
},
{
"path": "dataset.py",
"chars": 13204,
"preview": "import torch, random, cv2, os, math, glob\nimport torch.nn.functional as F\nimport numpy as np\nfrom bsr.degradations impor"
},
{
"path": "evaluate.py",
"chars": 2153,
"preview": "import torch, os, glob, pyiqa\nfrom argparse import ArgumentParser\nimport numpy as np\nfrom PIL import Image\nfrom tqdm imp"
},
{
"path": "evaluate_debug.sh",
"chars": 137,
"preview": "HF_ENDPOINT=https://hf-mirror.com \\\nCUDA_VISIBLE_DEVICES=0 \\\npython -u evaluate.py \\\n--HR_dir=testset/RealSR/HR \\\n--SR_d"
},
{
"path": "forward.py",
"chars": 1845,
"preview": "import torch\n\ndef MyUNet2DConditionModel_SD_forward(self, x):\n global skip\n x = self.conv_in(x)\n skip = [x]\n "
},
{
"path": "model.py",
"chars": 7813,
"preview": "import torch, types, copy\r\nfrom torch import nn\r\nimport torch.nn.functional as F\r\nfrom diffusers.models.unets.unet_2d_bl"
},
{
"path": "ram/configs/condition_config.json",
"chars": 18,
"preview": "{\n \"nf\": 64\n }"
},
{
"path": "ram/configs/med_config.json",
"chars": 524,
"preview": "{\n \"architectures\": [\n \"BertModel\"\n ],\n \"attention_probs_dropout_prob\": 0.1,\n \"hidden_act\": \"gelu\",\n "
},
{
"path": "ram/configs/q2l_config.json",
"chars": 557,
"preview": "{\n \"architectures\": [\n \"BertModel\"\n ],\n \"attention_probs_dropout_prob\": 0.1,\n \"hidden_act\": \"gelu\",\n "
},
{
"path": "ram/configs/swin/config_swinB_384.json",
"chars": 230,
"preview": "{\n \"ckpt\": \"pretrain_model/swin_base_patch4_window7_224_22k.pth\",\n \"vision_width\": 1024,\n \"image_res\": 384,\n "
},
{
"path": "ram/configs/swin/config_swinL_384.json",
"chars": 233,
"preview": "{\n \"ckpt\": \"pretrain_model/swin_large_patch4_window12_384_22k.pth\",\n \"vision_width\": 1536,\n \"image_res\": 384,\n "
},
{
"path": "ram/configs/swin/config_swinL_444.json",
"chars": 233,
"preview": "{\n \"ckpt\": \"pretrain_model/swin_large_patch4_window12_384_22k.pth\",\n \"vision_width\": 1536,\n \"image_res\": 444,\n "
},
{
"path": "ram/data/ram_tag_list.txt",
"chars": 41904,
"preview": "3D CG rendering\n3D glasses\nabacus\nabalone\nmonastery\nbelly\nacademy\naccessory\naccident\naccordion\nacorn\nacrylic paint\nact\na"
},
{
"path": "ram/data/ram_tag_list_chinese.txt",
"chars": 19597,
"preview": "三维CG渲染 \n3d眼镜\n算盘 \n鲍鱼 \n修道院 \n肚子 \n学院 \n附件 \n事故 \n手风琴 \n橡子 \n丙烯颜料\n表演\n行动 \n动作电影 \n活动 \n演员 \n改编本\n添加 \n胶带 \n调整 \n成人 \n冒险 \n广告 \n天线 \n有氧运动 \n喷雾罐\n爆"
},
{
"path": "ram/data/ram_tag_list_threshold.txt",
"chars": 22016,
"preview": "0.65\n0.65\n0.65\n0.65\n0.65\n0.65\n0.65\n0.8\n0.71\n0.75\n0.65\n0.65\n0.65\n0.8\n0.65\n0.8\n0.8\n0.65\n0.65\n0.65\n0.65\n0.8\n0.65\n0.8\n0.8\n0."
},
{
"path": "ram/data/tag_list.txt",
"chars": 29062,
"preview": "tennis\nbear cub\nobservatory\nbicycle\nhillside\njudge\nwatercolor illustration\ngranite\nlobster\nlivery\nstone\nceramic\nranch\ncl"
},
{
"path": "ram/models/__init__.py",
"chars": 52,
"preview": "from .ram import ram\nfrom .tag2text import tag2text\n"
},
{
"path": "ram/models/bert.py",
"chars": 45130,
"preview": "'''\n * Copyright (c) 2022, salesforce.com, inc.\n * All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n * For "
},
{
"path": "ram/models/bert_lora.py",
"chars": 45400,
"preview": "'''\n * Copyright (c) 2022, salesforce.com, inc.\n * All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n * For "
},
{
"path": "ram/models/ram.py",
"chars": 12606,
"preview": "'''\n * The Recognize Anything Model (RAM)\n * Written by Xinyu Huang\n'''\nimport json\nimport warnings\n\nimport numpy as np\n"
},
{
"path": "ram/models/ram_lora.py",
"chars": 13396,
"preview": "'''\n * The Recognize Anything Model (RAM)\n * Written by Xinyu Huang\n'''\nimport json\nimport warnings\n\nimport numpy as np\n"
},
{
"path": "ram/models/swin_transformer.py",
"chars": 28482,
"preview": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed "
},
{
"path": "ram/models/swin_transformer_lora.py",
"chars": 27102,
"preview": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed "
},
{
"path": "ram/models/tag2text.py",
"chars": 16366,
"preview": "'''\n * The Tag2Text Model\n * Written by Xinyu Huang\n'''\nimport numpy as np\nimport json\nimport torch\nimport warnings\n\nfro"
},
{
"path": "ram/models/tag2text_lora.py",
"chars": 16376,
"preview": "'''\n * The Tag2Text Model\n * Written by Xinyu Huang\n'''\nimport numpy as np\nimport json\nimport torch\nimport warnings\n\nfro"
},
{
"path": "ram/models/utils.py",
"chars": 14716,
"preview": "import os\nimport json\nimport torch\nimport math\n\nfrom torch import nn\nfrom typing import List\nfrom transformers import Be"
},
{
"path": "ram/models/vit.py",
"chars": 14240,
"preview": "'''\n * Copyright (c) 2022, salesforce.com, inc.\n * All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n * For "
},
{
"path": "requirements.txt",
"chars": 305,
"preview": "pillow==9.1.1\r\nopencv-python-headless==4.11.0.86\r\ntqdm==4.65.2\r\nomegaconf==2.3.0\r\ntorch==2.4.1\r\ntorchvision==0.19.1\r\ntor"
},
{
"path": "test.py",
"chars": 2661,
"preview": "import torch, os, glob, copy\nimport torch.nn.functional as F\nimport numpy as np\nfrom PIL import Image\nfrom argparse impo"
},
{
"path": "test_debug.sh",
"chars": 146,
"preview": "HF_ENDPOINT=https://hf-mirror.com \\\nCUDA_VISIBLE_DEVICES=0 \\\npython -u test.py \\\n--epoch=200 \\\n--LR_dir=testset/RealSR/L"
},
{
"path": "train.py",
"chars": 8781,
"preview": "import torch, os, glob, random, copy\nimport torch.nn.functional as F\nfrom torch.utils.data import DataLoader\nimport torc"
},
{
"path": "train.sh",
"chars": 176,
"preview": "HF_ENDPOINT=https://hf-mirror.com \\\nCUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \\\npython -m torch.distributed.run \\\n--nproc_per"
},
{
"path": "train_debug.sh",
"chars": 162,
"preview": "HF_ENDPOINT=https://hf-mirror.com \\\nCUDA_VISIBLE_DEVICES=0 \\\nnohup torchrun \\\n--nproc_per_node=1 \\\n--master_port=23333 \\"
},
{
"path": "utils.py",
"chars": 1361,
"preview": "import torch\nfrom peft import LoraConfig\n\ndef add_lora_to_unet(unet, rank=4):\n l_target_modules_encoder, l_target_mod"
}
]
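The symbol index lists `AsymmetricLoss` in `ram/models/utils.py`, with defaults `gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8`. Those defaults match the published Asymmetric Loss (ASL) for multi-label classification, which RAM uses for tagging. The sketch below assumes the class implements that formulation. Its forward pass is rewritten here as a plain NumPy function (the name `asymmetric_loss` is hypothetical) over logits `x` and multi-hot targets `y`. The intent is to illustrate the math, not to reproduce the repo's torch code.

```python
import numpy as np


def asymmetric_loss(x, y, gamma_neg=4.0, gamma_pos=1.0, clip=0.05, eps=1e-8):
    """Sum-reduced ASL over logits x and {0, 1} targets y (hypothetical helper)."""
    p = 1.0 / (1.0 + np.exp(-x))          # sigmoid probabilities
    p_neg = 1.0 - p
    if clip is not None and clip > 0:
        # Probability shifting: very easy negatives (p < clip) contribute zero loss
        p_neg = np.minimum(p_neg + clip, 1.0)
    loss_pos = y * np.log(np.clip(p, eps, None))
    loss_neg = (1 - y) * np.log(np.clip(p_neg, eps, None))
    loss = loss_pos + loss_neg
    if gamma_neg > 0 or gamma_pos > 0:
        # Asymmetric focusing: separate exponents for positives and negatives
        pt = p * y + p_neg * (1 - y)
        gamma = gamma_pos * y + gamma_neg * (1 - y)
        loss = loss * np.power(1 - pt, gamma)
    return -loss.sum()
```

With `gamma_neg = gamma_pos = 0` and `clip = 0`, the function reduces to summed binary cross-entropy. The larger `gamma_neg` down-weights the many easy negative tags more aggressively than the few positives.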
About this extraction
This page contains the full source code of the Guaishou74851/AdcSR GitHub repository, extracted and formatted as plain text. The extraction includes 54 files (518.0 KB), approximately 160.1k tokens, and a symbol index with 473 extracted functions, classes, methods, constants, and types.