Showing preview only (4,135K chars total). Download the full file or copy to clipboard to get everything.
Repository: XPixelGroup/BasicSR
Branch: master
Commit: 8d56e3a045f9
Files: 293
Total size: 3.9 MB
Directory structure:
gitextract_6e6skl2h/
├── .github/
│ └── workflows/
│ ├── publish-pip.yml
│ ├── pylint.yml
│ └── release.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── .vscode/
│ └── settings.json
├── CITATION.cff
├── LICENSE/
│ ├── LICENSE-NVIDIA
│ ├── LICENSE-stylegan2-pytorch
│ ├── LICENSE_SwinIR
│ ├── LICENSE_pytorch-image-models
│ └── README.md
├── LICENSE.txt
├── MANIFEST.in
├── README.md
├── README_CN.md
├── VERSION
├── basicsr/
│ ├── __init__.py
│ ├── archs/
│ │ ├── __init__.py
│ │ ├── arch_util.py
│ │ ├── basicvsr_arch.py
│ │ ├── basicvsrpp_arch.py
│ │ ├── dfdnet_arch.py
│ │ ├── dfdnet_util.py
│ │ ├── discriminator_arch.py
│ │ ├── duf_arch.py
│ │ ├── ecbsr_arch.py
│ │ ├── edsr_arch.py
│ │ ├── edvr_arch.py
│ │ ├── hifacegan_arch.py
│ │ ├── hifacegan_util.py
│ │ ├── inception.py
│ │ ├── rcan_arch.py
│ │ ├── ridnet_arch.py
│ │ ├── rrdbnet_arch.py
│ │ ├── spynet_arch.py
│ │ ├── srresnet_arch.py
│ │ ├── srvgg_arch.py
│ │ ├── stylegan2_arch.py
│ │ ├── stylegan2_bilinear_arch.py
│ │ ├── swinir_arch.py
│ │ ├── tof_arch.py
│ │ └── vgg_arch.py
│ ├── data/
│ │ ├── __init__.py
│ │ ├── data_sampler.py
│ │ ├── data_util.py
│ │ ├── degradations.py
│ │ ├── ffhq_dataset.py
│ │ ├── meta_info/
│ │ │ ├── meta_info_DIV2K800sub_GT.txt
│ │ │ ├── meta_info_REDS4_test_GT.txt
│ │ │ ├── meta_info_REDS_GT.txt
│ │ │ ├── meta_info_REDSofficial4_test_GT.txt
│ │ │ ├── meta_info_REDSval_official_test_GT.txt
│ │ │ ├── meta_info_Vimeo90K_test_GT.txt
│ │ │ ├── meta_info_Vimeo90K_test_fast_GT.txt
│ │ │ ├── meta_info_Vimeo90K_test_medium_GT.txt
│ │ │ ├── meta_info_Vimeo90K_test_slow_GT.txt
│ │ │ └── meta_info_Vimeo90K_train_GT.txt
│ │ ├── paired_image_dataset.py
│ │ ├── prefetch_dataloader.py
│ │ ├── realesrgan_dataset.py
│ │ ├── realesrgan_paired_dataset.py
│ │ ├── reds_dataset.py
│ │ ├── single_image_dataset.py
│ │ ├── transforms.py
│ │ ├── video_test_dataset.py
│ │ └── vimeo90k_dataset.py
│ ├── losses/
│ │ ├── __init__.py
│ │ ├── basic_loss.py
│ │ ├── gan_loss.py
│ │ └── loss_util.py
│ ├── metrics/
│ │ ├── README.md
│ │ ├── README_CN.md
│ │ ├── __init__.py
│ │ ├── fid.py
│ │ ├── metric_util.py
│ │ ├── niqe.py
│ │ ├── niqe_pris_params.npz
│ │ ├── psnr_ssim.py
│ │ └── test_metrics/
│ │ └── test_psnr_ssim.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── base_model.py
│ │ ├── edvr_model.py
│ │ ├── esrgan_model.py
│ │ ├── hifacegan_model.py
│ │ ├── lr_scheduler.py
│ │ ├── realesrgan_model.py
│ │ ├── realesrnet_model.py
│ │ ├── sr_model.py
│ │ ├── srgan_model.py
│ │ ├── stylegan2_model.py
│ │ ├── swinir_model.py
│ │ ├── video_base_model.py
│ │ ├── video_gan_model.py
│ │ ├── video_recurrent_gan_model.py
│ │ └── video_recurrent_model.py
│ ├── ops/
│ │ ├── __init__.py
│ │ ├── dcn/
│ │ │ ├── __init__.py
│ │ │ ├── deform_conv.py
│ │ │ └── src/
│ │ │ ├── deform_conv_cuda.cpp
│ │ │ ├── deform_conv_cuda_kernel.cu
│ │ │ └── deform_conv_ext.cpp
│ │ ├── fused_act/
│ │ │ ├── __init__.py
│ │ │ ├── fused_act.py
│ │ │ └── src/
│ │ │ ├── fused_bias_act.cpp
│ │ │ └── fused_bias_act_kernel.cu
│ │ └── upfirdn2d/
│ │ ├── __init__.py
│ │ ├── src/
│ │ │ ├── upfirdn2d.cpp
│ │ │ └── upfirdn2d_kernel.cu
│ │ └── upfirdn2d.py
│ ├── test.py
│ ├── train.py
│ └── utils/
│ ├── __init__.py
│ ├── color_util.py
│ ├── diffjpeg.py
│ ├── dist_util.py
│ ├── download_util.py
│ ├── file_client.py
│ ├── flow_util.py
│ ├── img_process_util.py
│ ├── img_util.py
│ ├── lmdb_util.py
│ ├── logger.py
│ ├── matlab_functions.py
│ ├── misc.py
│ ├── options.py
│ ├── plot_util.py
│ └── registry.py
├── colab/
│ └── README.md
├── docs/
│ ├── BasicSR_docs_CN.md
│ ├── Config.md
│ ├── DatasetPreparation.md
│ ├── DatasetPreparation_CN.md
│ ├── DesignConvention.md
│ ├── FAQ.md
│ ├── HOWTOs.md
│ ├── HOWTOs_CN.md
│ ├── INSTALL.md
│ ├── Logging.md
│ ├── Logging_CN.md
│ ├── Makefile
│ ├── Metrics.md
│ ├── Metrics_CN.md
│ ├── ModelZoo.md
│ ├── ModelZoo_CN.md
│ ├── Models.md
│ ├── README.md
│ ├── TrainTest.md
│ ├── TrainTest_CN.md
│ ├── auto_generate_api.py
│ ├── conf.py
│ ├── history_updates.md
│ ├── index.rst
│ ├── introduction.md
│ ├── make.bat
│ └── requirements.txt
├── inference/
│ ├── inference_basicvsr.py
│ ├── inference_basicvsrpp.py
│ ├── inference_dfdnet.py
│ ├── inference_esrgan.py
│ ├── inference_ridnet.py
│ ├── inference_stylegan2.py
│ └── inference_swinir.py
├── options/
│ ├── test/
│ │ ├── BasicVSR/
│ │ │ ├── test_BasicVSR_REDS.yml
│ │ │ ├── test_BasicVSR_Vimeo90K_BDx4.yml
│ │ │ ├── test_BasicVSR_Vimeo90K_BIx4.yml
│ │ │ ├── test_IconVSR_REDS.yml
│ │ │ ├── test_IconVSR_Vimeo90K_BDx4.yml
│ │ │ └── test_IconVSR_Vimeo90K_BIx4.yml
│ │ ├── DUF/
│ │ │ └── test_DUF_official.yml
│ │ ├── EDSR/
│ │ │ ├── test_EDSR_Lx2.yml
│ │ │ ├── test_EDSR_Lx3.yml
│ │ │ ├── test_EDSR_Lx4.yml
│ │ │ ├── test_EDSR_Mx2.yml
│ │ │ ├── test_EDSR_Mx3.yml
│ │ │ └── test_EDSR_Mx4.yml
│ │ ├── EDVR/
│ │ │ ├── test_EDVR_L_deblur_REDS.yml
│ │ │ ├── test_EDVR_L_deblurcomp_REDS.yml
│ │ │ ├── test_EDVR_L_x4_SR_REDS.yml
│ │ │ ├── test_EDVR_L_x4_SR_Vid4.yml
│ │ │ ├── test_EDVR_L_x4_SR_Vimeo90K.yml
│ │ │ ├── test_EDVR_L_x4_SRblur_REDS.yml
│ │ │ └── test_EDVR_M_x4_SR_REDS.yml
│ │ ├── ESRGAN/
│ │ │ ├── test_ESRGAN_x4.yml
│ │ │ ├── test_ESRGAN_x4_woGT.yml
│ │ │ └── test_RRDBNet_PSNR_x4.yml
│ │ ├── HiFaceGAN/
│ │ │ ├── test_hifacegan.yml
│ │ │ └── test_hifacegan_woGT.yml
│ │ ├── RCAN/
│ │ │ └── test_RCAN.yml
│ │ ├── SRResNet_SRGAN/
│ │ │ ├── test_MSRGAN_x4.yml
│ │ │ ├── test_MSRResNet_x2.yml
│ │ │ ├── test_MSRResNet_x3.yml
│ │ │ ├── test_MSRResNet_x4.yml
│ │ │ └── test_MSRResNet_x4_woGT.yml
│ │ └── TOF/
│ │ └── test_TOF_official.yml
│ └── train/
│ ├── BasicVSR/
│ │ ├── train_BasicVSR_REDS.yml
│ │ ├── train_BasicVSR_Vimeo90K_BDx4.yml
│ │ ├── train_BasicVSR_Vimeo90K_BIx4.yml
│ │ ├── train_IconVSR_REDS.yml
│ │ ├── train_IconVSR_Vimeo90K_BDx4.yml
│ │ └── train_IconVSR_Vimeo90K_BIx4.yml
│ ├── BasicVSRPP/
│ │ └── train_BasicVSRPP_REDS.yml
│ ├── ECBSR/
│ │ ├── train_ECBSR_x2_m4c16_prelu.yml
│ │ ├── train_ECBSR_x4_m4c16_prelu.yml
│ │ └── train_ECBSR_x4_m4c16_prelu_RGB.yml
│ ├── EDSR/
│ │ ├── train_EDSR_Lx2.yml
│ │ ├── train_EDSR_Lx3.yml
│ │ ├── train_EDSR_Lx4.yml
│ │ ├── train_EDSR_Mx2.yml
│ │ ├── train_EDSR_Mx3.yml
│ │ └── train_EDSR_Mx4.yml
│ ├── EDVR/
│ │ ├── train_EDVRM_woTSA_GAN_TODO.yml
│ │ ├── train_EDVR_L_x4_SR_REDS.yml
│ │ ├── train_EDVR_L_x4_SR_REDS_woTSA.yml
│ │ ├── train_EDVR_M_x4_SR_REDS.yml
│ │ └── train_EDVR_M_x4_SR_REDS_woTSA.yml
│ ├── ESRGAN/
│ │ ├── train_ESRGAN_x4.yml
│ │ └── train_RRDBNet_PSNR_x4.yml
│ ├── HiFaceGAN/
│ │ └── train_hifacegan.yml
│ ├── LDL/
│ │ └── train_LDL_Real_x4.yml
│ ├── RCAN/
│ │ └── train_RCAN_x2.yml
│ ├── RealESRGAN/
│ │ ├── train_realesrgan_x2plus.yml
│ │ ├── train_realesrgan_x4plus.yml
│ │ ├── train_realesrnet_x2plus.yml
│ │ └── train_realesrnet_x4plus.yml
│ ├── SRResNet_SRGAN/
│ │ ├── README.md
│ │ ├── train_MSRGAN_x4.yml
│ │ ├── train_MSRResNet_x2.yml
│ │ ├── train_MSRResNet_x3.yml
│ │ └── train_MSRResNet_x4.yml
│ ├── StyleGAN/
│ │ └── train_StyleGAN2_256_Cmul2_FFHQ.yml
│ ├── SwinIR/
│ │ ├── train_SwinIR_SRx2_scratch.yml
│ │ └── train_SwinIR_SRx4_scratch.yml
│ └── VideoRecurrentGAN/
│ └── train_VideoRecurrentGANModel_REDS.yml
├── requirements.txt
├── scripts/
│ ├── data_preparation/
│ │ ├── create_lmdb.py
│ │ ├── download_datasets.py
│ │ ├── extract_images_from_tfrecords.py
│ │ ├── extract_subimages.py
│ │ ├── generate_meta_info.py
│ │ ├── prepare_hifacegan_dataset.py
│ │ └── regroup_reds_dataset.py
│ ├── dist_test.sh
│ ├── dist_train.sh
│ ├── download_gdrive.py
│ ├── download_pretrained_models.py
│ ├── matlab_scripts/
│ │ ├── back_projection/
│ │ │ ├── backprojection.m
│ │ │ ├── main_bp.m
│ │ │ └── main_reverse_filter.m
│ │ ├── generate_LR_Vimeo90K.m
│ │ └── generate_bicubic_img.m
│ ├── metrics/
│ │ ├── calculate_fid_folder.py
│ │ ├── calculate_fid_stats_from_datasets.py
│ │ ├── calculate_lpips.py
│ │ ├── calculate_niqe.py
│ │ ├── calculate_psnr_ssim.py
│ │ └── calculate_stylegan2_fid.py
│ ├── model_conversion/
│ │ ├── convert_dfdnet.py
│ │ ├── convert_models.py
│ │ ├── convert_ridnet.py
│ │ └── convert_stylegan.py
│ ├── plot/
│ │ ├── README.md
│ │ └── model_complexity_cmp_bsrn.py
│ └── publish_models.py
├── setup.cfg
├── setup.py
├── test_scripts/
│ ├── test_discriminator_backward.py
│ ├── test_ffhq_dataset.py
│ ├── test_lr_scheduler.py
│ ├── test_niqe.py
│ ├── test_paired_image_dataset.py
│ ├── test_reds_dataset.py
│ └── test_vimeo90k_dataset.py
└── tests/
├── README.md
├── data/
│ ├── gt.lmdb/
│ │ ├── data.mdb
│ │ ├── lock.mdb
│ │ └── meta_info.txt
│ ├── lq.lmdb/
│ │ ├── data.mdb
│ │ ├── lock.mdb
│ │ └── meta_info.txt
│ ├── meta_info_gt.txt
│ └── meta_info_pair.txt
├── test_archs/
│ ├── test_basicvsr_arch.py
│ ├── test_discriminator_arch.py
│ ├── test_duf_arch.py
│ ├── test_ecbsr_arch.py
│ └── test_srresnet_arch.py
├── test_data/
│ ├── test_paired_image_dataset.py
│ └── test_single_image_dataset.py
├── test_losses/
│ └── test_losses.py
├── test_metrics/
│ └── test_psnr_ssim.py
└── test_models/
└── test_sr_model.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/publish-pip.yml
================================================
name: PyPI Publish
on: push
jobs:
build-n-publish:
runs-on: ubuntu-latest
if: startsWith(github.event.ref, 'refs/tags')
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.8
uses: actions/setup-python@v1
with:
python-version: 3.8
- name: Upgrade pip
run: pip install pip --upgrade
- name: Install PyTorch (cpu)
run: pip install torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install dependencies
run: pip install -r requirements.txt
- name: Build and install
run: rm -rf .eggs && pip install -e .
- name: Build for distribution
# Build an sdist only (no bdist_wheel) so that pip installs from source and can compile the CUDA extensions
run: python setup.py sdist
- name: Publish distribution to PyPI
uses: pypa/gh-action-pypi-publish@master
with:
password: ${{ secrets.PYPI_API_TOKEN }}
================================================
FILE: .github/workflows/pylint.yml
================================================
name: PyLint
on: [push, pull_request]
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install codespell flake8 isort yapf
- name: Lint
run: |
codespell
flake8 .
isort --check-only --diff basicsr/ options/ scripts/ tests/ inference/ setup.py
yapf -r -d basicsr/ options/ scripts/ tests/ inference/ setup.py
================================================
FILE: .github/workflows/release.yml
================================================
name: release
on:
push:
tags:
- '*'
jobs:
build:
permissions: write-all
name: Create Release
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Create Release
id: create_release
uses: actions/create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
tag_name: ${{ github.ref }}
release_name: BasicSR ${{ github.ref }} Release Note
body: |
🚀 See you again 😸
🚀 Have a nice day 😸 and happy every day 😃
🚀 Long time no see ☄️
✨ **Highlights**
✅ [Features] Support ...
🐛 **Bug Fixes**
🌴 **Improvements**
📢📢📢
<p align="center">
<img src="https://raw.githubusercontent.com/XPixelGroup/BasicSR/master/assets/basicsr_xpixel_logo.png" height=150>
</p>
draft: true
prerelease: false
================================================
FILE: .gitignore
================================================
# ignored folders
datasets/*
experiments/*
results/*
tb_logger/*
wandb/*
tmp/*
docs/api
scripts/__init__.py
*.DS_Store
.idea
# ignored files
version.py
# ignored files with suffix
*.html
*.png
*.jpeg
*.jpg
*.gif
*.pth
*.zip
# template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
================================================
FILE: .pre-commit-config.yaml
================================================
repos:
# flake8
- repo: https://github.com/PyCQA/flake8
rev: 3.8.3
hooks:
- id: flake8
args: ["--config=setup.cfg", "--ignore=W504, W503"]
# modify known_third_party
- repo: https://github.com/asottile/seed-isort-config
rev: v2.2.0
hooks:
- id: seed-isort-config
# isort
- repo: https://github.com/timothycrosley/isort
rev: 5.2.2
hooks:
- id: isort
# yapf
- repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.30.0
hooks:
- id: yapf
# codespell
- repo: https://github.com/codespell-project/codespell
rev: v2.1.0
hooks:
- id: codespell
# pre-commit-hooks
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
hooks:
- id: trailing-whitespace # Trim trailing whitespace
- id: check-yaml # Attempt to load all yaml files to verify syntax
- id: check-merge-conflict # Check for files that contain merge conflict strings
- id: double-quote-string-fixer # Replace double quoted strings with single quoted strings
- id: end-of-file-fixer # Make sure files end in a newline and only a newline
- id: requirements-txt-fixer # Sort entries in requirements.txt and remove incorrect entry for pkg-resources==0.0.0
- id: fix-encoding-pragma # Remove the coding pragma: # -*- coding: utf-8 -*-
args: ["--remove"]
- id: mixed-line-ending # Replace or check mixed line ending
args: ["--fix=lf"]
================================================
FILE: .readthedocs.yaml
================================================
# .readthedocs.yaml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
# Required
version: 2
# Set the version of Python and other tools you might need
build:
os: ubuntu-20.04
tools:
python: "3.8"
# You can also specify other tool versions:
# nodejs: "16"
# rust: "1.55"
# golang: "1.17"
# Build documentation in the docs/ directory with Sphinx
sphinx:
configuration: docs/conf.py
# If using Sphinx, optionally build your docs in additional formats such as PDF
# formats:
# - pdf
# Optionally declare the Python requirements required to build your docs
python:
install:
- requirements: docs/requirements.txt
================================================
FILE: .vscode/settings.json
================================================
{
"files.trimTrailingWhitespace": true,
"editor.wordWrap": "on",
"editor.rulers": [
80,
120
],
"editor.renderWhitespace": "all",
"editor.renderControlCharacters": true,
"python.formatting.provider": "yapf",
"python.formatting.yapfArgs": [
"--style",
"{BASED_ON_STYLE = pep8, BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true, SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true, COLUMN_LIMIT = 120}"
],
"python.linting.flake8Enabled": true,
"python.linting.flake8Args": [
"--max-line-length=120"
],
}
================================================
FILE: CITATION.cff
================================================
cff-version: 1.2.0
message: "If you use this project, please cite it as below."
title: "BasicSR: Open Source Image and Video Restoration Toolbox"
version: 1.3.5
date-released: 2022-02-16
url: "https://github.com/XPixelGroup/BasicSR"
license: Apache-2.0
authors:
- family-names: Wang
given-names: Xintao
- family-names: Xie
given-names: Liangbin
- family-names: Yu
given-names: Ke
- family-names: Chan
given-names: Kelvin C.K.
- family-names: Loy
given-names: Chen Change
- family-names: Dong
given-names: Chao
================================================
FILE: LICENSE/LICENSE-NVIDIA
================================================
Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
Nvidia Source Code License-NC
=======================================================================
1. Definitions
"Licensor" means any person or entity that distributes its Work.
"Software" means the original work of authorship made available under
this License.
"Work" means the Software and any additions to or derivative works of
the Software that are made available under this License.
"Nvidia Processors" means any central processing unit (CPU), graphics
processing unit (GPU), field-programmable gate array (FPGA),
application-specific integrated circuit (ASIC) or any combination
thereof designed, made, sold, or provided by Nvidia or its affiliates.
The terms "reproduce," "reproduction," "derivative works," and
"distribution" have the meaning as provided under U.S. copyright law;
provided, however, that for the purposes of this License, derivative
works shall not include works that remain separable from, or merely
link (or bind by name) to the interfaces of, the Work.
Works, including the Software, are "made available" under this License
by including in or with the Work either (a) a copyright notice
referencing the applicability of this License to the Work, or (b) a
copy of this License.
2. License Grants
2.1 Copyright Grant. Subject to the terms and conditions of this
License, each Licensor grants to you a perpetual, worldwide,
non-exclusive, royalty-free, copyright license to reproduce,
prepare derivative works of, publicly display, publicly perform,
sublicense and distribute its Work and any resulting derivative
works in any form.
3. Limitations
3.1 Redistribution. You may reproduce or distribute the Work only
if (a) you do so under this License, (b) you include a complete
copy of this License with your distribution, and (c) you retain
without modification any copyright, patent, trademark, or
attribution notices that are present in the Work.
3.2 Derivative Works. You may specify that additional or different
terms apply to the use, reproduction, and distribution of your
derivative works of the Work ("Your Terms") only if (a) Your Terms
provide that the use limitation in Section 3.3 applies to your
derivative works, and (b) you identify the specific derivative
works that are subject to Your Terms. Notwithstanding Your Terms,
this License (including the redistribution requirements in Section
3.1) will continue to apply to the Work itself.
3.3 Use Limitation. The Work and any derivative works thereof only
may be used or intended for use non-commercially. The Work or
derivative works thereof may be used or intended for use by Nvidia
or its affiliates commercially or non-commercially. As used herein,
"non-commercially" means for research or evaluation purposes only.
3.4 Patent Claims. If you bring or threaten to bring a patent claim
against any Licensor (including any claim, cross-claim or
counterclaim in a lawsuit) to enforce any patents that you allege
are infringed by any Work, then your rights under this License from
such Licensor (including the grants in Sections 2.1 and 2.2) will
terminate immediately.
3.5 Trademarks. This License does not grant any rights to use any
Licensor's or its affiliates' names, logos, or trademarks, except
as necessary to reproduce the notices described in this License.
3.6 Termination. If you violate any term of this License, then your
rights under this License (including the grants in Sections 2.1 and
2.2) will terminate immediately.
4. Disclaimer of Warranty.
THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
THIS LICENSE.
5. Limitation of Liability.
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
THE POSSIBILITY OF SUCH DAMAGES.
=======================================================================
================================================
FILE: LICENSE/LICENSE-stylegan2-pytorch
================================================
MIT License
Copyright (c) 2019 Kim Seonghyeon
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: LICENSE/LICENSE_SwinIR
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [2021] [SwinIR Authors]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: LICENSE/LICENSE_pytorch-image-models
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2019 Ross Wightman
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: LICENSE/README.md
================================================
# License and Acknowledgement
This BasicSR project is released under the Apache 2.0 license.
- StyleGAN2
- The codes are modified from the repository [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch). Many thanks to the author - [Kim Seonghyeon](https://rosinality.github.io/) :blush: for translating from the official TensorFlow codes to PyTorch ones. Here is the [license](LICENSE-stylegan2-pytorch) of stylegan2-pytorch.
- The official repository is <https://github.com/NVlabs/stylegan2>, and here is the [NVIDIA license](./LICENSE-NVIDIA).
- DFDNet
- The codes are largely modified from the repository [DFDNet](https://github.com/csxmli2016/DFDNet). Their license is [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](https://creativecommons.org/licenses/by-nc-sa/4.0/).
- DiffJPEG
- Modified from <https://github.com/mlomnitz/DiffJPEG>.
- [pytorch-image-models](https://github.com/rwightman/pytorch-image-models/)
- We use the implementation of `DropPath` and `trunc_normal_` from [pytorch-image-models](https://github.com/rwightman/pytorch-image-models/). The LICENSE is included as [LICENSE_pytorch-image-models](LICENSE_pytorch-image-models).
- [SwinIR](https://github.com/JingyunLiang/SwinIR)
- The arch implementation of SwinIR is from [SwinIR](https://github.com/JingyunLiang/SwinIR). The LICENSE is included as [LICENSE_SwinIR](LICENSE_SwinIR).
- [ECBSR](https://github.com/xindongzhang/ECBSR)
- The arch implementation of ECBSR is from [ECBSR](https://github.com/xindongzhang/ECBSR). The LICENSE of ECBSR is [Apache License 2.0](https://github.com/xindongzhang/ECBSR/blob/main/LICENSE).
## References
1. NIQE metric: the codes are translated from the [official MATLAB codes](http://live.ece.utexas.edu/research/quality/niqe_release.zip)
> A. Mittal, R. Soundararajan and A. C. Bovik, "Making a Completely Blind Image Quality Analyzer", IEEE Signal Processing Letters, 2012.
1. FID metric: the codes are modified from [pytorch-fid](https://github.com/mseitzer/pytorch-fid) and [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch).
================================================
FILE: LICENSE.txt
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2018-2022 BasicSR Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: MANIFEST.in
================================================
include basicsr/ops/dcn/src/*.cu basicsr/ops/dcn/src/*.cpp
include basicsr/ops/fused_act/src/*.cu basicsr/ops/fused_act/src/*.cpp
include basicsr/ops/upfirdn2d/src/*.cu basicsr/ops/upfirdn2d/src/*.cpp
include basicsr/metrics/niqe_pris_params.npz
include VERSION
include requirements.txt
================================================
FILE: README.md
================================================
<p align="center">
<img src="assets/basicsr_xpixel_logo.png" height=120>
</p>
## <div align="center"><b><a href="README.md">English</a> | <a href="README_CN.md">简体中文</a></b></div>
<div align="center">
[License](https://github.com/xinntao/BasicSR/blob/master/LICENSE.txt)
[PyPI](https://pypi.org/project/basicsr/)
[LGTM Python code quality](https://lgtm.com/projects/g/xinntao/BasicSR/context:python)
[Pylint workflow](https://github.com/xinntao/BasicSR/blob/master/.github/workflows/pylint.yml)
[Publish-pip workflow](https://github.com/xinntao/BasicSR/blob/master/.github/workflows/publish-pip.yml)
[Gitee mirror workflow](https://github.com/xinntao/BasicSR/blob/master/.github/workflows/gitee-mirror.yml)
</div>
<div align="center">
⚡[**HowTo**](#-HOWTOs) **|** 🔧[**Installation**](docs/INSTALL.md) **|** 💻[**Training Commands**](docs/TrainTest.md) **|** 🐢[**DatasetPrepare**](docs/DatasetPreparation.md) **|** 🏰[**Model Zoo**](docs/ModelZoo.md)
📕[**中文解读文档**](https://github.com/XPixelGroup/BasicSR-docs) **|** 📊 [**Plot scripts**](scripts/plot) **|** 📝[Introduction](docs/introduction.md) **|** <a href="https://github.com/XPixelGroup/BasicSR/tree/master/colab"><img src="https://colab.research.google.com/assets/colab-badge.svg" height="18" alt="google colab logo"></a> **|** ⏳[TODO List](https://github.com/xinntao/BasicSR/projects) **|** ❓[FAQ](docs/FAQ.md)
</div>
🚀 We add [BasicSR-Examples](https://github.com/xinntao/BasicSR-examples), which provides guidance and templates for using BasicSR as a Python package. 🚀 <br>
📢 **技术交流QQ群**:**320960100**   入群答案:**互帮互助共同进步** <br>
🧭 [入群二维码](#-contact) (QQ、微信)    [入群指南 (腾讯文档)](https://docs.qq.com/doc/DYXBSUmxOT0xBZ05u) <br>
---
BasicSR (**Basic** **S**uper **R**estoration) is an open-source **image and video restoration** toolbox based on PyTorch, supporting tasks such as super-resolution, denoising, deblurring, JPEG artifact removal, *etc*.<br>
BasicSR (**Basic** **S**uper **R**estoration) 是一个基于 PyTorch 的开源 图像视频复原工具箱, 比如 超分辨率, 去噪, 去模糊, 去 JPEG 压缩噪声等.
🚩 **New Features/Updates**
- ✅ July 26, 2022. Add plot scripts 📊[Plot](scripts/plot).
- ✅ May 9, 2022. BasicSR joins [XPixel](http://xpixel.group/).
- ✅ Oct 5, 2021. Add **ECBSR training and testing** codes: [ECBSR](https://github.com/xindongzhang/ECBSR).
> ACMMM21: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
- ✅ Sep 2, 2021. Add **SwinIR training and testing** codes: [SwinIR](https://github.com/JingyunLiang/SwinIR) by [Jingyun Liang](https://github.com/JingyunLiang). More details are in [HOWTOs.md](docs/HOWTOs.md#how-to-train-swinir-sr)
- ✅ Aug 5, 2021. Add NIQE, which produces the same results as MATLAB (both are 5.7296 for tests/data/baboon.png).
- ✅ July 31, 2021. Add **bi-directional video super-resolution** codes: [**BasicVSR** and IconVSR](https://arxiv.org/abs/2012.02181).
> CVPR21: BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond
- **[More](docs/history_updates.md)**
---
If BasicSR helps your research or work, please help to ⭐ this repo or recommend it to your friends. Thanks😊 <br>
Other recommended projects:<br>
▶️ [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN): A practical algorithm for general image restoration<br>
▶️ [GFPGAN](https://github.com/TencentARC/GFPGAN): A practical algorithm for real-world face restoration <br>
▶️ [facexlib](https://github.com/xinntao/facexlib): A collection that provides useful face-related functions.<br>
▶️ [HandyView](https://github.com/xinntao/HandyView): A PyQt5-based image viewer that is handy for viewing and comparing images. <br>
▶️ [HandyFigure](https://github.com/xinntao/HandyFigure): Open source of paper figures <br>
<sub>([ESRGAN](https://github.com/xinntao/ESRGAN), [EDVR](https://github.com/xinntao/EDVR), [DNI](https://github.com/xinntao/DNI), [SFTGAN](https://github.com/xinntao/SFTGAN))</sub>
<sub>([HandyCrawler](https://github.com/xinntao/HandyCrawler), [HandyWriting](https://github.com/xinntao/HandyWriting))</sub>
---
## ⚡ HOWTOs
We provide simple pipelines to train/test/inference models for a quick start.
These pipelines/commands cannot cover all the cases and more details are in the following sections.
| GAN | | | | | |
| :------------------- | :--------------------------------------------: | :----------------------------------------------------: | :------- | :--------------------------------------------: | :----------------------------------------------------: |
| StyleGAN2 | [Train](docs/HOWTOs.md#How-to-train-StyleGAN2) | [Inference](docs/HOWTOs.md#How-to-inference-StyleGAN2) | | | |
| **Face Restoration** | | | | | |
| DFDNet | - | [Inference](docs/HOWTOs.md#How-to-inference-DFDNet) | | | |
| **Super Resolution** | | | | | |
| ESRGAN | *TODO* | *TODO* | SRGAN | *TODO* | *TODO* |
| EDSR | *TODO* | *TODO* | SRResNet | *TODO* | *TODO* |
| RCAN | *TODO* | *TODO* | SwinIR | [Train](docs/HOWTOs.md#how-to-train-swinir-sr) | [Inference](docs/HOWTOs.md#how-to-inference-swinir-sr) |
| EDVR | *TODO* | *TODO* | DUF | - | *TODO* |
| BasicVSR | *TODO* | *TODO* | TOF | - | *TODO* |
| **Deblurring** | | | | | |
| DeblurGANv2 | - | *TODO* | | | |
| **Denoise** | | | | | |
| RIDNet | - | *TODO* | CBDNet | - | *TODO* |
## ✨ **Projects that use BasicSR**
- [**Real-ESRGAN**](https://github.com/xinntao/Real-ESRGAN): A practical algorithm for general image restoration
- [**GFPGAN**](https://github.com/TencentARC/GFPGAN): A practical algorithm for real-world face restoration
If you use `BasicSR` in your open-source projects, you are welcome to contact me (by [email](#-contact) or by opening an issue/pull request). I will add your projects to the above list 😊
## 📜 License and Acknowledgement
This project is released under the [Apache 2.0 license](LICENSE.txt).<br>
More details about **license** and **acknowledgement** are in [LICENSE](LICENSE/README.md).
## 🌏 Citations
If BasicSR helps your research or work, please cite BasicSR.<br>
The following is a BibTeX reference. The BibTeX entry requires the `url` LaTeX package.
``` latex
@misc{basicsr,
author = {Xintao Wang and Liangbin Xie and Ke Yu and Kelvin C.K. Chan and Chen Change Loy and Chao Dong},
title = {{BasicSR}: Open Source Image and Video Restoration Toolbox},
howpublished = {\url{https://github.com/XPixelGroup/BasicSR}},
year = {2022}
}
```
> Xintao Wang, Liangbin Xie, Ke Yu, Kelvin C.K. Chan, Chen Change Loy and Chao Dong. BasicSR: Open Source Image and Video Restoration Toolbox. <https://github.com/XPixelGroup/BasicSR>, 2022.
## 📧 Contact
If you have any questions, please email `xintao.alpha@gmail.com`, `xintao.wang@outlook.com`.
<br>
- **QQ群**: 扫描左边二维码 或者 搜索QQ群号: 320960100 入群答案:互帮互助共同进步
- **微信群**: 我们的一群已经满500人啦,二群也超过200人了;进群可以添加 Liangbin 的个人微信 (右边二维码),他会在空闲的时候拉大家入群~
<p align="center">
<img src="https://user-images.githubusercontent.com/17445847/134879983-6f2d663b-16e7-49f2-97e1-7c53c8a5f71a.jpg" height="300">  
<img src="https://user-images.githubusercontent.com/17445847/139572512-8e192aac-00fa-432b-ac8e-a33026b019df.png" height="300">
</p>
 (start from 2022-11-06)
================================================
FILE: README_CN.md
================================================
<p align="center">
<img src="assets/basicsr_xpixel_logo.png" height=120>
</p>
## <div align="center"><b><a href="README.md">English</a> | <a href="README_CN.md">简体中文</a></b></div>
[](https://github.com/xinntao/BasicSR/blob/master/LICENSE.txt)
[](https://pypi.org/project/basicsr/)
[](https://lgtm.com/projects/g/xinntao/BasicSR/context:python)
[](https://github.com/xinntao/BasicSR/blob/master/.github/workflows/pylint.yml)
[](https://github.com/xinntao/BasicSR/blob/master/.github/workflows/publish-pip.yml)
[](https://github.com/xinntao/BasicSR/blob/master/.github/workflows/gitee-mirror.yml)
<!-- [English](README.md) **|** [简体中文](README_CN.md)   [GitHub](https://github.com/xinntao/BasicSR) **|** [Gitee码云](https://gitee.com/xinntao/BasicSR) -->
:rocket: 我们添加了 [BasicSR-Examples](https://github.com/xinntao/BasicSR-examples), 它提供了使用BasicSR的指南以及模板 (以python package的形式) :rocket:
:loudspeaker: **技术交流QQ群**:**320960100**   入群答案:**互帮互助共同进步**
:compass: [入群二维码](#e-mail-%E8%81%94%E7%B3%BB) (QQ、微信)    [入群指南 (腾讯文档)](https://docs.qq.com/doc/DYXBSUmxOT0xBZ05u)
---
<a href="https://drive.google.com/drive/folders/1G_qcpvkT5ixmw5XoN6MupkOzcK1km625?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" height="18" alt="google colab logo"></a> Google Colab: [GitHub Link](colab) **|** [Google Drive Link](https://drive.google.com/drive/folders/1G_qcpvkT5ixmw5XoN6MupkOzcK1km625?usp=sharing) <br>
:m: [模型库](docs/ModelZoo_CN.md): :arrow_double_down: 百度网盘: [预训练模型](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g) **|** [复现实验](https://pan.baidu.com/s/1UElD6q8sVAgn_cxeBDOlvQ)
:arrow_double_down: Google Drive: [Pretrained Models](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing) **|** [Reproduced Experiments](https://drive.google.com/drive/folders/1XN4WXKJ53KQ0Cu0Yv-uCt8DZWq6uufaP?usp=sharing) <br>
:file_folder: [数据](docs/DatasetPreparation_CN.md): :arrow_double_down: [百度网盘](https://pan.baidu.com/s/1AZDcEAFwwc1OC3KCd7EDnQ) (提取码:basr) :arrow_double_down: [Google Drive](https://drive.google.com/drive/folders/1gt5eT293esqY0yr1Anbm36EdnxWW_5oH?usp=sharing) <br>
:chart_with_upwards_trend: [wandb的训练曲线](https://app.wandb.ai/xintao/basicsr) <br>
:computer: [训练和测试的命令](docs/TrainTest_CN.md) <br>
:zap: [HOWTOs](#zap-howtos)
---
BasicSR (**Basic** **S**uper **R**estoration) 是一个基于 PyTorch 的开源图像视频复原工具箱, 比如 超分辨率, 去噪, 去模糊, 去 JPEG 压缩噪声等.
:triangular_flag_on_post: **新的特性/更新**
- :white_check_mark: Oct 5, 2021. 添加 **ECBSR 训练和测试** 代码: [ECBSR](https://github.com/xindongzhang/ECBSR).
> ACMMM21: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
- :white_check_mark: Sep 2, 2021. 添加 **SwinIR 训练和测试** 代码: [SwinIR](https://github.com/JingyunLiang/SwinIR) by [Jingyun Liang](https://github.com/JingyunLiang). 更多内容参见 [HOWTOs.md](docs/HOWTOs.md#how-to-train-swinir-sr)
- :white_check_mark: Aug 5, 2021. 添加了NIQE, 它输出和MATLAB一样的结果 (both are 5.7296 for tests/data/baboon.png).
- :white_check_mark: July 31, 2021. Add **bi-directional video super-resolution** codes: [**BasicVSR** and IconVSR](https://arxiv.org/abs/2012.02181).
> CVPR21: BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond
- **[更多](docs/history_updates.md)**
:sparkles: **使用 BasicSR 的项目**
- [**Real-ESRGAN**](https://github.com/xinntao/Real-ESRGAN): 通用图像复原的实用算法
- [**GFPGAN**](https://github.com/TencentARC/GFPGAN): 真实场景人脸复原的实用算法
如果你的开源项目中使用了`BasicSR`, 欢迎联系我 ([邮件](#e-mail-%E8%81%94%E7%B3%BB)或者开一个issue/pull request)。我会将你的开源项目添加到上面的列表中 :blush:
---
如果 BasicSR 对你有所帮助,欢迎 :star: 这个仓库或推荐给你的朋友。Thanks:blush: <br>
其他推荐的项目:<br>
:arrow_forward: [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN): 通用图像复原的实用算法<br>
:arrow_forward: [GFPGAN](https://github.com/TencentARC/GFPGAN): 真实场景人脸复原的实用算法<br>
:arrow_forward: [facexlib](https://github.com/xinntao/facexlib): 提供实用的人脸相关功能的集合<br>
:arrow_forward: [HandyView](https://github.com/xinntao/HandyView): 基于PyQt5的 方便的看图比图工具<br>
<sub>([ESRGAN](https://github.com/xinntao/ESRGAN), [EDVR](https://github.com/xinntao/EDVR), [DNI](https://github.com/xinntao/DNI), [SFTGAN](https://github.com/xinntao/SFTGAN))</sub>
<sub>([HandyView](https://gitee.com/xinntao/HandyView), [HandyFigure](https://gitee.com/xinntao/HandyFigure), [HandyCrawler](https://gitee.com/xinntao/HandyCrawler), [HandyWriting](https://gitee.com/xinntao/HandyWriting))</sub>
---
## :zap: HOWTOs
我们提供了简单的流程来快速上手 训练/测试/推理 模型. 这些命令并不能涵盖所有用法, 更多的细节参见下面的部分.
| GAN | | | | | |
| :------------------- | :------------------------------------------: | :------------------------------------------: | :------- | :--------------------------------------------: | :----------------------------------------------------: |
| StyleGAN2 | [训练](docs/HOWTOs_CN.md#如何训练-StyleGAN2) | [测试](docs/HOWTOs_CN.md#如何测试-StyleGAN2) | | | |
| **Face Restoration** | | | | | |
| DFDNet | - | [测试](docs/HOWTOs_CN.md#如何测试-DFDNet) | | | |
| **Super Resolution** | | | | | |
| ESRGAN | *TODO* | *TODO* | SRGAN | *TODO* | *TODO* |
| EDSR | *TODO* | *TODO* | SRResNet | *TODO* | *TODO* |
| RCAN | *TODO* | *TODO* | SwinIR | [Train](docs/HOWTOs.md#how-to-train-swinir-sr) | [Inference](docs/HOWTOs.md#how-to-inference-swinir-sr) |
| EDVR | *TODO* | *TODO* | DUF | - | *TODO* |
| BasicVSR | *TODO* | *TODO* | TOF | - | *TODO* |
| **Deblurring** | | | | | |
| DeblurGANv2 | - | *TODO* | | | |
| **Denoise** | | | | | |
| RIDNet | - | *TODO* | CBDNet | - | *TODO* |
## :wrench: 依赖和安装
For detailed instructions refer to [docs/INSTALL.md](docs/INSTALL.md).
## :hourglass_flowing_sand: TODO 清单
参见 [project boards](https://github.com/xinntao/BasicSR/projects).
## :turtle: 数据准备
- 数据准备步骤, 参见 **[DatasetPreparation_CN.md](docs/DatasetPreparation_CN.md)**.
- 目前支持的数据集 (`torch.utils.data.Dataset`类), 参见 [Datasets_CN.md](docs/Datasets_CN.md).
## :computer: 训练和测试
- **训练和测试的命令**, 参见 **[TrainTest_CN.md](docs/TrainTest_CN.md)**.
- **Options/Configs**配置文件的说明, 参见 [Config_CN.md](docs/Config_CN.md).
- **Logging**日志系统的说明, 参见 [Logging_CN.md](docs/Logging_CN.md).
## :european_castle: 模型库和基准
- 目前支持的模型描述, 参见 [Models_CN.md](docs/Models_CN.md).
- **预训练模型和log样例**, 参见 **[ModelZoo_CN.md](docs/ModelZoo_CN.md)**.
- 我们也在 [wandb](https://app.wandb.ai/xintao/basicsr) 上提供了**训练曲线**等:
<p align="center">
<a href="https://app.wandb.ai/xintao/basicsr" target="_blank">
<img src="./assets/wandb.jpg" height="280">
</a></p>
## :memo: 代码库的设计和约定
参见 [DesignConvention_CN.md](docs/DesignConvention_CN.md).<br>
下图概括了整体的框架. 每个模块更多的描述参见: <br>
**[Datasets_CN.md](docs/Datasets_CN.md)** | **[Models_CN.md](docs/Models_CN.md)** | **[Config_CN.md](docs/Config_CN.md)** | **[Logging_CN.md](docs/Logging_CN.md)**

## :scroll: 许可
本项目使用 Apache 2.0 license.<br>
更多关于**许可**和**致谢**, 请参见 [LICENSE](LICENSE/README.md).
## :earth_asia: 引用
如果 BasicSR 对你有帮助, 请引用BasicSR. <br>
下面是一个 BibTeX 引用条目, 它需要 `url` LaTeX package.
``` latex
@misc{basicsr,
author = {Xintao Wang and Liangbin Xie and Ke Yu and Kelvin C.K. Chan and Chen Change Loy and Chao Dong},
title = {{BasicSR}: Open Source Image and Video Restoration Toolbox},
howpublished = {\url{https://github.com/XPixelGroup/BasicSR}},
year = {2022}
}
```
> Xintao Wang, Liangbin Xie, Ke Yu, Kelvin C.K. Chan, Chen Change Loy and Chao Dong. BasicSR: Open Source Image and Video Restoration Toolbox. <https://github.com/XPixelGroup/BasicSR>, 2022.
## :e-mail: 联系
若有任何问题, 请电邮 `xintao.alpha@gmail.com`, `xintao.wang@outlook.com`.
<br>
- **QQ群**: 扫描左边二维码 或者 搜索QQ群号: 320960100 入群答案:互帮互助共同进步
- **微信群**: 我们的群一已经满500人啦,进群二可以扫描中间的二维码;如果进群遇到问题,也可以添加 Liangbin 的个人微信 (右边二维码),他会在空闲的时候拉大家入群~
<p align="center">
<img src="https://user-images.githubusercontent.com/17445847/134879983-6f2d663b-16e7-49f2-97e1-7c53c8a5f71a.jpg" height="300">  
<img src="https://user-images.githubusercontent.com/52127135/172553058-6cf32e10-2959-42dd-b26a-f802f09343b0.png" height="300">  
<img src="https://user-images.githubusercontent.com/17445847/139572512-8e192aac-00fa-432b-ac8e-a33026b019df.png" height="300">
</p>
================================================
FILE: VERSION
================================================
1.4.2
================================================
FILE: basicsr/__init__.py
================================================
# https://github.com/xinntao/BasicSR
# flake8: noqa
from .archs import *
from .data import *
from .losses import *
from .metrics import *
from .models import *
from .ops import *
from .test import *
from .train import *
from .utils import *
from .version import __gitsha__, __version__
================================================
FILE: basicsr/archs/__init__.py
================================================
import importlib
from copy import deepcopy
from os import path as osp
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import ARCH_REGISTRY
__all__ = ['build_network']
# Automatically discover and import every architecture module so that its classes get registered.
# Scan the files under the 'archs' folder and keep the base names (extension removed) of those ending with '_arch.py'.
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# Import all the arch modules; importing a module executes its `@ARCH_REGISTRY.register()` class decorators.
_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]
def build_network(opt):
    """Instantiate a registered architecture from an option dict.

    Args:
        opt (dict): Network options. The 'type' entry names the class to look up in ARCH_REGISTRY; all
            remaining entries are forwarded to its constructor as keyword arguments.

    Returns:
        The constructed network object.
    """
    # Work on a private copy so that popping 'type' never alters the caller's dict.
    net_kwargs = deepcopy(opt)
    arch_cls = ARCH_REGISTRY.get(net_kwargs.pop('type'))
    net = arch_cls(**net_kwargs)
    get_root_logger().info(f'Network [{net.__class__.__name__}] is created.')
    return net
================================================
FILE: basicsr/archs/arch_util.py
================================================
import collections.abc
import math
import torch
import torchvision
import warnings
from distutils.version import LooseVersion
from itertools import repeat
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm
from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
from basicsr.utils import get_root_logger
@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize the weights of conv, linear and batch-norm layers.

    Conv2d and Linear weights are drawn with Kaiming-normal initialization and then multiplied by ``scale``;
    batch-norm weights are set to 1. Every bias that exists is filled with ``bias_fill``. Other layer types
    are left untouched.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized (all sub-modules are visited).
        scale (float): Factor applied to the initialized weights, especially useful for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias with. Default: 0
        kwargs (dict): Extra arguments forwarded to ``init.kaiming_normal_``.
    """
    roots = module_list if isinstance(module_list, list) else [module_list]
    for root in roots:
        for layer in root.modules():
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(layer.weight, **kwargs)
                layer.weight.data *= scale
            elif isinstance(layer, _BatchNorm):
                init.constant_(layer.weight, 1)
            else:
                # not a layer type handled by this initializer
                continue
            if layer.bias is not None:
                layer.bias.data.fill_(bias_fill)
def make_layer(basic_block, num_basic_block, **kwarg):
    """Stack several freshly constructed copies of one block type.

    Args:
        basic_block (nn.module): nn.module class (or any callable returning a module) for the basic block.
        num_basic_block (int): How many blocks to create.
        kwarg (dict): Keyword arguments passed to every ``basic_block`` constructor call.

    Returns:
        nn.Sequential: The blocks chained in creation order.
    """
    return nn.Sequential(*(basic_block(**kwarg) for _ in range(num_basic_block)))
class ResidualBlockNoBN(nn.Module):
    """Residual block without BN.

    Computes ``x + res_scale * conv2(relu(conv1(x)))`` with two 3x3 convolutions that keep the channel
    count and the spatial size.

    Args:
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Residual scale. Default: 1.
        pytorch_init (bool): If set to True, use pytorch default init,
            otherwise, use default_init_weights. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super().__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        if pytorch_init:
            return
        # re-initialize both convs with weights scaled by 0.1
        default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.relu(residual)  # in-place on the conv1 output, x itself is untouched
        residual = self.conv2(residual)
        return x + residual * self.res_scale
class Upsample(nn.Sequential):
    """Upsample module.

    Builds a chain of (3x3 conv, PixelShuffle) pairs: one x2 pair per factor of two for power-of-two
    scales, or a single x3 pair for scale 3. A scale of 1 yields an empty sequence.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: If ``scale`` is neither a positive power of two nor 3.
    """

    def __init__(self, scale, num_feat):
        m = []
        # `scale > 0` keeps 0 out of the power-of-two branch: 0 & -1 == 0 would pass the bit test.
        if scale > 0 and (scale & (scale - 1)) == 0:  # scale = 2^n
            # bit_length gives the exact exponent without going through a floating-point logarithm
            for _ in range(int(scale).bit_length() - 1):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow.

    Each output pixel (row, col) is sampled from ``x`` at the position (col + flow[..., 0],
    row + flow[..., 1]), i.e. the flow is expressed in pixels with the x offset first.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2), in pixel units.
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Forwarded to ``F.grid_sample``. The coordinate normalization below
            divides by (size - 1), which matches align_corners=True. Default: True.

    Returns:
        Tensor: Warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    height, width = x.size()[2], x.size()[3]

    # pixel-coordinate grid, last dimension ordered as (x, y)
    rows = torch.arange(0, height).type_as(x)
    cols = torch.arange(0, width).type_as(x)
    grid_y, grid_x = torch.meshgrid(rows, cols)
    base_grid = torch.stack((grid_x, grid_y), 2).float()
    base_grid.requires_grad = False

    # absolute sampling positions, then rescaled to the [-1, 1] range expected by grid_sample
    sample_pos = base_grid + flow
    norm_x = 2.0 * sample_pos[:, :, :, 0] / max(width - 1, 1) - 1.0
    norm_y = 2.0 * sample_pos[:, :, :, 1] / max(height - 1, 1) - 1.0
    norm_grid = torch.stack((norm_x, norm_y), dim=3)

    # TODO, what if align_corners=False
    return F.grid_sample(x, norm_grid, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
    """Resize a flow field and rescale its values to the new resolution.

    Args:
        flow (Tensor): Precomputed flow. shape [N, 2, H, W]. Channel 0 is scaled by the width ratio
            and channel 1 by the height ratio. The input tensor is not modified.
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): the ratio for resizing or the final output
            shape.
            1) For 'ratio' the order is [ratio_h, ratio_w]; values below 1.0 downsample and values
            above 1.0 upsample. The output size is the truncated product with the input size.
            2) For 'shape' the order is [out_h, out_w].
        interp_mode (str): The mode of interpolation for resizing.
            Default: 'bilinear'.
        align_corners (bool): Whether align corners. Default: False.

    Returns:
        Tensor: Resized flow.

    Raises:
        ValueError: If ``size_type`` is neither 'ratio' nor 'shape'.
    """
    _, _, src_h, src_w = flow.shape

    if size_type == 'shape':
        dst_h, dst_w = sizes[0], sizes[1]
    elif size_type == 'ratio':
        dst_h = int(src_h * sizes[0])
        dst_w = int(src_w * sizes[1])
    else:
        raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')

    # flow vectors are measured in pixels, so they must grow/shrink with the resolution
    scaled = flow.clone()
    scaled[:, 0, :, :] *= dst_w / src_w
    scaled[:, 1, :, :] *= dst_h / src_h

    return F.interpolate(input=scaled, size=(dst_h, dst_w), mode=interp_mode, align_corners=align_corners)
# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
    """Pixel unshuffle: fold each scale x scale spatial patch into the channel dimension.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw). Both hh and hw must be divisible by
            ``scale``.
        scale (int): Downsample ratio.

    Returns:
        Tensor: the pixel unshuffled feature with shape (b, c * scale^2, hh / scale, hw / scale).
    """
    batch, chans, in_h, in_w = x.size()
    out_h, rem_h = divmod(in_h, scale)
    out_w, rem_w = divmod(in_w, scale)
    assert rem_h == 0 and rem_w == 0

    # split both spatial axes into (coarse position, offset inside the patch) ...
    patches = x.view(batch, chans, out_h, scale, out_w, scale)
    # ... and move the two in-patch offsets next to the channel axis before merging them into it
    patches = patches.permute(0, 1, 3, 5, 2, 4)
    return patches.reshape(batch, chans * (scale**2), out_h, out_w)
class DCNv2Pack(ModulatedDeformConvPack):
    """Modulated deformable conv for deformable alignment.

    Unlike the official DCNv2Pack, which predicts offsets and masks from the features being convolved,
    this variant predicts them from a second, separately supplied feature map.

    ``Paper: Delving Deep into Deformable Alignment in Video Super-Resolution``
    """

    def forward(self, x, feat):
        """Apply the deformable conv to ``x`` with offsets and masks predicted from ``feat``."""
        # conv_offset output is split channel-wise into two offset thirds and one mask third
        offset_and_mask = self.conv_offset(feat)
        offset_a, offset_b, mask = torch.chunk(offset_and_mask, 3, dim=1)
        offset = torch.cat((offset_a, offset_b), dim=1)
        mask = torch.sigmoid(mask)

        # log (but do not stop) when the predicted offsets become very large on average
        offset_absmean = torch.mean(torch.abs(offset))
        if offset_absmean > 50:
            get_root_logger().warning(f'Offset abs mean is {offset_absmean}, larger than 50.')

        has_torchvision_op = LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0')
        if not has_torchvision_op:
            # older torchvision: fall back to the operator shipped with this project
            return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
                                         self.dilation, self.groups, self.deformable_groups)
        return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
                                             self.dilation, mask)
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
'The distribution of values may be incorrect.',
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
low = norm_cdf((a - mean) / std)
up = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [low, up], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * low - 1, 2 * up - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fill the input Tensor in place with values from a truncated normal distribution.

    From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py

    The values follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to the interval
    :math:`[a, b]`. The sampling method works best when
    :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The input ``tensor``, as returned by ``_no_grad_trunc_normal_``.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    filled = _no_grad_trunc_normal_(tensor, mean, std, a, b)
    return filled
# From PyTorch
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
# Ready-made converters for the common tuple lengths: scalars are repeated, iterables pass through unchanged.
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
# Public alias of the factory itself, for building a converter of any other length.
to_ntuple = _ntuple
================================================
FILE: basicsr/archs/basicvsr_arch.py
================================================
import torch
from torch import nn as nn
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import ResidualBlockNoBN, flow_warp, make_layer
from .edvr_arch import PCDAlignment, TSAFusion
from .spynet_arch import SpyNet
@ARCH_REGISTRY.register()
class BasicVSR(nn.Module):
    """A recurrent network for video SR. Now only x4 is supported.

    Features are propagated through the sequence once backward and once forward in time, each step
    warping the previous state with SPyNet optical flow. Both states are fused per frame and upsampled.

    Args:
        num_feat (int): Number of channels. Default: 64.
        num_block (int): Number of residual blocks for each branch. Default: 15
        spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
    """

    def __init__(self, num_feat=64, num_block=15, spynet_path=None):
        super().__init__()
        self.num_feat = num_feat

        # alignment
        self.spynet = SpyNet(spynet_path)

        # propagation (input: 3 image channels + propagated features)
        self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
        self.forward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)

        # reconstruction
        self.fusion = nn.Conv2d(num_feat * 2, num_feat, 1, 1, 0, bias=True)
        self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)

        self.pixel_shuffle = nn.PixelShuffle(2)

        # activation functions
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def get_flow(self, x):
        """Estimate optical flow between every pair of neighbouring frames, in both directions.

        Args:
            x (Tensor): Frames with shape (b, n, c, h, w).

        Returns:
            tuple[Tensor]: (flows_forward, flows_backward), each with shape (b, n - 1, 2, h, w).
        """
        b, n, c, h, w = x.size()

        earlier = x[:, :-1].reshape(-1, c, h, w)
        later = x[:, 1:].reshape(-1, c, h, w)

        flows_backward = self.spynet(earlier, later).view(b, n - 1, 2, h, w)
        flows_forward = self.spynet(later, earlier).view(b, n - 1, 2, h, w)

        return flows_forward, flows_backward

    def forward(self, x):
        """Forward function of BasicVSR.

        Args:
            x: Input frames with shape (b, n, c, h, w). n is the temporal dimension / number of frames.

        Returns:
            Tensor: Frames upsampled by 4, stacked along dim 1.
        """
        flows_forward, flows_backward = self.get_flow(x)
        b, n, _, h, w = x.size()

        # backward branch: walk from the last frame to the first
        out_l = []
        feat_prop = x.new_zeros(b, self.num_feat, h, w)
        for i in reversed(range(n)):
            frame = x[:, i]
            if i < n - 1:
                # align the state coming from frame i + 1 to frame i
                feat_prop = flow_warp(feat_prop, flows_backward[:, i].permute(0, 2, 3, 1))
            feat_prop = self.backward_trunk(torch.cat([frame, feat_prop], dim=1))
            out_l.append(feat_prop)
        out_l.reverse()  # out_l[i] now holds the backward feature of frame i

        # forward branch: walk from the first frame to the last and reconstruct on the fly
        feat_prop = torch.zeros_like(feat_prop)
        for i in range(n):
            frame = x[:, i]
            if i > 0:
                # align the state coming from frame i - 1 to frame i
                feat_prop = flow_warp(feat_prop, flows_forward[:, i - 1].permute(0, 2, 3, 1))
            feat_prop = self.forward_trunk(torch.cat([frame, feat_prop], dim=1))

            # fuse both directions, then upsample x2 twice
            out = self.lrelu(self.fusion(torch.cat([out_l[i], feat_prop], dim=1)))
            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
            out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
            out = self.lrelu(self.conv_hr(out))
            out = self.conv_last(out)

            # the network predicts a residual on top of the bilinearly upsampled input frame
            out += F.interpolate(frame, scale_factor=4, mode='bilinear', align_corners=False)
            out_l[i] = out

        return torch.stack(out_l, dim=1)
class ConvResidualBlocks(nn.Module):
    """Conv and residual block used in BasicVSR.

    A 3x3 convolution followed by LeakyReLU(0.1) and a stack of ``ResidualBlockNoBN``.

    Args:
        num_in_ch (int): Number of input channels. Default: 3.
        num_out_ch (int): Number of output channels. Default: 64.
        num_block (int): Number of residual blocks. Default: 15.
    """

    def __init__(self, num_in_ch=3, num_out_ch=64, num_block=15):
        super().__init__()
        head_conv = nn.Conv2d(num_in_ch, num_out_ch, 3, 1, 1, bias=True)
        head_act = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        res_stack = make_layer(ResidualBlockNoBN, num_block, num_feat=num_out_ch)
        self.main = nn.Sequential(head_conv, head_act, res_stack)

    def forward(self, fea):
        out = self.main(fea)
        return out
@ARCH_REGISTRY.register()
class IconVSR(nn.Module):
    """IconVSR, proposed also in the BasicVSR paper.

    Bidirectional recurrent video SR network (x4). In addition to flow-based propagation, features
    extracted by an EDVR-style module at keyframes are fused into the propagated state.

    Args:
        num_feat (int): Number of channels. Default: 64.
        num_block (int): Number of residual blocks for each branch. Default: 15.
        keyframe_stride (int): Keyframe stride. Default: 5.
        temporal_padding (int): Temporal padding. Only 2 and 3 are supported. Default: 2.
        spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
        edvr_path (str): Path to the pretrained EDVR model. Default: None.
    """

    def __init__(self,
                 num_feat=64,
                 num_block=15,
                 keyframe_stride=5,
                 temporal_padding=2,
                 spynet_path=None,
                 edvr_path=None):
        super().__init__()

        self.num_feat = num_feat
        self.temporal_padding = temporal_padding
        self.keyframe_stride = keyframe_stride

        # keyframe_branch
        self.edvr = EDVRFeatureExtractor(temporal_padding * 2 + 1, num_feat, edvr_path)
        # alignment
        self.spynet = SpyNet(spynet_path)

        # propagation
        self.backward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
        self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)

        self.forward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
        self.forward_trunk = ConvResidualBlocks(2 * num_feat + 3, num_feat, num_block)

        # reconstruction
        self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)

        self.pixel_shuffle = nn.PixelShuffle(2)

        # activation functions
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def pad_spatial(self, x):
        """Apply padding spatially.

        Since the PCD module in EDVR requires that the resolution is a multiple
        of 4, we apply padding to the input LR images if their resolution is
        not divisible by 4. Padding is added at the bottom and right with reflection.

        Args:
            x (Tensor): Input LR sequence with shape (n, t, c, h, w).

        Returns:
            Tensor: Padded LR sequence with shape (n, t, c, h_pad, w_pad).
        """
        n, t, c, h, w = x.size()

        pad_h = (4 - h % 4) % 4
        pad_w = (4 - w % 4) % 4

        # padding (F.pad works on the last two dims, so fold batch and time together first)
        x = x.view(-1, c, h, w)
        x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')

        return x.view(n, t, c, h + pad_h, w + pad_w)

    def get_flow(self, x):
        """Estimate optical flow between every pair of neighbouring frames, in both directions.

        Args:
            x (Tensor): Frames with shape (b, n, c, h, w).

        Returns:
            tuple[Tensor]: (flows_forward, flows_backward), each with shape (b, n - 1, 2, h, w).
        """
        b, n, c, h, w = x.size()

        x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
        x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)

        flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
        flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)

        return flows_forward, flows_backward

    def get_keyframe_feature(self, x, keyframe_idx):
        """Extract EDVR features for every keyframe.

        The sequence is first extended at both ends with ``temporal_padding`` mirrored frames taken from
        inside the sequence, so that each keyframe has a full window of ``2 * temporal_padding + 1`` frames.

        Args:
            x (Tensor): Frames with shape (b, n, c, h, w).
            keyframe_idx (list[int]): Indices (into the unpadded sequence) of the keyframes.

        Returns:
            dict[int, Tensor]: Keyframe index -> feature returned by the EDVR extractor.

        Raises:
            ValueError: If ``self.temporal_padding`` is not 2 or 3.
        """
        if self.temporal_padding == 2:
            front, back = [4, 3], [-4, -5]
        elif self.temporal_padding == 3:
            front, back = [6, 5, 4], [-5, -6, -7]
        else:
            # Without this check the sequence would stay unpadded and torch.cat would be called on a
            # bare tensor, failing with an unrelated-looking error.
            raise ValueError(f'temporal_padding should be 2 or 3, but got {self.temporal_padding}.')
        x = torch.cat([x[:, front], x, x[:, back]], dim=1)

        num_frames = 2 * self.temporal_padding + 1
        feats_keyframe = {}
        for i in keyframe_idx:
            # in the padded sequence the window centred on original frame i starts at index i
            feats_keyframe[i] = self.edvr(x[:, i:i + num_frames].contiguous())
        return feats_keyframe

    def forward(self, x):
        """Forward function of IconVSR.

        Args:
            x (Tensor): Input frames with shape (b, n, c, h, w).

        Returns:
            Tensor: Output frames with shape (b, n, 3, 4 * h, 4 * w).
        """
        b, n, _, h_input, w_input = x.size()

        x = self.pad_spatial(x)
        h, w = x.shape[3:]

        keyframe_idx = list(range(0, n, self.keyframe_stride))
        if keyframe_idx[-1] != n - 1:
            keyframe_idx.append(n - 1)  # last frame is a keyframe
        keyframe_set = set(keyframe_idx)  # constant-time membership tests inside the loops

        # compute flow and keyframe features
        flows_forward, flows_backward = self.get_flow(x)
        feats_keyframe = self.get_keyframe_feature(x, keyframe_idx)

        # backward branch
        out_l = []
        feat_prop = x.new_zeros(b, self.num_feat, h, w)
        for i in range(n - 1, -1, -1):
            x_i = x[:, i, :, :, :]
            if i < n - 1:
                flow = flows_backward[:, i, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
            if i in keyframe_set:
                feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
                feat_prop = self.backward_fusion(feat_prop)
            feat_prop = torch.cat([x_i, feat_prop], dim=1)
            feat_prop = self.backward_trunk(feat_prop)
            out_l.insert(0, feat_prop)

        # forward branch
        feat_prop = torch.zeros_like(feat_prop)
        for i in range(0, n):
            x_i = x[:, i, :, :, :]
            if i > 0:
                flow = flows_forward[:, i - 1, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
            if i in keyframe_set:
                feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
                feat_prop = self.forward_fusion(feat_prop)

            # the forward trunk also sees the backward feature of the same frame
            feat_prop = torch.cat([x_i, out_l[i], feat_prop], dim=1)
            feat_prop = self.forward_trunk(feat_prop)

            # upsample
            out = self.lrelu(self.pixel_shuffle(self.upconv1(feat_prop)))
            out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
            out = self.lrelu(self.conv_hr(out))
            out = self.conv_last(out)
            base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
            out += base
            out_l[i] = out

        # crop away the area that corresponds to the spatial padding
        return torch.stack(out_l, dim=1)[..., :4 * h_input, :4 * w_input]
class EDVRFeatureExtractor(nn.Module):
    """EDVR feature extractor used in IconVSR.

    Builds a three-level feature pyramid per frame, aligns every frame to the center frame with PCD
    alignment and merges the aligned features with TSA fusion.

    Args:
        num_input_frame (int): Number of input frames.
        num_feat (int): Number of feature channels
        load_path (str): Path to the pretrained weights of EDVR. Default: None.
    """

    def __init__(self, num_input_frame, num_feat, load_path):
        super().__init__()

        self.center_frame_idx = num_input_frame // 2

        # extract pyramid features
        self.conv_first = nn.Conv2d(3, num_feat, 3, 1, 1)
        self.feature_extraction = make_layer(ResidualBlockNoBN, 5, num_feat=num_feat)
        self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

        # pcd and tsa module
        self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=8)
        self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_input_frame, center_frame_idx=self.center_frame_idx)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        if load_path:
            # the checkpoint stores the weights under the 'params' key; tensors stay on CPU while loading
            checkpoint = torch.load(load_path, map_location=lambda storage, loc: storage)
            self.load_state_dict(checkpoint['params'])

    def forward(self, x):
        """Extract one fused feature map from a window of frames.

        Args:
            x (Tensor): Frames with shape (b, n, c, h, w).

        Returns:
            Tensor: Output of the TSA fusion module.
        """
        b, n, c, h, w = x.size()

        # extract features for each frame (batch and time are folded together)
        # L1
        lvl1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))
        lvl1 = self.feature_extraction(lvl1)
        # L2 (half resolution)
        lvl2 = self.lrelu(self.conv_l2_1(lvl1))
        lvl2 = self.lrelu(self.conv_l2_2(lvl2))
        # L3 (quarter resolution)
        lvl3 = self.lrelu(self.conv_l3_1(lvl2))
        lvl3 = self.lrelu(self.conv_l3_2(lvl3))

        # unfold the time dimension again
        pyramid = (
            lvl1.view(b, n, -1, h, w),
            lvl2.view(b, n, -1, h // 2, w // 2),
            lvl3.view(b, n, -1, h // 4, w // 4),
        )

        # PCD alignment of every frame against the center (reference) frame
        center = self.center_frame_idx
        ref_feat_l = [level[:, center, :, :, :].clone() for level in pyramid]
        aligned = []
        for frame_idx in range(n):
            nbr_feat_l = [level[:, frame_idx, :, :, :].clone() for level in pyramid]
            aligned.append(self.pcd_align(nbr_feat_l, ref_feat_l))
        aligned_feat = torch.stack(aligned, dim=1)  # (b, t, c, h, w)

        # TSA fusion
        return self.fusion(aligned_feat)
================================================
FILE: basicsr/archs/basicvsrpp_arch.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import warnings
from basicsr.archs.arch_util import flow_warp
from basicsr.archs.basicvsr_arch import ConvResidualBlocks
from basicsr.archs.spynet_arch import SpyNet
from basicsr.ops.dcn import ModulatedDeformConvPack
from basicsr.utils.registry import ARCH_REGISTRY
@ARCH_REGISTRY.register()
class BasicVSRPlusPlus(nn.Module):
"""BasicVSR++ network structure.
Support either x4 upsampling or same size output. Since DCN is used in this
model, it can only be used with CUDA enabled. If CUDA is not enabled,
feature alignment will be skipped. Besides, we adopt the official DCN
implementation and the version of torch need to be higher than 1.9.
``Paper: BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment``
Args:
mid_channels (int, optional): Channel number of the intermediate
features. Default: 64.
num_blocks (int, optional): The number of residual blocks in each
propagation branch. Default: 7.
max_residue_magnitude (int): The maximum magnitude of the offset
residue (Eq. 6 in paper). Default: 10.
is_low_res_input (bool, optional): Whether the input is low-resolution
or not. If False, the output resolution is equal to the input
resolution. Default: True.
spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
cpu_cache_length (int, optional): When the length of sequence is larger
than this value, the intermediate features are sent to CPU. This
saves GPU memory, but slows down the inference speed. You can
increase this number if you have a GPU with large memory.
Default: 100.
"""
def __init__(self,
             mid_channels=64,
             num_blocks=7,
             max_residue_magnitude=10,
             is_low_res_input=True,
             spynet_path=None,
             cpu_cache_length=100):
    """Create the sub-modules of BasicVSR++.

    Args:
        mid_channels (int): Channel number of the intermediate features.
        num_blocks (int): Residual blocks in every propagation branch.
        max_residue_magnitude (int): Bound passed to each
            ``SecondOrderDeformableAlignment`` module.
        is_low_res_input (bool): If False, the feature extractor starts
            with two stride-2 convolutions.
        spynet_path (str | None): Forwarded to ``SpyNet``.
        cpu_cache_length (int): Sequence length above which ``forward``
            switches on CPU caching.
    """
    super().__init__()

    self.mid_channels = mid_channels
    self.is_low_res_input = is_low_res_input
    self.cpu_cache_length = cpu_cache_length

    # optical flow estimator
    self.spynet = SpyNet(spynet_path)

    # per-frame feature extractor
    if not is_low_res_input:
        # two stride-2 convolutions come before the residual blocks
        extractor_layers = [
            nn.Conv2d(3, mid_channels, 3, 2, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(mid_channels, mid_channels, 3, 2, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            ConvResidualBlocks(mid_channels, mid_channels, 5),
        ]
        self.feat_extract = nn.Sequential(*extractor_layers)
    else:
        self.feat_extract = ConvResidualBlocks(3, mid_channels, 5)

    # propagation branches, in the order they are run by forward()
    self.deform_align = nn.ModuleDict()
    self.backbone = nn.ModuleDict()
    branch_names = ('backward_1', 'forward_1', 'backward_2', 'forward_2')
    for branch_no, branch in enumerate(branch_names):
        # the alignment module is only created when CUDA is available
        if torch.cuda.is_available():
            self.deform_align[branch] = SecondOrderDeformableAlignment(
                2 * mid_channels,
                mid_channels,
                3,
                padding=1,
                deformable_groups=16,
                max_residue_magnitude=max_residue_magnitude)
        # input = spatial feature + one feature per earlier branch + propagated feature
        backbone_in = (2 + branch_no) * mid_channels
        self.backbone[branch] = ConvResidualBlocks(backbone_in, mid_channels, num_blocks)

    # reconstruction / upsampling head (spatial feature + 4 branch features)
    self.reconstruction = ConvResidualBlocks(5 * mid_channels, mid_channels, 5)
    self.upconv1 = nn.Conv2d(mid_channels, mid_channels * 4, 3, 1, 1, bias=True)
    self.upconv2 = nn.Conv2d(mid_channels, 64 * 4, 3, 1, 1, bias=True)
    self.pixel_shuffle = nn.PixelShuffle(2)
    self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
    self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
    self.img_upsample = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)

    # shared activation
    self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    # updated per input by check_if_mirror_extended()
    self.is_mirror_extended = False

    # alignment is possible only if the deformable modules were built above
    self.is_with_alignment = len(self.deform_align) > 0
    if not self.is_with_alignment:
        warnings.warn('Deformable alignment module is not added. '
                      'Probably your CUDA is not configured correctly. DCN can only '
                      'be used with CUDA enabled. Alignment is skipped now.')
def check_if_mirror_extended(self, lqs):
    """Check whether the input is a mirror-extended sequence.

    If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the
    (t-1-i)-th frame. The result is stored in ``self.is_mirror_extended``.

    Args:
        lqs (tensor): Input low quality (LQ) sequence with shape (n, t, c, h, w).
    """
    # The flag is instance state that survives between forward() calls, so it
    # has to be reset for every input. Otherwise a single mirror-extended
    # sequence would make all later (non-mirrored) sequences reuse the
    # flipped backward flows as forward flows.
    self.is_mirror_extended = False

    # only an even number of frames can be split into two mirrored halves
    if lqs.size(1) % 2 == 0:
        lqs_1, lqs_2 = torch.chunk(lqs, 2, dim=1)
        if torch.norm(lqs_1 - lqs_2.flip(1)) == 0:
            self.is_mirror_extended = True
def compute_flow(self, lqs):
    """Estimate the optical flows between neighbouring frames with SPyNet.

    ``flows_backward`` is ``spynet(frame_i, frame_{i+1})`` and
    ``flows_forward`` is ``spynet(frame_{i+1}, frame_i)``. When
    ``self.is_mirror_extended`` is set, the second SPyNet pass is skipped
    and ``flows_forward`` is ``flows_backward.flip(1)``.

    Args:
        lqs (tensor): Input low quality (LQ) sequence with
            shape (n, t, c, h, w).

    Return:
        tuple(Tensor): ``(flows_forward, flows_backward)``, each with shape
        (n, t - 1, 2, h, w). Both are moved to CPU when ``self.cpu_cache``
        is set.
    """
    batch, num_frames, chans, height, width = lqs.size()
    flow_shape = (batch, num_frames - 1, 2, height, width)

    # every frame but the last, and every frame but the first
    earlier = lqs[:, :-1].reshape(-1, chans, height, width)
    later = lqs[:, 1:].reshape(-1, chans, height, width)

    flows_backward = self.spynet(earlier, later).view(*flow_shape)

    if not self.is_mirror_extended:
        flows_forward = self.spynet(later, earlier).view(*flow_shape)
    else:
        # the second half of the sequence mirrors the first half
        flows_forward = flows_backward.flip(1)

    if self.cpu_cache:
        flows_backward = flows_backward.cpu()
        flows_forward = flows_forward.cpu()

    return flows_forward, flows_backward
def propagate(self, feats, flows, module_name):
    """Run one propagation branch over the whole sequence.

    Args:
        feats dict(list[tensor]): Features from previous branches. Each
            component is a list of tensors with shape (n, c, h, w).
            ``feats[module_name]`` must already exist as a list; it is
            filled in place.
        flows (tensor): Optical flows with shape (n, t - 1, 2, h, w).
        module_name (str): The name of the propagation branch. Can either
            be 'backward_1', 'forward_1', 'backward_2', 'forward_2'.

    Return:
        dict(list[tensor]): The same ``feats`` dictionary, with
        ``feats[module_name]`` holding one propagated feature per frame in
        temporal order.
    """
    n, t, _, h, w = flows.size()
    is_backward = 'backward' in module_name

    # ``flows`` holds t flows, i.e. it connects t + 1 frames
    if is_backward:
        frame_idx = range(0, t + 1)[::-1]
        flow_idx = frame_idx
    else:
        frame_idx = range(0, t + 1)
        flow_idx = range(-1, t)

    # indices 0..len-1 followed by the same indices reversed
    spatial_ids = list(range(0, len(feats['spatial'])))
    mapping_idx = spatial_ids + spatial_ids[::-1]

    # branches that have already been computed (the key set does not change in the loop)
    other_branches = [k for k in feats if k not in ['spatial', module_name]]

    feat_prop = flows.new_zeros(n, self.mid_channels, h, w)
    for step, idx in enumerate(frame_idx):
        feat_current = feats['spatial'][mapping_idx[idx]]
        if self.cpu_cache:
            feat_current = feat_current.cuda()
            feat_prop = feat_prop.cuda()

        # second-order deformable alignment (nothing to align at the first step)
        if step > 0 and self.is_with_alignment:
            flow_n1 = flows[:, flow_idx[step]]
            if self.cpu_cache:
                flow_n1 = flow_n1.cuda()
            grid_n1 = flow_n1.permute(0, 2, 3, 1)
            cond_n1 = flow_warp(feat_prop, grid_n1)

            if step > 1:
                # second-order terms come from the feature two steps back
                feat_n2 = feats[module_name][-2]
                flow_n2 = flows[:, flow_idx[step - 1]]
                if self.cpu_cache:
                    feat_n2 = feat_n2.cuda()
                    flow_n2 = flow_n2.cuda()
                # compose the two one-step flows
                flow_n2 = flow_n1 + flow_warp(flow_n2, grid_n1)
                cond_n2 = flow_warp(feat_n2, flow_n2.permute(0, 2, 3, 1))
            else:
                # no second-order information yet: use zeros
                feat_n2 = torch.zeros_like(feat_prop)
                flow_n2 = torch.zeros_like(flow_n1)
                cond_n2 = torch.zeros_like(cond_n1)

            # flow-guided deformable convolution
            cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
            stacked = torch.cat([feat_prop, feat_n2], dim=1)
            feat_prop = self.deform_align[module_name](stacked, cond, flow_n1, flow_n2)

        # concatenate current, earlier-branch and propagated features
        pieces = [feat_current]
        pieces.extend(feats[k][idx] for k in other_branches)
        pieces.append(feat_prop)
        if self.cpu_cache:
            pieces = [p.cuda() for p in pieces]

        # residual update through this branch's backbone
        feat_prop = feat_prop + self.backbone[module_name](torch.cat(pieces, dim=1))

        if self.cpu_cache:
            feats[module_name].append(feat_prop.cpu())
            torch.cuda.empty_cache()
        else:
            feats[module_name].append(feat_prop)

    # backward branches were filled last-frame-first; restore temporal order
    if is_backward:
        feats[module_name] = feats[module_name][::-1]

    return feats
def upsample(self, lqs, feats):
    """Reconstruct the output frames from the propagated features.

    Note that the per-branch lists in ``feats`` (every key except
    'spatial') are consumed: one entry is popped from the front of each
    list per output frame.

    Args:
        lqs (tensor): Input low quality (LQ) sequence with
            shape (n, t, c, h, w).
        feats (dict): The features from the propagation branches.

    Returns:
        Tensor: Output sequence stacked along dim 1. The input frame is
        added as a residual: bilinearly upsampled by 4 when
        ``is_low_res_input`` is set, unchanged otherwise.
    """
    branch_keys = [k for k in feats if k != 'spatial']

    # indices 0..len-1 followed by the same indices reversed
    spatial_ids = list(range(0, len(feats['spatial'])))
    mapping_idx = spatial_ids + spatial_ids[::-1]

    outputs = []
    for i in range(0, lqs.size(1)):
        # spatial feature first, then the next entry of every branch
        per_frame = [feats['spatial'][mapping_idx[i]]]
        per_frame.extend(feats[k].pop(0) for k in branch_keys)
        hr = torch.cat(per_frame, dim=1)
        if self.cpu_cache:
            hr = hr.cuda()

        hr = self.reconstruction(hr)
        # two pixel-shuffle(2) stages
        for upconv in (self.upconv1, self.upconv2):
            hr = self.lrelu(self.pixel_shuffle(upconv(hr)))
        hr = self.conv_last(self.lrelu(self.conv_hr(hr)))

        # global residual from the input frame
        frame = lqs[:, i, :, :, :]
        hr += self.img_upsample(frame) if self.is_low_res_input else frame

        if self.cpu_cache:
            hr = hr.cpu()
            torch.cuda.empty_cache()

        outputs.append(hr)

    return torch.stack(outputs, dim=1)
def forward(self, lqs):
    """Forward function for BasicVSR++.

    Args:
        lqs (tensor): Input low quality (LQ) sequence with
            shape (n, t, c, h, w).

    Returns:
        Tensor: Output sequence produced by ``self.upsample``; shape
        (n, t, c, 4h, 4w) when ``is_low_res_input`` is set, otherwise the
        same spatial size as the input.

    Raises:
        AssertionError: If the low-resolution frames used for flow
            estimation are smaller than 64 x 64.
    """
    n, t, c, h, w = lqs.size()

    # whether to cache the features in CPU
    self.cpu_cache = t > self.cpu_cache_length

    if self.is_low_res_input:
        lqs_downsample = lqs.clone()
    else:
        # reshape (not view) so that non-contiguous inputs are accepted too
        lqs_downsample = F.interpolate(
            lqs.reshape(-1, c, h, w), scale_factor=0.25, mode='bicubic').view(n, t, c, h // 4, w // 4)

    # Optical flow is computed on the low-res frames. Validate their size
    # before any expensive work, and report the size that was actually
    # checked (not the feature-map size or the full-resolution input size).
    lr_h, lr_w = lqs_downsample.size(3), lqs_downsample.size(4)
    assert lr_h >= 64 and lr_w >= 64, (
        'The height and width of low-res inputs must be at least 64, '
        f'but got {lr_h} and {lr_w}.')

    # check whether the input is an extended sequence
    self.check_if_mirror_extended(lqs)

    feats = {}
    # compute spatial features
    if self.cpu_cache:
        # frame by frame, moving every result to CPU right away
        feats['spatial'] = []
        for i in range(0, t):
            feat = self.feat_extract(lqs[:, i, :, :, :]).cpu()
            feats['spatial'].append(feat)
            torch.cuda.empty_cache()
    else:
        feats_ = self.feat_extract(lqs.reshape(-1, c, h, w))
        feat_h, feat_w = feats_.shape[2:]
        feats_ = feats_.view(n, t, -1, feat_h, feat_w)
        feats['spatial'] = [feats_[:, i, :, :, :] for i in range(0, t)]

    # compute optical flow using the low-res inputs
    flows_forward, flows_backward = self.compute_flow(lqs_downsample)

    # feature propagation: backward_1, forward_1, backward_2, forward_2
    for iter_ in [1, 2]:
        for direction in ['backward', 'forward']:
            module = f'{direction}_{iter_}'

            feats[module] = []

            if direction == 'backward':
                flows = flows_backward
            elif flows_forward is not None:
                flows = flows_forward
            else:
                # kept for compute_flow overrides that return None for the forward flows
                flows = flows_backward.flip(1)

            feats = self.propagate(feats, flows, module)
            if self.cpu_cache:
                del flows
                torch.cuda.empty_cache()

    return self.upsample(lqs, feats)
class SecondOrderDeformableAlignment(ModulatedDeformConvPack):
    """Second-order deformable alignment module.

    A modulated deformable convolution whose offsets are predicted as a
    bounded residual on top of two optical flows.

    Args:
        *args, **kwargs: Forwarded unchanged to ``ModulatedDeformConvPack``
            (in_channels, out_channels, kernel_size, stride, padding,
            dilation, groups, bias, ... -- see that class for the exact
            signature).
        max_residue_magnitude (int): Keyword-only. The maximum magnitude of
            the offset residue (Eq. 6 in paper). Default: 10.
    """

    def __init__(self, *args, **kwargs):
        # must be removed from kwargs before the parent constructor sees them
        self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)

        super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)

        width = self.out_channels
        # input of the offset network: the condition features (3 * out_channels)
        # concatenated with two 2-channel flows (see forward)
        layers = [nn.Conv2d(3 * width + 4, width, 3, 1, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True)]
        for _ in range(2):
            layers.append(nn.Conv2d(width, width, 3, 1, 1))
            layers.append(nn.LeakyReLU(negative_slope=0.1, inplace=True))
        # output is split into three equal parts in forward: o1, o2 and mask
        layers.append(nn.Conv2d(width, 27 * self.deformable_groups, 3, 1, 1))
        self.conv_offset = nn.Sequential(*layers)

        self.init_offset()

    def init_offset(self):
        """Fill weight and bias of the last ``conv_offset`` layer with zeros."""
        last_conv = self.conv_offset[-1]
        for attr in ('weight', 'bias'):
            param = getattr(last_conv, attr, None)
            if param is not None:
                nn.init.constant_(param, 0)

    def forward(self, x, extra_feat, flow_1, flow_2):
        """Align ``x`` with offsets predicted from ``extra_feat`` and the flows.

        Args:
            x (Tensor): Features to be aligned (input of the deformable conv).
            extra_feat (Tensor): Condition features for offset prediction.
            flow_1 (Tensor): First-order flow, concatenated on the channel dim.
            flow_2 (Tensor): Second-order flow, concatenated on the channel dim.

        Returns:
            Tensor: Result of ``torchvision.ops.deform_conv2d``.
        """
        cond = torch.cat([extra_feat, flow_1, flow_2], dim=1)
        raw_o1, raw_o2, raw_mask = torch.chunk(self.conv_offset(cond), 3, dim=1)

        # residual offsets, bounded to +/- max_residue_magnitude by tanh
        residue = self.max_residue_magnitude * torch.tanh(torch.cat((raw_o1, raw_o2), dim=1))
        residue_1, residue_2 = torch.chunk(residue, 2, dim=1)

        # add the channel-reversed flow, tiled over all offset pairs
        # NOTE(review): flip(1) presumably converts (x, y) flow into the (y, x)
        # layout expected by deform_conv2d -- confirm.
        base_1 = flow_1.flip(1).repeat(1, residue_1.size(1) // 2, 1, 1)
        base_2 = flow_2.flip(1).repeat(1, residue_2.size(1) // 2, 1, 1)
        offset = torch.cat([residue_1 + base_1, residue_2 + base_2], dim=1)

        # modulation mask in (0, 1)
        mask = torch.sigmoid(raw_mask)

        return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
                                             self.dilation, mask)
# if __name__ == '__main__':
# spynet_path = 'experiments/pretrained_models/flownet/spynet_sintel_final-3d2a1287.pth'
# model = BasicVSRPlusPlus(spynet_path=spynet_path).cuda()
# input = torch.rand(1, 2, 3, 64, 64).cuda()
# output = model(input)
# print('===================')
# print(output.shape)
================================================
FILE: basicsr/archs/dfdnet_arch.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.spectral_norm import spectral_norm
from basicsr.utils.registry import ARCH_REGISTRY
from .dfdnet_util import AttentionBlock, Blur, MSDilationBlock, UpResBlock, adaptive_instance_normalization
from .vgg_arch import VGGFeatureExtractor
class SFTUpBlock(nn.Module):
    """Spatial feature transform (SFT) with upsampling block.

    Args:
        in_channel (int): Number of input channels.
        out_channel (int): Number of output channels.
        kernel_size (int): Kernel size in convolutions. Default: 3.
        padding (int): Padding in convolutions. Default: 1.
    """

    def __init__(self, in_channel, out_channel, kernel_size=3, padding=1):
        super(SFTUpBlock, self).__init__()

        # blur -> spectral-norm conv -> LeakyReLU
        pre_layers = [
            Blur(in_channel),
            spectral_norm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
            # The official codes use two LeakyReLU here, so 0.04 for equivalent
            nn.LeakyReLU(0.04, True),
        ]
        self.conv1 = nn.Sequential(*pre_layers)

        # x2 bilinear upsampling -> spectral-norm conv -> LeakyReLU
        up_layers = [
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            spectral_norm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
            nn.LeakyReLU(0.2, True),
        ]
        self.convup = nn.Sequential(*up_layers)

        # SFT scale branch
        scale_layers = [
            spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)),
            nn.LeakyReLU(0.2, True),
            spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)),
        ]
        self.scale_block = nn.Sequential(*scale_layers)

        # SFT shift branch
        # The official codes use sigmoid for shift block, do not know why
        shift_layers = [
            spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)),
            nn.LeakyReLU(0.2, True),
            spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)),
            nn.Sigmoid(),
        ]
        self.shift_block = nn.Sequential(*shift_layers)

    def forward(self, x, updated_feat):
        """Modulate ``x`` with scale/shift from ``updated_feat``, then upsample.

        Args:
            x (Tensor): Features to be modulated.
            updated_feat (Tensor): Features from which scale and shift are predicted.

        Returns:
            Tensor: Output of the upsampling branch.
        """
        feat = self.conv1(x)
        # SFT: element-wise affine transform
        modulated = feat * self.scale_block(updated_feat) + self.shift_block(updated_feat)
        return self.convup(modulated)
@ARCH_REGISTRY.register()
class DFDNet(nn.Module):
"""DFDNet: Deep Face Dictionary Network.
It only processes faces with 512x512 size.
Args:
num_feat (int): Number of feature channels.
dict_path (str): Path to the facial component dictionary.
"""
def __init__(self, num_feat, dict_path):
    """Build DFDNet and load the facial component dictionary.

    Args:
        num_feat (int): Number of feature channels.
        dict_path (str): Path to the facial component dictionary.
    """
    super().__init__()
    self.parts = ['left_eye', 'right_eye', 'nose', 'mouth']
    # part_sizes: [80, 80, 50, 110]
    channel_sizes = [128, 256, 512, 512]
    self.feature_sizes = np.array([256, 128, 64, 32])
    self.vgg_layers = ['relu2_2', 'relu3_4', 'relu4_4', 'conv5_4']
    self.flag_dict_device = False

    # dict
    # NOTE(security): torch.load unpickles the file; only load dictionaries
    # from a trusted source.
    # map_location='cpu' lets a dictionary that was saved from CUDA tensors
    # be loaded on a CPU-only host. put_dict_to_device() moves the entries
    # to the device/dtype of the first input anyway.
    self.dict = torch.load(dict_path, map_location='cpu')

    # vgg face extractor
    self.vgg_extractor = VGGFeatureExtractor(
        layer_name_list=self.vgg_layers,
        vgg_type='vgg19',
        use_input_norm=True,
        range_norm=True,
        requires_grad=False)

    # attention block for fusing dictionary features and input features,
    # one per (part, feature size) pair
    self.attn_blocks = nn.ModuleDict()
    for idx, feat_size in enumerate(self.feature_sizes):
        for name in self.parts:
            self.attn_blocks[f'{name}_{feat_size}'] = AttentionBlock(channel_sizes[idx])

    # multi scale dilation block (applied to the 'conv5_4' features in forward)
    self.multi_scale_dilation = MSDilationBlock(num_feat * 8, dilation=[4, 3, 2, 1])

    # upsampling and reconstruction
    self.upsample0 = SFTUpBlock(num_feat * 8, num_feat * 8)
    self.upsample1 = SFTUpBlock(num_feat * 8, num_feat * 4)
    self.upsample2 = SFTUpBlock(num_feat * 4, num_feat * 2)
    self.upsample3 = SFTUpBlock(num_feat * 2, num_feat)
    self.upsample4 = nn.Sequential(
        spectral_norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1)), nn.LeakyReLU(0.2, True), UpResBlock(num_feat),
        UpResBlock(num_feat), nn.Conv2d(num_feat, 3, kernel_size=3, stride=1, padding=1), nn.Tanh())
def swap_feat(self, vgg_feat, updated_feat, dict_feat, location, part_name, f_size):
    """Replace one facial part in ``updated_feat`` using the dictionary.

    Args:
        vgg_feat (Tensor): Original VGG features of the input.
        updated_feat (Tensor): Features to be updated; the part region is
            overwritten in place.
        dict_feat (Tensor): Dictionary features of this part.
        location (Tensor): Part box, indexed as (x0, y0, x1, y1).
        part_name (str): Name of the part; used to select the attention block.
        f_size: Feature size; used to select the attention block.

    Returns:
        Tensor: ``updated_feat`` with the part region replaced.
    """
    rows = slice(location[1], location[3])
    cols = slice(location[0], location[2])

    # crop the part from the original vgg features
    part_feat = vgg_feat[:, :, rows, cols].clone()
    part_hw = part_feat.size()[2:4]

    # resize the crop to the spatial size of the dictionary entries
    part_resize_feat = F.interpolate(part_feat, dict_feat.size()[2:4], mode='bilinear', align_corners=False)

    # use adaptive instance normalization to adjust color and illuminations
    dict_feat = adaptive_instance_normalization(dict_feat, part_resize_feat)

    # similarity of the crop to every dictionary entry
    scores = F.conv2d(part_resize_feat, dict_feat)
    scores = F.softmax(scores.view(-1), dim=0)

    # pick the most similar entry (after norm) and resize it back to the crop size
    best = torch.argmax(scores)
    swapped = F.interpolate(dict_feat[best:best + 1], part_hw)

    # attention computed from the difference between dictionary and input features
    attn_block = self.attn_blocks[f'{part_name}_' + str(f_size)]
    attn = attn_block(swapped - part_feat)

    # write the fused features back into the part region
    updated_feat[:, :, rows, cols] = attn * swapped + part_feat
    return updated_feat
def put_dict_to_device(self, x):
    """Convert every dictionary tensor with ``.to(x)``, once.

    The conversion happens only while ``self.flag_dict_device`` is False;
    the flag is set to True afterwards, so later calls do nothing.

    Args:
        x (Tensor): Reference tensor passed to ``Tensor.to``.
    """
    if self.flag_dict_device is not False:
        return

    for entries in self.dict.values():
        # only values are replaced, so the key set stays stable while iterating
        for key in list(entries.keys()):
            entries[key] = entries[key].to(x)

    self.flag_dict_device = True
def forward(self, x, part_locations):
    """Restore a face using the component dictionary.

    Only the first sample of the batch (index 0) is used to read the part
    locations, i.e. only testing with batch size = 1 is supported.

    Args:
        x (Tensor): Input faces with shape (b, c, 512, 512).
        part_locations (list[Tensor]): Part locations, one entry per part
            in the order of ``self.parts``.

    Returns:
        Tensor: Output of the final reconstruction block.
    """
    self.put_dict_to_device(x)

    # extract vggface features
    vgg_features = self.vgg_extractor(x)

    sample = 0  # part locations are always read from the first sample
    updated_vgg_features = []
    for layer_name, f_size in zip(self.vgg_layers, self.feature_sizes):
        dict_features = self.dict[f'{f_size}']
        vgg_feat = vgg_features[layer_name]
        updated_feat = vgg_feat.clone()

        # part locations are given at 512 resolution; bring them to this scale
        stride = 512 / f_size
        for part_idx, part_name in enumerate(self.parts):
            location = (part_locations[part_idx][sample] // stride).int()
            updated_feat = self.swap_feat(vgg_feat, updated_feat, dict_features[part_name], location, part_name,
                                          f_size)

        updated_vgg_features.append(updated_feat)

    # use updated vgg features to modulate the upsampled features with
    # SFT (Spatial Feature Transform) scaling and shifting manner,
    # going from the deepest updated feature to the shallowest.
    feat = self.multi_scale_dilation(vgg_features['conv5_4'])
    stages = (self.upsample0, self.upsample1, self.upsample2, self.upsample3)
    for stage, guide in zip(stages, updated_vgg_features[3::-1]):
        feat = stage(feat, guide)

    return self.upsample4(feat)
================================================
FILE: basicsr/archs/dfdnet_util.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.nn.utils.spectral_norm import spectral_norm
class BlurFunctionBackward(Function):
    """Autograd function used as the backward pass of ``BlurFunction``.

    Its forward is a depthwise convolution with ``kernel_flip``; its own
    backward is a depthwise convolution with ``kernel``.
    """

    @staticmethod
    def forward(ctx, grad_output, kernel, kernel_flip):
        ctx.save_for_backward(kernel, kernel_flip)
        # one group per channel, i.e. a depthwise convolution
        num_channels = grad_output.shape[1]
        return F.conv2d(grad_output, kernel_flip, padding=1, groups=num_channels)

    @staticmethod
    def backward(ctx, gradgrad_output):
        kernel = ctx.saved_tensors[0]
        num_channels = gradgrad_output.shape[1]
        grad_input = F.conv2d(gradgrad_output, kernel, padding=1, groups=num_channels)
        # no gradients for the two kernels
        return grad_input, None, None
class BlurFunction(Function):
    """Depthwise blur as a custom autograd function.

    Forward convolves with ``kernel``; backward is delegated to
    ``BlurFunctionBackward``.
    """

    @staticmethod
    def forward(ctx, x, kernel, kernel_flip):
        ctx.save_for_backward(kernel, kernel_flip)
        # one group per channel, i.e. a depthwise convolution
        num_channels = x.shape[1]
        return F.conv2d(x, kernel, padding=1, groups=num_channels)

    @staticmethod
    def backward(ctx, grad_output):
        saved_kernel, saved_kernel_flip = ctx.saved_tensors
        grad_input = BlurFunctionBackward.apply(grad_output, saved_kernel, saved_kernel_flip)
        # no gradients for the two kernels
        return grad_input, None, None


# functional entry point used by the Blur module
blur = BlurFunction.apply
class Blur(nn.Module):
    """Depthwise 3x3 blur with a normalized [[1, 2, 1], [2, 4, 2], [1, 2, 1]] kernel.

    Args:
        channel (int): Number of channels; the kernel is repeated once per channel.
    """

    def __init__(self, channel):
        super().__init__()
        base = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
        # normalize so that the weights sum to one
        base = (base / base.sum()).view(1, 1, 3, 3)
        flipped = torch.flip(base, [2, 3])
        # stored as plain tensor attributes; converted per call in forward
        self.kernel = base.repeat(channel, 1, 1, 1)
        self.kernel_flip = flipped.repeat(channel, 1, 1, 1)

    def forward(self, x):
        # match the kernels to the input with type_as on every call
        kernel = self.kernel.type_as(x)
        kernel_flip = self.kernel_flip.type_as(x)
        return blur(x, kernel, kernel_flip)
def calc_mean_std(feat, eps=1e-5):
    """Per-sample, per-channel mean and std over the spatial dimensions.

    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.

    Returns:
        tuple(Tensor): ``(mean, std)``, each with shape (n, c, 1, 1).
    """
    shape = feat.size()
    assert len(shape) == 4, 'The input feature should be 4D tensor.'
    n, c = shape[0], shape[1]

    # flatten the spatial dimensions
    flat = feat.view(n, c, -1)
    feat_mean = flat.mean(dim=2).view(n, c, 1, 1)
    feat_std = (flat.var(dim=2) + eps).sqrt().view(n, c, 1, 1)
    return feat_mean, feat_std
def adaptive_instance_normalization(content_feat, style_feat):
    """Adaptive instance normalization.

    Normalizes ``content_feat`` with its own per-channel statistics and
    re-scales it with the per-channel statistics of ``style_feat``, so that
    the reference features get similar color and illumination as the
    degraded features.

    Args:
        content_feat (Tensor): The reference feature.
        style_feat (Tensor): The degraded features.

    Returns:
        Tensor: Re-normalized features with the size of ``content_feat``.
    """
    target_size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)

    # remove the content statistics ...
    whitened = (content_feat - content_mean.expand(target_size)) / content_std.expand(target_size)
    # ... and apply the style statistics
    return whitened * style_std.expand(target_size) + style_mean.expand(target_size)
def AttentionBlock(in_channel):
    """Build a conv -> LeakyReLU(0.2) -> conv stack with spectral-norm 3x3 convolutions.

    Args:
        in_channel (int): Number of input channels; the output has the same number.

    Returns:
        nn.Sequential: The attention block.
    """
    first_conv = spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1))
    activation = nn.LeakyReLU(0.2, True)
    second_conv = spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1))
    return nn.Sequential(first_conv, activation, second_conv)
def conv_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True):
    """Conv block used in MSDilationBlock.

    Two spectral-norm convolutions with a LeakyReLU(0.2) in between. Both
    convolutions share kernel size, stride, dilation and bias setting; the
    padding is ``((kernel_size - 1) // 2) * dilation``.

    Args:
        in_channels (int): Input channels of the first convolution.
        out_channels (int): Output channels of both convolutions.
        kernel_size (int): Kernel size. Default: 3.
        stride (int): Stride. Default: 1.
        dilation (int): Dilation. Default: 1.
        bias (bool): Whether the convolutions have a bias. Default: True.

    Returns:
        nn.Sequential: The conv block.
    """
    pad = ((kernel_size - 1) // 2) * dilation

    def _sn_conv(num_in):
        conv = nn.Conv2d(
            num_in,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=pad,
            bias=bias)
        return spectral_norm(conv)

    return nn.Sequential(_sn_conv(in_channels), nn.LeakyReLU(0.2), _sn_conv(out_channels))
class MSDilationBlock(nn.Module):
    """Multi-scale dilation block.

    Four parallel conv blocks (one per dilation rate) whose outputs are
    concatenated, fused by a spectral-norm convolution and added to the input.

    Args:
        in_channels (int): Number of input (and output) channels.
        kernel_size (int): Kernel size of all convolutions. Default: 3.
        dilation (sequence[int]): Dilation of the four branches; only the
            first four entries are used. Default: (1, 1, 1, 1).
        bias (bool): Whether the convolutions have a bias. Default: True.
    """

    def __init__(self, in_channels, kernel_size=3, dilation=(1, 1, 1, 1), bias=True):
        super(MSDilationBlock, self).__init__()

        self.conv_blocks = nn.ModuleList()
        for branch in range(4):
            branch_block = conv_block(in_channels, in_channels, kernel_size, dilation=dilation[branch], bias=bias)
            self.conv_blocks.append(branch_block)

        fusion_conv = nn.Conv2d(
            in_channels * 4,
            in_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            bias=bias)
        self.conv_fusion = spectral_norm(fusion_conv)

    def forward(self, x):
        # run the four branches on the same input and stack them on the channel dim
        branch_outs = [self.conv_blocks[branch](x) for branch in range(4)]
        fused = self.conv_fusion(torch.cat(branch_outs, 1))
        # residual connection
        return fused + x
class UpResBlock(nn.Module):
    """Residual block: x + conv(LeakyReLU(conv(x))) with 3x3 convolutions.

    Args:
        in_channel (int): Number of input (and output) channels.
    """

    def __init__(self, in_channel):
        super(UpResBlock, self).__init__()
        layers = [
            nn.Conv2d(in_channel, in_channel, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(in_channel, in_channel, 3, 1, 1),
        ]
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.body(x)
        return x + residual
================================================
FILE: basicsr/archs/discriminator_arch.py
================================================
from torch import nn as nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm
from basicsr.utils.registry import ARCH_REGISTRY
@ARCH_REGISTRY.register()
class VGGStyleDiscriminator(nn.Module):
    """VGG style discriminator with input size 128 x 128 or 256 x 256.

    It is used to train SRGAN, ESRGAN, and VideoGAN.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_feat (int): Channel number of base intermediate features.
        input_size (int): Expected input height, 128 or 256. Default: 128.
    """

    def __init__(self, num_in_ch, num_feat, input_size=128):
        super(VGGStyleDiscriminator, self).__init__()
        self.input_size = input_size
        assert self.input_size in (128, 256), (
            f'input size must be 128 or 256, but received {input_size}')

        def _add_conv_bn(tag, ch_in, ch_out, kernel, stride):
            # registers self.conv<tag> followed by self.bn<tag>
            setattr(self, f'conv{tag}', nn.Conv2d(ch_in, ch_out, kernel, stride, 1, bias=False))
            setattr(self, f'bn{tag}', nn.BatchNorm2d(ch_out, affine=True))

        def _add_stage(stage, ch_in, ch_out):
            # a 3x3 stride-1 conv that changes channels, then a 4x4 stride-2 conv
            _add_conv_bn(f'{stage}_0', ch_in, ch_out, 3, 1)
            _add_conv_bn(f'{stage}_1', ch_out, ch_out, 4, 2)

        # stage 0: the first convolution has a bias and no batch norm
        self.conv0_0 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
        _add_conv_bn('0_1', num_feat, num_feat, 4, 2)

        _add_stage(1, num_feat, num_feat * 2)
        _add_stage(2, num_feat * 2, num_feat * 4)
        _add_stage(3, num_feat * 4, num_feat * 8)
        _add_stage(4, num_feat * 8, num_feat * 8)

        # one more downsampling stage for the larger input
        if self.input_size == 256:
            _add_stage(5, num_feat * 8, num_feat * 8)

        # classifier on the flattened 4 x 4 feature map
        self.linear1 = nn.Linear(num_feat * 8 * 4 * 4, 100)
        self.linear2 = nn.Linear(100, 1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        assert x.size(2) == self.input_size, (f'Input size must be identical to input_size, but received {x.size()}.')

        feat = self.lrelu(self.conv0_0(x))
        feat = self.lrelu(self.bn0_1(self.conv0_1(feat)))  # output spatial size: /2

        # every stage halves the spatial size: /4, /8, /16, /32 (and /64 for 256 inputs)
        num_stages = 5 if self.input_size == 256 else 4
        for stage in range(1, num_stages + 1):
            for sub in (0, 1):
                conv = getattr(self, f'conv{stage}_{sub}')
                bn = getattr(self, f'bn{stage}_{sub}')
                feat = self.lrelu(bn(conv(feat)))

        # spatial size: (4, 4)
        feat = feat.view(feat.size(0), -1)
        feat = self.lrelu(self.linear1(feat))
        return self.linear2(feat)
@ARCH_REGISTRY.register(suffix='basicsr')
class UNetDiscriminatorSN(nn.Module):
    """Defines a U-Net discriminator with spectral normalization (SN)

    It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_feat (int): Channel number of base intermediate features. Default: 64.
        skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
    """

    def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
        super(UNetDiscriminatorSN, self).__init__()
        self.skip_connection = skip_connection
        norm = spectral_norm

        feat1, feat2, feat4, feat8 = num_feat, num_feat * 2, num_feat * 4, num_feat * 8

        # the first convolution
        self.conv0 = nn.Conv2d(num_in_ch, feat1, kernel_size=3, stride=1, padding=1)
        # encoder: 4x4 stride-2 convolutions
        self.conv1 = norm(nn.Conv2d(feat1, feat2, 4, 2, 1, bias=False))
        self.conv2 = norm(nn.Conv2d(feat2, feat4, 4, 2, 1, bias=False))
        self.conv3 = norm(nn.Conv2d(feat4, feat8, 4, 2, 1, bias=False))
        # decoder: 3x3 convolutions applied after x2 interpolation
        self.conv4 = norm(nn.Conv2d(feat8, feat4, 3, 1, 1, bias=False))
        self.conv5 = norm(nn.Conv2d(feat4, feat2, 3, 1, 1, bias=False))
        self.conv6 = norm(nn.Conv2d(feat2, feat1, 3, 1, 1, bias=False))
        # extra convolutions
        self.conv7 = norm(nn.Conv2d(feat1, feat1, 3, 1, 1, bias=False))
        self.conv8 = norm(nn.Conv2d(feat1, feat1, 3, 1, 1, bias=False))
        self.conv9 = nn.Conv2d(feat1, 1, 3, 1, 1)

    def forward(self, x):

        def act(tensor):
            return F.leaky_relu(tensor, negative_slope=0.2, inplace=True)

        # encoder
        x0 = act(self.conv0(x))
        x1 = act(self.conv1(x0))
        x2 = act(self.conv2(x1))
        x3 = act(self.conv3(x2))

        # decoder: upsample, convolve, then optionally add the matching encoder feature
        feat = x3
        for conv, skip in ((self.conv4, x2), (self.conv5, x1), (self.conv6, x0)):
            feat = F.interpolate(feat, scale_factor=2, mode='bilinear', align_corners=False)
            feat = act(conv(feat))
            if self.skip_connection:
                feat = feat + skip

        # extra convolutions
        out = act(self.conv7(feat))
        out = act(self.conv8(out))
        return self.conv9(out)
================================================
FILE: basicsr/archs/duf_arch.py
================================================
import numpy as np
import torch
from torch import nn as nn
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
class DenseBlocksTemporalReduce(nn.Module):
    """A concatenation of 3 dense blocks with reduction in temporal dimension.

    Note that the output temporal dimension is 6 fewer the input temporal dimension, since there are 3 blocks.

    Args:
        num_feat (int): Number of channels in the blocks. Default: 64.
        num_grow_ch (int): Growing factor of the dense blocks. Default: 32
        adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
            Set to false if you want to train from scratch. Default: False.
    """

    def __init__(self, num_feat=64, num_grow_ch=32, adapt_official_weights=False):
        super(DenseBlocksTemporalReduce, self).__init__()
        # batch-norm settings: official-weight values vs. pytorch default values
        eps, momentum = (1e-3, 1e-3) if adapt_official_weights else (1e-05, 0.1)

        def _reduce_block(channels):
            # BN-ReLU-1x1x1 conv, then BN-ReLU-3x3x3 conv without temporal padding
            return nn.Sequential(
                nn.BatchNorm3d(channels, eps=eps, momentum=momentum),
                nn.ReLU(inplace=True),
                nn.Conv3d(channels, channels, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True),
                nn.BatchNorm3d(channels, eps=eps, momentum=momentum),
                nn.ReLU(inplace=True),
                nn.Conv3d(channels, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))

        # the input of every block grows by num_grow_ch channels
        self.temporal_reduce1 = _reduce_block(num_feat)
        self.temporal_reduce2 = _reduce_block(num_feat + num_grow_ch)
        self.temporal_reduce3 = _reduce_block(num_feat + 2 * num_grow_ch)

    def forward(self, x):
        """
        Args:
            x (Tensor): Input tensor with shape (b, num_feat, t, h, w).

        Returns:
            Tensor: Output with shape (b, num_feat + num_grow_ch * 3, t - 6, h, w).
        """
        for reduce in (self.temporal_reduce1, self.temporal_reduce2, self.temporal_reduce3):
            grown = reduce(x)
            # drop the first and last frame of the input so it matches the block output
            x = torch.cat((x[:, :, 1:-1, :, :], grown), 1)
        return x
class DenseBlocks(nn.Module):
    """ A concatenation of N dense blocks.

    Args:
        num_block (int): Number of dense blocks. The values are:
            DUF-S (16 layers): 3
            DUF-M (18 layers): 9
            DUF-L (52 layers): 21
            (NOTE(review): "18 layers" is kept from the original docstring; it
            may be a typo for 28 -- confirm against the DUF paper.)
        num_feat (int): Number of channels in the blocks. Default: 64.
        num_grow_ch (int): Growing factor of the dense blocks. Default: 16.
        adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
            Set to false if you want to train from scratch. Default: False.
    """

    def __init__(self, num_block, num_feat=64, num_grow_ch=16, adapt_official_weights=False):
        super(DenseBlocks, self).__init__()
        # batch-norm settings: official-weight values vs. pytorch default values
        eps, momentum = (1e-3, 1e-3) if adapt_official_weights else (1e-05, 0.1)

        self.dense_blocks = nn.ModuleList()
        for block_no in range(0, num_block):
            # every earlier block has added num_grow_ch channels to the input
            channels = num_feat + block_no * num_grow_ch
            block = nn.Sequential(
                nn.BatchNorm3d(channels, eps=eps, momentum=momentum),
                nn.ReLU(inplace=True),
                nn.Conv3d(channels, channels, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True),
                nn.BatchNorm3d(channels, eps=eps, momentum=momentum),
                nn.ReLU(inplace=True),
                nn.Conv3d(channels, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=True))
            self.dense_blocks.append(block)

    def forward(self, x):
        """
        Args:
            x (Tensor): Input tensor with shape (b, num_feat, t, h, w).

        Returns:
            Tensor: Output with shape (b, num_feat + num_block * num_grow_ch, t, h, w).
        """
        for block in self.dense_blocks:
            # dense connection: append the block output to its own input
            x = torch.cat((x, block(x)), 1)
        return x
class DynamicUpsamplingFilter(nn.Module):
    """Dynamic upsampling filter used in DUF.

    Reference: https://github.com/yhjo09/VSR-DUF

    It only supports input with 3 channels. And it applies the same filters to 3 channels.

    Args:
        filter_size (tuple): Filter size of generated filters. The shape is (kh, kw). Default: (5, 5).

    Raises:
        TypeError: If ``filter_size`` is not a tuple.
        ValueError: If ``filter_size`` does not have exactly two entries.
    """

    def __init__(self, filter_size=(5, 5)):
        super(DynamicUpsamplingFilter, self).__init__()
        if not isinstance(filter_size, tuple):
            # report the offending *type* (the message previously printed the value)
            raise TypeError(f'The type of filter_size must be tuple, but got type {type(filter_size)}')
        if len(filter_size) != 2:
            raise ValueError(f'The length of filter size must be 2, but got {len(filter_size)}.')
        # generate a local expansion filter, similar to im2col
        self.filter_size = filter_size
        # plain Python int so that it can be used directly as a tensor dimension
        filter_prod = int(np.prod(filter_size))
        expansion_filter = torch.eye(filter_prod).view(filter_prod, 1, *filter_size)  # (kh*kw, 1, kh, kw)
        self.expansion_filter = expansion_filter.repeat(3, 1, 1, 1)  # repeat for all the 3 channels

    def forward(self, x, filters):
        """Forward function for DynamicUpsamplingFilter.

        Args:
            x (Tensor): Input image with 3 channels. The shape is (n, 3, h, w).
            filters (Tensor): Generated dynamic filters. The shape is (n, filter_prod, upsampling_square, h, w).
                filter_prod: prod of filter kernel size, e.g., 1*5*5=25.
                upsampling_square: similar to pixel shuffle, upsampling_square = upsampling * upsampling.
                e.g., for x 4 upsampling, upsampling_square= 4*4 = 16

        Returns:
            Tensor: Filtered image with shape (n, 3*upsampling_square, h, w)
        """
        n, filter_prod, upsampling_square, h, w = filters.size()
        kh, kw = self.filter_size
        # gather every pixel's kh x kw neighbourhood into the channel dimension;
        # the expansion filter is moved to the device/dtype of x on each call
        expanded_input = F.conv2d(
            x, self.expansion_filter.to(x), padding=(kh // 2, kw // 2), groups=3)  # (n, 3*filter_prod, h, w)
        expanded_input = expanded_input.view(n, 3, filter_prod, h, w).permute(0, 3, 4, 1, 2)  # (n, h, w, 3, filter_prod)
        filters = filters.permute(0, 3, 4, 1, 2)  # (n, h, w, filter_prod, upsampling_square)
        out = torch.matmul(expanded_input, filters)  # (n, h, w, 3, upsampling_square)
        return out.permute(0, 3, 4, 1, 2).view(n, 3 * upsampling_square, h, w)
@ARCH_REGISTRY.register()
class DUF(nn.Module):
    """Network architecture for DUF

    ``Paper: Deep Video Super-Resolution Network Using Dynamic Upsampling Filters Without Explicit Motion Compensation``

    Reference: https://github.com/yhjo09/VSR-DUF

    'adapt_official_weights' is only necessary when loading the weights
    converted from the official TensorFlow weights. Please set it to False
    if you are training the model from scratch.

    Three model sizes are available, selected through ``num_layer``:
    16, 28 and 52 layers.

    Args:
        scale (int): The upsampling factor. Default: 4.
        num_layer (int): The number of layers. Default: 52.
        adapt_official_weights (bool): Whether to adapt the weights
            translated from the official implementation. Set to false if you
            want to train from scratch. Default: False.
    """

    def __init__(self, scale=4, num_layer=52, adapt_official_weights=False):
        super().__init__()
        self.scale = scale
        # BatchNorm settings: 1e-3/1e-3 for official weights, else pytorch default values
        eps, momentum = (1e-3, 1e-3) if adapt_official_weights else (1e-05, 0.1)

        self.conv3d1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
        self.dynamic_filter = DynamicUpsamplingFilter((5, 5))

        # num_layer -> (number of dense blocks, growth channels)
        layer_cfg = {16: (3, 32), 28: (9, 16), 52: (21, 16)}
        if num_layer not in layer_cfg:
            raise ValueError(f'Only supported (16, 28, 52) layers, but got {num_layer}.')
        num_block, num_grow_ch = layer_cfg[num_layer]

        self.dense_block1 = DenseBlocks(
            num_block=num_block, num_feat=64, num_grow_ch=num_grow_ch,
            adapt_official_weights=adapt_official_weights)  # T = 7
        trunk_ch = 64 + num_grow_ch * num_block
        self.dense_block2 = DenseBlocksTemporalReduce(
            trunk_ch, num_grow_ch, adapt_official_weights=adapt_official_weights)  # T = 1
        channels = trunk_ch + num_grow_ch * 3
        self.bn3d2 = nn.BatchNorm3d(channels, eps=eps, momentum=momentum)
        self.conv3d2 = nn.Conv3d(channels, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)

        # residual-image head
        self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
        self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)

        # dynamic-filter head
        self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
        self.conv3d_f2 = nn.Conv3d(
            512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)

    def forward(self, x):
        """
        Args:
            x (Tensor): Input with shape (b, 7, c, h, w)

        Returns:
            Tensor: Output with shape (b, c, h * scale, w * scale)
        """
        num_batches, num_imgs, _, h, w = x.size()

        x = x.permute(0, 2, 1, 3, 4)  # (b, c, 7, h, w) for Conv3D
        x_center = x[:, :, num_imgs // 2, :, :]

        feat = self.conv3d1(x)
        feat = self.dense_block1(feat)
        feat = self.dense_block2(feat)
        feat = F.relu(self.bn3d2(feat), inplace=True)
        feat = F.relu(self.conv3d2(feat), inplace=True)

        # residual image
        res_hidden = F.relu(self.conv3d_r1(feat), inplace=True)
        res = self.conv3d_r2(res_hidden)

        # filter: softmax over the 25 taps of the 5x5 kernel
        filter_hidden = F.relu(self.conv3d_f1(feat), inplace=True)
        filter_ = self.conv3d_f2(filter_hidden)
        filter_ = F.softmax(filter_.view(num_batches, 25, self.scale**2, h, w), dim=1)

        # dynamic filter applied to the center frame, plus the predicted residual
        out = self.dynamic_filter(x_center, filter_)
        out += res.squeeze_(2)
        return F.pixel_shuffle(out, self.scale)
================================================
FILE: basicsr/archs/ecbsr_arch.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from basicsr.utils.registry import ARCH_REGISTRY
class SeqConv3x3(nn.Module):
    """The re-parameterizable block used in the ECBSR architecture.

    ``Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices``

    Reference: https://github.com/xindongzhang/ECBSR

    A 1x1 convolution is followed either by a learnable 3x3 convolution or by
    a fixed 3x3 edge stencil (Sobel-x, Sobel-y, Laplacian) applied depth-wise
    with a learnable per-channel scale and bias.

    Args:
        seq_type (str): Sequence type, option: conv1x1-conv3x3 | conv1x1-sobelx | conv1x1-sobely | conv1x1-laplacian.
        in_channels (int): Channel number of input.
        out_channels (int): Channel number of output.
        depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1.
    """

    # Fixed 3x3 stencils of the edge-oriented branches, keyed by seq_type.
    _EDGE_STENCILS = {
        'conv1x1-sobelx': ((1.0, 0.0, -1.0), (2.0, 0.0, -2.0), (1.0, 0.0, -1.0)),
        'conv1x1-sobely': ((1.0, 2.0, 1.0), (0.0, 0.0, 0.0), (-1.0, -2.0, -1.0)),
        'conv1x1-laplacian': ((0.0, 1.0, 0.0), (1.0, -4.0, 1.0), (0.0, 1.0, 0.0)),
    }

    def __init__(self, seq_type, in_channels, out_channels, depth_multiplier=1):
        super().__init__()
        self.seq_type = seq_type
        self.in_channels = in_channels
        self.out_channels = out_channels

        if seq_type == 'conv1x1-conv3x3':
            self.mid_planes = int(out_channels * depth_multiplier)
            # temporary conv layers are only used for their initialised parameters
            pointwise = torch.nn.Conv2d(in_channels, self.mid_planes, kernel_size=1, padding=0)
            self.k0, self.b0 = pointwise.weight, pointwise.bias
            spatial = torch.nn.Conv2d(self.mid_planes, out_channels, kernel_size=3)
            self.k1, self.b1 = spatial.weight, spatial.bias
        elif seq_type in self._EDGE_STENCILS:
            pointwise = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
            self.k0, self.b0 = pointwise.weight, pointwise.bias
            # init scale and bias (scale is drawn before bias)
            self.scale = nn.Parameter(torch.randn(size=(out_channels, 1, 1, 1)) * 1e-3)
            self.bias = nn.Parameter(torch.randn(out_channels) * 1e-3)
            # init mask: the same fixed stencil for every output channel, never trained
            stencil = torch.tensor(self._EDGE_STENCILS[seq_type], dtype=torch.float32)
            self.mask = nn.Parameter(data=stencil.repeat(out_channels, 1, 1, 1), requires_grad=False)
        else:
            raise ValueError('The type of seqconv is not supported!')

    def forward(self, x):
        """Apply the 1x1 conv followed by the 3x3 stage of this branch."""
        # conv-1x1
        y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
        # explicitly padding with bias: the one-pixel border is filled with b0
        y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
        border = self.b0.view(1, -1, 1, 1)
        y0[:, :, 0:1, :] = border
        y0[:, :, -1:, :] = border
        y0[:, :, :, 0:1] = border
        y0[:, :, :, -1:] = border
        # conv-3x3
        if self.seq_type == 'conv1x1-conv3x3':
            return F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
        return F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias, stride=1, groups=self.out_channels)

    def rep_params(self):
        """Collapse the two stages into one equivalent 3x3 kernel and bias.

        Returns:
            tuple[Tensor, Tensor]: The merged weight and the merged bias.
        """
        device = self.k0.get_device()
        if device < 0:
            device = None  # get_device() is negative on CPU

        if self.seq_type == 'conv1x1-conv3x3':
            second_weight, second_bias, mid_ch = self.k1, self.b1, self.mid_planes
        else:
            # expand the depth-wise stencil into a dense (out, out, 3, 3) kernel
            scaled = self.scale * self.mask
            second_weight = torch.zeros((self.out_channels, self.out_channels, 3, 3), device=device)
            diag = torch.arange(self.out_channels, device=device)
            second_weight[diag, diag, :, :] = scaled[:, 0, :, :]
            second_bias, mid_ch = self.bias, self.out_channels

        # re-param conv kernel
        rep_weight = F.conv2d(input=second_weight, weight=self.k0.permute(1, 0, 2, 3))
        # re-param conv bias
        bias_plane = torch.ones(1, mid_ch, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
        rep_bias = F.conv2d(input=bias_plane, weight=second_weight).view(-1, ) + second_bias
        return rep_weight, rep_bias
class ECB(nn.Module):
    """The ECB block used in the ECBSR architecture.

    Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
    Ref git repo: https://github.com/xindongzhang/ECBSR

    In training mode the block sums five parallel branches (a plain 3x3 conv,
    a 1x1-3x3 sequence and three 1x1-edge-filter sequences). In eval mode the
    branches are merged into a single 3x3 convolution via ``rep_params``.

    Args:
        in_channels (int): Channel number of input.
        out_channels (int): Channel number of output.
        depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1.
        act_type (str): Activation type. Option: prelu | relu | rrelu | softplus | linear. Default: prelu.
        with_idt (bool): Whether to use identity connection. Default: False.

    Raises:
        ValueError: If ``act_type`` is not one of the supported options.
    """

    def __init__(self, in_channels, out_channels, depth_multiplier, act_type='prelu', with_idt=False):
        super(ECB, self).__init__()

        self.depth_multiplier = depth_multiplier
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.act_type = act_type

        # the identity shortcut is only possible when the channel counts match
        self.with_idt = bool(with_idt and (self.in_channels == self.out_channels))

        self.conv3x3 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1)
        self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.in_channels, self.out_channels, self.depth_multiplier)
        self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.in_channels, self.out_channels)
        self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.in_channels, self.out_channels)
        self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.in_channels, self.out_channels)

        if self.act_type == 'prelu':
            self.act = nn.PReLU(num_parameters=self.out_channels)
        elif self.act_type == 'relu':
            self.act = nn.ReLU(inplace=True)
        elif self.act_type == 'rrelu':
            self.act = nn.RReLU(lower=-0.05, upper=0.05)
        elif self.act_type == 'softplus':
            self.act = nn.Softplus()
        elif self.act_type == 'linear':
            pass  # no activation module is created for the linear case
        else:
            raise ValueError('The type of activation is not supported!')

    def forward(self, x):
        """Multi-branch sum while training, single merged conv otherwise."""
        if self.training:
            y = self.conv3x3(x) + self.conv1x1_3x3(x) + self.conv1x1_sbx(x) + self.conv1x1_sby(x) + self.conv1x1_lpl(x)
            if self.with_idt:
                y += x
        else:
            rep_weight, rep_bias = self.rep_params()
            y = F.conv2d(input=x, weight=rep_weight, bias=rep_bias, stride=1, padding=1)
        if self.act_type != 'linear':
            y = self.act(y)
        return y

    def rep_params(self):
        """Merge all branches into one 3x3 kernel and bias.

        Returns:
            tuple[Tensor, Tensor]: The merged weight and the merged bias.
        """
        weight0, bias0 = self.conv3x3.weight, self.conv3x3.bias
        weight1, bias1 = self.conv1x1_3x3.rep_params()
        weight2, bias2 = self.conv1x1_sbx.rep_params()
        weight3, bias3 = self.conv1x1_sby.rep_params()
        weight4, bias4 = self.conv1x1_lpl.rep_params()
        rep_weight = weight0 + weight1 + weight2 + weight3 + weight4
        rep_bias = bias0 + bias1 + bias2 + bias3 + bias4

        if self.with_idt:
            device = rep_weight.get_device()
            if device < 0:
                device = None  # get_device() is negative on CPU
            # identity as a 3x3 kernel: 1 at the center tap of every (i, i) channel pair
            weight_idt = torch.zeros(self.out_channels, self.out_channels, 3, 3, device=device)
            weight_idt[:, :, 1, 1] = torch.eye(self.out_channels, device=device)
            bias_idt = 0.0
            rep_weight, rep_bias = rep_weight + weight_idt, rep_bias + bias_idt
        return rep_weight, rep_bias
@ARCH_REGISTRY.register()
class ECBSR(nn.Module):
    """ECBSR architecture.

    Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
    Ref git repo: https://github.com/xindongzhang/ECBSR

    The trunk is a sequence of ECB blocks whose last block outputs
    ``num_out_ch * scale * scale`` channels, which a pixel-shuffle layer
    rearranges into the upsampled image.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        num_block (int): Block number in the trunk network.
        num_channel (int): Channel number.
        with_idt (bool): Whether use identity in convolution layers.
        act_type (str): Activation type.
        scale (int): Upsampling factor.
    """

    def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with_idt, act_type, scale):
        super().__init__()
        self.num_in_ch = num_in_ch
        self.scale = scale

        # head block, num_block trunk blocks, then a linear tail block
        layers = [ECB(num_in_ch, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
        for _ in range(num_block):
            layers.append(ECB(num_channel, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt))
        layers.append(
            ECB(num_channel, num_out_ch * scale * scale, depth_multiplier=2.0, act_type='linear', with_idt=with_idt))

        self.backbone = nn.Sequential(*layers)
        self.upsampler = nn.PixelShuffle(scale)

    def forward(self, x):
        """Add the channel-repeated input to the trunk output and pixel-shuffle."""
        if self.num_in_ch > 1:
            shortcut = torch.repeat_interleave(x, self.scale * self.scale, dim=1)
        else:
            # single-channel input: the addition below broadcasts it over
            # the scale * scale output channels
            shortcut = x
        return self.upsampler(self.backbone(x) + shortcut)
================================================
FILE: basicsr/archs/edsr_arch.py
================================================
import torch
from torch import nn as nn
from basicsr.archs.arch_util import ResidualBlockNoBN, Upsample, make_layer
from basicsr.utils.registry import ARCH_REGISTRY
@ARCH_REGISTRY.register()
class EDSR(nn.Module):
    """EDSR network structure.

    Paper: Enhanced Deep Residual Networks for Single Image Super-Resolution.
    Ref git repo: https://github.com/thstkdgus35/EDSR-PyTorch

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        num_block (int): Block number in the trunk network. Default: 16.
        upscale (int): Upsampling factor. Support 2^n and 3.
            Default: 4.
        res_scale (float): Used to scale the residual in residual block.
            Default: 1.
        img_range (float): Image range. Default: 255.
        rgb_mean (tuple[float]): Image mean in RGB orders.
            Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
    """

    def __init__(self,
                 num_in_ch,
                 num_out_ch,
                 num_feat=64,
                 num_block=16,
                 upscale=4,
                 res_scale=1,
                 img_range=255.,
                 rgb_mean=(0.4488, 0.4371, 0.4040)):
        super().__init__()

        self.img_range = img_range
        # kept as a plain tensor attribute; forward() casts it to the input's type
        self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)

        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat, res_scale=res_scale, pytorch_init=True)
        self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.upsample = Upsample(upscale, num_feat)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

    def forward(self, x):
        """Normalise, run the residual trunk, upsample and de-normalise."""
        self.mean = self.mean.type_as(x)

        shallow = self.conv_first((x - self.mean) * self.img_range)
        deep = self.conv_after_body(self.body(shallow))
        deep += shallow  # global residual connection

        out = self.conv_last(self.upsample(deep))
        return out / self.img_range + self.mean
================================================
FILE: basicsr/archs/edvr_arch.py
================================================
import torch
from torch import nn as nn
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import DCNv2Pack, ResidualBlockNoBN, make_layer
class PCDAlignment(nn.Module):
    """Alignment module using Pyramid, Cascading and Deformable convolution
    (PCD). It is used in EDVR.

    ``Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks``

    Args:
        num_feat (int): Channel number of middle features. Default: 64.
        deformable_groups (int): Deformable groups. Defaults: 8.
    """

    def __init__(self, num_feat=64, deformable_groups=8):
        super().__init__()

        # Pyramid has three levels:
        # L3: level 3, 1/4 spatial size
        # L2: level 2, 1/2 spatial size
        # L1: level 1, original spatial size
        self.offset_conv1 = nn.ModuleDict()
        self.offset_conv2 = nn.ModuleDict()
        self.offset_conv3 = nn.ModuleDict()
        self.dcn_pack = nn.ModuleDict()
        self.feat_conv = nn.ModuleDict()

        # Pyramids, built from the coarsest level (l3) down to l1
        for lvl in (3, 2, 1):
            key = f'l{lvl}'
            coarsest = lvl == 3
            self.offset_conv1[key] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
            # below l3 the offset is fused with the upsampled coarser offset,
            # hence the doubled input width and the extra third conv
            self.offset_conv2[key] = nn.Conv2d(num_feat if coarsest else num_feat * 2, num_feat, 3, 1, 1)
            if not coarsest:
                self.offset_conv3[key] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.dcn_pack[key] = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)
            if not coarsest:
                self.feat_conv[key] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)

        # Cascading dcn
        self.cas_offset_conv1 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
        self.cas_offset_conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.cas_dcnpack = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, nbr_feat_l, ref_feat_l):
        """Align neighboring frame features to the reference frame features.

        Args:
            nbr_feat_l (list[Tensor]): Neighboring feature list. It
                contains three pyramid levels (L1, L2, L3),
                each with shape (b, c, h, w).
            ref_feat_l (list[Tensor]): Reference feature list. It
                contains three pyramid levels (L1, L2, L3),
                each with shape (b, c, h, w).

        Returns:
            Tensor: Aligned features.
        """
        # Pyramids
        upsampled_offset, upsampled_feat = None, None
        for lvl in (3, 2, 1):
            key = f'l{lvl}'
            nbr_feat, ref_feat = nbr_feat_l[lvl - 1], ref_feat_l[lvl - 1]

            offset = torch.cat([nbr_feat, ref_feat], dim=1)
            offset = self.lrelu(self.offset_conv1[key](offset))
            if lvl == 3:
                offset = self.lrelu(self.offset_conv2[key](offset))
            else:
                fused = torch.cat([offset, upsampled_offset], dim=1)
                offset = self.lrelu(self.offset_conv2[key](fused))
                offset = self.lrelu(self.offset_conv3[key](offset))

            feat = self.dcn_pack[key](nbr_feat, offset)
            if lvl < 3:
                feat = self.feat_conv[key](torch.cat([feat, upsampled_feat], dim=1))

            if lvl > 1:
                feat = self.lrelu(feat)
                # upsample offset and features
                # x2: when we upsample the offset, we should also enlarge
                # the magnitude.
                upsampled_offset = self.upsample(offset) * 2
                upsampled_feat = self.upsample(feat)

        # Cascading
        offset = torch.cat([feat, ref_feat_l[0]], dim=1)
        offset = self.lrelu(self.cas_offset_conv1(offset))
        offset = self.lrelu(self.cas_offset_conv2(offset))
        return self.lrelu(self.cas_dcnpack(feat, offset))
class TSAFusion(nn.Module):
    """Temporal Spatial Attention (TSA) fusion module.

    Temporal: Calculate the correlation between center frame and
        neighboring frames;
    Spatial: It has 3 pyramid levels, the attention is similar to SFT.
        (SFT: Recovering realistic texture in image super-resolution by deep
            spatial feature transform.)

    Args:
        num_feat (int): Channel number of middle features. Default: 64.
        num_frame (int): Number of frames. Default: 5.
        center_frame_idx (int): The index of center frame. Default: 2.
    """

    def __init__(self, num_feat=64, num_frame=5, center_frame_idx=2):
        super().__init__()
        self.center_frame_idx = center_frame_idx
        stacked_ch = num_frame * num_feat  # channels after flattening all frames

        # temporal attention (before fusion conv)
        self.temporal_attn1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.temporal_attn2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.feat_fusion = nn.Conv2d(stacked_ch, num_feat, 1, 1)

        # spatial attention (after fusion conv)
        self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)
        self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1)
        self.spatial_attn1 = nn.Conv2d(stacked_ch, num_feat, 1)
        self.spatial_attn2 = nn.Conv2d(num_feat * 2, num_feat, 1)
        self.spatial_attn3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.spatial_attn4 = nn.Conv2d(num_feat, num_feat, 1)
        self.spatial_attn5 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.spatial_attn_l1 = nn.Conv2d(num_feat, num_feat, 1)
        self.spatial_attn_l2 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
        self.spatial_attn_l3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.spatial_attn_add1 = nn.Conv2d(num_feat, num_feat, 1)
        self.spatial_attn_add2 = nn.Conv2d(num_feat, num_feat, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, aligned_feat):
        """
        Args:
            aligned_feat (Tensor): Aligned features with shape (b, t, c, h, w).

        Returns:
            Tensor: Features after TSA with the shape (b, c, h, w).
        """
        b, t, c, h, w = aligned_feat.size()

        # temporal attention
        center = aligned_feat[:, self.center_frame_idx, :, :, :].clone()
        embedding_ref = self.temporal_attn1(center)
        embedding = self.temporal_attn2(aligned_feat.view(-1, c, h, w))
        embedding = embedding.view(b, t, -1, h, w)  # (b, t, c, h, w)

        # per-frame correlation with the center frame, each (b, 1, h, w)
        corr_l = [torch.sum(embedding[:, idx, :, :, :] * embedding_ref, 1).unsqueeze(1) for idx in range(t)]
        corr_prob = torch.sigmoid(torch.cat(corr_l, dim=1))  # (b, t, h, w)
        corr_prob = corr_prob.unsqueeze(2).expand(b, t, c, h, w)
        corr_prob = corr_prob.contiguous().view(b, -1, h, w)  # (b, t*c, h, w)
        aligned_feat = aligned_feat.view(b, -1, h, w) * corr_prob

        # fusion
        feat = self.lrelu(self.feat_fusion(aligned_feat))

        # spatial attention
        attn = self.lrelu(self.spatial_attn1(aligned_feat))
        pooled = torch.cat([self.max_pool(attn), self.avg_pool(attn)], dim=1)
        attn = self.lrelu(self.spatial_attn2(pooled))

        # pyramid levels
        attn_level = self.lrelu(self.spatial_attn_l1(attn))
        pooled_level = torch.cat([self.max_pool(attn_level), self.avg_pool(attn_level)], dim=1)
        attn_level = self.lrelu(self.spatial_attn_l2(pooled_level))
        attn_level = self.lrelu(self.spatial_attn_l3(attn_level))
        attn_level = self.upsample(attn_level)

        attn = self.lrelu(self.spatial_attn3(attn)) + attn_level
        attn = self.lrelu(self.spatial_attn4(attn))
        attn = self.upsample(attn)
        attn = self.spatial_attn5(attn)
        # additive term is computed from the pre-sigmoid attention map
        attn_add = self.spatial_attn_add2(self.lrelu(self.spatial_attn_add1(attn)))
        attn = torch.sigmoid(attn)

        # after initialization, * 2 makes (attn * 2) to be close to 1.
        return feat * attn * 2 + attn_add
class PredeblurModule(nn.Module):
    """Pre-deblur module.

    Builds a three-level feature pyramid with strided convolutions, refines
    each level with residual blocks and merges the levels back by bilinear
    upsampling and addition.

    Args:
        num_in_ch (int): Channel number of input image. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        hr_in (bool): Whether the input has high resolution. Default: False.
    """

    def __init__(self, num_in_ch=3, num_feat=64, hr_in=False):
        super().__init__()
        self.hr_in = hr_in

        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        if hr_in:
            # downsample x4 by stride conv
            self.stride_conv_hr1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
            self.stride_conv_hr2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)

        # generate feature pyramid
        self.stride_conv_l2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.stride_conv_l3 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)

        self.resblock_l3 = ResidualBlockNoBN(num_feat=num_feat)
        self.resblock_l2_1 = ResidualBlockNoBN(num_feat=num_feat)
        self.resblock_l2_2 = ResidualBlockNoBN(num_feat=num_feat)
        self.resblock_l1 = nn.ModuleList([ResidualBlockNoBN(num_feat=num_feat) for _ in range(5)])

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        """Return the refined level-1 features of the input."""
        feat_l1 = self.lrelu(self.conv_first(x))
        if self.hr_in:
            feat_l1 = self.lrelu(self.stride_conv_hr1(feat_l1))
            feat_l1 = self.lrelu(self.stride_conv_hr2(feat_l1))

        # generate feature pyramid
        feat_l2 = self.lrelu(self.stride_conv_l2(feat_l1))
        feat_l3 = self.lrelu(self.stride_conv_l3(feat_l2))

        # coarse-to-fine merge: l3 -> l2 -> l1
        feat_l3 = self.upsample(self.resblock_l3(feat_l3))
        feat_l2 = self.resblock_l2_1(feat_l2) + feat_l3
        feat_l2 = self.upsample(self.resblock_l2_2(feat_l2))

        for idx, block in enumerate(self.resblock_l1):
            if idx == 2:
                # the level-2 features are injected after the first two blocks
                feat_l1 = feat_l1 + feat_l2
            feat_l1 = block(feat_l1)
        return feat_l1
@ARCH_REGISTRY.register()
class EDVR(nn.Module):
    """EDVR network structure for video super-resolution.

    Now only support X4 upsampling factor.

    ``Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks``

    Args:
        num_in_ch (int): Channel number of input image. Default: 3.
        num_out_ch (int): Channel number of output image. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_frame (int): Number of input frames. Default: 5.
        deformable_groups (int): Deformable groups. Defaults: 8.
        num_extract_block (int): Number of blocks for feature extraction.
            Default: 5.
        num_reconstruct_block (int): Number of blocks for reconstruction.
            Default: 10.
        center_frame_idx (int): The index of center frame. Frame counting from
            0. Default: Middle of input frames.
        hr_in (bool): Whether the input has high resolution. Default: False.
        with_predeblur (bool): Whether has predeblur module.
            Default: False.
        with_tsa (bool): Whether has TSA module. Default: True.
    """

    def __init__(self,
                 num_in_ch=3,
                 num_out_ch=3,
                 num_feat=64,
                 num_frame=5,
                 deformable_groups=8,
                 num_extract_block=5,
                 num_reconstruct_block=10,
                 center_frame_idx=None,
                 hr_in=False,
                 with_predeblur=False,
                 with_tsa=True):
        super(EDVR, self).__init__()
        if center_frame_idx is None:
            self.center_frame_idx = num_frame // 2
        else:
            self.center_frame_idx = center_frame_idx
        self.hr_in = hr_in
        self.with_predeblur = with_predeblur
        self.with_tsa = with_tsa

        # extract features for each frame
        if self.with_predeblur:
            # forward num_in_ch so the predeblur branch accepts the same input
            # channels as the plain conv_first branch (identical for the default 3)
            self.predeblur = PredeblurModule(num_in_ch=num_in_ch, num_feat=num_feat, hr_in=self.hr_in)
            self.conv_1x1 = nn.Conv2d(num_feat, num_feat, 1, 1)
        else:
            self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)

        # extract pyramid features
        self.feature_extraction = make_layer(ResidualBlockNoBN, num_extract_block, num_feat=num_feat)
        self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

        # pcd and tsa module
        self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=deformable_groups)
        if self.with_tsa:
            self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_frame, center_frame_idx=self.center_frame_idx)
        else:
            self.fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)

        # reconstruction
        self.reconstruction = make_layer(ResidualBlockNoBN, num_reconstruct_block, num_feat=num_feat)
        # upsample
        # NOTE(review): the upsampling tail hard-codes 64 feature channels and
        # 3 output channels; num_out_ch is currently not used here. Left
        # unchanged so existing checkpoints keep loading.
        self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1)
        self.pixel_shuffle = nn.PixelShuffle(2)
        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        """Restore the center frame from a stack of frames.

        Args:
            x (Tensor): Input frames with shape (b, t, c, h, w). h and w must
                be multiples of 4 (16 when ``hr_in`` is set).

        Returns:
            Tensor: The restored center frame; the network output is added to
                the center frame itself (``hr_in``) or to its x4 bilinear
                upsampling.
        """
        b, t, c, h, w = x.size()
        if self.hr_in:
            assert h % 16 == 0 and w % 16 == 0, ('The height and width must be multiple of 16.')
        else:
            assert h % 4 == 0 and w % 4 == 0, ('The height and width must be multiple of 4.')

        x_center = x[:, self.center_frame_idx, :, :, :].contiguous()

        # extract features for each frame
        # L1
        if self.with_predeblur:
            feat_l1 = self.conv_1x1(self.predeblur(x.view(-1, c, h, w)))
            if self.hr_in:
                # the predeblur module downsampled the features by 4
                h, w = h // 4, w // 4
        else:
            feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))

        feat_l1 = self.feature_extraction(feat_l1)
        # L2
        feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
        feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
        # L3
        feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
        feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))

        feat_l1 = feat_l1.view(b, t, -1, h, w)
        feat_l2 = feat_l2.view(b, t, -1, h // 2, w // 2)
        feat_l3 = feat_l3.view(b, t, -1, h // 4, w // 4)

        # PCD alignment
        ref_feat_l = [  # reference feature list
            feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
            feat_l3[:, self.center_frame_idx, :, :, :].clone()
        ]
        aligned_feat = []
        for i in range(t):
            nbr_feat_l = [  # neighboring feature list
                feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
            ]
            aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
        aligned_feat = torch.stack(aligned_feat, dim=1)  # (b, t, c, h, w)

        if not self.with_tsa:
            # the plain conv fusion expects the frames flattened into channels
            aligned_feat = aligned_feat.view(b, -1, h, w)
        feat = self.fusion(aligned_feat)

        out = self.reconstruction(feat)
        out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
        out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
        out = self.lrelu(self.conv_hr(out))
        out = self.conv_last(out)
        if self.hr_in:
            base = x_center
        else:
            base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False)
        out += base
        return out
================================================
FILE: basicsr/archs/hifacegan_arch.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from basicsr.utils.registry import ARCH_REGISTRY
from .hifacegan_util import BaseNetwork, LIPEncoder, SPADEResnetBlock, get_nonspade_norm_layer
class SPADEGenerator(BaseNetwork):
    """Generator with SPADEResBlock"""

    def __init__(self,
                 num_in_ch=3,
                 num_feat=64,
                 use_vae=False,
                 z_dim=256,
                 crop_size=512,
                 norm_g='spectralspadesyncbatch3x3',
                 is_train=True,
                 init_train_phase=3):  # progressive training disabled
        super().__init__()
        self.nf = num_feat
        self.input_nc = num_in_ch
        self.is_train = is_train
        self.train_phase = init_train_phase

        self.scale_ratio = 5  # hardcoded now
        self.sw = crop_size // (2**self.scale_ratio)
        self.sh = self.sw  # 20210519: By default use square image, aspect_ratio = 1.0

        if use_vae:
            # In case of VAE, we will sample from random z vector
            self.fc = nn.Linear(z_dim, 16 * self.nf * self.sw * self.sh)
        else:
            # Otherwise, we make the network deterministic by starting with
            # downsampled segmentation map instead of random z
            self.fc = nn.Conv2d(num_in_ch, 16 * self.nf, 3, padding=1)

        self.head_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)

        self.g_middle_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
        self.g_middle_1 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)

        self.ups = nn.ModuleList([
            SPADEResnetBlock(16 * self.nf, 8 * self.nf, norm_g),
            SPADEResnetBlock(8 * self.nf, 4 * self.nf, norm_g),
            SPADEResnetBlock(4 * self.nf, 2 * self.nf, norm_g),
            SPADEResnetBlock(2 * self.nf, 1 * self.nf, norm_g)
        ])

        # one RGB head per upsampling stage; only the head of the last
        # executed stage is used in forward
        self.to_rgbs = nn.ModuleList([
            nn.Conv2d(8 * self.nf, 3, 3, padding=1),
            nn.Conv2d(4 * self.nf, 3, 3, padding=1),
            nn.Conv2d(2 * self.nf, 3, 3, padding=1),
            nn.Conv2d(1 * self.nf, 3, 3, padding=1)
        ])

        self.up = nn.Upsample(scale_factor=2)

    def encode(self, input_tensor):
        """
        Encode input_tensor into feature maps, can be overridden in derived classes
        Default: nearest downsampling of 2**5 = 32 times
        """
        h, w = input_tensor.size()[-2:]
        sh, sw = h // 2**self.scale_ratio, w // 2**self.scale_ratio
        x = F.interpolate(input_tensor, size=(sh, sw))
        return self.fc(x)

    def forward(self, x):
        """Generate an image, using the input itself as the SPADE guidance."""
        # In original SPADE, seg means a segmentation map, but here we use x instead.
        seg = x

        x = self.encode(x)
        x = self.head_0(x, seg)

        x = self.up(x)
        x = self.g_middle_0(x, seg)
        x = self.g_middle_1(x, seg)

        # number of upsampling stages to run
        if self.is_train:
            phase = self.train_phase + 1
        else:
            phase = len(self.to_rgbs)

        for i in range(phase):
            x = self.up(x)
            x = self.ups[i](x, seg)

        x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
        x = torch.tanh(x)

        return x

    def mixed_guidance_forward(self, input_x, seg=None, n=0, mode='progressive'):
        """
        A helper method for subspace visualization. Input and seg are different images.
        For the first n levels (including encoder) we use input, for the rest we use seg.

        If mode = 'progressive', the output's like: AAABBB
        If mode = 'one_plug', the output's like:    AAABAA
        If mode = 'one_ablate', the output's like:  BBBABB

        Raises:
            ValueError: If ``mode`` is not one of the three supported values
                (only checked when ``seg`` is given).
        """

        if seg is None:
            return self.forward(input_x)

        if self.is_train:
            phase = self.train_phase + 1
        else:
            phase = len(self.to_rgbs)

        # one guidance image per level: encoder, head, two middle blocks, then `phase` ups
        if mode == 'progressive':
            n = max(min(n, 4 + phase), 0)
            guide_list = [input_x] * n + [seg] * (4 + phase - n)
        elif mode == 'one_plug':
            n = max(min(n, 4 + phase - 1), 0)
            guide_list = [seg] * (4 + phase)
            guide_list[n] = input_x
        elif mode == 'one_ablate':
            if n > 3 + phase:
                return self.forward(input_x)
            guide_list = [input_x] * (4 + phase)
            guide_list[n] = seg
        else:
            # fail with a clear message instead of an unbound guide_list below
            raise ValueError(f"Unsupported mode '{mode}'. Supported modes are: "
                             "'progressive', 'one_plug', 'one_ablate'.")

        x = self.encode(guide_list[0])
        x = self.head_0(x, guide_list[1])

        x = self.up(x)
        x = self.g_middle_0(x, guide_list[2])
        x = self.g_middle_1(x, guide_list[3])

        for i in range(phase):
            x = self.up(x)
            x = self.ups[i](x, guide_list[4 + i])

        x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
        x = torch.tanh(x)

        return x
@ARCH_REGISTRY.register()
class HiFaceGAN(SPADEGenerator):
    """
    HiFaceGAN generator.

    Same as SPADEGenerator, except that the initial feature map is produced by
    a learnable encoder (currently LIPEncoder) instead of the default encode().
    """

    def __init__(self,
                 num_in_ch=3,
                 num_feat=64,
                 use_vae=False,
                 z_dim=256,
                 crop_size=512,
                 norm_g='spectralspadesyncbatch3x3',
                 is_train=True,
                 init_train_phase=3):
        super().__init__(
            num_in_ch=num_in_ch,
            num_feat=num_feat,
            use_vae=use_vae,
            z_dim=z_dim,
            crop_size=crop_size,
            norm_g=norm_g,
            is_train=is_train,
            init_train_phase=init_train_phase)
        # sw, sh and scale_ratio are set by the parent constructor.
        self.lip_encoder = LIPEncoder(num_in_ch, num_feat, self.sw, self.sh, self.scale_ratio)

    def encode(self, input_tensor):
        """Encode the input with the learnable LIP encoder."""
        encoded = self.lip_encoder(input_tensor)
        return encoded
@ARCH_REGISTRY.register()
class HiFaceGANDiscriminator(BaseNetwork):
    """
    Inspired by pix2pixHD multiscale discriminator.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_out_ch (int): Channel number of outputs. It is added to the input
            channel count of each sub-discriminator when ``conditional_d`` is
            True. Default: 3.
        conditional_d (bool): Whether use conditional discriminator.
            Default: True.
        num_d (int): Number of Multiscale discriminators. Default: 2.
        n_layers_d (int): Number of downsample layers in each D. Default: 4.
        num_feat (int): Channel number of base intermediate features.
            Default: 64.
        norm_d (str): String to determine normalization layers in D.
            Choices: [spectral][instance/batch/syncbatch]
            Default: 'spectralinstance'.
        keep_features (bool): Keep intermediate features for matching loss, etc.
            Default: True.
    """

    def __init__(self,
                 num_in_ch=3,
                 num_out_ch=3,
                 conditional_d=True,
                 num_d=2,
                 n_layers_d=4,
                 num_feat=64,
                 norm_d='spectralinstance',
                 keep_features=True):
        super().__init__()
        self.num_d = num_d

        input_nc = num_in_ch
        if conditional_d:
            input_nc += num_out_ch

        for i in range(num_d):
            subnet_d = NLayerDiscriminator(input_nc, n_layers_d, num_feat, norm_d, keep_features)
            self.add_module(f'discriminator_{i}', subnet_d)

    def downsample(self, x):
        """Halve the spatial resolution with a 3x3, stride-2 average pooling."""
        return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)

    def forward(self, x):
        """Run every sub-discriminator on a progressively downsampled input.

        Returns:
            list: One entry per sub-discriminator, in registration order. Each
            entry is whatever that sub-discriminator returns (a list of
            intermediate outputs when ``keep_features`` is True).
        """
        result = []
        nets = list(self.children())
        last = len(nets) - 1
        for idx, net_d in enumerate(nets):
            result.append(net_d(x))
            # The downsampled tensor is only needed by the next scale, so skip
            # the pooling after the last sub-discriminator.
            if idx != last:
                x = self.downsample(x)
        return result
class NLayerDiscriminator(BaseNetwork):
    """PatchGAN discriminator made of a stack of strided 4x4 convolutions.

    The layers are registered as separate groups ('model0', 'model1', ...) so
    that the output of every group can be returned.
    """

    def __init__(self, input_nc, n_layers_d, num_feat, norm_d, keep_features):
        super().__init__()
        kernel = 4
        pad = int(np.ceil((kernel - 1.0) / 2))
        self.keep_features = keep_features
        wrap_with_norm = get_nonspade_norm_layer(norm_d)

        # First group: plain convolution without normalization.
        groups = [[
            nn.Conv2d(input_nc, num_feat, kernel_size=kernel, stride=2, padding=pad),
            nn.LeakyReLU(0.2, False)
        ]]

        # Middle groups: channels double (capped at 512); the last one keeps
        # the resolution by using stride 1.
        channels = num_feat
        for depth in range(1, n_layers_d):
            prev_channels, channels = channels, min(channels * 2, 512)
            step = 2 if depth != n_layers_d - 1 else 1
            conv = nn.Conv2d(prev_channels, channels, kernel_size=kernel, stride=step, padding=pad)
            groups.append([wrap_with_norm(conv), nn.LeakyReLU(0.2, False)])

        # Final group: single-channel prediction map.
        groups.append([nn.Conv2d(channels, 1, kernel_size=kernel, stride=1, padding=pad)])

        # We divide the layers into groups to extract intermediate layer outputs
        for idx, layers in enumerate(groups):
            self.add_module('model' + str(idx), nn.Sequential(*layers))

    def forward(self, x):
        """Return all group outputs if ``keep_features``, else only the last one."""
        features = []
        current = x
        for group in self.children():
            current = group(current)
            features.append(current)

        if self.keep_features:
            return features
        return current
================================================
FILE: basicsr/archs/hifacegan_util.py
================================================
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
# Warning: spectral norm could be buggy
# under eval mode and multi-GPU inference
# A workaround is sticking to single-GPU inference and train mode
from torch.nn.utils import spectral_norm
class SPADE(nn.Module):
    """Spatially-adaptive normalization layer.

    The input is first normalized without learned affine parameters, then
    scaled and shifted by per-pixel ``gamma`` and ``beta`` maps predicted from
    a guidance map.

    Args:
        config_text (str): Config of the form 'spade<norm><k>x<k>', e.g.
            'spadeinstance3x3'. <norm> is one of 'instance', 'syncbatch',
            'batch'; <k> is a single-digit kernel size.
        norm_nc (int): Channel number of the features to normalize.
        label_nc (int): Channel number of the guidance map.

    Raises:
        ValueError: If ``config_text`` cannot be parsed or names an unknown
            normalization type.
    """

    def __init__(self, config_text, norm_nc, label_nc):
        super().__init__()

        assert config_text.startswith('spade')
        parsed = re.search('spade(\\D+)(\\d)x\\d', config_text)
        if parsed is None:
            # re.search returns None on no match; fail with a clear message
            # instead of an AttributeError on `None.group`.
            raise ValueError(f'{config_text} is not a valid SPADE config, expected e.g. "spadeinstance3x3"')
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))

        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc)
        elif param_free_norm_type == 'syncbatch':
            print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
            self.param_free_norm = nn.InstanceNorm2d(norm_nc)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError(f'{param_free_norm_type} is not a recognized param-free norm type in SPADE')

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128 if norm_nc > 128 else norm_nc

        pw = ks // 2
        self.mlp_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)

    def forward(self, x, segmap):
        """Normalize ``x`` and modulate it with maps predicted from ``segmap``.

        ``segmap`` is resized (nearest) to the spatial size of ``x`` first.
        """
        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on semantic map
        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        out = normalized * gamma + beta
        return out
class SPADEResnetBlock(nn.Module):
    """
    Residual block whose normalization layers are SPADE layers.

    Unlike the pix2pixHD residual block, it receives the guidance map as a
    second input, normalizes before each convolution, and learns a 1x1
    shortcut when the input and output channel numbers differ.
    The code was inspired from https://github.com/LMescheder/GAN_stability.
    """

    def __init__(self, fin, fout, norm_g='spectralspadesyncbatch3x3', semantic_nc=3):
        super().__init__()
        # A learned shortcut is only needed when the channel count changes.
        self.learned_shortcut = (fin != fout)
        hidden = min(fin, fout)

        # convolution layers
        self.conv_0 = nn.Conv2d(fin, hidden, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(hidden, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # optionally wrap every convolution with spectral norm
        if 'spectral' in norm_g:
            conv_names = ['conv_0', 'conv_1']
            if self.learned_shortcut:
                conv_names.append('conv_s')
            for conv_name in conv_names:
                setattr(self, conv_name, spectral_norm(getattr(self, conv_name)))

        # SPADE normalization layers (config without the 'spectral' marker)
        spade_cfg = norm_g.replace('spectral', '')
        self.norm_0 = SPADE(spade_cfg, fin, semantic_nc)
        self.norm_1 = SPADE(spade_cfg, hidden, semantic_nc)
        if self.learned_shortcut:
            self.norm_s = SPADE(spade_cfg, fin, semantic_nc)

    def forward(self, x, seg):
        """Apply the block to ``x`` using ``seg`` as the guidance map."""
        skip = self.shortcut(x, seg)

        residual = self.norm_0(x, seg)
        residual = self.conv_0(self.act(residual))
        residual = self.norm_1(residual, seg)
        residual = self.conv_1(self.act(residual))

        return skip + residual

    def shortcut(self, x, seg):
        """Return the skip branch: identity, or normalized 1x1 convolution."""
        if not self.learned_shortcut:
            return x
        return self.conv_s(self.norm_s(x, seg))

    def act(self, x):
        """Leaky ReLU with negative slope 0.2."""
        return F.leaky_relu(x, 2e-1)
class BaseNetwork(nn.Module):
    """Base class for hifacegan archs, providing a custom weight initialization."""

    def init_weights(self, init_type='normal', gain=0.02):
        """Initialize the weights of all submodules.

        Args:
            init_type (str): One of 'normal', 'xavier', 'xavier_uniform',
                'kaiming', 'orthogonal', 'none'. Default: 'normal'.
            gain (float): Std / gain passed to the chosen initializer.
                Default: 0.02.

        Raises:
            NotImplementedError: If ``init_type`` is not supported and a
                Conv/Linear module with a ``weight`` attribute is encountered.
        """

        def _zero_bias(module):
            if hasattr(module, 'bias') and module.bias is not None:
                init.constant_(module.bias.data, 0.0)

        def _init_module(module):
            cls_name = module.__class__.__name__

            # BatchNorm2d: weight ~ N(1, gain), bias = 0
            if 'BatchNorm2d' in cls_name:
                if hasattr(module, 'weight') and module.weight is not None:
                    init.normal_(module.weight.data, 1.0, gain)
                _zero_bias(module)
                return

            # Everything except Conv*/Linear* modules with a weight is left alone.
            if not (hasattr(module, 'weight') and ('Conv' in cls_name or 'Linear' in cls_name)):
                return

            if init_type == 'normal':
                init.normal_(module.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(module.weight.data, gain=gain)
            elif init_type == 'xavier_uniform':
                init.xavier_uniform_(module.weight.data, gain=1.0)
            elif init_type == 'kaiming':
                init.kaiming_normal_(module.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(module.weight.data, gain=gain)
            elif init_type == 'none':  # uses pytorch's default init method
                module.reset_parameters()
            else:
                raise NotImplementedError(f'initialization method [{init_type}] is not implemented')
            _zero_bias(module)

        self.apply(_init_module)

        # propagate to children
        for child in self.children():
            if hasattr(child, 'init_weights'):
                child.init_weights(init_type, gain)

    def forward(self, x):
        """Placeholder; subclasses provide the real forward pass."""
        pass
def lip2d(x, logit, kernel=3, stride=2, padding=1):
    """Pool ``x`` with per-element weights ``exp(logit)``.

    Computes avg_pool(x * w) / avg_pool(w) with w = exp(logit), i.e. a
    weighted average over each pooling window.
    """
    importance = logit.exp()
    weighted = F.avg_pool2d(x * importance, kernel_size=kernel, stride=stride, padding=padding)
    normalizer = F.avg_pool2d(importance, kernel_size=kernel, stride=stride, padding=padding)
    return weighted / normalizer
class SoftGate(nn.Module):
    """Sigmoid gate scaled by a constant, so outputs lie in (0, COEFF)."""

    # Scale applied to the sigmoid output.
    COEFF = 12.0

    def forward(self, x):
        gate = torch.sigmoid(x)
        return gate * self.COEFF
class SimplifiedLIP(nn.Module):
    """Pooling layer whose weights come from a small learned logit branch.

    The logit branch is conv3x3 -> InstanceNorm2d(affine) -> SoftGate, and its
    output is used as the logit argument of lip2d.
    """

    def __init__(self, channels):
        super().__init__()
        logit_conv = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        logit_norm = nn.InstanceNorm2d(channels, affine=True)
        self.logit = nn.Sequential(logit_conv, logit_norm, SoftGate())

    def init_layer(self):
        """Zero the weights of the logit convolution."""
        self.logit[0].weight.data.fill_(0.0)

    def forward(self, x):
        return lip2d(x, self.logit(x))
class LIPEncoder(BaseNetwork):
    """Encoder using Local Importance-based Pooling (Ziteng Gao et.al.,ICCV 2019).

    A stem (conv -> norm -> ReLU) is followed by ``n_2xdown`` stages of
    SimplifiedLIP -> conv -> norm, with a ReLU after every stage but the last.
    The channel multiplier doubles per stage and is capped at ``max_ratio``.
    """

    def __init__(self, input_nc, ngf, sw, sh, n_2xdown, norm_layer=nn.InstanceNorm2d):
        super().__init__()
        self.sw = sw
        self.sh = sh
        self.max_ratio = 16

        # 20200310: Several Convolution (stride 1) + LIP blocks, 4 fold
        kernel = 3
        pad = (kernel - 1) // 2

        # stem
        layers = [
            nn.Conv2d(input_nc, ngf, kernel, stride=1, padding=pad, bias=False),
            norm_layer(ngf),
            nn.ReLU(),
        ]

        ratio = 1
        for stage in range(n_2xdown):
            grown = min(ratio * 2, self.max_ratio)
            layers.append(SimplifiedLIP(ngf * ratio))
            layers.append(nn.Conv2d(ngf * ratio, ngf * grown, kernel, stride=1, padding=pad))
            layers.append(norm_layer(ngf * grown))
            ratio = grown
            # no activation after the final stage
            if stage < n_2xdown - 1:
                layers.append(nn.ReLU(inplace=True))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
def get_nonspade_norm_layer(norm_type='instance'):
    """Build a function that attaches a normalization layer to a given layer.

    Args:
        norm_type (str): '[spectral]<norm>', where <norm> is one of 'none',
            '' (empty), 'batch', 'sync_batch' / 'syncbatch', 'instance'.
            With the 'spectral' prefix the layer is also wrapped with spectral
            norm. Default: 'instance'.

    Returns:
        callable: ``add_norm_layer(layer)``. It returns ``layer`` itself when
        no normalization is requested, otherwise ``nn.Sequential(layer, norm)``.
        It raises ValueError for an unrecognized normalization name.
    """

    # helper function to get # output channels of the previous layer
    def get_out_channel(layer):
        if hasattr(layer, 'out_channels'):
            return getattr(layer, 'out_channels')
        return layer.weight.size(0)

    # this function will be returned
    def add_norm_layer(layer):
        if norm_type.startswith('spectral'):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len('spectral'):]
        else:
            # Without the prefix the whole string names the normalization.
            # (Previously subnorm_type was left undefined on this path.)
            subnorm_type = norm_type

        if subnorm_type == 'none' or len(subnorm_type) == 0:
            return layer

        # remove bias in the previous layer, which is meaningless
        # since it has no effect after normalization
        if getattr(layer, 'bias', None) is not None:
            delattr(layer, 'bias')
            layer.register_parameter('bias', None)

        if subnorm_type == 'batch':
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type in ('sync_batch', 'syncbatch'):
            # Both spellings are accepted: 'syncbatch' is the one used by the
            # SPADE config strings in this file.
            print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
            # norm_layer = SynchronizedBatchNorm2d(
            #    get_out_channel(layer), affine=True)
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        elif subnorm_type == 'instance':
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        else:
            raise ValueError(f'normalization layer {subnorm_type} is not recognized')

        return nn.Sequential(layer, norm_layer)

    print('This is a legacy from nvlabs/SPADE, and will be removed in future versions.')
    return add_norm_layer
================================================
FILE: basicsr/archs/inception.py
================================================
# Modified from https://github.com/mseitzer/pytorch-fid/blob/master/pytorch_fid/inception.py # noqa: E501
# For FID metric
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.model_zoo import load_url
from torchvision import models
# Inception weights ported to Pytorch from
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
# Remote location of the ported checkpoint.
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
# Repository-relative path with the same checkpoint file name.
# NOTE(review): presumably a local copy preferred over downloading -- confirm against the loader.
LOCAL_FID_WEIGHTS = 'experiments/pretrained_models/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
class InceptionV3(nn.Module):
"""Pretrained InceptionV3 network returning feature maps"""
# Index of default block of inception to return,
# corresponds to output of final average pooling
DEFAULT_BLOCK_INDEX = 3
# Maps feature dimensionality to their output blocks indices
BLOCK_INDEX_BY_DIM = {
64: 0, # First max pooling features
192: 1, # Second max pooling features
768: 2, # Pre-aux classifier features
2048: 3 # Final average pooling features
}
def __init__(self,
output_blocks=(DEFAULT_BLOCK_INDEX),
resize_input=True,
normalize_input=True,
requires_grad=False,
use_fid_inception=True):
"""Build pretrained InceptionV3.
Args:
output_blocks (list[int]): Indices of blocks to return features of.
Possible values are:
- 0: corresponds to output of first max pooling
- 1: corresponds to output of second max pooling
- 2: corresponds to output which is fed to aux classifier
- 3: corresponds to output of final average pooling
resize_input (bool): If true, bilinearly resizes input to width and
height 299 before feeding input to model. As the network
without fully connected layers is fully convolutional, it
should be able to handle inputs of arbitrary size, so resizing
might not be strictly needed. Default: True.
normalize_input (bool): If true, scales the input from range (0, 1)
to the range the pretrained Inception network expects,
namely (-1, 1). Default: True.
requires_grad (bool): If true, parameters of the model require
gradients. Possibly useful for finetuning the network.
Default: False.
use_fid_inception (bool): If true, uses the pretrained Inception
model used in Tensorflow's FID implementation.
gitextract_6e6skl2h/
├── .github/
│ └── workflows/
│ ├── publish-pip.yml
│ ├── pylint.yml
│ └── release.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── .vscode/
│ └── settings.json
├── CITATION.cff
├── LICENSE/
│ ├── LICENSE-NVIDIA
│ ├── LICENSE-stylegan2-pytorch
│ ├── LICENSE_SwinIR
│ ├── LICENSE_pytorch-image-models
│ └── README.md
├── LICENSE.txt
├── MANIFEST.in
├── README.md
├── README_CN.md
├── VERSION
├── basicsr/
│ ├── __init__.py
│ ├── archs/
│ │ ├── __init__.py
│ │ ├── arch_util.py
│ │ ├── basicvsr_arch.py
│ │ ├── basicvsrpp_arch.py
│ │ ├── dfdnet_arch.py
│ │ ├── dfdnet_util.py
│ │ ├── discriminator_arch.py
│ │ ├── duf_arch.py
│ │ ├── ecbsr_arch.py
│ │ ├── edsr_arch.py
│ │ ├── edvr_arch.py
│ │ ├── hifacegan_arch.py
│ │ ├── hifacegan_util.py
│ │ ├── inception.py
│ │ ├── rcan_arch.py
│ │ ├── ridnet_arch.py
│ │ ├── rrdbnet_arch.py
│ │ ├── spynet_arch.py
│ │ ├── srresnet_arch.py
│ │ ├── srvgg_arch.py
│ │ ├── stylegan2_arch.py
│ │ ├── stylegan2_bilinear_arch.py
│ │ ├── swinir_arch.py
│ │ ├── tof_arch.py
│ │ └── vgg_arch.py
│ ├── data/
│ │ ├── __init__.py
│ │ ├── data_sampler.py
│ │ ├── data_util.py
│ │ ├── degradations.py
│ │ ├── ffhq_dataset.py
│ │ ├── meta_info/
│ │ │ ├── meta_info_DIV2K800sub_GT.txt
│ │ │ ├── meta_info_REDS4_test_GT.txt
│ │ │ ├── meta_info_REDS_GT.txt
│ │ │ ├── meta_info_REDSofficial4_test_GT.txt
│ │ │ ├── meta_info_REDSval_official_test_GT.txt
│ │ │ ├── meta_info_Vimeo90K_test_GT.txt
│ │ │ ├── meta_info_Vimeo90K_test_fast_GT.txt
│ │ │ ├── meta_info_Vimeo90K_test_medium_GT.txt
│ │ │ ├── meta_info_Vimeo90K_test_slow_GT.txt
│ │ │ └── meta_info_Vimeo90K_train_GT.txt
│ │ ├── paired_image_dataset.py
│ │ ├── prefetch_dataloader.py
│ │ ├── realesrgan_dataset.py
│ │ ├── realesrgan_paired_dataset.py
│ │ ├── reds_dataset.py
│ │ ├── single_image_dataset.py
│ │ ├── transforms.py
│ │ ├── video_test_dataset.py
│ │ └── vimeo90k_dataset.py
│ ├── losses/
│ │ ├── __init__.py
│ │ ├── basic_loss.py
│ │ ├── gan_loss.py
│ │ └── loss_util.py
│ ├── metrics/
│ │ ├── README.md
│ │ ├── README_CN.md
│ │ ├── __init__.py
│ │ ├── fid.py
│ │ ├── metric_util.py
│ │ ├── niqe.py
│ │ ├── niqe_pris_params.npz
│ │ ├── psnr_ssim.py
│ │ └── test_metrics/
│ │ └── test_psnr_ssim.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── base_model.py
│ │ ├── edvr_model.py
│ │ ├── esrgan_model.py
│ │ ├── hifacegan_model.py
│ │ ├── lr_scheduler.py
│ │ ├── realesrgan_model.py
│ │ ├── realesrnet_model.py
│ │ ├── sr_model.py
│ │ ├── srgan_model.py
│ │ ├── stylegan2_model.py
│ │ ├── swinir_model.py
│ │ ├── video_base_model.py
│ │ ├── video_gan_model.py
│ │ ├── video_recurrent_gan_model.py
│ │ └── video_recurrent_model.py
│ ├── ops/
│ │ ├── __init__.py
│ │ ├── dcn/
│ │ │ ├── __init__.py
│ │ │ ├── deform_conv.py
│ │ │ └── src/
│ │ │ ├── deform_conv_cuda.cpp
│ │ │ ├── deform_conv_cuda_kernel.cu
│ │ │ └── deform_conv_ext.cpp
│ │ ├── fused_act/
│ │ │ ├── __init__.py
│ │ │ ├── fused_act.py
│ │ │ └── src/
│ │ │ ├── fused_bias_act.cpp
│ │ │ └── fused_bias_act_kernel.cu
│ │ └── upfirdn2d/
│ │ ├── __init__.py
│ │ ├── src/
│ │ │ ├── upfirdn2d.cpp
│ │ │ └── upfirdn2d_kernel.cu
│ │ └── upfirdn2d.py
│ ├── test.py
│ ├── train.py
│ └── utils/
│ ├── __init__.py
│ ├── color_util.py
│ ├── diffjpeg.py
│ ├── dist_util.py
│ ├── download_util.py
│ ├── file_client.py
│ ├── flow_util.py
│ ├── img_process_util.py
│ ├── img_util.py
│ ├── lmdb_util.py
│ ├── logger.py
│ ├── matlab_functions.py
│ ├── misc.py
│ ├── options.py
│ ├── plot_util.py
│ └── registry.py
├── colab/
│ └── README.md
├── docs/
│ ├── BasicSR_docs_CN.md
│ ├── Config.md
│ ├── DatasetPreparation.md
│ ├── DatasetPreparation_CN.md
│ ├── DesignConvention.md
│ ├── FAQ.md
│ ├── HOWTOs.md
│ ├── HOWTOs_CN.md
│ ├── INSTALL.md
│ ├── Logging.md
│ ├── Logging_CN.md
│ ├── Makefile
│ ├── Metrics.md
│ ├── Metrics_CN.md
│ ├── ModelZoo.md
│ ├── ModelZoo_CN.md
│ ├── Models.md
│ ├── README.md
│ ├── TrainTest.md
│ ├── TrainTest_CN.md
│ ├── auto_generate_api.py
│ ├── conf.py
│ ├── history_updates.md
│ ├── index.rst
│ ├── introduction.md
│ ├── make.bat
│ └── requirements.txt
├── inference/
│ ├── inference_basicvsr.py
│ ├── inference_basicvsrpp.py
│ ├── inference_dfdnet.py
│ ├── inference_esrgan.py
│ ├── inference_ridnet.py
│ ├── inference_stylegan2.py
│ └── inference_swinir.py
├── options/
│ ├── test/
│ │ ├── BasicVSR/
│ │ │ ├── test_BasicVSR_REDS.yml
│ │ │ ├── test_BasicVSR_Vimeo90K_BDx4.yml
│ │ │ ├── test_BasicVSR_Vimeo90K_BIx4.yml
│ │ │ ├── test_IconVSR_REDS.yml
│ │ │ ├── test_IconVSR_Vimeo90K_BDx4.yml
│ │ │ └── test_IconVSR_Vimeo90K_BIx4.yml
│ │ ├── DUF/
│ │ │ └── test_DUF_official.yml
│ │ ├── EDSR/
│ │ │ ├── test_EDSR_Lx2.yml
│ │ │ ├── test_EDSR_Lx3.yml
│ │ │ ├── test_EDSR_Lx4.yml
│ │ │ ├── test_EDSR_Mx2.yml
│ │ │ ├── test_EDSR_Mx3.yml
│ │ │ └── test_EDSR_Mx4.yml
│ │ ├── EDVR/
│ │ │ ├── test_EDVR_L_deblur_REDS.yml
│ │ │ ├── test_EDVR_L_deblurcomp_REDS.yml
│ │ │ ├── test_EDVR_L_x4_SR_REDS.yml
│ │ │ ├── test_EDVR_L_x4_SR_Vid4.yml
│ │ │ ├── test_EDVR_L_x4_SR_Vimeo90K.yml
│ │ │ ├── test_EDVR_L_x4_SRblur_REDS.yml
│ │ │ └── test_EDVR_M_x4_SR_REDS.yml
│ │ ├── ESRGAN/
│ │ │ ├── test_ESRGAN_x4.yml
│ │ │ ├── test_ESRGAN_x4_woGT.yml
│ │ │ └── test_RRDBNet_PSNR_x4.yml
│ │ ├── HiFaceGAN/
│ │ │ ├── test_hifacegan.yml
│ │ │ └── test_hifacegan_woGT.yml
│ │ ├── RCAN/
│ │ │ └── test_RCAN.yml
│ │ ├── SRResNet_SRGAN/
│ │ │ ├── test_MSRGAN_x4.yml
│ │ │ ├── test_MSRResNet_x2.yml
│ │ │ ├── test_MSRResNet_x3.yml
│ │ │ ├── test_MSRResNet_x4.yml
│ │ │ └── test_MSRResNet_x4_woGT.yml
│ │ └── TOF/
│ │ └── test_TOF_official.yml
│ └── train/
│ ├── BasicVSR/
│ │ ├── train_BasicVSR_REDS.yml
│ │ ├── train_BasicVSR_Vimeo90K_BDx4.yml
│ │ ├── train_BasicVSR_Vimeo90K_BIx4.yml
│ │ ├── train_IconVSR_REDS.yml
│ │ ├── train_IconVSR_Vimeo90K_BDx4.yml
│ │ └── train_IconVSR_Vimeo90K_BIx4.yml
│ ├── BasicVSRPP/
│ │ └── train_BasicVSRPP_REDS.yml
│ ├── ECBSR/
│ │ ├── train_ECBSR_x2_m4c16_prelu.yml
│ │ ├── train_ECBSR_x4_m4c16_prelu.yml
│ │ └── train_ECBSR_x4_m4c16_prelu_RGB.yml
│ ├── EDSR/
│ │ ├── train_EDSR_Lx2.yml
│ │ ├── train_EDSR_Lx3.yml
│ │ ├── train_EDSR_Lx4.yml
│ │ ├── train_EDSR_Mx2.yml
│ │ ├── train_EDSR_Mx3.yml
│ │ └── train_EDSR_Mx4.yml
│ ├── EDVR/
│ │ ├── train_EDVRM_woTSA_GAN_TODO.yml
│ │ ├── train_EDVR_L_x4_SR_REDS.yml
│ │ ├── train_EDVR_L_x4_SR_REDS_woTSA.yml
│ │ ├── train_EDVR_M_x4_SR_REDS.yml
│ │ └── train_EDVR_M_x4_SR_REDS_woTSA.yml
│ ├── ESRGAN/
│ │ ├── train_ESRGAN_x4.yml
│ │ └── train_RRDBNet_PSNR_x4.yml
│ ├── HiFaceGAN/
│ │ └── train_hifacegan.yml
│ ├── LDL/
│ │ └── train_LDL_Real_x4.yml
│ ├── RCAN/
│ │ └── train_RCAN_x2.yml
│ ├── RealESRGAN/
│ │ ├── train_realesrgan_x2plus.yml
│ │ ├── train_realesrgan_x4plus.yml
│ │ ├── train_realesrnet_x2plus.yml
│ │ └── train_realesrnet_x4plus.yml
│ ├── SRResNet_SRGAN/
│ │ ├── README.md
│ │ ├── train_MSRGAN_x4.yml
│ │ ├── train_MSRResNet_x2.yml
│ │ ├── train_MSRResNet_x3.yml
│ │ └── train_MSRResNet_x4.yml
│ ├── StyleGAN/
│ │ └── train_StyleGAN2_256_Cmul2_FFHQ.yml
│ ├── SwinIR/
│ │ ├── train_SwinIR_SRx2_scratch.yml
│ │ └── train_SwinIR_SRx4_scratch.yml
│ └── VideoRecurrentGAN/
│ └── train_VideoRecurrentGANModel_REDS.yml
├── requirements.txt
├── scripts/
│ ├── data_preparation/
│ │ ├── create_lmdb.py
│ │ ├── download_datasets.py
│ │ ├── extract_images_from_tfrecords.py
│ │ ├── extract_subimages.py
│ │ ├── generate_meta_info.py
│ │ ├── prepare_hifacegan_dataset.py
│ │ └── regroup_reds_dataset.py
│ ├── dist_test.sh
│ ├── dist_train.sh
│ ├── download_gdrive.py
│ ├── download_pretrained_models.py
│ ├── matlab_scripts/
│ │ ├── back_projection/
│ │ │ ├── backprojection.m
│ │ │ ├── main_bp.m
│ │ │ └── main_reverse_filter.m
│ │ ├── generate_LR_Vimeo90K.m
│ │ └── generate_bicubic_img.m
│ ├── metrics/
│ │ ├── calculate_fid_folder.py
│ │ ├── calculate_fid_stats_from_datasets.py
│ │ ├── calculate_lpips.py
│ │ ├── calculate_niqe.py
│ │ ├── calculate_psnr_ssim.py
│ │ └── calculate_stylegan2_fid.py
│ ├── model_conversion/
│ │ ├── convert_dfdnet.py
│ │ ├── convert_models.py
│ │ ├── convert_ridnet.py
│ │ └── convert_stylegan.py
│ ├── plot/
│ │ ├── README.md
│ │ └── model_complexity_cmp_bsrn.py
│ └── publish_models.py
├── setup.cfg
├── setup.py
├── test_scripts/
│ ├── test_discriminator_backward.py
│ ├── test_ffhq_dataset.py
│ ├── test_lr_scheduler.py
│ ├── test_niqe.py
│ ├── test_paired_image_dataset.py
│ ├── test_reds_dataset.py
│ └── test_vimeo90k_dataset.py
└── tests/
├── README.md
├── data/
│ ├── gt.lmdb/
│ │ ├── data.mdb
│ │ ├── lock.mdb
│ │ └── meta_info.txt
│ ├── lq.lmdb/
│ │ ├── data.mdb
│ │ ├── lock.mdb
│ │ └── meta_info.txt
│ ├── meta_info_gt.txt
│ └── meta_info_pair.txt
├── test_archs/
│ ├── test_basicvsr_arch.py
│ ├── test_discriminator_arch.py
│ ├── test_duf_arch.py
│ ├── test_ecbsr_arch.py
│ └── test_srresnet_arch.py
├── test_data/
│ ├── test_paired_image_dataset.py
│ └── test_single_image_dataset.py
├── test_losses/
│ └── test_losses.py
├── test_metrics/
│ └── test_psnr_ssim.py
└── test_models/
└── test_sr_model.py
SYMBOL INDEX (952 symbols across 134 files)
FILE: basicsr/archs/__init__.py
function build_network (line 18) | def build_network(opt):
FILE: basicsr/archs/arch_util.py
function default_init_weights (line 18) | def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
function make_layer (line 48) | def make_layer(basic_block, num_basic_block, **kwarg):
class ResidualBlockNoBN (line 64) | class ResidualBlockNoBN(nn.Module):
method __init__ (line 75) | def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
method forward (line 85) | def forward(self, x):
class Upsample (line 91) | class Upsample(nn.Sequential):
method __init__ (line 99) | def __init__(self, scale, num_feat):
function flow_warp (line 113) | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', ali...
function resize_flow (line 147) | def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_co...
function pixel_unshuffle (line 186) | def pixel_unshuffle(x, scale):
class DCNv2Pack (line 205) | class DCNv2Pack(ModulatedDeformConvPack):
method forward (line 215) | def forward(self, x, feat):
function _no_grad_trunc_normal_ (line 234) | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
function trunc_normal_ (line 272) | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
function _ntuple (line 299) | def _ntuple(n):
FILE: basicsr/archs/basicvsr_arch.py
class BasicVSR (line 12) | class BasicVSR(nn.Module):
method __init__ (line 21) | def __init__(self, num_feat=64, num_block=15, spynet_path=None):
method get_flow (line 44) | def get_flow(self, x):
method forward (line 55) | def forward(self, x):
class ConvResidualBlocks (line 101) | class ConvResidualBlocks(nn.Module):
method __init__ (line 110) | def __init__(self, num_in_ch=3, num_out_ch=64, num_block=15):
method forward (line 116) | def forward(self, fea):
class IconVSR (line 121) | class IconVSR(nn.Module):
method __init__ (line 133) | def __init__(self,
method pad_spatial (line 169) | def pad_spatial(self, x):
method get_flow (line 192) | def get_flow(self, x):
method get_keyframe_feature (line 203) | def get_keyframe_feature(self, x, keyframe_idx):
method forward (line 216) | def forward(self, x):
class EDVRFeatureExtractor (line 271) | class EDVRFeatureExtractor(nn.Module):
method __init__ (line 280) | def __init__(self, num_input_frame, num_feat, load_path):
method forward (line 304) | def forward(self, x):
FILE: basicsr/archs/basicvsrpp_arch.py
class BasicVSRPlusPlus (line 15) | class BasicVSRPlusPlus(nn.Module):
method __init__ (line 43) | def __init__(self,
method check_if_mirror_extended (line 109) | def check_if_mirror_extended(self, lqs):
method compute_flow (line 123) | def compute_flow(self, lqs):
method propagate (line 156) | def propagate(self, feats, flows, module_name):
method upsample (line 237) | def upsample(self, lqs, feats):
method forward (line 280) | def forward(self, lqs):
class SecondOrderDeformableAlignment (line 347) | class SecondOrderDeformableAlignment(ModulatedDeformConvPack):
method __init__ (line 365) | def __init__(self, *args, **kwargs):
method init_offset (line 382) | def init_offset(self):
method forward (line 392) | def forward(self, x, extra_feat, flow_1, flow_2):
FILE: basicsr/archs/dfdnet_arch.py
class SFTUpBlock (line 12) | class SFTUpBlock(nn.Module):
method __init__ (line 22) | def __init__(self, in_channel, out_channel, kernel_size=3, padding=1):
method forward (line 45) | def forward(self, x, updated_feat):
class DFDNet (line 57) | class DFDNet(nn.Module):
method __init__ (line 67) | def __init__(self, num_feat, dict_path):
method swap_feat (line 105) | def swap_feat(self, vgg_feat, updated_feat, dict_feat, location, part_...
method put_dict_to_device (line 126) | def put_dict_to_device(self, x):
method forward (line 133) | def forward(self, x, part_locations):
FILE: basicsr/archs/dfdnet_util.py
class BlurFunctionBackward (line 8) | class BlurFunctionBackward(Function):
method forward (line 11) | def forward(ctx, grad_output, kernel, kernel_flip):
method backward (line 17) | def backward(ctx, gradgrad_output):
class BlurFunction (line 23) | class BlurFunction(Function):
method forward (line 26) | def forward(ctx, x, kernel, kernel_flip):
method backward (line 32) | def backward(ctx, grad_output):
class Blur (line 41) | class Blur(nn.Module):
method __init__ (line 43) | def __init__(self, channel):
method forward (line 53) | def forward(self, x):
function calc_mean_std (line 57) | def calc_mean_std(feat, eps=1e-5):
function adaptive_instance_normalization (line 74) | def adaptive_instance_normalization(content_feat, style_feat):
function AttentionBlock (line 91) | def AttentionBlock(in_channel):
function conv_block (line 97) | def conv_block(in_channels, out_channels, kernel_size=3, stride=1, dilat...
class MSDilationBlock (line 123) | class MSDilationBlock(nn.Module):
method __init__ (line 126) | def __init__(self, in_channels, kernel_size=3, dilation=(1, 1, 1, 1), ...
method forward (line 141) | def forward(self, x):
class UpResBlock (line 150) | class UpResBlock(nn.Module):
method __init__ (line 152) | def __init__(self, in_channel):
method forward (line 160) | def forward(self, x):
FILE: basicsr/archs/discriminator_arch.py
class VGGStyleDiscriminator (line 9) | class VGGStyleDiscriminator(nn.Module):
method __init__ (line 19) | def __init__(self, num_in_ch, num_feat, input_size=128):
method forward (line 61) | def forward(self, x):
class UNetDiscriminatorSN (line 91) | class UNetDiscriminatorSN(nn.Module):
method __init__ (line 102) | def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
method forward (line 121) | def forward(self, x):
FILE: basicsr/archs/duf_arch.py
class DenseBlocksTemporalReduce (line 9) | class DenseBlocksTemporalReduce(nn.Module):
method __init__ (line 21) | def __init__(self, num_feat=64, num_grow_ch=32, adapt_official_weights...
method forward (line 58) | def forward(self, x):
class DenseBlocks (line 78) | class DenseBlocks(nn.Module):
method __init__ (line 92) | def __init__(self, num_block, num_feat=64, num_grow_ch=16, adapt_offic...
method forward (line 120) | def forward(self, x):
class DynamicUpsamplingFilter (line 134) | class DynamicUpsamplingFilter(nn.Module):
method __init__ (line 145) | def __init__(self, filter_size=(5, 5)):
method forward (line 157) | def forward(self, x, filters):
class DUF (line 182) | class DUF(nn.Module):
method __init__ (line 204) | def __init__(self, scale=4, num_layer=52, adapt_official_weights=False):
method forward (line 245) | def forward(self, x):
FILE: basicsr/archs/ecbsr_arch.py
class SeqConv3x3 (line 8) | class SeqConv3x3(nn.Module):
method __init__ (line 22) | def __init__(self, seq_type, in_channels, out_channels, depth_multipli...
method forward (line 105) | def forward(self, x):
method rep_params (line 131) | def rep_params(self):
class ECB (line 156) | class ECB(nn.Module):
method __init__ (line 170) | def __init__(self, in_channels, out_channels, depth_multiplier, act_ty...
method forward (line 202) | def forward(self, x):
method rep_params (line 214) | def rep_params(self):
class ECBSR (line 236) | class ECBSR(nn.Module):
method __init__ (line 252) | def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with...
method forward (line 268) | def forward(self, x):
FILE: basicsr/archs/edsr_arch.py
class EDSR (line 9) | class EDSR(nn.Module):
method __init__ (line 30) | def __init__(self,
method forward (line 50) | def forward(self, x):
FILE: basicsr/archs/edvr_arch.py
class PCDAlignment (line 9) | class PCDAlignment(nn.Module):
method __init__ (line 20) | def __init__(self, num_feat=64, deformable_groups=8):
method forward (line 55) | def forward(self, nbr_feat_l, ref_feat_l):
class TSAFusion (line 100) | class TSAFusion(nn.Module):
method __init__ (line 115) | def __init__(self, num_feat=64, num_frame=5, center_frame_idx=2):
method forward (line 140) | def forward(self, aligned_feat):
class PredeblurModule (line 192) | class PredeblurModule(nn.Module):
method __init__ (line 201) | def __init__(self, num_in_ch=3, num_feat=64, hr_in=False):
method forward (line 223) | def forward(self, x):
class EDVR (line 246) | class EDVR(nn.Module):
method __init__ (line 271) | def __init__(self,
method forward (line 325) | def forward(self, x):
FILE: basicsr/archs/hifacegan_arch.py
class SPADEGenerator (line 10) | class SPADEGenerator(BaseNetwork):
method __init__ (line 13) | def __init__(self,
method encode (line 61) | def encode(self, input_tensor):
method forward (line 71) | def forward(self, x):
method mixed_guidance_forward (line 96) | def mixed_guidance_forward(self, input_x, seg=None, n=0, mode='progres...
class HiFaceGAN (line 145) | class HiFaceGAN(SPADEGenerator):
method __init__ (line 151) | def __init__(self,
method encode (line 163) | def encode(self, input_tensor):
class HiFaceGANDiscriminator (line 168) | class HiFaceGANDiscriminator(BaseNetwork):
method __init__ (line 188) | def __init__(self,
method downsample (line 208) | def downsample(self, x):
method forward (line 213) | def forward(self, x):
class NLayerDiscriminator (line 223) | class NLayerDiscriminator(BaseNetwork):
method __init__ (line 226) | def __init__(self, input_nc, n_layers_d, num_feat, norm_d, keep_featur...
method forward (line 251) | def forward(self, x):
FILE: basicsr/archs/hifacegan_util.py
class SPADE (line 12) | class SPADE(nn.Module):
method __init__ (line 14) | def __init__(self, config_text, norm_nc, label_nc):
method forward (line 40) | def forward(self, x, segmap):
class SPADEResnetBlock (line 57) | class SPADEResnetBlock(nn.Module):
method __init__ (line 67) | def __init__(self, fin, fout, norm_g='spectralspadesyncbatch3x3', sema...
method forward (line 95) | def forward(self, x, seg):
method shortcut (line 102) | def shortcut(self, x, seg):
method act (line 109) | def act(self, x):
class BaseNetwork (line 113) | class BaseNetwork(nn.Module):
method init_weights (line 116) | def init_weights(self, init_type='normal', gain=0.02):
method forward (line 150) | def forward(self, x):
function lip2d (line 154) | def lip2d(x, logit, kernel=3, stride=2, padding=1):
class SoftGate (line 159) | class SoftGate(nn.Module):
method forward (line 162) | def forward(self, x):
class SimplifiedLIP (line 166) | class SimplifiedLIP(nn.Module):
method __init__ (line 168) | def __init__(self, channels):
method init_layer (line 174) | def init_layer(self):
method forward (line 177) | def forward(self, x):
class LIPEncoder (line 182) | class LIPEncoder(BaseNetwork):
method __init__ (line 185) | def __init__(self, input_nc, ngf, sw, sh, n_2xdown, norm_layer=nn.Inst...
method forward (line 213) | def forward(self, x):
function get_nonspade_norm_layer (line 217) | def get_nonspade_norm_layer(norm_type='instance'):
FILE: basicsr/archs/inception.py
class InceptionV3 (line 17) | class InceptionV3(nn.Module):
method __init__ (line 32) | def __init__(self,
method forward (line 124) | def forward(self, x):
function fid_inception_v3 (line 155) | def fid_inception_v3():
class FIDInceptionA (line 189) | class FIDInceptionA(models.inception.InceptionA):
method __init__ (line 192) | def __init__(self, in_channels, pool_features):
method forward (line 195) | def forward(self, x):
class FIDInceptionC (line 214) | class FIDInceptionC(models.inception.InceptionC):
method __init__ (line 217) | def __init__(self, in_channels, channels_7x7):
method forward (line 220) | def forward(self, x):
class FIDInceptionE_1 (line 242) | class FIDInceptionE_1(models.inception.InceptionE):
method __init__ (line 245) | def __init__(self, in_channels):
method forward (line 248) | def forward(self, x):
class FIDInceptionE_2 (line 275) | class FIDInceptionE_2(models.inception.InceptionE):
method __init__ (line 278) | def __init__(self, in_channels):
method forward (line 281) | def forward(self, x):
FILE: basicsr/archs/rcan_arch.py
class ChannelAttention (line 8) | class ChannelAttention(nn.Module):
method __init__ (line 16) | def __init__(self, num_feat, squeeze_factor=16):
method forward (line 22) | def forward(self, x):
class RCAB (line 27) | class RCAB(nn.Module):
method __init__ (line 36) | def __init__(self, num_feat, squeeze_factor=16, res_scale=1):
method forward (line 44) | def forward(self, x):
class ResidualGroup (line 49) | class ResidualGroup(nn.Module):
method __init__ (line 59) | def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1):
method forward (line 66) | def forward(self, x):
class RCAN (line 72) | class RCAN(nn.Module):
method __init__ (line 96) | def __init__(self,
method forward (line 124) | def forward(self, x):
FILE: basicsr/archs/ridnet_arch.py
class MeanShift (line 8) | class MeanShift(nn.Conv2d):
method __init__ (line 21) | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1, requires_gra...
class EResidualBlockNoBN (line 31) | class EResidualBlockNoBN(nn.Module):
method __init__ (line 37) | def __init__(self, in_channels, out_channels):
method forward (line 49) | def forward(self, x):
class MergeRun (line 55) | class MergeRun(nn.Module):
method __init__ (line 65) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
method forward (line 78) | def forward(self, x):
class ChannelAttention (line 87) | class ChannelAttention(nn.Module):
method __init__ (line 95) | def __init__(self, mid_channels, squeeze_factor=16):
method forward (line 101) | def forward(self, x):
class EAM (line 106) | class EAM(nn.Module):
method __init__ (line 119) | def __init__(self, in_channels, mid_channels, out_channels):
method forward (line 129) | def forward(self, x):
class RIDNet (line 138) | class RIDNet(nn.Module):
method __init__ (line 154) | def __init__(self,
method forward (line 174) | def forward(self, x):
FILE: basicsr/archs/rrdbnet_arch.py
class ResidualDenseBlock (line 9) | class ResidualDenseBlock(nn.Module):
method __init__ (line 19) | def __init__(self, num_feat=64, num_grow_ch=32):
method forward (line 32) | def forward(self, x):
class RRDB (line 42) | class RRDB(nn.Module):
method __init__ (line 52) | def __init__(self, num_feat, num_grow_ch=32):
method forward (line 58) | def forward(self, x):
class RRDBNet (line 67) | class RRDBNet(nn.Module):
method __init__ (line 87) | def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_bl...
method forward (line 105) | def forward(self, x):
FILE: basicsr/archs/spynet_arch.py
class BasicModule (line 10) | class BasicModule(nn.Module):
method __init__ (line 14) | def __init__(self):
method forward (line 24) | def forward(self, tensor_input):
class SpyNet (line 29) | class SpyNet(nn.Module):
method __init__ (line 36) | def __init__(self, load_path=None):
method preprocess (line 45) | def preprocess(self, tensor_input):
method process (line 49) | def process(self, ref, supp):
method forward (line 81) | def forward(self, ref, supp):
FILE: basicsr/archs/srresnet_arch.py
class MSRResNet (line 9) | class MSRResNet(nn.Module):
method __init__ (line 25) | def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=1...
method forward (line 52) | def forward(self, x):
FILE: basicsr/archs/srvgg_arch.py
class SRVGGNetCompact (line 8) | class SRVGGNetCompact(nn.Module):
method __init__ (line 23) | def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16...
method forward (line 61) | def forward(self, x):
FILE: basicsr/archs/stylegan2_arch.py
class NormStyleCode (line 12) | class NormStyleCode(nn.Module):
method forward (line 14) | def forward(self, x):
function make_resample_kernel (line 26) | def make_resample_kernel(k):
class UpFirDnUpsample (line 43) | class UpFirDnUpsample(nn.Module):
method __init__ (line 56) | def __init__(self, resample_kernel, factor=2):
method forward (line 64) | def forward(self, x):
method __repr__ (line 68) | def __repr__(self):
class UpFirDnDownsample (line 72) | class UpFirDnDownsample(nn.Module):
method __init__ (line 81) | def __init__(self, resample_kernel, factor=2):
method forward (line 89) | def forward(self, x):
method __repr__ (line 93) | def __repr__(self):
class UpFirDnSmooth (line 97) | class UpFirDnSmooth(nn.Module):
method __init__ (line 108) | def __init__(self, resample_kernel, upsample_factor=1, downsample_fact...
method forward (line 125) | def forward(self, x):
method __repr__ (line 129) | def __repr__(self):
class EqualLinear (line 134) | class EqualLinear(nn.Module):
method __init__ (line 148) | def __init__(self, in_channels, out_channels, bias=True, bias_init_val...
method forward (line 165) | def forward(self, x):
method __repr__ (line 177) | def __repr__(self):
class ModulatedConv2d (line 182) | class ModulatedConv2d(nn.Module):
method __init__ (line 202) | def __init__(self,
method forward (line 239) | def forward(self, x, style):
method __repr__ (line 281) | def __repr__(self):
class StyleConv (line 288) | class StyleConv(nn.Module):
method __init__ (line 303) | def __init__(self,
method forward (line 323) | def forward(self, x, style, noise=None):
class ToRGB (line 336) | class ToRGB(nn.Module):
method __init__ (line 347) | def __init__(self, in_channels, num_style_feat, upsample=True, resampl...
method forward (line 357) | def forward(self, x, style, skip=None):
class ConstantInput (line 377) | class ConstantInput(nn.Module):
method __init__ (line 385) | def __init__(self, num_channel, size):
method forward (line 389) | def forward(self, batch):
class StyleGAN2Generator (line 395) | class StyleGAN2Generator(nn.Module):
method __init__ (line 411) | def __init__(self,
method make_noise (line 493) | def make_noise(self):
method get_latent (line 504) | def get_latent(self, x):
method mean_latent (line 507) | def mean_latent(self, num_latent):
method forward (line 512) | def forward(self,
class ScaledLeakyReLU (line 589) | class ScaledLeakyReLU(nn.Module):
method __init__ (line 596) | def __init__(self, negative_slope=0.2):
method forward (line 600) | def forward(self, x):
class EqualConv2d (line 605) | class EqualConv2d(nn.Module):
method __init__ (line 620) | def __init__(self, in_channels, out_channels, kernel_size, stride=1, p...
method forward (line 635) | def forward(self, x):
method __repr__ (line 646) | def __repr__(self):
class ConvLayer (line 654) | class ConvLayer(nn.Sequential):
method __init__ (line 671) | def __init__(self,
class ResBlock (line 704) | class ResBlock(nn.Module):
method __init__ (line 716) | def __init__(self, in_channels, out_channels, resample_kernel=(1, 3, 3...
method forward (line 725) | def forward(self, x):
class StyleGAN2Discriminator (line 734) | class StyleGAN2Discriminator(nn.Module):
method __init__ (line 748) | def __init__(self, out_size, channel_multiplier=2, resample_kernel=(1,...
method forward (line 783) | def forward(self, x):
FILE: basicsr/archs/stylegan2_bilinear_arch.py
class NormStyleCode (line 11) | class NormStyleCode(nn.Module):
method forward (line 13) | def forward(self, x):
class EqualLinear (line 25) | class EqualLinear(nn.Module):
method __init__ (line 39) | def __init__(self, in_channels, out_channels, bias=True, bias_init_val...
method forward (line 56) | def forward(self, x):
method __repr__ (line 68) | def __repr__(self):
class ModulatedConv2d (line 73) | class ModulatedConv2d(nn.Module):
method __init__ (line 91) | def __init__(self,
method forward (line 121) | def forward(self, x, style):
method __repr__ (line 156) | def __repr__(self):
class StyleConv (line 163) | class StyleConv(nn.Module):
method __init__ (line 176) | def __init__(self,
method forward (line 196) | def forward(self, x, style, noise=None):
class ToRGB (line 209) | class ToRGB(nn.Module):
method __init__ (line 218) | def __init__(self, in_channels, num_style_feat, upsample=True, interpo...
method forward (line 236) | def forward(self, x, style, skip=None):
class ConstantInput (line 257) | class ConstantInput(nn.Module):
method __init__ (line 265) | def __init__(self, num_channel, size):
method forward (line 269) | def forward(self, batch):
class StyleGAN2GeneratorBilinear (line 275) | class StyleGAN2GeneratorBilinear(nn.Module):
method __init__ (line 288) | def __init__(self,
method make_noise (line 370) | def make_noise(self):
method get_latent (line 381) | def get_latent(self, x):
method mean_latent (line 384) | def mean_latent(self, num_latent):
method forward (line 389) | def forward(self,
class ScaledLeakyReLU (line 466) | class ScaledLeakyReLU(nn.Module):
method __init__ (line 473) | def __init__(self, negative_slope=0.2):
method forward (line 477) | def forward(self, x):
class EqualConv2d (line 482) | class EqualConv2d(nn.Module):
method __init__ (line 497) | def __init__(self, in_channels, out_channels, kernel_size, stride=1, p...
method forward (line 512) | def forward(self, x):
method __repr__ (line 523) | def __repr__(self):
class ConvLayer (line 531) | class ConvLayer(nn.Sequential):
method __init__ (line 544) | def __init__(self,
class ResBlock (line 580) | class ResBlock(nn.Module):
method __init__ (line 588) | def __init__(self, in_channels, out_channels, interpolation_mode='bili...
method forward (line 609) | def forward(self, x):
FILE: basicsr/archs/swinir_arch.py
function drop_path (line 14) | def drop_path(x, drop_prob: float = 0., training: bool = False):
class DropPath (line 29) | class DropPath(nn.Module):
method __init__ (line 35) | def __init__(self, drop_prob=None):
method forward (line 39) | def forward(self, x):
class Mlp (line 43) | class Mlp(nn.Module):
method __init__ (line 45) | def __init__(self, in_features, hidden_features=None, out_features=Non...
method forward (line 54) | def forward(self, x):
function window_partition (line 63) | def window_partition(x, window_size):
function window_reverse (line 78) | def window_reverse(windows, window_size, h, w):
class WindowAttention (line 95) | class WindowAttention(nn.Module):
method __init__ (line 109) | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scal...
method forward (line 144) | def forward(self, x, mask=None):
method extra_repr (line 177) | def extra_repr(self) -> str:
method flops (line 180) | def flops(self, n):
class SwinTransformerBlock (line 194) | class SwinTransformerBlock(nn.Module):
method __init__ (line 213) | def __init__(self,
method calculate_mask (line 262) | def calculate_mask(self, x_size):
method forward (line 283) | def forward(self, x, x_size):
method extra_repr (line 325) | def extra_repr(self) -> str:
method flops (line 329) | def flops(self):
class PatchMerging (line 344) | class PatchMerging(nn.Module):
method __init__ (line 353) | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
method forward (line 360) | def forward(self, x):
method extra_repr (line 383) | def extra_repr(self) -> str:
method flops (line 386) | def flops(self):
class BasicLayer (line 393) | class BasicLayer(nn.Module):
method __init__ (line 413) | def __init__(self,
method forward (line 458) | def forward(self, x, x_size):
method extra_repr (line 468) | def extra_repr(self) -> str:
method flops (line 471) | def flops(self):
class RSTB (line 480) | class RSTB(nn.Module):
method __init__ (line 503) | def __init__(self,
method forward (line 557) | def forward(self, x, x_size):
method flops (line 560) | def flops(self):
class PatchEmbed (line 571) | class PatchEmbed(nn.Module):
method __init__ (line 582) | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=9...
method forward (line 600) | def forward(self, x):
method flops (line 606) | def flops(self):
class PatchUnEmbed (line 614) | class PatchUnEmbed(nn.Module):
method __init__ (line 625) | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=9...
method forward (line 638) | def forward(self, x, x_size):
method flops (line 642) | def flops(self):
class Upsample (line 647) | class Upsample(nn.Sequential):
method __init__ (line 655) | def __init__(self, scale, num_feat):
class UpsampleOneStep (line 669) | class UpsampleOneStep(nn.Sequential):
method __init__ (line 679) | def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
method flops (line 687) | def flops(self):
class SwinIR (line 694) | class SwinIR(nn.Module):
method __init__ (line 722) | def __init__(self,
method _init_weights (line 859) | def _init_weights(self, m):
method no_weight_decay (line 869) | def no_weight_decay(self):
method no_weight_decay_keywords (line 873) | def no_weight_decay_keywords(self):
method forward_features (line 876) | def forward_features(self, x):
method forward (line 891) | def forward(self, x):
method flops (line 924) | def flops(self):
FILE: basicsr/archs/tof_arch.py
class BasicModule (line 9) | class BasicModule(nn.Module):
method __init__ (line 16) | def __init__(self):
method forward (line 29) | def forward(self, tensor_input):
class SPyNetTOF (line 42) | class SPyNetTOF(nn.Module):
method __init__ (line 59) | def __init__(self, load_path=None):
method forward (line 66) | def forward(self, ref, supp):
class TOFlow (line 94) | class TOFlow(nn.Module):
method __init__ (line 111) | def __init__(self, adapt_official_weights=False):
method normalize (line 131) | def normalize(self, img):
method denormalize (line 134) | def denormalize(self, img):
method forward (line 137) | def forward(self, lrs):
FILE: basicsr/archs/vgg_arch.py
function insert_bn (line 36) | def insert_bn(names):
class VGGFeatureExtractor (line 55) | class VGGFeatureExtractor(nn.Module):
method __init__ (line 78) | def __init__(self,
method forward (line 141) | def forward(self, x):
FILE: basicsr/data/__init__.py
function build_dataset (line 25) | def build_dataset(dataset_opt):
function build_dataloader (line 40) | def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sample...
function worker_init_fn (line 97) | def worker_init_fn(worker_id, num_workers, rank, seed):
FILE: basicsr/data/data_sampler.py
class EnlargedSampler (line 6) | class EnlargedSampler(Sampler):
method __init__ (line 21) | def __init__(self, dataset, num_replicas, rank, ratio=1):
method __iter__ (line 29) | def __iter__(self):
method __len__ (line 44) | def __len__(self):
method set_epoch (line 47) | def set_epoch(self, epoch):
FILE: basicsr/data/data_util.py
function read_img_seq (line 11) | def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=F...
function generate_frame_indices (line 43) | def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='...
function paired_paths_from_lmdb (line 95) | def paired_paths_from_lmdb(folders, keys):
function paired_paths_from_meta_info_file (line 156) | def paired_paths_from_meta_info_file(folders, keys, meta_info_file, file...
function paired_paths_from_folder (line 200) | def paired_paths_from_folder(folders, keys, filename_tmpl):
function paths_from_folder (line 236) | def paths_from_folder(folder):
function paths_from_lmdb (line 251) | def paths_from_lmdb(folder):
function generate_gaussian_kernel (line 267) | def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
function duf_downsample (line 285) | def duf_downsample(x, kernel_size=13, scale=4):
FILE: basicsr/data/degradations.py
function sigma_matrix2 (line 16) | def sigma_matrix2(sig_x, sig_y, theta):
function mesh_grid (line 32) | def mesh_grid(kernel_size):
function pdf2 (line 50) | def pdf2(sigma_matrix, grid):
function cdf2 (line 66) | def cdf2(d_matrix, grid):
function bivariate_Gaussian (line 84) | def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isot...
function bivariate_generalized_Gaussian (line 112) | def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, bet...
function bivariate_plateau (line 143) | def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None,...
function random_bivariate_Gaussian (line 176) | def random_bivariate_Gaussian(kernel_size,
function random_bivariate_generalized_Gaussian (line 220) | def random_bivariate_generalized_Gaussian(kernel_size,
function random_bivariate_plateau (line 272) | def random_bivariate_plateau(kernel_size,
function random_mixed_kernels (line 324) | def random_mixed_kernels(kernel_list,
function circular_lowpass_kernel (line 389) | def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
function generate_gaussian_noise (line 419) | def generate_gaussian_noise(img, sigma=10, gray_noise=False):
function add_gaussian_noise (line 438) | def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_nois...
function generate_gaussian_noise_pt (line 460) | def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
function add_gaussian_noise_pt (line 492) | def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds...
function random_generate_gaussian_noise (line 515) | def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
function random_add_gaussian_noise (line 524) | def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, cl...
function random_generate_gaussian_noise_pt (line 536) | def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_pro...
function random_add_gaussian_noise_pt (line 544) | def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0,...
function generate_poisson_noise (line 559) | def generate_poisson_noise(img, scale=1.0, gray_noise=False):
function add_poisson_noise (line 586) | def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_nois...
function generate_poisson_noise_pt (line 609) | def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
function add_poisson_noise_pt (line 657) | def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_n...
function random_generate_poisson_noise (line 685) | def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
function random_add_poisson_noise (line 694) | def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, cli...
function random_generate_poisson_noise_pt (line 706) | def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_pro...
function random_add_poisson_noise_pt (line 714) | def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, ...
function add_jpg_compression (line 731) | def add_jpg_compression(img, quality=90):
function random_add_jpg_compression (line 750) | def random_add_jpg_compression(img, quality_range=(90, 100)):
FILE: basicsr/data/ffhq_dataset.py
class FFHQDataset (line 13) | class FFHQDataset(data.Dataset):
method __init__ (line 26) | def __init__(self, opt):
method __getitem__ (line 47) | def __getitem__(self, index):
method __len__ (line 79) | def __len__(self):
FILE: basicsr/data/paired_image_dataset.py
class PairedImageDataset (line 11) | class PairedImageDataset(data.Dataset):
method __init__ (line 38) | def __init__(self, opt):
method __getitem__ (line 63) | def __getitem__(self, index):
method __len__ (line 105) | def __len__(self):
FILE: basicsr/data/prefetch_dataloader.py
class PrefetchGenerator (line 7) | class PrefetchGenerator(threading.Thread):
method __init__ (line 17) | def __init__(self, generator, num_prefetch_queue):
method run (line 24) | def run(self):
method __next__ (line 29) | def __next__(self):
method __iter__ (line 35) | def __iter__(self):
class PrefetchDataLoader (line 39) | class PrefetchDataLoader(DataLoader):
method __init__ (line 53) | def __init__(self, num_prefetch_queue, **kwargs):
method __iter__ (line 57) | def __iter__(self):
class CPUPrefetcher (line 61) | class CPUPrefetcher():
method __init__ (line 68) | def __init__(self, loader):
method next (line 72) | def next(self):
method reset (line 78) | def reset(self):
class CUDAPrefetcher (line 82) | class CUDAPrefetcher():
method __init__ (line 94) | def __init__(self, loader, opt):
method preload (line 102) | def preload(self):
method next (line 114) | def next(self):
method reset (line 120) | def reset(self):
FILE: basicsr/data/realesrgan_dataset.py
class RealESRGANDataset (line 18) | class RealESRGANDataset(data.Dataset):
method __init__ (line 36) | def __init__(self, opt):
method __getitem__ (line 84) | def __getitem__(self, index):
method __len__ (line 192) | def __len__(self):
FILE: basicsr/data/realesrgan_paired_dataset.py
class RealESRGANPairedDataset (line 12) | class RealESRGANPairedDataset(data.Dataset):
method __init__ (line 39) | def __init__(self, opt):
method __getitem__ (line 73) | def __getitem__(self, index):
method __len__ (line 105) | def __len__(self):
FILE: basicsr/data/reds_dataset.py
class REDSDataset (line 14) | class REDSDataset(data.Dataset):
method __init__ (line 49) | def __init__(self, opt):
method __getitem__ (line 95) | def __getitem__(self, index):
method __len__ (line 204) | def __len__(self):
class REDSRecurrentDataset (line 209) | class REDSRecurrentDataset(data.Dataset):
method __init__ (line 244) | def __init__(self, opt):
method __getitem__ (line 290) | def __getitem__(self, index):
method __len__ (line 351) | def __len__(self):
FILE: basicsr/data/single_image_dataset.py
class SingleImageDataset (line 11) | class SingleImageDataset(data.Dataset):
method __init__ (line 27) | def __init__(self, opt):
method __getitem__ (line 47) | def __getitem__(self, index):
method __len__ (line 67) | def __len__(self):
FILE: basicsr/data/transforms.py
function mod_crop (line 6) | def mod_crop(img, scale):
function paired_random_crop (line 26) | def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=N...
function augment (line 94) | def augment(imgs, hflip=True, rotation=True, flows=None, return_status=F...
function img_rotate (line 161) | def img_rotate(img, angle, center=None, scale=1.0):
FILE: basicsr/data/video_test_dataset.py
class VideoTestDataset (line 12) | class VideoTestDataset(data.Dataset):
method __init__ (line 46) | def __init__(self, opt):
method __getitem__ (line 102) | def __getitem__(self, index):
method __len__ (line 129) | def __len__(self):
class VideoTestVimeo90KDataset (line 134) | class VideoTestVimeo90KDataset(data.Dataset):
method __init__ (line 153) | def __init__(self, opt):
method __getitem__ (line 181) | def __getitem__(self, index):
method __len__ (line 197) | def __len__(self):
class VideoTestDUFDataset (line 202) | class VideoTestDUFDataset(VideoTestDataset):
method __getitem__ (line 212) | def __getitem__(self, index):
class VideoRecurrentTestDataset (line 252) | class VideoRecurrentTestDataset(VideoTestDataset):
method __init__ (line 262) | def __init__(self, opt):
method __getitem__ (line 267) | def __getitem__(self, index):
method __len__ (line 282) | def __len__(self):
FILE: basicsr/data/vimeo90k_dataset.py
class Vimeo90KDataset (line 12) | class Vimeo90KDataset(data.Dataset):
method __init__ (line 59) | def __init__(self, opt):
method __getitem__ (line 84) | def __getitem__(self, index):
method __len__ (line 132) | def __len__(self):
class Vimeo90KRecurrentDataset (line 137) | class Vimeo90KRecurrentDataset(Vimeo90KDataset):
method __init__ (line 139) | def __init__(self, opt):
method __getitem__ (line 145) | def __getitem__(self, index):
method __len__ (line 198) | def __len__(self):
FILE: basicsr/losses/__init__.py
function build_loss (line 19) | def build_loss(opt):
FILE: basicsr/losses/basic_loss.py
function l1_loss (line 13) | def l1_loss(pred, target):
function mse_loss (line 18) | def mse_loss(pred, target):
function charbonnier_loss (line 23) | def charbonnier_loss(pred, target, eps=1e-12):
class L1Loss (line 28) | class L1Loss(nn.Module):
method __init__ (line 37) | def __init__(self, loss_weight=1.0, reduction='mean'):
method forward (line 45) | def forward(self, pred, target, weight=None, **kwargs):
class MSELoss (line 56) | class MSELoss(nn.Module):
method __init__ (line 65) | def __init__(self, loss_weight=1.0, reduction='mean'):
method forward (line 73) | def forward(self, pred, target, weight=None, **kwargs):
class CharbonnierLoss (line 84) | class CharbonnierLoss(nn.Module):
method __init__ (line 98) | def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
method forward (line 107) | def forward(self, pred, target, weight=None, **kwargs):
class WeightedTVLoss (line 118) | class WeightedTVLoss(L1Loss):
method __init__ (line 125) | def __init__(self, loss_weight=1.0, reduction='mean'):
method forward (line 130) | def forward(self, pred, weight=None):
class PerceptualLoss (line 147) | class PerceptualLoss(nn.Module):
method __init__ (line 170) | def __init__(self,
method forward (line 198) | def forward(self, x, gt):
method _gram_mat (line 240) | def _gram_mat(self, x):
FILE: basicsr/losses/gan_loss.py
class GANLoss (line 11) | class GANLoss(nn.Module):
method __init__ (line 23) | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, l...
method _wgan_loss (line 43) | def _wgan_loss(self, input, target):
method _wgan_softplus_loss (line 55) | def _wgan_softplus_loss(self, input, target):
method get_target_label (line 72) | def get_target_label(self, input, target_is_real):
method forward (line 89) | def forward(self, input, target_is_real, is_disc=False):
class MultiScaleGANLoss (line 116) | class MultiScaleGANLoss(GANLoss):
method __init__ (line 121) | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, l...
method forward (line 124) | def forward(self, input, target_is_real, is_disc=False):
function r1_penalty (line 143) | def r1_penalty(real_pred, real_img):
function g_path_regularize (line 159) | def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
function gradient_penalty_loss (line 171) | def gradient_penalty_loss(discriminator, real_data, fake_data, weight=No...
FILE: basicsr/losses/loss_util.py
function reduce_loss (line 6) | def reduce_loss(loss, reduction):
function weight_reduce_loss (line 26) | def weight_reduce_loss(loss, weight=None, reduction='mean'):
function weighted_loss (line 58) | def weighted_loss(loss_func):
function get_local_weights (line 99) | def get_local_weights(residual, ksize):
function get_refined_artifact_map (line 121) | def get_refined_artifact_map(img_gt, img_output, img_ema, ksize):
FILE: basicsr/metrics/__init__.py
function calculate_metric (line 10) | def calculate_metric(data, opt):
FILE: basicsr/metrics/fid.py
function load_patched_inception_v3 (line 10) | def load_patched_inception_v3(device='cuda', resize_input=True, normaliz...
function extract_inception_features (line 19) | def extract_inception_features(data_generator, inception, len_generator=...
function calculate_fid (line 50) | def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
FILE: basicsr/metrics/metric_util.py
function reorder_image (line 6) | def reorder_image(img, input_order='HWC'):
function to_y_channel (line 32) | def to_y_channel(img):
FILE: basicsr/metrics/niqe.py
function estimate_aggd_param (line 13) | def estimate_aggd_param(block):
function compute_feature (line 41) | def compute_feature(block):
function niqe (line 68) | def niqe(img, mu_pris_param, cov_pris_param, gaussian_window, block_size...
function calculate_niqe (line 145) | def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y', ...
FILE: basicsr/metrics/psnr_ssim.py
function calculate_psnr (line 12) | def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_cha...
function calculate_psnr_pt (line 52) | def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False, **kw...
function calculate_ssim (line 85) | def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_cha...
function calculate_ssim_pt (line 132) | def calculate_ssim_pt(img, img2, crop_border, test_y_channel=False, **kw...
function _ssim (line 170) | def _ssim(img, img2):
function _ssim_pth (line 201) | def _ssim_pth(img, img2):
FILE: basicsr/metrics/test_metrics/test_psnr_ssim.py
function test (line 9) | def test(img_path, img_path2, crop_border, test_y_channel=False):
FILE: basicsr/models/__init__.py
function build_model (line 18) | def build_model(opt):
FILE: basicsr/models/base_model.py
class BaseModel (line 13) | class BaseModel():
method __init__ (line 16) | def __init__(self, opt):
method feed_data (line 23) | def feed_data(self, data):
method optimize_parameters (line 26) | def optimize_parameters(self):
method get_current_visuals (line 29) | def get_current_visuals(self):
method save (line 32) | def save(self, epoch, current_iter):
method validation (line 36) | def validation(self, dataloader, current_iter, tb_logger, save_img=Fal...
method _initialize_best_metric_results (line 50) | def _initialize_best_metric_results(self, dataset_name):
method _update_best_metric_result (line 65) | def _update_best_metric_result(self, dataset_name, metric, val, curren...
method model_ema (line 75) | def model_ema(self, decay=0.999):
method get_current_log (line 84) | def get_current_log(self):
method model_to_device (line 87) | def model_to_device(self, net):
method get_optimizer (line 103) | def get_optimizer(self, optim_type, params, lr, **kwargs):
method setup_schedulers (line 122) | def setup_schedulers(self):
method get_bare_model (line 135) | def get_bare_model(self, net):
method print_network (line 144) | def print_network(self, net):
method _set_lr (line 163) | def _set_lr(self, lr_groups_l):
method _get_init_lr (line 173) | def _get_init_lr(self):
method update_learning_rate (line 181) | def update_learning_rate(self, current_iter, warmup_iter=-1):
method get_current_learning_rate (line 204) | def get_current_learning_rate(self):
method save_network (line 208) | def save_network(self, net, net_label, current_iter, param_key='params'):
method _print_different_keys_loading (line 254) | def _print_different_keys_loading(self, crt_net, load_net, strict=True):
method load_network (line 289) | def load_network(self, net, load_path, strict=True, param_key='params'):
method save_training_state (line 318) | def save_training_state(self, epoch, current_iter):
method resume_training (line 352) | def resume_training(self, resume_state):
method reduce_loss_dict (line 367) | def reduce_loss_dict(self, loss_dict):
FILE: basicsr/models/edvr_model.py
class EDVRModel (line 7) | class EDVRModel(VideoBaseModel):
method __init__ (line 13) | def __init__(self, opt):
method setup_optimizers (line 18) | def setup_optimizers(self):
method optimize_parameters (line 48) | def optimize_parameters(self, current_iter):
FILE: basicsr/models/esrgan_model.py
class ESRGANModel (line 9) | class ESRGANModel(SRGANModel):
method optimize_parameters (line 12) | def optimize_parameters(self, current_iter):
FILE: basicsr/models/hifacegan_model.py
class HiFaceGANModel (line 15) | class HiFaceGANModel(SRModel):
method init_training_settings (line 21) | def init_training_settings(self):
method setup_optimizers (line 63) | def setup_optimizers(self):
method discriminate (line 74) | def discriminate(self, input_lq, output, ground_truth):
method _divide_pred (line 98) | def _divide_pred(pred):
method optimize_parameters (line 116) | def optimize_parameters(self, current_iter):
method validation (line 194) | def validation(self, dataloader, current_iter, tb_logger, save_img=Fal...
method nondist_validation (line 216) | def nondist_validation(self, dataloader, current_iter, tb_logger, save...
method save (line 282) | def save(self, epoch, current_iter):
FILE: basicsr/models/lr_scheduler.py
class MultiStepRestartLR (line 6) | class MultiStepRestartLR(_LRScheduler):
method __init__ (line 19) | def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), r...
method get_lr (line 27) | def get_lr(self):
function get_position_from_periods (line 36) | def get_position_from_periods(iteration, cumulative_period):
class CosineAnnealingRestartLR (line 57) | class CosineAnnealingRestartLR(_LRScheduler):
method __init__ (line 77) | def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=...
method get_lr (line 86) | def get_lr(self):
FILE: basicsr/models/realesrgan_model.py
class RealESRGANModel (line 17) | class RealESRGANModel(SRGANModel):
method __init__ (line 25) | def __init__(self, opt):
method _dequeue_and_enqueue (line 32) | def _dequeue_and_enqueue(self):
method feed_data (line 69) | def feed_data(self, data):
method nondist_validation (line 187) | def nondist_validation(self, dataloader, current_iter, tb_logger, save...
method optimize_parameters (line 193) | def optimize_parameters(self, current_iter):
FILE: basicsr/models/realesrnet_model.py
class RealESRNetModel (line 15) | class RealESRNetModel(SRModel):
method __init__ (line 24) | def __init__(self, opt):
method _dequeue_and_enqueue (line 31) | def _dequeue_and_enqueue(self):
method feed_data (line 68) | def feed_data(self, data):
method nondist_validation (line 185) | def nondist_validation(self, dataloader, current_iter, tb_logger, save...
FILE: basicsr/models/sr_model.py
class SRModel (line 15) | class SRModel(BaseModel):
method __init__ (line 18) | def __init__(self, opt):
method init_training_settings (line 35) | def init_training_settings(self):
method setup_optimizers (line 73) | def setup_optimizers(self):
method feed_data (line 87) | def feed_data(self, data):
method optimize_parameters (line 92) | def optimize_parameters(self, current_iter):
method test (line 121) | def test(self):
method test_selfensemble (line 132) | def test_selfensemble(self):
method dist_validation (line 180) | def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
method nondist_validation (line 184) | def nondist_validation(self, dataloader, current_iter, tb_logger, save...
method _log_validation_metric_values (line 251) | def _log_validation_metric_values(self, current_iter, dataset_name, tb...
method get_current_visuals (line 266) | def get_current_visuals(self):
method save (line 274) | def save(self, epoch, current_iter):
FILE: basicsr/models/srgan_model.py
class SRGANModel (line 12) | class SRGANModel(SRModel):
method init_training_settings (line 15) | def init_training_settings(self):
method setup_optimizers (line 74) | def setup_optimizers(self):
method optimize_parameters (line 85) | def optimize_parameters(self, current_iter):
method save (line 143) | def save(self, epoch, current_iter):
FILE: basicsr/models/stylegan2_model.py
class StyleGAN2Model (line 18) | class StyleGAN2Model(BaseModel):
method __init__ (line 21) | def __init__(self, opt):
method init_training_settings (line 42) | def init_training_settings(self):
method setup_optimizers (line 88) | def setup_optimizers(self):
method feed_data (line 169) | def feed_data(self, data):
method make_noise (line 172) | def make_noise(self, batch, num_noise):
method mixing_noise (line 179) | def mixing_noise(self, batch, prob):
method optimize_parameters (line 185) | def optimize_parameters(self, current_iter):
method test (line 256) | def test(self):
method dist_validation (line 261) | def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
method nondist_validation (line 265) | def nondist_validation(self, dataloader, current_iter, tb_logger, save...
method save (line 280) | def save(self, epoch, current_iter):
FILE: basicsr/models/swinir_model.py
class SwinIRModel (line 9) | class SwinIRModel(SRModel):
method test (line 11) | def test(self):
FILE: basicsr/models/video_base_model.py
class VideoBaseModel (line 15) | class VideoBaseModel(SRModel):
method dist_validation (line 18) | def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
method nondist_validation (line 113) | def nondist_validation(self, dataloader, current_iter, tb_logger, save...
method _log_validation_metric_values (line 118) | def _log_validation_metric_values(self, current_iter, dataset_name, tb...
FILE: basicsr/models/video_gan_model.py
class VideoGANModel (line 7) | class VideoGANModel(SRGANModel, VideoBaseModel):
FILE: basicsr/models/video_recurrent_gan_model.py
class VideoRecurrentGANModel (line 12) | class VideoRecurrentGANModel(VideoRecurrentModel):
method init_training_settings (line 14) | def init_training_settings(self):
method setup_optimizers (line 68) | def setup_optimizers(self):
method optimize_parameters (line 101) | def optimize_parameters(self, current_iter):
method save (line 174) | def save(self, epoch, current_iter):
FILE: basicsr/models/video_recurrent_model.py
class VideoRecurrentModel (line 15) | class VideoRecurrentModel(VideoBaseModel):
method __init__ (line 17) | def __init__(self, opt):
method setup_optimizers (line 22) | def setup_optimizers(self):
method optimize_parameters (line 52) | def optimize_parameters(self, current_iter):
method dist_validation (line 66) | def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
method test (line 176) | def test(self):
FILE: basicsr/ops/dcn/deform_conv.py
class DeformConvFunction (line 33) | class DeformConvFunction(Function):
method forward (line 36) | def forward(ctx,
method backward (line 75) | def backward(ctx, grad_output):
method _output_size (line 107) | def _output_size(input, weight, padding, dilation, stride):
class ModulatedDeformConvFunction (line 121) | class ModulatedDeformConvFunction(Function):
method forward (line 124) | def forward(ctx,
method backward (line 157) | def backward(ctx, grad_output):
method _infer_shape (line 177) | def _infer_shape(ctx, input, weight):
class DeformConv (line 191) | class DeformConv(nn.Module):
method __init__ (line 193) | def __init__(self,
method reset_parameters (line 225) | def reset_parameters(self):
method forward (line 232) | def forward(self, x, offset):
class DeformConvPack (line 248) | class DeformConvPack(DeformConv):
method __init__ (line 266) | def __init__(self, *args, **kwargs):
method init_offset (line 279) | def init_offset(self):
method forward (line 283) | def forward(self, x):
class ModulatedDeformConv (line 289) | class ModulatedDeformConv(nn.Module):
method __init__ (line 291) | def __init__(self,
method init_weights (line 322) | def init_weights(self):
method forward (line 331) | def forward(self, x, offset, mask):
class ModulatedDeformConvPack (line 336) | class ModulatedDeformConvPack(ModulatedDeformConv):
method __init__ (line 354) | def __init__(self, *args, **kwargs):
method init_weights (line 367) | def init_weights(self):
method forward (line 373) | def forward(self, x):
FILE: basicsr/ops/dcn/src/deform_conv_cuda.cpp
function shape_check (line 62) | void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOu...
function deform_conv_forward_cuda (line 152) | int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
function deform_conv_backward_input_cuda (line 262) | int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
function deform_conv_backward_parameters_cuda (line 376) | int deform_conv_backward_parameters_cuda(
function modulated_deform_conv_cuda_forward (line 490) | void modulated_deform_conv_cuda_forward(
function modulated_deform_conv_cuda_backward (line 571) | void modulated_deform_conv_cuda_backward(
FILE: basicsr/ops/dcn/src/deform_conv_ext.cpp
function deform_conv_forward (line 52) | int deform_conv_forward(at::Tensor input, at::Tensor weight,
function deform_conv_backward_input (line 70) | int deform_conv_backward_input(at::Tensor input, at::Tensor offset,
function deform_conv_backward_parameters (line 89) | int deform_conv_backward_parameters(
function modulated_deform_conv_forward (line 107) | void modulated_deform_conv_forward(
function modulated_deform_conv_backward (line 127) | void modulated_deform_conv_backward(
function PYBIND11_MODULE (line 150) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: basicsr/ops/fused_act/fused_act.py
class FusedLeakyReLUFunctionBackward (line 30) | class FusedLeakyReLUFunctionBackward(Function):
method forward (line 33) | def forward(ctx, grad_output, out, negative_slope, scale):
method backward (line 52) | def backward(ctx, gradgrad_input, gradgrad_bias):
class FusedLeakyReLUFunction (line 60) | class FusedLeakyReLUFunction(Function):
method forward (line 63) | def forward(ctx, input, bias, negative_slope, scale):
method backward (line 73) | def backward(ctx, grad_output):
class FusedLeakyReLU (line 81) | class FusedLeakyReLU(nn.Module):
method __init__ (line 83) | def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
method forward (line 90) | def forward(self, input):
function fused_leaky_relu (line 94) | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
FILE: basicsr/ops/fused_act/src/fused_bias_act.cpp
function fused_bias_act (line 14) | torch::Tensor fused_bias_act(const torch::Tensor& input,
function PYBIND11_MODULE (line 24) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: basicsr/ops/upfirdn2d/src/upfirdn2d.cpp
function upfirdn2d (line 13) | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor&...
function PYBIND11_MODULE (line 22) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: basicsr/ops/upfirdn2d/upfirdn2d.py
class UpFirDn2dBackward (line 30) | class UpFirDn2dBackward(Function):
method forward (line 33) | def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pa...
method backward (line 73) | def backward(ctx, gradgrad_input):
class UpFirDn2d (line 97) | class UpFirDn2d(Function):
method forward (line 100) | def forward(ctx, input, kernel, up, down, pad):
method backward (line 135) | def backward(ctx, grad_output):
function upfirdn2d (line 153) | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
function upfirdn2d_native (line 162) | def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, ...
FILE: basicsr/test.py
function test_pipeline (line 11) | def test_pipeline(root_path):
FILE: basicsr/train.py
function init_tb_loggers (line 17) | def init_tb_loggers(opt):
function create_train_val_dataloader (line 29) | def create_train_val_dataloader(opt, logger):
function load_resume_state (line 68) | def load_resume_state(opt):
function train_pipeline (line 91) | def train_pipeline(root_path):
FILE: basicsr/utils/color_util.py
function rgb2ycbcr (line 5) | def rgb2ycbcr(img, y_only=False):
function bgr2ycbcr (line 38) | def bgr2ycbcr(img, y_only=False):
function ycbcr2rgb (line 71) | def ycbcr2rgb(img):
function ycbcr2bgr (line 100) | def ycbcr2bgr(img):
function _convert_input_type_range (line 129) | def _convert_input_type_range(img):
function _convert_output_type_range (line 156) | def _convert_output_type_range(img, dst_type):
function rgb2ycbcr_pt (line 186) | def rgb2ycbcr_pt(img, y_only=False):
FILE: basicsr/utils/diffjpeg.py
function diff_round (line 26) | def diff_round(x):
function quality_to_factor (line 32) | def quality_to_factor(quality):
class RGB2YCbCrJpeg (line 49) | class RGB2YCbCrJpeg(nn.Module):
method __init__ (line 53) | def __init__(self):
method forward (line 60) | def forward(self, image):
class ChromaSubsampling (line 73) | class ChromaSubsampling(nn.Module):
method __init__ (line 77) | def __init__(self):
method forward (line 80) | def forward(self, image):
class BlockSplitting (line 98) | class BlockSplitting(nn.Module):
method __init__ (line 102) | def __init__(self):
method forward (line 106) | def forward(self, image):
class DCT8x8 (line 121) | class DCT8x8(nn.Module):
method __init__ (line 125) | def __init__(self):
method forward (line 134) | def forward(self, image):
class YQuantize (line 148) | class YQuantize(nn.Module):
method __init__ (line 155) | def __init__(self, rounding):
method forward (line 160) | def forward(self, image, factor=1):
class CQuantize (line 178) | class CQuantize(nn.Module):
method __init__ (line 185) | def __init__(self, rounding):
method forward (line 190) | def forward(self, image, factor=1):
class CompressJpeg (line 208) | class CompressJpeg(nn.Module):
method __init__ (line 215) | def __init__(self, rounding=torch.round):
method forward (line 222) | def forward(self, image, factor=1):
class YDequantize (line 247) | class YDequantize(nn.Module):
method __init__ (line 251) | def __init__(self):
method forward (line 255) | def forward(self, image, factor=1):
class CDequantize (line 272) | class CDequantize(nn.Module):
method __init__ (line 276) | def __init__(self):
method forward (line 280) | def forward(self, image, factor=1):
class iDCT8x8 (line 297) | class iDCT8x8(nn.Module):
method __init__ (line 301) | def __init__(self):
method forward (line 310) | def forward(self, image):
class BlockMerging (line 324) | class BlockMerging(nn.Module):
method __init__ (line 328) | def __init__(self):
method forward (line 331) | def forward(self, patches, height, width):
class ChromaUpsampling (line 348) | class ChromaUpsampling(nn.Module):
method __init__ (line 352) | def __init__(self):
method forward (line 355) | def forward(self, y, cb, cr):
class YCbCr2RGBJpeg (line 378) | class YCbCr2RGBJpeg(nn.Module):
method __init__ (line 382) | def __init__(self):
method forward (line 389) | def forward(self, image):
class DeCompressJpeg (line 401) | class DeCompressJpeg(nn.Module):
method __init__ (line 408) | def __init__(self, rounding=torch.round):
method forward (line 417) | def forward(self, y, cb, cr, imgh, imgw, factor=1):
class DiffJPEG (line 449) | class DiffJPEG(nn.Module):
method __init__ (line 457) | def __init__(self, differentiable=True):
method forward (line 467) | def forward(self, x, quality):
FILE: basicsr/utils/dist_util.py
function init_dist (line 10) | def init_dist(launcher, backend='nccl', **kwargs):
function _init_dist_pytorch (line 21) | def _init_dist_pytorch(backend, **kwargs):
function _init_dist_slurm (line 28) | def _init_dist_slurm(backend, port=None):
function get_dist_info (line 60) | def get_dist_info():
function master_only (line 74) | def master_only(func):
FILE: basicsr/utils/download_util.py
function download_file_from_google_drive (line 11) | def download_file_from_google_drive(file_id, save_path):
function get_confirm_token (line 41) | def get_confirm_token(response):
function save_response_content (line 48) | def save_response_content(response, destination, file_size=None, chunk_s...
function load_file_from_url (line 69) | def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
FILE: basicsr/utils/file_client.py
class BaseStorageBackend (line 5) | class BaseStorageBackend(metaclass=ABCMeta):
method get (line 14) | def get(self, filepath):
method get_text (line 18) | def get_text(self, filepath):
class MemcachedBackend (line 22) | class MemcachedBackend(BaseStorageBackend):
method __init__ (line 32) | def __init__(self, server_list_cfg, client_cfg, sys_path=None):
method get (line 47) | def get(self, filepath):
method get_text (line 54) | def get_text(self, filepath):
class HardDiskBackend (line 58) | class HardDiskBackend(BaseStorageBackend):
method get (line 61) | def get(self, filepath):
method get_text (line 67) | def get_text(self, filepath):
class LmdbBackend (line 74) | class LmdbBackend(BaseStorageBackend):
method __init__ (line 94) | def __init__(self, db_paths, client_keys='default', readonly=True, loc...
method get (line 114) | def get(self, filepath, client_key):
method get_text (line 128) | def get_text(self, filepath):
class FileClient (line 132) | class FileClient(object):
method __init__ (line 151) | def __init__(self, backend='disk', **kwargs):
method get (line 158) | def get(self, filepath, client_key='default'):
method get_text (line 166) | def get_text(self, filepath):
FILE: basicsr/utils/flow_util.py
function flowread (line 7) | def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
function flowwrite (line 45) | def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kw...
function quantize_flow (line 76) | def quantize_flow(flow, max_val=0.02, norm=True):
function dequantize_flow (line 102) | def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
function quantize (line 126) | def quantize(arr, min_val, max_val, levels, dtype=np.int64):
function dequantize (line 150) | def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
FILE: basicsr/utils/img_process_util.py
function filter2D (line 7) | def filter2D(img, kernel):
function usm_sharp (line 34) | def usm_sharp(img, weight=0.5, radius=50, threshold=10):
class USMSharp (line 63) | class USMSharp(torch.nn.Module):
method __init__ (line 65) | def __init__(self, radius=50, sigma=0):
method forward (line 74) | def forward(self, img, weight=0.5, threshold=10):
FILE: basicsr/utils/img_util.py
function img2tensor (line 9) | def img2tensor(imgs, bgr2rgb=True, float32=True):
function tensor2img (line 38) | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
function tensor2img_fast (line 97) | def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
function imfrombytes (line 114) | def imfrombytes(content, flag='color', float32=False):
function imwrite (line 135) | def imwrite(img, file_path, params=None, auto_mkdir=True):
function crop_border (line 156) | def crop_border(imgs, crop_border):
FILE: basicsr/utils/lmdb_util.py
function make_lmdb_from_imgs (line 9) | def make_lmdb_from_imgs(data_path,
function read_img_worker (line 135) | def read_img_worker(path, key, compress_level):
class LmdbMaker (line 159) | class LmdbMaker():
method __init__ (line 170) | def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_l...
method put (line 185) | def put(self, img_byte, key, img_shape):
method close (line 196) | def close(self):
FILE: basicsr/utils/logger.py
class AvgTimer (line 10) | class AvgTimer():
method __init__ (line 12) | def __init__(self, window=200):
method start (line 20) | def start(self):
method record (line 23) | def record(self):
method get_current_time (line 38) | def get_current_time(self):
method get_avg_time (line 41) | def get_avg_time(self):
class MessageLogger (line 45) | class MessageLogger():
method __init__ (line 58) | def __init__(self, opt, start_iter=1, tb_logger=None):
method reset_start_time (line 68) | def reset_start_time(self):
method __call__ (line 72) | def __call__(self, log_vars):
function init_tb_logger (line 119) | def init_tb_logger(log_dir):
function init_wandb_logger (line 126) | def init_wandb_logger(opt):
function get_root_logger (line 146) | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_f...
function get_env_info (line 188) | def get_env_info():
FILE: basicsr/utils/matlab_functions.py
function cubic (line 6) | def cubic(x):
function calculate_weights_indices (line 16) | def calculate_weights_indices(in_length, out_length, scale, kernel, kern...
function imresize (line 86) | def imresize(img, scale, antialiasing=True):
FILE: basicsr/utils/misc.py
function set_random_seed (line 11) | def set_random_seed(seed):
function get_time_str (line 20) | def get_time_str():
function mkdir_and_rename (line 24) | def mkdir_and_rename(path):
function make_exp_dirs (line 38) | def make_exp_dirs(opt):
function scandir (line 52) | def scandir(dir_path, suffix=None, recursive=False, full_path=False):
function check_resume (line 94) | def check_resume(opt, resume_iter):
function sizeof_fmt (line 127) | def sizeof_fmt(size, suffix='B'):
FILE: basicsr/utils/options.py
function ordered_yaml (line 13) | def ordered_yaml():
function yaml_load (line 38) | def yaml_load(f):
function dict2str (line 54) | def dict2str(opt, indent_level=1):
function _postprocess_yml_value (line 75) | def _postprocess_yml_value(value):
function parse_options (line 99) | def parse_options(root_path, is_train=True):
function copy_opt_file (line 205) | def copy_opt_file(opt_file, experiments_root):
FILE: basicsr/utils/plot_util.py
function read_data_from_tensorboard (line 4) | def read_data_from_tensorboard(log_path, tag):
function read_data_from_txt_2v (line 23) | def read_data_from_txt_2v(path, pattern, step_one=False):
function read_data_from_txt_1v (line 48) | def read_data_from_txt_1v(path, pattern):
function smooth_data (line 68) | def smooth_data(values, smooth_weight):
FILE: basicsr/utils/registry.py
class Registry (line 4) | class Registry():
method __init__ (line 30) | def __init__(self, name):
method _do_register (line 38) | def _do_register(self, name, obj, suffix=None):
method register (line 46) | def register(self, obj=None, suffix=None):
method get (line 65) | def get(self, name, suffix='basicsr'):
method __contains__ (line 74) | def __contains__(self, name):
method __iter__ (line 77) | def __iter__(self):
method keys (line 80) | def keys(self):
FILE: docs/auto_generate_api.py
function scandir (line 5) | def scandir(dir_path, suffix=None, recursive=False, full_path=False):
FILE: docs/conf.py
function auto_generate_api (line 78) | def auto_generate_api(app):
function setup (line 82) | def setup(app):
FILE: inference/inference_basicvsr.py
function inference (line 13) | def inference(imgs, imgnames, model, save_path):
function main (line 24) | def main():
FILE: inference/inference_basicvsrpp.py
function inference (line 13) | def inference(imgs, imgnames, model, save_path):
function main (line 24) | def main():
FILE: inference/inference_dfdnet.py
function get_part_location (line 20) | def get_part_location(landmarks):
FILE: inference/inference_esrgan.py
function main (line 11) | def main():
FILE: inference/inference_stylegan2.py
function generate (line 11) | def generate(args, g_ema, device, mean_latent, randomize_noise):
FILE: inference/inference_swinir.py
function main (line 13) | def main():
function define_model (line 79) | def define_model(args):
FILE: scripts/data_preparation/create_lmdb.py
function create_lmdb_for_div2k (line 8) | def create_lmdb_for_div2k():
function prepare_keys_div2k (line 47) | def prepare_keys_div2k(folder_path):
function create_lmdb_for_reds (line 64) | def create_lmdb_for_reds():
function prepare_keys_reds (line 89) | def prepare_keys_reds(folder_path):
function create_lmdb_for_vimeo90k (line 106) | def create_lmdb_for_vimeo90k():
function prepare_keys_vimeo90k (line 127) | def prepare_keys_vimeo90k(folder_path, train_list_path, mode):
FILE: scripts/data_preparation/download_datasets.py
function download_dataset (line 9) | def download_dataset(dataset, file_ids):
FILE: scripts/data_preparation/extract_images_from_tfrecords.py
function convert_celeba_tfrecords (line 10) | def convert_celeba_tfrecords(tf_file, log_resolution, save_root, save_ty...
function convert_ffhq_tfrecords (line 69) | def convert_ffhq_tfrecords(tf_file, log_resolution, save_root, save_type...
function make_ffhq_lmdb_from_imgs (line 119) | def make_ffhq_lmdb_from_imgs(folder_path, log_resolution, save_root, sav...
FILE: scripts/data_preparation/extract_subimages.py
function main (line 12) | def main():
function extract_subimages (line 79) | def extract_subimages(opt):
function worker (line 109) | def worker(path, opt):
FILE: scripts/data_preparation/generate_meta_info.py
function generate_meta_info_div2k (line 7) | def generate_meta_info_div2k():
FILE: scripts/data_preparation/prepare_hifacegan_dataset.py
class Mosaic16x (line 6) | class Mosaic16x:
method augment_image (line 13) | def augment_image(self, x):
class DegradationSimulator (line 25) | class DegradationSimulator:
method __init__ (line 35) | def __init__(self, ):
method create_training_dataset (line 80) | def create_training_dataset(self, deg, gt_folder, lq_folder=None):
FILE: scripts/data_preparation/regroup_reds_dataset.py
function regroup_reds_dataset (line 5) | def regroup_reds_dataset(train_path, val_path):
FILE: scripts/download_pretrained_models.py
function download_pretrained_models (line 8) | def download_pretrained_models(method, file_ids):
FILE: scripts/metrics/calculate_fid_folder.py
function calculate_fid_folder (line 11) | def calculate_fid_folder():
FILE: scripts/metrics/calculate_fid_stats_from_datasets.py
function calculate_stats_from_dataset (line 11) | def calculate_stats_from_dataset():
FILE: scripts/metrics/calculate_lpips.py
function main (line 15) | def main():
FILE: scripts/metrics/calculate_niqe.py
function main (line 10) | def main(args):
FILE: scripts/metrics/calculate_psnr_ssim.py
function main (line 10) | def main(args):
FILE: scripts/metrics/calculate_stylegan2_fid.py
function calculate_stylegan2_fid (line 11) | def calculate_stylegan2_fid():
FILE: scripts/model_conversion/convert_dfdnet.py
function convert_net (line 7) | def convert_net(ori_net, crt_net):
FILE: scripts/model_conversion/convert_models.py
function convert_edvr (line 4) | def convert_edvr():
function convert_edsr (line 102) | def convert_edsr(ori_net_path, crt_net_path, save_path, num_block=32):
function convert_rcan_model (line 138) | def convert_rcan_model():
function convert_esrgan_model (line 174) | def convert_esrgan_model():
function convert_duf_model (line 202) | def convert_duf_model():
FILE: scripts/model_conversion/convert_stylegan.py
function convert_net_g (line 6) | def convert_net_g(ori_net, crt_net):
function convert_net_d (line 49) | def convert_net_d(ori_net, crt_net):
FILE: scripts/plot/model_complexity_cmp_bsrn.py
function main (line 1) | def main():
FILE: scripts/publish_models.py
function update_sha (line 8) | def update_sha(paths):
function convert_to_backward_compatible_models (line 39) | def convert_to_backward_compatible_models(paths):
FILE: setup.py
function readme (line 12) | def readme():
function get_git_hash (line 18) | def get_git_hash():
function get_hash (line 43) | def get_hash():
function write_version_py (line 59) | def write_version_py():
function get_version (line 76) | def get_version():
function make_cuda_ext (line 82) | def make_cuda_ext(name, module, sources, sources_cuda=None):
function get_requirements (line 108) | def get_requirements(filename='requirements.txt'):
FILE: test_scripts/test_discriminator_backward.py
class ToyDiscriminator (line 7) | class ToyDiscriminator(nn.Module):
method __init__ (line 9) | def __init__(self):
method forward (line 18) | def forward(self, x):
function main (line 26) | def main():
FILE: test_scripts/test_ffhq_dataset.py
function main (line 9) | def main():
FILE: test_scripts/test_lr_scheduler.py
function main (line 15) | def main():
FILE: test_scripts/test_niqe.py
function main (line 7) | def main():
FILE: test_scripts/test_paired_image_dataset.py
function main (line 8) | def main(mode='folder'):
FILE: test_scripts/test_reds_dataset.py
function main (line 8) | def main(mode='folder'):
FILE: test_scripts/test_vimeo90k_dataset.py
function main (line 8) | def main(mode='folder'):
FILE: tests/test_archs/test_basicvsr_arch.py
function test_basicvsr (line 6) | def test_basicvsr():
function test_convresidualblocks (line 16) | def test_convresidualblocks():
function test_iconvsr (line 26) | def test_iconvsr():
FILE: tests/test_archs/test_discriminator_arch.py
function test_vggstylediscriminator (line 7) | def test_vggstylediscriminator():
FILE: tests/test_archs/test_duf_arch.py
function test_duf (line 7) | def test_duf():
function test_dynamicupsamplingfilter (line 31) | def test_dynamicupsamplingfilter():
FILE: tests/test_archs/test_ecbsr_arch.py
function test_ecbsr (line 7) | def test_ecbsr():
function test_seqconv3x3 (line 23) | def test_seqconv3x3():
function test_ecb (line 57) | def test_ecb():
FILE: tests/test_archs/test_srresnet_arch.py
function test_msrresnet (line 6) | def test_msrresnet():
FILE: tests/test_data/test_paired_image_dataset.py
function test_pairedimagedataset (line 6) | def test_pairedimagedataset():
FILE: tests/test_data/test_single_image_dataset.py
function test_singleimagedataset (line 6) | def test_singleimagedataset():
FILE: tests/test_losses/test_losses.py
function test_pixellosses (line 8) | def test_pixellosses(loss_class):
function test_weightedtvloss (line 41) | def test_weightedtvloss():
FILE: tests/test_metrics/test_psnr_ssim.py
function test_calculate_psnr (line 7) | def test_calculate_psnr():
function test_calculate_ssim (line 26) | def test_calculate_ssim():
FILE: tests/test_models/test_sr_model.py
function test_srmodel (line 11) | def test_srmodel():
Condensed preview — 293 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (4,253K chars).
[
{
"path": ".github/workflows/publish-pip.yml",
"chars": 979,
"preview": "name: PyPI Publish\n\non: push\n\njobs:\n build-n-publish:\n runs-on: ubuntu-latest\n if: startsWith(github.event.ref, '"
},
{
"path": ".github/workflows/pylint.yml",
"chars": 707,
"preview": "name: PyLint\n\non: [push, pull_request]\n\njobs:\n build:\n\n runs-on: ubuntu-latest\n strategy:\n matrix:\n p"
},
{
"path": ".github/workflows/release.yml",
"chars": 1000,
"preview": "name: release\non:\n push:\n tags:\n - '*'\n\njobs:\n build:\n permissions: write-all\n name: Create Release\n "
},
{
"path": ".gitignore",
"chars": 1444,
"preview": "# ignored folders\ndatasets/*\nexperiments/*\nresults/*\ntb_logger/*\nwandb/*\ntmp/*\n\ndocs/api\nscripts/__init__.py\n\n*.DS_Store"
},
{
"path": ".pre-commit-config.yaml",
"chars": 1483,
"preview": "repos:\n # flake8\n - repo: https://github.com/PyCQA/flake8\n rev: 3.8.3\n hooks:\n - id: flake8\n args: ["
},
{
"path": ".readthedocs.yaml",
"chars": 709,
"preview": "# .readthedocs.yaml\n# Read the Docs configuration file\n# See https://docs.readthedocs.io/en/stable/config-file/v2.html f"
},
{
"path": ".vscode/settings.json",
"chars": 581,
"preview": "{\n \"files.trimTrailingWhitespace\": true,\n \"editor.wordWrap\": \"on\",\n \"editor.rulers\": [\n 80,\n 120\n"
},
{
"path": "CITATION.cff",
"chars": 546,
"preview": "cff-version: 1.2.0\nmessage: \"If you use this project, please cite it as below.\"\ntitle: \"BasicSR: Open Source Image and V"
},
{
"path": "LICENSE/LICENSE-NVIDIA",
"chars": 4665,
"preview": "Copyright (c) 2019, NVIDIA Corporation. All rights reserved.\n\n\nNvidia Source Code License-NC\n\n=========================="
},
{
"path": "LICENSE/LICENSE-stylegan2-pytorch",
"chars": 1071,
"preview": "MIT License\n\nCopyright (c) 2019 Kim Seonghyeon\n\nPermission is hereby granted, free of charge, to any person obtaining a "
},
{
"path": "LICENSE/LICENSE_SwinIR",
"chars": 11348,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "LICENSE/LICENSE_pytorch-image-models",
"chars": 11343,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "LICENSE/README.md",
"chars": 2149,
"preview": "# License and Acknowledgement\n\nThis BasicSR project is released under the Apache 2.0 license.\n\n- StyleGAN2\n - The codes"
},
{
"path": "LICENSE.txt",
"chars": 11350,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "MANIFEST.in",
"chars": 287,
"preview": "include basicsr/ops/dcn/src/*.cu basicsr/ops/dcn/src/*.cpp\ninclude basicsr/ops/fused_act/src/*.cu basicsr/ops/fused_act/"
},
{
"path": "README.md",
"chars": 10288,
"preview": "<p align=\"center\">\n <img src=\"assets/basicsr_xpixel_logo.png\" height=120>\n</p>\n\n## <div align=\"center\"><b><a href=\"READ"
},
{
"path": "README_CN.md",
"chars": 11167,
"preview": "<p align=\"center\">\n <img src=\"assets/basicsr_xpixel_logo.png\" height=120>\n</p>\n\n## <div align=\"center\"><b><a href=\"READ"
},
{
"path": "VERSION",
"chars": 6,
"preview": "1.4.2\n"
},
{
"path": "basicsr/__init__.py",
"chars": 286,
"preview": "# https://github.com/xinntao/BasicSR\n# flake8: noqa\nfrom .archs import *\nfrom .data import *\nfrom .losses import *\nfrom "
},
{
"path": "basicsr/archs/__init__.py",
"chars": 884,
"preview": "import importlib\nfrom copy import deepcopy\nfrom os import path as osp\n\nfrom basicsr.utils import get_root_logger, scandi"
},
{
"path": "basicsr/archs/arch_util.py",
"chars": 11348,
"preview": "import collections.abc\nimport math\nimport torch\nimport torchvision\nimport warnings\nfrom distutils.version import LooseVe"
},
{
"path": "basicsr/archs/basicvsr_arch.py",
"chars": 12595,
"preview": "import torch\nfrom torch import nn as nn\nfrom torch.nn import functional as F\n\nfrom basicsr.utils.registry import ARCH_RE"
},
{
"path": "basicsr/archs/basicvsrpp_arch.py",
"chars": 16542,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision\nimport warnings\n\nfrom basicsr.arch"
},
{
"path": "basicsr/archs/dfdnet_arch.py",
"chars": 7464,
"preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn.utils.spectral_norm "
},
{
"path": "basicsr/archs/dfdnet_util.py",
"chars": 5213,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.autograd import Function\nfrom torch.nn.uti"
},
{
"path": "basicsr/archs/discriminator_arch.py",
"chars": 6813,
"preview": "from torch import nn as nn\nfrom torch.nn import functional as F\nfrom torch.nn.utils import spectral_norm\n\nfrom basicsr.u"
},
{
"path": "basicsr/archs/duf_arch.py",
"chars": 11751,
"preview": "import numpy as np\nimport torch\nfrom torch import nn as nn\nfrom torch.nn import functional as F\n\nfrom basicsr.utils.regi"
},
{
"path": "basicsr/archs/ecbsr_arch.py",
"chars": 11973,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom basicsr.utils.registry import ARCH_REGISTRY\n\n\nc"
},
{
"path": "basicsr/archs/edsr_arch.py",
"chars": 2162,
"preview": "import torch\nfrom torch import nn as nn\n\nfrom basicsr.archs.arch_util import ResidualBlockNoBN, Upsample, make_layer\nfro"
},
{
"path": "basicsr/archs/edvr_arch.py",
"chars": 16239,
"preview": "import torch\nfrom torch import nn as nn\nfrom torch.nn import functional as F\n\nfrom basicsr.utils.registry import ARCH_RE"
},
{
"path": "basicsr/archs/hifacegan_arch.py",
"chars": 8933,
"preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom basicsr.utils.registry impor"
},
{
"path": "basicsr/archs/hifacegan_util.py",
"chars": 9644,
"preview": "import re\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import init\n# Warning: spectr"
},
{
"path": "basicsr/archs/inception.py",
"chars": 12055,
"preview": "# Modified from https://github.com/mseitzer/pytorch-fid/blob/master/pytorch_fid/inception.py # noqa: E501\n# For FID metr"
},
{
"path": "basicsr/archs/rcan_arch.py",
"chars": 4587,
"preview": "import torch\nfrom torch import nn as nn\n\nfrom basicsr.utils.registry import ARCH_REGISTRY\nfrom .arch_util import Upsampl"
},
{
"path": "basicsr/archs/ridnet_arch.py",
"chars": 6306,
"preview": "import torch\nimport torch.nn as nn\n\nfrom basicsr.utils.registry import ARCH_REGISTRY\nfrom .arch_util import ResidualBloc"
},
{
"path": "basicsr/archs/rrdbnet_arch.py",
"chars": 4621,
"preview": "import torch\nfrom torch import nn as nn\nfrom torch.nn import functional as F\n\nfrom basicsr.utils.registry import ARCH_RE"
},
{
"path": "basicsr/archs/spynet_arch.py",
"chars": 3776,
"preview": "import math\nimport torch\nfrom torch import nn as nn\nfrom torch.nn import functional as F\n\nfrom basicsr.utils.registry im"
},
{
"path": "basicsr/archs/srresnet_arch.py",
"chars": 2665,
"preview": "from torch import nn as nn\nfrom torch.nn import functional as F\n\nfrom basicsr.utils.registry import ARCH_REGISTRY\nfrom ."
},
{
"path": "basicsr/archs/srvgg_arch.py",
"chars": 2739,
"preview": "from torch import nn as nn\nfrom torch.nn import functional as F\n\nfrom basicsr.utils.registry import ARCH_REGISTRY\n\n\n@ARC"
},
{
"path": "basicsr/archs/stylegan2_arch.py",
"chars": 30194,
"preview": "import math\nimport random\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom basicsr.ops.fused"
},
{
"path": "basicsr/archs/stylegan2_bilinear_arch.py",
"chars": 22326,
"preview": "import math\nimport random\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom basicsr.ops.fused"
},
{
"path": "basicsr/archs/swinir_arch.py",
"chars": 37271,
"preview": "# Modified from https://github.com/JingyunLiang/SwinIR\n# SwinIR: Image Restoration Using Swin Transformer, https://arxiv"
},
{
"path": "basicsr/archs/tof_arch.py",
"chars": 6160,
"preview": "import torch\nfrom torch import nn as nn\nfrom torch.nn import functional as F\n\nfrom basicsr.utils.registry import ARCH_RE"
},
{
"path": "basicsr/archs/vgg_arch.py",
"chars": 6142,
"preview": "import os\nimport torch\nfrom collections import OrderedDict\nfrom torch import nn as nn\nfrom torchvision.models import vgg"
},
{
"path": "basicsr/data/__init__.py",
"chars": 4332,
"preview": "import importlib\nimport numpy as np\nimport random\nimport torch\nimport torch.utils.data\nfrom copy import deepcopy\nfrom fu"
},
{
"path": "basicsr/data/data_sampler.py",
"chars": 1639,
"preview": "import math\nimport torch\nfrom torch.utils.data.sampler import Sampler\n\n\nclass EnlargedSampler(Sampler):\n \"\"\"Sampler t"
},
{
"path": "basicsr/data/data_util.py",
"chars": 11806,
"preview": "import cv2\nimport numpy as np\nimport torch\nfrom os import path as osp\nfrom torch.nn import functional as F\n\nfrom basicsr"
},
{
"path": "basicsr/data/degradations.py",
"chars": 28194,
"preview": "import cv2\nimport math\nimport numpy as np\nimport random\nimport torch\nfrom scipy import special\nfrom scipy.stats import m"
},
{
"path": "basicsr/data/ffhq_dataset.py",
"chars": 3023,
"preview": "import random\nimport time\nfrom os import path as osp\nfrom torch.utils import data as data\nfrom torchvision.transforms.fu"
},
{
"path": "basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt",
"chars": 847392,
"preview": "0001_s001.png (480,480,3)\n0001_s002.png (480,480,3)\n0001_s003.png (480,480,3)\n0001_s004.png (480,480,3)\n0001_s005.png (4"
},
{
"path": "basicsr/data/meta_info/meta_info_REDS4_test_GT.txt",
"chars": 84,
"preview": "000 100 (720,1280,3)\n011 100 (720,1280,3)\n015 100 (720,1280,3)\n020 100 (720,1280,3)\n"
},
{
"path": "basicsr/data/meta_info/meta_info_REDS_GT.txt",
"chars": 5670,
"preview": "000 100 (720,1280,3)\n001 100 (720,1280,3)\n002 100 (720,1280,3)\n003 100 (720,1280,3)\n004 100 (720,1280,3)\n005 100 (720,12"
},
{
"path": "basicsr/data/meta_info/meta_info_REDSofficial4_test_GT.txt",
"chars": 84,
"preview": "240 100 (720,1280,3)\n241 100 (720,1280,3)\n246 100 (720,1280,3)\n257 100 (720,1280,3)\n"
},
{
"path": "basicsr/data/meta_info/meta_info_REDSval_official_test_GT.txt",
"chars": 630,
"preview": "240 100 (720,1280,3)\n241 100 (720,1280,3)\n242 100 (720,1280,3)\n243 100 (720,1280,3)\n244 100 (720,1280,3)\n245 100 (720,12"
},
{
"path": "basicsr/data/meta_info/meta_info_Vimeo90K_test_GT.txt",
"chars": 195600,
"preview": "00001/0266 7 (256,448,3)\n00001/0268 7 (256,448,3)\n00001/0275 7 (256,448,3)\n00001/0278 7 (256,448,3)\n00001/0285 7 (256,44"
},
{
"path": "basicsr/data/meta_info/meta_info_Vimeo90K_test_fast_GT.txt",
"chars": 30625,
"preview": "00001/0625 7 (256,448,3)\n00001/0632 7 (256,448,3)\n00001/0807 7 (256,448,3)\n00001/0832 7 (256,448,3)\n00001/0834 7 (256,44"
},
{
"path": "basicsr/data/meta_info/meta_info_Vimeo90K_test_medium_GT.txt",
"chars": 124425,
"preview": "00001/0285 7 (256,448,3)\n00001/0619 7 (256,448,3)\n00001/0622 7 (256,448,3)\n00001/0628 7 (256,448,3)\n00001/0629 7 (256,44"
},
{
"path": "basicsr/data/meta_info/meta_info_Vimeo90K_test_slow_GT.txt",
"chars": 40325,
"preview": "00001/0266 7 (256,448,3)\n00001/0268 7 (256,448,3)\n00001/0275 7 (256,448,3)\n00001/0278 7 (256,448,3)\n00001/0287 7 (256,44"
},
{
"path": "basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt",
"chars": 1615300,
"preview": "00001/0001 7 (256,448,3)\n00001/0002 7 (256,448,3)\n00001/0003 7 (256,448,3)\n00001/0004 7 (256,448,3)\n00001/0005 7 (256,44"
},
{
"path": "basicsr/data/paired_image_dataset.py",
"chars": 4972,
"preview": "from torch.utils import data as data\nfrom torchvision.transforms.functional import normalize\n\nfrom basicsr.data.data_uti"
},
{
"path": "basicsr/data/prefetch_dataloader.py",
"chars": 3137,
"preview": "import queue as Queue\nimport threading\nimport torch\nfrom torch.utils.data import DataLoader\n\n\nclass PrefetchGenerator(th"
},
{
"path": "basicsr/data/realesrgan_dataset.py",
"chars": 8750,
"preview": "import cv2\nimport math\nimport numpy as np\nimport os\nimport os.path as osp\nimport random\nimport time\nimport torch\nfrom to"
},
{
"path": "basicsr/data/realesrgan_paired_dataset.py",
"chars": 4974,
"preview": "import os\nfrom torch.utils import data as data\nfrom torchvision.transforms.functional import normalize\n\nfrom basicsr.dat"
},
{
"path": "basicsr/data/reds_dataset.py",
"chars": 15157,
"preview": "import numpy as np\nimport random\nimport torch\nfrom pathlib import Path\nfrom torch.utils import data as data\n\nfrom basics"
},
{
"path": "basicsr/data/single_image_dataset.py",
"chars": 2690,
"preview": "from os import path as osp\nfrom torch.utils import data as data\nfrom torchvision.transforms.functional import normalize\n"
},
{
"path": "basicsr/data/transforms.py",
"chars": 6225,
"preview": "import cv2\nimport random\nimport torch\n\n\ndef mod_crop(img, scale):\n \"\"\"Mod crop images, used during testing.\n\n Args"
},
{
"path": "basicsr/data/video_test_dataset.py",
"chars": 11963,
"preview": "import glob\nimport torch\nfrom os import path as osp\nfrom torch.utils import data as data\n\nfrom basicsr.data.data_util im"
},
{
"path": "basicsr/data/vimeo90k_dataset.py",
"chars": 6945,
"preview": "import random\nimport torch\nfrom pathlib import Path\nfrom torch.utils import data as data\n\nfrom basicsr.data.transforms i"
},
{
"path": "basicsr/losses/__init__.py",
"chars": 1149,
"preview": "import importlib\nfrom copy import deepcopy\nfrom os import path as osp\n\nfrom basicsr.utils import get_root_logger, scandi"
},
{
"path": "basicsr/losses/basic_loss.py",
"chars": 9246,
"preview": "import torch\nfrom torch import nn as nn\nfrom torch.nn import functional as F\n\nfrom basicsr.archs.vgg_arch import VGGFeat"
},
{
"path": "basicsr/losses/gan_loss.py",
"chars": 7486,
"preview": "import math\nimport torch\nfrom torch import autograd as autograd\nfrom torch import nn as nn\nfrom torch.nn import function"
},
{
"path": "basicsr/losses/loss_util.py",
"chars": 4772,
"preview": "import functools\nimport torch\nfrom torch.nn import functional as F\n\n\ndef reduce_loss(loss, reduction):\n \"\"\"Reduce los"
},
{
"path": "basicsr/metrics/README.md",
"chars": 1503,
"preview": "# Metrics\n\n[English](README.md) **|** [简体中文](README_CN.md)\n\n- [约定](#约定)\n- [PSNR 和 SSIM](#psnr-和-ssim)\n\n## 约定\n\n因为不同的输入类型会"
},
{
"path": "basicsr/metrics/README_CN.md",
"chars": 1503,
"preview": "# Metrics\n\n[English](README.md) **|** [简体中文](README_CN.md)\n\n- [约定](#约定)\n- [PSNR 和 SSIM](#psnr-和-ssim)\n\n## 约定\n\n因为不同的输入类型会"
},
{
"path": "basicsr/metrics/__init__.py",
"chars": 557,
"preview": "from copy import deepcopy\n\nfrom basicsr.utils.registry import METRIC_REGISTRY\nfrom .niqe import calculate_niqe\nfrom .psn"
},
{
"path": "basicsr/metrics/fid.py",
"chars": 3209,
"preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom scipy import linalg\nfrom tqdm import tqdm\n\nfrom basicsr.archs"
},
{
"path": "basicsr/metrics/metric_util.py",
"chars": 1268,
"preview": "import numpy as np\n\nfrom basicsr.utils import bgr2ycbcr\n\n\ndef reorder_image(img, input_order='HWC'):\n \"\"\"Reorder imag"
},
{
"path": "basicsr/metrics/niqe.py",
"chars": 8356,
"preview": "import cv2\nimport math\nimport numpy as np\nimport os\nfrom scipy.ndimage import convolve\nfrom scipy.special import gamma\n\n"
},
{
"path": "basicsr/metrics/psnr_ssim.py",
"chars": 8488,
"preview": "import cv2\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\nfrom basicsr.metrics.metric_util import reor"
},
{
"path": "basicsr/metrics/test_metrics/test_psnr_ssim.py",
"chars": 2598,
"preview": "import cv2\nimport torch\n\nfrom basicsr.metrics import calculate_psnr, calculate_ssim\nfrom basicsr.metrics.psnr_ssim impor"
},
{
"path": "basicsr/models/__init__.py",
"chars": 1011,
"preview": "import importlib\nfrom copy import deepcopy\nfrom os import path as osp\n\nfrom basicsr.utils import get_root_logger, scandi"
},
{
"path": "basicsr/models/base_model.py",
"chars": 15818,
"preview": "import os\nimport time\nimport torch\nfrom collections import OrderedDict\nfrom copy import deepcopy\nfrom torch.nn.parallel "
},
{
"path": "basicsr/models/edvr_model.py",
"chars": 2429,
"preview": "from basicsr.utils import get_root_logger\nfrom basicsr.utils.registry import MODEL_REGISTRY\nfrom .video_base_model impor"
},
{
"path": "basicsr/models/esrgan_model.py",
"chars": 3174,
"preview": "import torch\nfrom collections import OrderedDict\n\nfrom basicsr.utils.registry import MODEL_REGISTRY\nfrom .srgan_model im"
},
{
"path": "basicsr/models/hifacegan_model.py",
"chars": 11766,
"preview": "import torch\nfrom collections import OrderedDict\nfrom os import path as osp\nfrom tqdm import tqdm\n\nfrom basicsr.archs im"
},
{
"path": "basicsr/models/lr_scheduler.py",
"chars": 3956,
"preview": "import math\nfrom collections import Counter\nfrom torch.optim.lr_scheduler import _LRScheduler\n\n\nclass MultiStepRestartLR"
},
{
"path": "basicsr/models/realesrgan_model.py",
"chars": 12261,
"preview": "import numpy as np\nimport random\nimport torch\nfrom collections import OrderedDict\nfrom torch.nn import functional as F\n\n"
},
{
"path": "basicsr/models/realesrnet_model.py",
"chars": 9115,
"preview": "import numpy as np\nimport random\nimport torch\nfrom torch.nn import functional as F\n\nfrom basicsr.data.degradations impor"
},
{
"path": "basicsr/models/sr_model.py",
"chars": 10932,
"preview": "import torch\nfrom collections import OrderedDict\nfrom os import path as osp\nfrom tqdm import tqdm\n\nfrom basicsr.archs im"
},
{
"path": "basicsr/models/srgan_model.py",
"chars": 5835,
"preview": "import torch\nfrom collections import OrderedDict\n\nfrom basicsr.archs import build_network\nfrom basicsr.losses import bui"
},
{
"path": "basicsr/models/stylegan2_model.py",
"chars": 11642,
"preview": "import cv2\nimport math\nimport numpy as np\nimport random\nimport torch\nfrom collections import OrderedDict\nfrom os import "
},
{
"path": "basicsr/models/swinir_model.py",
"chars": 1115,
"preview": "import torch\nfrom torch.nn import functional as F\n\nfrom basicsr.utils.registry import MODEL_REGISTRY\nfrom .sr_model impo"
},
{
"path": "basicsr/models/video_base_model.py",
"chars": 7433,
"preview": "import torch\nfrom collections import Counter\nfrom os import path as osp\nfrom torch import distributed as dist\nfrom tqdm "
},
{
"path": "basicsr/models/video_gan_model.py",
"chars": 507,
"preview": "from basicsr.utils.registry import MODEL_REGISTRY\nfrom .srgan_model import SRGANModel\nfrom .video_base_model import Vide"
},
{
"path": "basicsr/models/video_recurrent_gan_model.py",
"chars": 7142,
"preview": "import torch\nfrom collections import OrderedDict\n\nfrom basicsr.archs import build_network\nfrom basicsr.losses import bui"
},
{
"path": "basicsr/models/video_recurrent_model.py",
"chars": 8168,
"preview": "import torch\nfrom collections import Counter\nfrom os import path as osp\nfrom torch import distributed as dist\nfrom tqdm "
},
{
"path": "basicsr/ops/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "basicsr/ops/dcn/__init__.py",
"chars": 306,
"preview": "from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,\n "
},
{
"path": "basicsr/ops/dcn/deform_conv.py",
"chars": 15724,
"preview": "import math\nimport os\nimport torch\nfrom torch import nn as nn\nfrom torch.autograd import Function\nfrom torch.autograd.fu"
},
{
"path": "basicsr/ops/dcn/src/deform_conv_cuda.cpp",
"chars": 28842,
"preview": "// modify from\n// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/def"
},
{
"path": "basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu",
"chars": 42622,
"preview": "/*!\n ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************\n *\n * COPYRIGHT\n *\n * All contribu"
},
{
"path": "basicsr/ops/dcn/src/deform_conv_ext.cpp",
"chars": 7492,
"preview": "// modify from\n// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/def"
},
{
"path": "basicsr/ops/fused_act/__init__.py",
"chars": 106,
"preview": "from .fused_act import FusedLeakyReLU, fused_leaky_relu\n\n__all__ = ['FusedLeakyReLU', 'fused_leaky_relu']\n"
},
{
"path": "basicsr/ops/fused_act/fused_act.py",
"chars": 2941,
"preview": "# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501\n\nimport os\nimport "
},
{
"path": "basicsr/ops/fused_act/src/fused_bias_act.cpp",
"chars": 1092,
"preview": "// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp\n#include <torch/extension.h>\n\n"
},
{
"path": "basicsr/ops/fused_act/src/fused_bias_act_kernel.cu",
"chars": 2874,
"preview": "// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu\n// Copyright (c) 2019, N"
},
{
"path": "basicsr/ops/upfirdn2d/__init__.py",
"chars": 58,
"preview": "from .upfirdn2d import upfirdn2d\n\n__all__ = ['upfirdn2d']\n"
},
{
"path": "basicsr/ops/upfirdn2d/src/upfirdn2d.cpp",
"chars": 1052,
"preview": "// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp\n#include <torch/extension.h>\n\n\ntorc"
},
{
"path": "basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu",
"chars": 11803,
"preview": "// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu\n// Copyright (c) 2019, NVIDIA"
},
{
"path": "basicsr/ops/upfirdn2d/upfirdn2d.py",
"chars": 6085,
"preview": "# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501\n\nimport os\nimport"
},
{
"path": "basicsr/test.py",
"chars": 1730,
"preview": "import logging\nimport torch\nfrom os import path as osp\n\nfrom basicsr.data import build_dataloader, build_dataset\nfrom ba"
},
{
"path": "basicsr/train.py",
"chars": 9672,
"preview": "import datetime\nimport logging\nimport math\nimport time\nimport torch\nfrom os import path as osp\n\nfrom basicsr.data import"
},
{
"path": "basicsr/utils/__init__.py",
"chars": 1220,
"preview": "from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb\nfrom .diffjpeg import DiffJPEG\nfrom .fi"
},
{
"path": "basicsr/utils/color_util.py",
"chars": 7981,
"preview": "import numpy as np\nimport torch\n\n\ndef rgb2ycbcr(img, y_only=False):\n \"\"\"Convert a RGB image to YCbCr image.\n\n This"
},
{
"path": "basicsr/utils/diffjpeg.py",
"chars": 15666,
"preview": "\"\"\"\nModified from https://github.com/mlomnitz/DiffJPEG\n\nFor images not divisible by 8\nhttps://dsp.stackexchange.com/ques"
},
{
"path": "basicsr/utils/dist_util.py",
"chars": 2608,
"preview": "# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501\nimport functools\n"
},
{
"path": "basicsr/utils/download_util.py",
"chars": 3341,
"preview": "import math\nimport os\nimport requests\nfrom torch.hub import download_url_to_file, get_dir\nfrom tqdm import tqdm\nfrom url"
},
{
"path": "basicsr/utils/file_client.py",
"chars": 6014,
"preview": "# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501\nfrom abc import "
},
{
"path": "basicsr/utils/flow_util.py",
"chars": 6159,
"preview": "# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501\nimport cv2\nimport num"
},
{
"path": "basicsr/utils/img_process_util.py",
"chars": 2563,
"preview": "import cv2\nimport numpy as np\nimport torch\nfrom torch.nn import functional as F\n\n\ndef filter2D(img, kernel):\n \"\"\"PyTo"
},
{
"path": "basicsr/utils/img_util.py",
"chars": 6195,
"preview": "import cv2\nimport math\nimport numpy as np\nimport os\nimport torch\nfrom torchvision.utils import make_grid\n\n\ndef img2tenso"
},
{
"path": "basicsr/utils/lmdb_util.py",
"chars": 7130,
"preview": "import cv2\nimport lmdb\nimport sys\nfrom multiprocessing import Pool\nfrom os import path as osp\nfrom tqdm import tqdm\n\n\nde"
},
{
"path": "basicsr/utils/logger.py",
"chars": 7148,
"preview": "import datetime\nimport logging\nimport time\n\nfrom .dist_util import get_dist_info, master_only\n\ninitialized_logger = {}\n\n"
},
{
"path": "basicsr/utils/matlab_functions.py",
"chars": 6962,
"preview": "import math\nimport numpy as np\nimport torch\n\n\ndef cubic(x):\n \"\"\"cubic function used for calculate_weights_indices.\"\"\""
},
{
"path": "basicsr/utils/misc.py",
"chars": 4655,
"preview": "import numpy as np\nimport os\nimport random\nimport time\nimport torch\nfrom os import path as osp\n\nfrom .dist_util import m"
},
{
"path": "basicsr/utils/options.py",
"chars": 6998,
"preview": "import argparse\nimport os\nimport random\nimport torch\nimport yaml\nfrom collections import OrderedDict\nfrom os import path"
},
{
"path": "basicsr/utils/plot_util.py",
"chars": 2525,
"preview": "import re\n\n\ndef read_data_from_tensorboard(log_path, tag):\n \"\"\"Get raw data (steps and values) from tensorboard event"
},
{
"path": "basicsr/utils/registry.py",
"chars": 2477,
"preview": "# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501\n\n\nclass "
},
{
"path": "colab/README.md",
"chars": 821,
"preview": "# Colab\n\n<a href=\"https://drive.google.com/drive/folders/1G_qcpvkT5ixmw5XoN6MupkOzcK1km625?usp=sharing\"><img src=\"https:"
},
{
"path": "docs/BasicSR_docs_CN.md",
"chars": 183,
"preview": "# BasicSR 中文文档\n\n我们提供了更完整的 BasicSR 中文解读文档 PDF,你所需要的内容可以在相应的章节中找到。\n\n文档的最新版可以从 [BasicSR-docs/releases](https://github.com/X"
},
{
"path": "docs/Config.md",
"chars": 11720,
"preview": "# Configuration\n\n[English](Config.md) **|** [简体中文](BasicSR_docs_CN.md)\n\n#### Contents\n\n1. [Experiment Name Convention](#"
},
{
"path": "docs/DatasetPreparation.md",
"chars": 15317,
"preview": "# Dataset Preparation\n\n[English](DatasetPreparation.md) **|** [简体中文](DatasetPreparation_CN.md)\n\n📁 Dataset Download: ⏬ [G"
},
{
"path": "docs/DatasetPreparation_CN.md",
"chars": 10441,
"preview": "# 数据准备\n\n[English](DatasetPreparation.md) **|** [简体中文](DatasetPreparation_CN.md)\n\n#### 目录\n\n1. [数据存储形式](#数据存储形式)\n 1. [如"
},
{
"path": "docs/DesignConvention.md",
"chars": 4341,
"preview": "# Codebase Designs and Conventions\n\n[English](DesignConvention.md) **|** [简体中文](BasicSR_docs_CN.md)\n\n#### Contents\n\n1. ["
},
{
"path": "docs/FAQ.md",
"chars": 6,
"preview": "# FAQ\n"
},
{
"path": "docs/HOWTOs.md",
"chars": 5149,
"preview": "# HOWTOs\n\n[English](HOWTOs.md) **|** [简体中文](HOWTOs_CN.md)\n\n## How to train StyleGAN2\n\n1. Prepare training dataset: [FFHQ"
},
{
"path": "docs/HOWTOs_CN.md",
"chars": 4545,
"preview": "# HOWTOs\n\n[English](HOWTOs.md) **|** [简体中文](HOWTOs_CN.md)\n\n## 如何训练 StyleGAN2\n\n1. 准备训练数据集: [FFHQ](https://github.com/NVla"
},
{
"path": "docs/INSTALL.md",
"chars": 5574,
"preview": "# Installation\n\n## Contents\n\n- [Requirements](#requirements)\n- [BASICSR_EXT and BASICSR_JIT environment variables](#basi"
},
{
"path": "docs/Logging.md",
"chars": 1629,
"preview": "# Logging\n\n[English](Logging.md) **|** [简体中文](Logging_CN.md)\n\n#### Contents\n\n1. [Text Logger](#Text-Logger)\n1. [Tensorbo"
},
{
"path": "docs/Logging_CN.md",
"chars": 1195,
"preview": "# Logging日志\n\n[English](Logging.md) **|** [简体中文](Logging_CN.md)\n\n#### 目录\n\n1. [文本屏幕日志](#文本屏幕日志)\n1. [Tensorboard日志](#Tensor"
},
{
"path": "docs/Makefile",
"chars": 587,
"preview": "# You can set these variables from the command line, and also\n# from the environment for the first two.\nSPHINXOPTS ?="
},
{
"path": "docs/Metrics.md",
"chars": 1782,
"preview": "# Metrics\n\n[English](Metrics.md) **|** [简体中文](Metrics_CN.md)\n\n## PSNR and SSIM\n\n## NIQE\n\n## FID\n\n> FID measures the simi"
},
{
"path": "docs/Metrics_CN.md",
"chars": 1753,
"preview": "# 评价指标\n\n[English](Metrics.md) **|** [简体中文](Metrics_CN.md)\n\n## PSNR and SSIM\n\n## NIQE\n\n## FID\n\n> FID measures the similar"
},
{
"path": "docs/ModelZoo.md",
"chars": 8642,
"preview": "# Model Zoo and Baselines\n\n[English](ModelZoo.md) **|** [简体中文](ModelZoo_CN.md)\n\nDownload: ⏬ Google Drive: [Pretrained Mo"
},
{
"path": "docs/ModelZoo_CN.md",
"chars": 8083,
"preview": "# 模型库和基准\n\n[English](ModelZoo.md) **|** [简体中文](ModelZoo_CN.md)\n\n:arrow_double_down: 百度网盘: [预训练模型](https://pan.baidu.com/s"
},
{
"path": "docs/Models.md",
"chars": 1254,
"preview": "# Models\n\n[English](Models.md) **|** [简体中文](BasicSR_docs_CN.md)\n\n#### Contents\n\n1. [Supported Models](#Supported-Models)"
},
{
"path": "docs/README.md",
"chars": 1179,
"preview": "# BasicSR docs\n\nThis folder includes:\n\n- Auto-generated API in [*basicsr.readthedocs.io*](https://basicsr.readthedocs.io"
},
{
"path": "docs/TrainTest.md",
"chars": 5183,
"preview": "# Training and Testing\n\n[English](TrainTest.md) **|** [简体中文](TrainTest_CN.md)\n\nPlease run the commands in the root path "
},
{
"path": "docs/TrainTest_CN.md",
"chars": 4534,
"preview": "# 训练和测试\n\n[English](TrainTest.md) **|** [简体中文](TrainTest_CN.md)\n\n所有的命令都在 `BasicSR` 的根目录下运行. <br>\n一般来说, 训练和测试都有以下的步骤:\n\n1. "
},
{
"path": "docs/auto_generate_api.py",
"chars": 3399,
"preview": "import os\nfrom os import path as osp\n\n\ndef scandir(dir_path, suffix=None, recursive=False, full_path=False):\n \"\"\"Scan"
},
{
"path": "docs/conf.py",
"chars": 2934,
"preview": "# Configuration file for the Sphinx documentation builder.\n#\n# This file only contains a selection of the most common op"
},
{
"path": "docs/history_updates.md",
"chars": 2156,
"preview": "# History of New Features/Updates\n\n:triangular_flag_on_post: **New Features/Updates**\n\n- :white_check_mark: Oct 5, 2021."
},
{
"path": "docs/index.rst",
"chars": 258,
"preview": "Welcome to BasicSR's documentation!\n===================================\n\n.. toctree::\n :maxdepth: 4\n :caption: API\n\n"
},
{
"path": "docs/introduction.md",
"chars": 548,
"preview": "# Introduction\n\n## Codebase Designs and Conventions\n\nPlease see [DesignConvention.md](DesignConvention.md) for the desig"
},
{
"path": "docs/make.bat",
"chars": 759,
"preview": "@ECHO OFF\n\npushd %~dp0\n\nREM Command file for Sphinx documentation\n\nif \"%SPHINXBUILD%\" == \"\" (\n\tset SPHINXBUILD=sphinx-bu"
},
{
"path": "docs/requirements.txt",
"chars": 246,
"preview": "# add all requirements to auto generate the docs\naddict\nfuture\nlmdb\nnumpy\nopencv-python\nPillow\npyyaml\nrecommonmark\nreque"
},
{
"path": "inference/inference_basicvsr.py",
"chars": 2737,
"preview": "import argparse\nimport cv2\nimport glob\nimport os\nimport shutil\nimport torch\n\nfrom basicsr.archs.basicvsr_arch import Bas"
},
{
"path": "inference/inference_basicvsrpp.py",
"chars": 2774,
"preview": "import argparse\nimport cv2\nimport glob\nimport os\nimport shutil\nimport torch\n\nfrom basicsr.archs.basicvsrpp_arch import B"
},
{
"path": "inference/inference_dfdnet.py",
"chars": 8817,
"preview": "import argparse\nimport glob\nimport numpy as np\nimport os\nimport torch\nimport torchvision.transforms as transforms\nfrom s"
},
{
"path": "inference/inference_esrgan.py",
"chars": 1940,
"preview": "import argparse\nimport cv2\nimport glob\nimport numpy as np\nimport os\nimport torch\n\nfrom basicsr.archs.rrdbnet_arch import"
},
{
"path": "inference/inference_ridnet.py",
"chars": 1990,
"preview": "import argparse\nimport cv2\nimport glob\nimport numpy as np\nimport os\nimport torch\nfrom tqdm import tqdm\n\nfrom basicsr.arc"
},
{
"path": "inference/inference_stylegan2.py",
"chars": 2262,
"preview": "import argparse\nimport math\nimport os\nimport torch\nfrom torchvision import utils\n\nfrom basicsr.archs.stylegan2_arch impo"
},
{
"path": "inference/inference_swinir.py",
"chars": 6842,
"preview": "# Modified from https://github.com/JingyunLiang/SwinIR\nimport argparse\nimport cv2\nimport glob\nimport numpy as np\nimport "
},
{
"path": "options/test/BasicVSR/test_BasicVSR_REDS.yml",
"chars": 1000,
"preview": "name: BasicVSR_REDS\nmodel_type: VideoRecurrentModel\nscale: 4\nnum_gpu: 1 # set num_gpu: 0 for cpu mode\nmanual_seed: 0\n\nd"
},
{
"path": "options/test/BasicVSR/test_BasicVSR_Vimeo90K_BDx4.yml",
"chars": 1141,
"preview": "name: BasicVSR_Vimeo90K_BIx4\nmodel_type: VideoRecurrentModel\nscale: 4\nnum_gpu: 1 # set num_gpu: 0 for cpu mode\nmanual_s"
},
{
"path": "options/test/BasicVSR/test_BasicVSR_Vimeo90K_BIx4.yml",
"chars": 1141,
"preview": "name: BasicVSR_Vimeo90K_BIx4\nmodel_type: VideoRecurrentModel\nscale: 4\nnum_gpu: 1 # set num_gpu: 0 for cpu mode\nmanual_s"
},
{
"path": "options/test/BasicVSR/test_IconVSR_REDS.yml",
"chars": 1137,
"preview": "name: IconVSR_REDS\nmodel_type: VideoRecurrentModel\nscale: 4\nnum_gpu: 1 # set num_gpu: 0 for cpu mode\nmanual_seed: 0\n\nda"
},
{
"path": "options/test/BasicVSR/test_IconVSR_Vimeo90K_BDx4.yml",
"chars": 1283,
"preview": "name: IconVSR_Vimeo90K_BDx4\nmodel_type: VideoRecurrentModel\nscale: 4\nnum_gpu: 1 # set num_gpu: 0 for cpu mode\nmanual_se"
},
{
"path": "options/test/BasicVSR/test_IconVSR_Vimeo90K_BIx4.yml",
"chars": 1283,
"preview": "name: IconVSR_Vimeo90K_BIx4\nmodel_type: VideoRecurrentModel\nscale: 4\nnum_gpu: 1 # set num_gpu: 0 for cpu mode\nmanual_se"
},
{
"path": "options/test/DUF/test_DUF_official.yml",
"chars": 945,
"preview": "name: DUF_x4_52L_official\nmodel_type: VideoBaseModel\nscale: 4\nnum_gpu: 8 # set num_gpu: 0 for cpu mode\nmanual_seed: 0\n\n"
},
{
"path": "options/test/EDSR/test_EDSR_Lx2.yml",
"chars": 1423,
"preview": "name: EDSR_Lx2_f256b32_DIV2K_official\nmodel_type: SRModel\nscale: 2\nnum_gpu: 1 # set num_gpu: 0 for cpu mode\nmanual_seed"
},
{
"path": "options/test/EDSR/test_EDSR_Lx3.yml",
"chars": 1397,
"preview": "name: EDSR_Lx3_f256b32_DIV2K_official\nmodel_type: SRModel\nscale: 3\nnum_gpu: 1 # set num_gpu: 0 for cpu mode\nmanual_seed"
},
{
"path": "options/test/EDSR/test_EDSR_Lx4.yml",
"chars": 1397,
"preview": "name: EDSR_Lx4_f256b32_DIV2K_official\nmodel_type: SRModel\nscale: 4\nnum_gpu: 1 # set num_gpu: 0 for cpu mode\nmanual_seed"
},
{
"path": "options/test/EDSR/test_EDSR_Mx2.yml",
"chars": 1392,
"preview": "name: EDSR_Mx2_f64b16_DIV2K_official\nmodel_type: SRModel\nscale: 2\nnum_gpu: 1 # set num_gpu: 0 for cpu mode\nmanual_seed:"
},
{
"path": "options/test/EDSR/test_EDSR_Mx3.yml",
"chars": 1392,
"preview": "name: EDSR_Mx3_f64b16_DIV2K_official\nmodel_type: SRModel\nscale: 3\nnum_gpu: 1 # set num_gpu: 0 for cpu mode\nmanual_seed:"
},
{
"path": "options/test/EDSR/test_EDSR_Mx4.yml",
"chars": 1392,
"preview": "name: EDSR_Mx4_f64b16_DIV2K_official\nmodel_type: SRModel\nscale: 4\nnum_gpu: 1 # set num_gpu: 0 for cpu mode\nmanual_seed:"
},
{
"path": "options/test/EDVR/test_EDVR_L_deblur_REDS.yml",
"chars": 1068,
"preview": "name: EDVR_L_REDS_deblur_official\nmodel_type: EDVRModel\nscale: 1\nnum_gpu: 4 # set num_gpu: 0 for cpu mode\nmanual_seed: "
},
{
"path": "options/test/EDVR/test_EDVR_L_deblurcomp_REDS.yml",
"chars": 1081,
"preview": "name: EDVR_L_REDS_deblurcomp_official\nmodel_type: EDVRModel\nscale: 1\nnum_gpu: 4 # set num_gpu: 0 for cpu mode\nmanual_se"
},
{
"path": "options/test/EDVR/test_EDVR_L_x4_SR_REDS.yml",
"chars": 1085,
"preview": "name: EDVR_L_x4_REDS_SR_official\nmodel_type: EDVRModel\nscale: 4\nnum_gpu: 4 # set num_gpu: 0 for cpu mode\nmanual_seed: 0"
},
{
"path": "options/test/EDVR/test_EDVR_L_x4_SR_Vid4.yml",
"chars": 997,
"preview": "name: EDVR_L_x4_Vimeo90K_SR_official\nmodel_type: EDVRModel\nscale: 4\nnum_gpu: 4 # set num_gpu: 0 for cpu mode\nmanual_see"
},
{
"path": "options/test/EDVR/test_EDVR_L_x4_SR_Vimeo90K.yml",
"chars": 1151,
"preview": "name: EDVR_L_x4_Vimeo90K_SR_official\nmodel_type: EDVRModel\nscale: 4\nnum_gpu: 8 # set num_gpu: 0 for cpu mode\nmanual_see"
},
{
"path": "options/test/EDVR/test_EDVR_L_x4_SRblur_REDS.yml",
"chars": 1083,
"preview": "name: EDVR_L_x4_REDS_SRblur_official\nmodel_type: EDVRModel\nscale: 4\nnum_gpu: 4 # set num_gpu: 0 for cpu mode\nmanual_see"
},
{
"path": "options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml",
"chars": 1084,
"preview": "name: EDVR_M_x4_SR_REDS_official\nmodel_type: EDVRModel\nscale: 4\nnum_gpu: 4 # set num_gpu: 0 for cpu mode\nmanual_seed: 0"
},
{
"path": "options/test/ESRGAN/test_ESRGAN_x4.yml",
"chars": 1332,
"preview": "name: ESRGAN_SRx4_DF2KOST_official\nmodel_type: ESRGANModel\nscale: 4\nnum_gpu: 1 # set num_gpu: 0 for cpu mode\nmanual_see"
},
{
"path": "options/test/ESRGAN/test_ESRGAN_x4_woGT.yml",
"chars": 814,
"preview": "name: ESRGAN_SRx4_DF2KOST_official\nmodel_type: ESRGANModel\nscale: 4\nnum_gpu: 1 # set num_gpu: 0 for cpu mode\nmanual_see"
},
{
"path": "options/test/ESRGAN/test_RRDBNet_PSNR_x4.yml",
"chars": 1332,
"preview": "name: ESRGAN_PSNR_SRx4_DF2K_official\nmodel_type: SRModel\nscale: 4\nnum_gpu: 1 # set num_gpu: 0 for cpu mode\nmanual_seed:"
},
{
"path": "options/test/HiFaceGAN/test_hifacegan.yml",
"chars": 1952,
"preview": "name: HiFaceGAN_SR4x_test\nmodel_type: HiFaceGANModel\nscale: 1 # HiFaceGAN does not resize lq input\nnum_gpu: 1 # set "
},
{
"path": "options/test/HiFaceGAN/test_hifacegan_woGT.yml",
"chars": 1009,
"preview": "name: HiFaceGAN_generic_test\nmodel_type: HiFaceGANModel\nscale: 1 # HiFaceGAN does not resize lq input\nnum_gpu: 1 # s"
},
{
"path": "options/test/RCAN/test_RCAN.yml",
"chars": 1248,
"preview": "name: RCAN_BIX4-official\nsuffix: ~ # add suffix to saved images\nmodel_type: SRModel\nscale: 4\ncrop_border: ~ # crop bor"
},
{
"path": "options/test/SRResNet_SRGAN/test_MSRGAN_x4.yml",
"chars": 1344,
"preview": "name: 004_MSRGAN_x4_f64b16_DIV2K_400k_B16G1_wandb\nmodel_type: SRGANModel\nscale: 4\nnum_gpu: 1 # set num_gpu: 0 for cpu m"
},
{
"path": "options/test/SRResNet_SRGAN/test_MSRResNet_x2.yml",
"chars": 1374,
"preview": "name: 002_MSRResNet_x2_f64b16_DIV2K_1000k_B16G1_001pretrain_wandb\nmodel_type: SRModel\nscale: 2\nnum_gpu: 1 # set num_gpu"
},
{
"path": "options/test/SRResNet_SRGAN/test_MSRResNet_x3.yml",
"chars": 1374,
"preview": "name: 003_MSRResNet_x3_f64b16_DIV2K_1000k_B16G1_001pretrain_wandb\nmodel_type: SRModel\nscale: 3\nnum_gpu: 1 # set num_gpu"
},
{
"path": "options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml",
"chars": 1717,
"preview": "# ----------- Commands for running\n# ----------- Single GPU\n# PYTHONPATH=\"./:${PYTHONPATH}\" CUDA_VISIBLE_DEVICES=0 pyth"
},
{
"path": "options/test/SRResNet_SRGAN/test_MSRResNet_x4_woGT.yml",
"chars": 832,
"preview": "name: 001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb\nmodel_type: SRModel\nscale: 4\nnum_gpu: 1 # set num_gpu: 0 for cpu "
},
{
"path": "options/test/TOF/test_TOF_official.yml",
"chars": 909,
"preview": "name: TOF_official\nmodel_type: VideoBaseModel\nscale: 4\nnum_gpu: 1 # set num_gpu: 0 for cpu mode\nmanual_seed: 0\n\ndataset"
},
{
"path": "options/train/BasicVSR/train_BasicVSR_REDS.yml",
"chars": 2142,
"preview": "# general settings\nname: BasicVSR_REDS\nmodel_type: VideoRecurrentModel\nscale: 4\nnum_gpu: auto # official: 8 GPUs\nmanual"
},
{
"path": "options/train/BasicVSR/train_BasicVSR_Vimeo90K_BDx4.yml",
"chars": 2082,
"preview": "# general settings\nname: BasicVSR_Vimeo90K_BDx4\nmodel_type: VideoRecurrentModel\nscale: 4\nnum_gpu: 2 # set num_gpu: 0 fo"
},
{
"path": "options/train/BasicVSR/train_BasicVSR_Vimeo90K_BIx4.yml",
"chars": 2082,
"preview": "# general settings\nname: BasicVSR_Vimeo90K_BIx4\nmodel_type: VideoRecurrentModel\nscale: 4\nnum_gpu: 2 # set num_gpu: 0 fo"
},
{
"path": "options/train/BasicVSR/train_IconVSR_REDS.yml",
"chars": 2263,
"preview": "# general settings\nname: IconVSR_REDS\nmodel_type: VideoRecurrentModel\nscale: 4\nnum_gpu: 2 # set num_gpu: 0 for cpu mode"
},
{
"path": "options/train/BasicVSR/train_IconVSR_Vimeo90K_BDx4.yml",
"chars": 2195,
"preview": "# general settings\nname: IconVSR_Vimeo90K_BDx4\nmodel_type: VideoRecurrentModel\nscale: 4\nnum_gpu: 2 # set num_gpu: 0 for"
}
]
// ... and 93 more files (download for full content)
About this extraction
This page contains the full source code of the XPixelGroup/BasicSR GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 293 files (3.9 MB), approximately 1.0M tokens, and a symbol index with 952 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.