Repository: TencentARC/AnimeSR
Branch: main
Commit: 80a24bf7a527
Files: 47
Total size: 228.9 KB
Directory structure:
gitextract_0qygtin4/
├── .github/
│   └── workflows/
│       └── pylint.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── Training.md
├── VERSION
├── animesr/
│   ├── __init__.py
│   ├── archs/
│   │   ├── __init__.py
│   │   ├── discriminator_arch.py
│   │   ├── simple_degradation_arch.py
│   │   └── vsr_arch.py
│   ├── data/
│   │   ├── __init__.py
│   │   ├── data_utils.py
│   │   ├── ffmpeg_anime_dataset.py
│   │   ├── ffmpeg_anime_lbo_dataset.py
│   │   └── paired_image_dataset.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── degradation_gan_model.py
│   │   ├── degradation_model.py
│   │   ├── video_recurrent_gan_model.py
│   │   └── video_recurrent_model.py
│   ├── test.py
│   ├── train.py
│   └── utils/
│       ├── __init__.py
│       ├── inference_base.py
│       ├── shot_detector.py
│       └── video_util.py
├── cog.yaml
├── options/
│   ├── train_animesr_step1_gan_BasicOPonly.yml
│   ├── train_animesr_step1_net_BasicOPonly.yml
│   ├── train_animesr_step2_lbo_1_gan.yml
│   ├── train_animesr_step2_lbo_1_net.yml
│   └── train_animesr_step3_gan_3LBOs.yml
├── predict.py
├── requirements.txt
├── scripts/
│   ├── anime_videos_preprocessing.py
│   ├── inference_animesr_frames.py
│   ├── inference_animesr_video.py
│   └── metrics/
│       ├── MANIQA/
│       │   ├── inference_MANIQA.py
│       │   ├── models/
│       │   │   ├── model_attentionIQA2.py
│       │   │   └── swin.py
│       │   ├── pipal_data.py
│       │   └── utils.py
│       └── README.md
├── setup.cfg
└── setup.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/pylint.yml
================================================
name: PyLint

on: [push, pull_request]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.8]

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install flake8 yapf isort
    - name: Lint
      run: |
        flake8 .
        isort --check-only --diff animesr/ options/ scripts/ setup.py
        yapf -r -d animesr/ options/ scripts/ setup.py
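The same three checks the workflow runs can be reproduced locally. A minimal sketch, assuming `pip` is on your PATH; `demo_ok.py` is a made-up throwaway file used here just to show each tool passing on clean code:

```shell
# Install the same linters the CI job uses
python -m pip install --quiet flake8 isort yapf

# A tiny, style-clean file to lint
printf 'x = 1\n' > demo_ok.py

# Mirror the three CI lint commands: flake8 checks style/errors,
# isort --check-only verifies import order without rewriting,
# yapf -d prints a diff of formatting changes (empty diff = pass)
flake8 demo_ok.py
isort --check-only --diff demo_ok.py
yapf -d demo_ok.py
```

In the actual repository you would point the commands at `animesr/ options/ scripts/ setup.py`, as the workflow does.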
================================================
FILE: .gitignore
================================================
datasets/*
experiments/*
results/*
tb_logger/*
wandb/*
tmp/*
weights/*
inputs/*
*.DS_Store
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
.idea/
================================================
FILE: .pre-commit-config.yaml
================================================
repos:
  # flake8
  - repo: https://github.com/PyCQA/flake8
    rev: 3.7.9
    hooks:
      - id: flake8
        args: ["--config=setup.cfg", "--ignore=W504, W503"]

  # modify known_third_party
  - repo: https://github.com/asottile/seed-isort-config
    rev: v2.2.0
    hooks:
      - id: seed-isort-config
        args: ["--exclude=scripts/metrics/MANIQA"]

  # isort
  - repo: https://github.com/timothycrosley/isort
    rev: 5.2.2
    hooks:
      - id: isort

  # yapf
  - repo: https://github.com/pre-commit/mirrors-yapf
    rev: v0.30.0
    hooks:
      - id: yapf

  # pre-commit-hooks
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v3.2.0
    hooks:
      - id: trailing-whitespace  # Trim trailing whitespace
      - id: check-yaml  # Attempt to load all yaml files to verify syntax
      - id: check-merge-conflict  # Check for files that contain merge conflict strings
      - id: end-of-file-fixer  # Make sure files end in a newline and only a newline
      - id: requirements-txt-fixer  # Sort entries in requirements.txt and remove incorrect entry for pkg-resources==0.0.0
      - id: mixed-line-ending  # Replace or check mixed line ending
        args: ["--fix=lf"]
================================================
FILE: LICENSE
================================================
Tencent is pleased to support the open source community by making AnimeSR available.
Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
AnimeSR is licensed under the Apache License Version 2.0 except for the third-party components listed below.
Terms of the Apache License Version 2.0:
---------------------------------------------
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
“License” shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
“Licensor” shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
“Legal Entity” shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, “control” means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
“You” (or “Your”) shall mean an individual or Legal Entity exercising permissions granted by this License.
“Source” form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
“Object” form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
“Work” shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
“Derivative Works” shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
“Contribution” shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as “Not a Contribution.”
“Contributor” shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
You must give any other recipients of the Work or Derivative Works a copy of this License; and
You must cause any modified files to carry prominent notices stating that You changed the files; and
You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
If the Work includes a “NOTICE” text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
Other dependencies and licenses:
Open Source Software Licensed under the Apache License Version 2.0:
--------------------------------------------------------------------
1. ffmpeg-python
Copyright 2017 Karl Kroening
2. basicsr
Copyright 2018-2022 BasicSR Authors
Terms of the Apache License Version 2.0:
--------------------------------------------------------------------
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
You must give any other recipients of the Work or Derivative Works a copy of this License; and
You must cause any modified files to carry prominent notices stating that You changed the files; and
You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
Open Source Software Licensed under the BSD 3-Clause License:
--------------------------------------------------------------------
1. torch
From PyTorch:
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
Copyright (c) 2011-2013 NYU (Clement Farabet)
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
From Caffe2:
Copyright (c) 2016-present, Facebook Inc. All rights reserved.
All contributions by Facebook:
Copyright (c) 2016 Facebook Inc.
All contributions by Google:
Copyright (c) 2015 Google Inc.
All rights reserved.
All contributions by Yangqing Jia:
Copyright (c) 2015 Yangqing Jia
All rights reserved.
All contributions by Kakao Brain:
Copyright 2019-2020 Kakao Brain
All contributions by Cruise LLC:
Copyright (c) 2022 Cruise LLC.
All rights reserved.
All contributions from Caffe:
Copyright(c) 2013, 2014, 2015, the respective contributors
All rights reserved.
All other contributions:
Copyright(c) 2015, 2016 the respective contributors
All rights reserved.
Caffe2 uses a copyright model similar to Caffe: each contributor holds
copyright over their contributions to Caffe2. The project versioning records
all such contribution and copyright details. If a contributor wants to further
mark their specific copyright on a particular contribution, they should
indicate their copyright solely in the commit message of the change when it is
committed.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
and IDIAP Research Institute nor the names of its contributors may be
used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
2. torchvision
Copyright (c) Soumith Chintala 2016,
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3. numpy
Copyright (c) 2005-2022, NumPy Developers.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following
disclaimer in the documentation and/or other materials provided
with the distribution.
* Neither the name of the NumPy Developers nor the names of any
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
4. psutil
Copyright (c) 2009, Jay Loden, Dave Daeschler, Giampaolo Rodola'
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the psutil authors nor the names of its contributors
may be used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Open Source Software Licensed under the HPND License:
--------------------------------------------------------------------
1. Pillow
The Python Imaging Library (PIL) is
Copyright © 1997-2011 by Secret Labs AB
Copyright © 1995-2011 by Fredrik Lundh
Pillow is the friendly PIL fork. It is
Copyright © 2010-2022 by Alex Clark and contributors
Like PIL, Pillow is licensed under the open source HPND License:
By obtaining, using, and/or copying this software and/or its associated
documentation, you agree that you have read, understood, and will comply
with the following terms and conditions:
Permission to use, copy, modify, and distribute this software and its
associated documentation for any purpose and without fee is hereby granted,
provided that the above copyright notice appears in all copies, and that
both that copyright notice and this permission notice appear in supporting
documentation, and that the name of Secret Labs AB or the author not be
used in advertising or publicity pertaining to distribution of the software
without specific, written prior permission.
SECRET LABS AB AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS
SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS.
IN NO EVENT SHALL SECRET LABS AB OR THE AUTHOR BE LIABLE FOR ANY SPECIAL,
INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
PERFORMANCE OF THIS SOFTWARE.
Open Source Software Licensed under the MIT License:
--------------------------------------------------------------------
1. opencv-python
Copyright (c) Olli-Pekka Heinisuo
Terms of the MIT License:
--------------------------------------------------------------------
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
Open Source Software Licensed under the MIT and MPL v2.0 Licenses:
--------------------------------------------------------------------
1. tqdm
`tqdm` is a product of collaborative work.
Unless otherwise stated, all authors (see commit logs) retain copyright
for their respective work, and release the work under the MIT licence
(text below).
Exceptions or notable authors are listed below
in reverse chronological order:
* files: *
MPLv2.0 2015-2021 (c) Casper da Costa-Luis
[casperdcl](https://github.com/casperdcl).
* files: tqdm/_tqdm.py
MIT 2016 (c) [PR #96] on behalf of Google Inc.
* files: tqdm/_tqdm.py setup.py README.rst MANIFEST.in .gitignore
MIT 2013 (c) Noam Yorav-Raphael, original author.
[PR #96]: https://github.com/tqdm/tqdm/pull/96
Mozilla Public Licence (MPL) v. 2.0 - Exhibit A
-----------------------------------------------
This Source Code Form is subject to the terms of the
Mozilla Public License, v. 2.0.
If a copy of the MPL was not distributed with this project,
You can obtain one at https://mozilla.org/MPL/2.0/.
MIT License (MIT)
-----------------
Copyright (c) 2013 noamraph
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
================================================
FILE: README.md
================================================
# AnimeSR (NeurIPS 2022)
### :open_book: AnimeSR: Learning Real-World Super-Resolution Models for Animation Videos
> [](https://arxiv.org/abs/2206.07038)<br>
> [Yanze Wu](https://github.com/ToTheBeginning), [Xintao Wang](https://xinntao.github.io/), [Gen Li](https://scholar.google.com.hk/citations?user=jBxlX7oAAAAJ), [Ying Shan](https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=en) <br>
> [Tencent ARC Lab](https://arc.tencent.com/en/index); Platform Technologies, Tencent Online Video
### :triangular_flag_on_post: Updates
* **2022.11.28**: release codes and models.
* **2022.08.29**: release AVC-Train and AVC-Test.
## Web Demo and API
[](https://replicate.com/cjwbw/animesr)
## Video Demos
https://user-images.githubusercontent.com/11482921/204205018-d69e2e51-fbdc-4766-8293-a40ffce3ed25.mp4
https://user-images.githubusercontent.com/11482921/204205109-35866094-fa7f-413b-8b43-bb479b42dfb6.mp4
## :wrench: Dependencies and Installation
- Python >= 3.7 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
- [PyTorch >= 1.7](https://pytorch.org/)
- Other required packages in `requirements.txt`
### Installation
1. Clone repo
```bash
git clone https://github.com/TencentARC/AnimeSR.git
cd AnimeSR
```
2. Install
```bash
# Install dependent packages
pip install -r requirements.txt
# Install AnimeSR
python setup.py develop
```
## :zap: Quick Inference
Download the pre-trained AnimeSR models [[Google Drive](https://drive.google.com/drive/folders/1gwNTbKLUjt5FlgT6PQQnBz5wFzmNUX8g?usp=share_link)], and put them into the [weights](weights/) folder. Currently, the available pre-trained models are:
- `AnimeSR_v1-PaperModel.pth`: v1 model, also the paper model. You can use this model to reproduce the paper results.
- `AnimeSR_v2.pth`: v2 model. Compared with v1, this version has better naturalness, fewer artifacts, and better texture/background restoration. If you want better results, use this model.
AnimeSR supports both frames and videos as input for inference. We provide several sample test cases in [google drive](https://drive.google.com/drive/folders/1K8JzGNqY_pHahYBv51iUUI7_hZXmamt2?usp=share_link); you can download them and put them into the [inputs](inputs/) folder.
**Inference on Frames**
```bash
python scripts/inference_animesr_frames.py -i inputs/tom_and_jerry -n AnimeSR_v2 --expname animesr_v2 --save_video_too --fps 20
```
```console
Usage:
-i --input Input frames folder/root. Supports first-level dirs (i.e., input/*.png) and second-level dirs (i.e., input/*/*.png)
-n --model_name AnimeSR model name. Default: AnimeSR_v2, can also be AnimeSR_v1-PaperModel
-s --outscale The netscale is x4, but you can achieve arbitrary output scale (e.g., x2 or x1) with the argument outscale.
The program will further perform cheap resize operation after the AnimeSR output. Default: 4
-o --output Output root. Default: results
-expname Identify the name of your current inference. The outputs will be saved in $output/$expname
-save_video_too Save the output frames to video. Default: off
-fps The fps of the (possible) saved videos. Default: 24
```
After running the above command, you will get the SR frames in `results/animesr_v2/frames` and the SR video in `results/animesr_v2/videos`.
**Inference on Video**
```bash
# single gpu and single process inference
CUDA_VISIBLE_DEVICES=0 python scripts/inference_animesr_video.py -i inputs/TheMonkeyKing1965.mp4 -n AnimeSR_v2 -s 4 --expname animesr_v2 --num_process_per_gpu 1 --suffix 1gpu1process
# single gpu and multi process inference (you can use multi-processing to improve GPU utilization)
CUDA_VISIBLE_DEVICES=0 python scripts/inference_animesr_video.py -i inputs/TheMonkeyKing1965.mp4 -n AnimeSR_v2 -s 4 --expname animesr_v2 --num_process_per_gpu 3 --suffix 1gpu3process
# multi gpu and multi process inference
CUDA_VISIBLE_DEVICES=0,1 python scripts/inference_animesr_video.py -i inputs/TheMonkeyKing1965.mp4 -n AnimeSR_v2 -s 4 --expname animesr_v2 --num_process_per_gpu 3 --suffix 2gpu6process
```
```console
Usage:
-i --input Input video path or extracted frames folder
-n --model_name AnimeSR model name. Default: AnimeSR_v2, can also be AnimeSR_v1-PaperModel
-s --outscale The netscale is x4, but you can achieve arbitrary output scale (e.g., x2 or x1) with the argument outscale.
The program will further perform cheap resize operation after the AnimeSR output. Default: 4
-o --output Output root. Default: results
-expname Identify the name of your current inference. The outputs will be saved in $output/$expname
-fps The fps of the (possible) saved videos. Default: None
-extract_frame_first If the input is a video, you can still extract the frames first; otherwise AnimeSR will read from the video stream
-num_process_per_gpu Slow I/O can keep GPU utilization low, so as long as the video memory is sufficient, we recommend placing multiple processes on one GPU to increase the utilization of each GPU.
The total number of processes will be num_process_per_gpu * num_gpu
-suffix You can add a suffix string to the SR video name, for example, 1gpu3processx2, which means the SR video was generated with one GPU, three processes, and an outscale of x2
-half Use half precision for inference; it won't have a big impact on the visual results
```
SR videos are saved in `results/animesr_v2/videos/$video_name` folder.
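The process placement described by `-num_process_per_gpu` can be sketched as below. This is a hypothetical helper to show the arithmetic; the actual script may distribute work differently.

```python
def gpu_assignment(num_gpus: int, num_process_per_gpu: int) -> list:
    """Return the GPU id for each worker.
    Total workers = num_gpus * num_process_per_gpu, spread round-robin over GPUs."""
    total = num_gpus * num_process_per_gpu
    return [worker % num_gpus for worker in range(total)]

print(gpu_assignment(2, 3))  # [0, 1, 0, 1, 0, 1] -> 2 GPUs, 3 processes each
```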
If you are looking for portable executable files, you can try our [realesr-animevideov3](https://github.com/xinntao/Real-ESRGAN/blob/master/docs/anime_video_model.md) model, which shares similar technology with AnimeSR.
## :computer: Training
See [Training.md](Training.md)
## Request for AVC-Dataset
1. Download and carefully read the [LICENSE AGREEMENT](assets/LICENSE%20AGREEMENT.pdf) PDF file.
2. If you understand, acknowledge, and agree to all the terms specified in the [LICENSE AGREEMENT](assets/LICENSE%20AGREEMENT.pdf), please email `wuyanze123@gmail.com` with the **LICENSE AGREEMENT PDF** file, **your name**, and **institution**. We will keep the license and send you the download link of the AVC dataset.
## Acknowledgement
This project is built on [BasicSR](https://github.com/XPixelGroup/BasicSR).
## Citation
If you find this project useful for your research, please consider citing our paper:
```bibtex
@InProceedings{wu2022animesr,
author={Wu, Yanze and Wang, Xintao and Li, Gen and Shan, Ying},
title={AnimeSR: Learning Real-World Super-Resolution Models for Animation Videos},
booktitle={Advances in Neural Information Processing Systems},
year={2022}
}
```
## :e-mail: Contact
If you have any questions, please email `wuyanze123@gmail.com`.
================================================
FILE: Training.md
================================================
# :computer: How to Train AnimeSR
- [Overview](#overview)
- [Dataset Preparation](#dataset-preparation)
- [Training](#training)
- [Training step 1](#training-step-1)
- [Training step 2](#training-step-2)
- [Training step 3](#training-step-3)
- [Evaluation](#evaluation)
- [The Pre-Trained Checkpoints](#the-pre-trained-checkpoints)
- [Other Tips](#other-tips)
- [How to build your own (training) dataset ?](#how-to-build-your-own-training-dataset-)
## Overview
The training has been divided into three steps.
1. Training a Video Super-Resolution (VSR) model with a degradation model that only contains the classic basic operators (*i.e.*, blur, noise, downscale, compression).
2. Training several **L**earnable **B**asic **O**perators (**LBO**s). The VSR model from step 1 and the input-rescaling strategy are used to generate pseudo HR for the real-world LR. The paired pseudo HR-LR data are then used to train each LBO in a supervised manner.
3. Training the final VSR model with a degradation model containing both classic basic operators and learnable basic operators.
Specifically, the model training in each step consists of two stages. In the first stage, the model is trained with L1 loss from scratch. In the second stage, the model is fine-tuned with the combination of L1 loss, perceptual loss, and GAN loss.
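The second-stage objective described above can be sketched as a weighted sum of the three terms. The weights and the perceptual stub below are illustrative only; the actual option files use VGG features and their own loss weights.

```python
import torch
import torch.nn.functional as F

def second_stage_loss(sr, gt, d_pred, l1_w=1.0, percep_w=1.0, gan_w=0.1):
    """Fine-tuning loss: L1 + perceptual + GAN terms.
    `d_pred` is the discriminator score on the SR output. The perceptual term
    is stubbed with pixel-space L1 here instead of a VGG-feature distance."""
    l1 = F.l1_loss(sr, gt)
    percep = F.l1_loss(sr, gt)  # stand-in for the VGG-feature perceptual loss
    gan = F.softplus(-d_pred).mean()  # non-saturating generator loss
    return l1_w * l1 + percep_w * percep + gan_w * gan
```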
## Dataset Preparation
We use AVC-Train dataset for our training. The AVC dataset is released under request, please refer to [Request for AVC-Dataset](README.md#request-for-avc-dataset).
After you download the AVC-Train dataset, put all the clips into one folder (dataset root). The dataset root should contain 553 folders (clips); each folder contains 100 frames, from `00000000.png` to `00000099.png`.
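A quick sanity check of the layout described above can be sketched like this. This is a hypothetical helper, not part of the repo; `dataset_root` is a placeholder path.

```python
import os

def check_avc_layout(dataset_root: str, n_frames: int = 100) -> list:
    """Return the list of clip folders missing any of the expected frames
    00000000.png .. 000000{n_frames-1}.png; an empty list means all clips are complete."""
    bad = []
    for clip in sorted(os.listdir(dataset_root)):
        clip_dir = os.path.join(dataset_root, clip)
        if not os.path.isdir(clip_dir):
            continue
        expected = {f'{i:08d}.png' for i in range(n_frames)}
        if not expected <= set(os.listdir(clip_dir)):
            bad.append(clip)
    return bad
```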
If you want to build your own (training) dataset or enlarge AVC-Train dataset, please refer to [How to build your own (training) dataset](#how-to-build-your-own-training-dataset).
## Training
As described in the paper, all the training is performed on four NVIDIA A100 GPUs in an internal cluster. You may need to adjust the batch size according to the CUDA memory of your GPU card.
Before the training, you should modify the [option files](options/) accordingly. For example, you should modify the `dataroot_gt` to your own dataset root. We have commented all the lines you should modify with the `TO_MODIFY` tag.
### Training step 1
1. Train `Net` model
Before the training, you should modify the [yaml file](options/train_animesr_step1_net_BasicOPonly.yml) accordingly. For example, you should modify the `dataroot_gt` to your own dataset root.
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=12345 animesr/train.py -opt options/train_animesr_step1_net_BasicOPonly.yml --launcher pytorch --auto_resume
```
2. Train `GAN` model
The GAN model is fine-tuned from the `Net` model, as specified in `pretrain_network_g`
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=12345 animesr/train.py -opt options/train_animesr_step1_gan_BasicOPonly.yml --launcher pytorch --auto_resume
```
### Training step 2
The input frames for training the LBOs in the paper are included in the AVC dataset download link we sent you. These frames come from three real-world LR animation videos; ~2,000 frames were selected from each video.
To obtain the paired data required for training an LBO, process these input frames with the VSR model from step 1 and the input-rescaling strategy described in the paper to produce pseudo HR-LR pairs.
```bash
# make the soft link for the VSR model obtained in step 1
ln -s experiments/train_animesr_step1_net_BasicOPonly/models/net_g_300000.pth weights/step1_vsr_gan_model.pth
# using input-rescaling strategy to inference
python scripts/inference_animesr_frames.py -i datasets/lbo_training_data/real_world_video_to_train_lbo_1 -n step1_vsr_gan_model --input_rescaling_factor 0.5 --mod_scale 8 --expname input_rescaling_strategy_lbo_1
```
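The shape arithmetic behind the input-rescaling command above: the real LR is first rescaled (here by 0.5) and snapped to a multiple of `--mod_scale`, then passed through the x4 network, so the pseudo HR ends up at roughly 2x the original LR size. A small illustration; the exact padding/cropping in the script may differ.

```python
def pseudo_hr_size(lr_h, lr_w, rescale=0.5, netscale=4, mod_scale=8):
    """Size of the pseudo HR frame produced by the input-rescaling strategy.
    The rescaled input is snapped down to a multiple of `mod_scale` before
    the x4 network, matching the --mod_scale argument above."""
    h = int(lr_h * rescale) // mod_scale * mod_scale
    w = int(lr_w * rescale) // mod_scale * mod_scale
    return h * netscale, w * netscale

print(pseudo_hr_size(480, 640))  # (960, 1280): 2x the original 480x640 LR
```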
After the inference, you can train the LBO. Note that we only provide one [option file](options/train_animesr_step2_lbo_1_net.yml) for training the `Net` model and one [option file](options/train_animesr_step2_lbo_1_gan.yml) for training the `GAN` model. If you want to train multiple LBOs, just copy those option files and modify the `name`, `dataroot_gt`, and `dataroot_lq` entries.
```bash
# train Net model
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=12345 animesr/train.py -opt options/train_animesr_step2_lbo_1_net.yml --launcher pytorch --auto_resume
# train GAN model
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=12345 animesr/train.py -opt options/train_animesr_step2_lbo_1_gan.yml --launcher pytorch --auto_resume
```
### Training step 3
Before the training, you will need to modify the `degradation_model_path` to the pre-trained LBO path.
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=12345 animesr/train.py -opt options/train_animesr_step3_gan_3LBOs.yml --launcher pytorch --auto_resume
```
## Evaluation
See [evaluation readme](scripts/metrics/README.md).
## The Pre-Trained Checkpoints
You can download the checkpoints of all steps in [google drive](https://drive.google.com/drive/folders/1hCXhKNZYBADXsS_weHO2z3HhNE-Eg_jw?usp=share_link).
## Other Tips
#### How to build your own (training) dataset ?
Suppose you have a batch of HQ (high-resolution, high-bitrate, high-quality) animation videos; we provide the [anime_videos_preprocessing.py](scripts/anime_videos_preprocessing.py) script to help you prepare training clips from the raw videos.
The preprocessing consists of 6 steps:
1. use FFmpeg to extract frames. Note that this step will take up a lot of disk space.
2. shot detection using [PySceneDetect](https://github.com/Breakthrough/PySceneDetect/)
3. flow estimation using spynet
4. black frames detection
5. image quality assessment using hyperIQA
6. generate clips for each video
```console
Usage: python scripts/anime_videos_preprocessing.py --dataroot datasets/YOUR_OWN_ANIME --n_thread 4 --run 1
--dataroot dataset root, dataroot/raw_videos should contains your HQ videos to be processed
--n_thread number of workers to process in parallel
--run which step to run. Since each step may take a long time, we recommend running the steps one at a time and checking after each step whether the output files are as expected
--n_frames_per_clip number of frames per clip. Default 100. You can increase the number if you want more training data
--n_clips_per_video number of clips per video. Default 1. You can increase the number if you want more training data
```
After you finish all the steps, you will get the clips in `dataroot/select_clips`.
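Step 1 of the preprocessing (FFmpeg frame extraction) can be sketched as below. The exact flags the script uses may differ, and the paths are placeholders.

```python
def ffmpeg_extract_cmd(video_path: str, out_dir: str) -> list:
    """Build an FFmpeg command that dumps every frame of a video as
    zero-padded PNGs (00000000.png, 00000001.png, ...)."""
    return [
        'ffmpeg', '-i', video_path,
        '-start_number', '0',  # name frames from 00000000.png
        f'{out_dir}/%08d.png',
    ]

print(' '.join(ffmpeg_extract_cmd('raw_videos/demo.mp4', 'frames/demo')))
```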
================================================
FILE: VERSION
================================================
0.1.0
================================================
FILE: animesr/__init__.py
================================================
# flake8: noqa
from .archs import *
from .data import *
from .models import *
# from .version import __gitsha__, __version__
================================================
FILE: animesr/archs/__init__.py
================================================
import importlib
from os import path as osp
from basicsr.utils import scandir
# automatically scan and import arch modules for registry
# scan all the files that end with '_arch.py' under the archs folder
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
_arch_modules = [importlib.import_module(f'animesr.archs.{file_name}') for file_name in arch_filenames]
================================================
FILE: animesr/archs/discriminator_arch.py
================================================
import functools
from torch import nn as nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm
from basicsr.utils.registry import ARCH_REGISTRY
def get_conv_layer(input_nc, ndf, kernel_size, stride, padding, bias=True, use_sn=False):
if not use_sn:
return nn.Conv2d(input_nc, ndf, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
return spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias))
@ARCH_REGISTRY.register()
class UNetDiscriminatorSN(nn.Module):
"""Defines a U-Net discriminator with spectral normalization (SN). copy from real-esrgan"""
def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
super(UNetDiscriminatorSN, self).__init__()
self.skip_connection = skip_connection
norm = spectral_norm
self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
# upsample
self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
# extra
self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
def forward(self, x):
x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
# upsample
x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
if self.skip_connection:
x4 = x4 + x2
x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
if self.skip_connection:
x5 = x5 + x1
x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
if self.skip_connection:
x6 = x6 + x0
# extra
out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
out = self.conv9(out)
return out
@ARCH_REGISTRY.register()
class PatchDiscriminator(nn.Module):
"""Defines a PatchGAN discriminator, the receptive field of default config is 70x70.
Args:
use_sn (bool): Use spectra_norm or not, if use_sn is True, then norm_type should be none.
"""
def __init__(self,
num_in_ch,
num_feat=64,
num_layers=3,
max_nf_mult=8,
norm_type='batch',
use_sigmoid=False,
use_sn=False):
super(PatchDiscriminator, self).__init__()
norm_layer = self._get_norm_layer(norm_type)
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func != nn.BatchNorm2d
else:
use_bias = norm_layer != nn.BatchNorm2d
kw = 4
padw = 1
sequence = [
get_conv_layer(num_in_ch, num_feat, kernel_size=kw, stride=2, padding=padw, use_sn=use_sn),
nn.LeakyReLU(0.2, True)
]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, num_layers): # gradually increase the number of filters
nf_mult_prev = nf_mult
nf_mult = min(2**n, max_nf_mult)
sequence += [
get_conv_layer(
num_feat * nf_mult_prev,
num_feat * nf_mult,
kernel_size=kw,
stride=2,
padding=padw,
bias=use_bias,
use_sn=use_sn),
norm_layer(num_feat * nf_mult),
nn.LeakyReLU(0.2, True)
]
nf_mult_prev = nf_mult
nf_mult = min(2**num_layers, max_nf_mult)
sequence += [
get_conv_layer(
num_feat * nf_mult_prev,
num_feat * nf_mult,
kernel_size=kw,
stride=1,
padding=padw,
bias=use_bias,
use_sn=use_sn),
norm_layer(num_feat * nf_mult),
nn.LeakyReLU(0.2, True)
]
# output 1 channel prediction map
sequence += [get_conv_layer(num_feat * nf_mult, 1, kernel_size=kw, stride=1, padding=padw, use_sn=use_sn)]
if use_sigmoid:
sequence += [nn.Sigmoid()]
self.model = nn.Sequential(*sequence)
def _get_norm_layer(self, norm_type='batch'):
if norm_type == 'batch':
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
elif norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
elif norm_type == 'batchnorm2d':
norm_layer = nn.BatchNorm2d
elif norm_type == 'none':
norm_layer = nn.Identity
else:
raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
return norm_layer
def forward(self, x):
return self.model(x)
@ARCH_REGISTRY.register()
class MultiScaleDiscriminator(nn.Module):
"""Define a multi-scale discriminator, each discriminator is a instance of PatchDiscriminator.
Args:
num_layers (int or list): If the type of this variable is int, then degrade to PatchDiscriminator.
If the type of this variable is list, then the length of the list is
the number of discriminators.
use_downscale (bool): Progressively downscale the input to feed into different discriminators.
If set to True, the discriminators should all have the same config.
"""
def __init__(self,
num_in_ch,
num_feat=64,
num_layers=3,
max_nf_mult=8,
norm_type='batch',
use_sigmoid=False,
use_sn=False,
use_downscale=False):
super(MultiScaleDiscriminator, self).__init__()
if isinstance(num_layers, int):
num_layers = [num_layers]
# check whether the discriminators are the same
if use_downscale:
assert len(set(num_layers)) == 1
self.use_downscale = use_downscale
self.num_dis = len(num_layers)
self.dis_list = nn.ModuleList()
for nl in num_layers:
self.dis_list.append(
PatchDiscriminator(
num_in_ch,
num_feat=num_feat,
num_layers=nl,
max_nf_mult=max_nf_mult,
norm_type=norm_type,
use_sigmoid=use_sigmoid,
use_sn=use_sn,
))
def forward(self, x):
outs = []
h, w = x.size()[2:]
y = x
for i in range(self.num_dis):
if i != 0 and self.use_downscale:
y = F.interpolate(y, size=(h // 2, w // 2), mode='bilinear', align_corners=True)
h, w = y.size()[2:]
outs.append(self.dis_list[i](y))
return outs
================================================
FILE: animesr/archs/simple_degradation_arch.py
================================================
from torch import nn as nn
from basicsr.archs.arch_util import default_init_weights, pixel_unshuffle
from basicsr.utils.registry import ARCH_REGISTRY
@ARCH_REGISTRY.register()
class SimpleDegradationArch(nn.Module):
"""simple degradation architecture which consists several conv and non-linear layer
it learns the mapping from HR to LR
"""
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, downscale=2):
"""
:param num_in_ch: input is a pseudo HR image, channel is 3
:param num_out_ch: output is an LR image, channel is also 3
:param num_feat: we use a small network, hidden dimension is 64
:param downscale: suppose (h, w) is the height&width of a real-world LR video.
Firstly, we select the best rescaling factor (usually around 0.5) for this LR video.
Secondly, we obtain the pseudo HR frames and resize them to (2h, 2w).
To learn the mapping from pseudo HR to LR, LBO contains a pixel-unshuffle layer with
a scale factor of 2 to perform the downsampling at the beginning.
"""
super(SimpleDegradationArch, self).__init__()
num_in_ch = num_in_ch * downscale * downscale
self.main = nn.Sequential(
nn.Conv2d(num_in_ch, num_feat, 3, 1, 1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(num_feat, num_feat, 3, 1, 1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(num_feat, num_out_ch, 3, 1, 1),
)
self.downscale = downscale
default_init_weights(self.main)
def forward(self, x):
x = pixel_unshuffle(x, self.downscale)
x = self.main(x)
return x
================================================
FILE: animesr/archs/vsr_arch.py
================================================
import torch
from torch import nn as nn
from torch.nn import functional as F
from basicsr.archs.arch_util import ResidualBlockNoBN, pixel_unshuffle
from basicsr.utils.registry import ARCH_REGISTRY
class RightAlignMSConvResidualBlocks(nn.Module):
"""right align multi-scale ConvResidualBlocks, currently only support 3 scales (1, 2, 4)"""
def __init__(self, num_in_ch=3, num_state_ch=64, num_out_ch=64, num_block=(5, 3, 2)):
super().__init__()
assert len(num_block) == 3
assert num_block[0] >= num_block[1] >= num_block[2]
self.num_block = num_block
self.conv_s1_first = nn.Sequential(
nn.Conv2d(num_in_ch, num_state_ch, 3, 1, 1, bias=True), nn.LeakyReLU(negative_slope=0.1, inplace=True))
self.conv_s2_first = nn.Sequential(
nn.Conv2d(num_state_ch, num_state_ch, 3, 2, 1, bias=True), nn.LeakyReLU(negative_slope=0.1, inplace=True))
self.conv_s4_first = nn.Sequential(
nn.Conv2d(num_state_ch, num_state_ch, 3, 2, 1, bias=True),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
)
self.body_s1_first = nn.ModuleList()
for _ in range(num_block[0]):
self.body_s1_first.append(ResidualBlockNoBN(num_feat=num_state_ch))
self.body_s2_first = nn.ModuleList()
for _ in range(num_block[1]):
self.body_s2_first.append(ResidualBlockNoBN(num_feat=num_state_ch))
self.body_s4_first = nn.ModuleList()
for _ in range(num_block[2]):
self.body_s4_first.append(ResidualBlockNoBN(num_feat=num_state_ch))
self.upsample_x2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
self.upsample_x4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
self.fusion = nn.Sequential(
nn.Conv2d(3 * num_state_ch, 2 * num_out_ch, 3, 1, 1, bias=True),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
nn.Conv2d(2 * num_out_ch, num_out_ch, 3, 1, 1, bias=True),
)
def up(self, x, scale=2):
if isinstance(x, int):
return x
elif scale == 2:
return self.upsample_x2(x)
else:
return self.upsample_x4(x)
def forward(self, x):
x_s1 = self.conv_s1_first(x)
x_s2 = self.conv_s2_first(x_s1)
x_s4 = self.conv_s4_first(x_s2)
flag_s2 = False
flag_s4 = False
for i in range(0, self.num_block[0]):
x_s1 = self.body_s1_first[i](
x_s1 + (self.up(x_s2, 2) if flag_s2 else 0) + (self.up(x_s4, 4) if flag_s4 else 0))
if i >= self.num_block[0] - self.num_block[1]:
x_s2 = self.body_s2_first[i - self.num_block[0] + self.num_block[1]](
x_s2 + (self.up(x_s4, 2) if flag_s4 else 0))
flag_s2 = True
if i >= self.num_block[0] - self.num_block[2]:
x_s4 = self.body_s4_first[i - self.num_block[0] + self.num_block[2]](x_s4)
flag_s4 = True
x_fusion = self.fusion(torch.cat((x_s1, self.upsample_x2(x_s2), self.upsample_x4(x_s4)), dim=1))
return x_fusion
@ARCH_REGISTRY.register()
class MSRSWVSR(nn.Module):
"""
Multi-Scale, unidirectional Recurrent, Sliding Window (MSRSW)
The implementation refers to the paper: Efficient Video Super-Resolution through Recurrent Latent Space Propagation
"""
def __init__(self, num_feat=64, num_block=(5, 3, 2), netscale=4):
super(MSRSWVSR, self).__init__()
self.num_feat = num_feat
# 3(img channel) * 3(prev cur nxt 3 imgs) + 3(hr img channel) * netscale * netscale + num_feat
self.recurrent_cell = RightAlignMSConvResidualBlocks(3 * 3 + 3 * netscale * netscale + num_feat, num_feat,
num_feat + 3 * netscale * netscale, num_block)
self.lrelu = nn.LeakyReLU(negative_slope=0.1)
self.pixel_shuffle = nn.PixelShuffle(netscale)
self.netscale = netscale
def cell(self, x, fb, state):
res = x[:, 3:6]
# prev frame, cur frame, next frame, prev sr frame, prev hidden state
inp = torch.cat((x, pixel_unshuffle(fb, self.netscale), state), dim=1)
# the out contains both state and sr frame
out = self.recurrent_cell(inp)
out_img = self.pixel_shuffle(out[:, :3 * self.netscale * self.netscale]) + F.interpolate(
res, scale_factor=self.netscale, mode='bilinear', align_corners=False)
out_state = self.lrelu(out[:, 3 * self.netscale * self.netscale:])
return out_img, out_state
def forward(self, x):
b, n, c, h, w = x.size()
# initialize previous sr frame and previous hidden state as zero tensor
out = x.new_zeros(b, c, h * self.netscale, w * self.netscale)
state = x.new_zeros(b, self.num_feat, h, w)
out_l = []
for i in range(n):
if i == 0:
# there is no previous frame for the 1st frame, so reuse 1st frame as previous
out, state = self.cell(torch.cat((x[:, i], x[:, i], x[:, i + 1]), dim=1), out, state)
elif i == n - 1:
# there is no next frame for the last frame, so reuse last frame as next
out, state = self.cell(torch.cat((x[:, i - 1], x[:, i], x[:, i]), dim=1), out, state)
else:
out, state = self.cell(torch.cat((x[:, i - 1], x[:, i], x[:, i + 1]), dim=1), out, state)
out_l.append(out)
return torch.stack(out_l, dim=1)
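The channel bookkeeping of the recurrent cell can be checked with a small sketch (plain Python; the counts mirror the `RightAlignMSConvResidualBlocks` constructor call in `__init__` above):

```python
def cell_channels(num_feat=64, netscale=4):
    """Input/output channel counts of the MSRSWVSR recurrent cell."""
    # input: prev/cur/nxt frames (3 imgs x 3 ch) + pixel-unshuffled
    # previous SR frame (3 * s^2) + previous hidden state (num_feat)
    in_ch = 3 * 3 + 3 * netscale * netscale + num_feat
    # output: SR frame before pixel shuffle (3 * s^2) + next hidden state
    out_ch = num_feat + 3 * netscale * netscale
    return in_ch, out_ch

print(cell_channels())  # default 4x config: (121, 112)
```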
================================================
FILE: animesr/data/__init__.py
================================================
import importlib
from os import path as osp
from basicsr.utils import scandir
# automatically scan and import dataset modules for registry
# scan all the files that end with '_dataset.py' under the data folder
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules
_dataset_modules = [importlib.import_module(f'animesr.data.{file_name}') for file_name in dataset_filenames]
================================================
FILE: animesr/data/data_utils.py
================================================
import random
import torch
def random_crop(imgs, patch_size, top=None, left=None):
"""
randomly crop patches from imgs
:param imgs: can be (list of) tensor, cv2 img
:param patch_size: patch size, usually 256
:param top: will sample if is None
:param left: will sample if is None
:return: cropped patches from input imgs
"""
if not isinstance(imgs, list):
imgs = [imgs]
# determine input type: Numpy array or Tensor
input_type = 'Tensor' if torch.is_tensor(imgs[0]) else 'Numpy'
if input_type == 'Tensor':
h, w = imgs[0].size()[-2:]
else:
h, w = imgs[0].shape[0:2]
# randomly choose top and left coordinates
if top is None:
top = random.randint(0, h - patch_size)
if left is None:
left = random.randint(0, w - patch_size)
if input_type == 'Tensor':
imgs = [v[:, :, top:top + patch_size, left:left + patch_size] for v in imgs]
else:
imgs = [v[top:top + patch_size, left:left + patch_size, ...] for v in imgs]
if len(imgs) == 1:
imgs = imgs[0]
return imgs
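A minimal usage sketch of the Numpy branch above, with fixed `top`/`left` for determinism (the helper samples them randomly when they are `None`):

```python
import numpy as np

img = np.arange(16 * 16 * 3, dtype=np.float32).reshape(16, 16, 3)
patch_size, top, left = 8, 4, 2
# same indexing as the Numpy branch of random_crop
patch = img[top:top + patch_size, left:left + patch_size, ...]
print(patch.shape)  # (8, 8, 3)
```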
================================================
FILE: animesr/data/ffmpeg_anime_dataset.py
================================================
import cv2
import ffmpeg
import glob
import numpy as np
import os
import random
import torch
from os import path as osp
from torch.utils import data as data
from basicsr.data.degradations import random_add_gaussian_noise, random_mixed_kernels
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
from .data_utils import random_crop
@DATASET_REGISTRY.register()
class FFMPEGAnimeDataset(data.Dataset):
"""Anime datasets with only classic basic operators"""
def __init__(self, opt):
super(FFMPEGAnimeDataset, self).__init__()
self.opt = opt
self.num_frame = opt['num_frame']
self.num_half_frames = opt['num_frame'] // 2
self.keys = []
self.clip_frames = {}
self.gt_root = opt['dataroot_gt']
logger = get_root_logger()
clip_names = os.listdir(self.gt_root)
for clip_name in clip_names:
num_frames = len(glob.glob(osp.join(self.gt_root, clip_name, '*.png')))
self.keys.extend([f'{clip_name}/{i:08d}' for i in range(num_frames)])
self.clip_frames[clip_name] = num_frames
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.is_lmdb = False
self.iso_blur_range = opt.get('iso_blur_range', [0.2, 4])
self.aniso_blur_range = opt.get('aniso_blur_range', [0.8, 3])
self.noise_range = opt.get('noise_range', [0, 10])
self.crf_range = opt.get('crf_range', [18, 35])
self.ffmpeg_profile_names = opt.get('ffmpeg_profile_names', ['baseline', 'main', 'high'])
self.ffmpeg_profile_probs = opt.get('ffmpeg_profile_probs', [0.1, 0.2, 0.7])
self.scale = opt.get('scale', 4)
assert self.scale in (2, 4)
# temporal augmentation configs
self.interval_list = opt.get('interval_list', [1])
self.random_reverse = opt.get('random_reverse', False)
interval_str = ','.join(str(x) for x in self.interval_list)
logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
f'random reverse is {self.random_reverse}.')
def get_gt_clip(self, index):
"""
get the GT(hr) clip with self.num_frame frames
:param index: the index from __getitem__
:return: a list of images, with numpy(cv2) format
"""
key = self.keys[index] # get clip from this key frame (if possible)
clip_name, frame_name = key.split('/') # key example: 000/00000000
# determine the "interval" of neighboring frames
interval = random.choice(self.interval_list)
# ensure not exceeding the borders
center_frame_idx = int(frame_name)
start_frame_idx = center_frame_idx - self.num_half_frames * interval
end_frame_idx = center_frame_idx + self.num_half_frames * interval
# if the index doesn't satisfy the requirement, resample it
if (start_frame_idx < 0) or (end_frame_idx >= self.clip_frames[clip_name]):
center_frame_idx = random.randint(self.num_half_frames * interval,
self.clip_frames[clip_name] - 1 - self.num_half_frames * interval)
start_frame_idx = center_frame_idx - self.num_half_frames * interval
end_frame_idx = center_frame_idx + self.num_half_frames * interval
# determine the neighbor frames
neighbor_list = list(range(start_frame_idx, end_frame_idx + 1, interval))
# random reverse
if self.random_reverse and random.random() < 0.5:
neighbor_list.reverse()
# get the neighboring GT frames
img_gts = []
for neighbor in neighbor_list:
if self.is_lmdb:
img_gt_path = f'{clip_name}/{neighbor:08d}'
else:
img_gt_path = osp.join(self.gt_root, clip_name, f'{neighbor:08d}.png')
# get GT
img_bytes = self.file_client.get(img_gt_path, 'gt')
img_gt = imfrombytes(img_bytes, float32=True)
img_gts.append(img_gt)
# random crop
img_gts = random_crop(img_gts, self.opt['gt_size'])
# augmentation
img_gts = augment(img_gts, self.opt['use_flip'], self.opt['use_rot'])
return img_gts
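The index arithmetic of `get_gt_clip` can be sketched as follows (a deterministic clamp stands in for the random resample used above; `num_half` corresponds to `self.num_half_frames`):

```python
def neighbor_indices(center, num_half, interval, num_frames):
    """Pick num_half*2+1 frame indices centered on `center`, spaced by `interval`."""
    start = center - num_half * interval
    end = center + num_half * interval
    if start < 0 or end >= num_frames:
        # the dataset resamples center randomly; clamp here for determinism
        center = max(num_half * interval,
                     min(center, num_frames - 1 - num_half * interval))
        start = center - num_half * interval
        end = center + num_half * interval
    return list(range(start, end + 1, interval))

print(neighbor_indices(5, 2, 2, 20))  # [1, 3, 5, 7, 9]
```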
def add_ffmpeg_compression(self, img_lqs, width, height):
# ffmpeg
loglevel = 'error'
format = 'h264'
# random fps still has problems, so fix it to 25 for now
# fps = random.choices([24, 25, 30, 50, 60], [0.2, 0.2, 0.2, 0.2, 0.2])[0]
fps = 25
crf = np.random.uniform(self.crf_range[0], self.crf_range[1])
try:
extra_args = dict()
if format == 'h264':
vcodec = 'libx264'
profile = random.choices(self.ffmpeg_profile_names, self.ffmpeg_profile_probs)[0]
extra_args['profile:v'] = profile
ffmpeg_img2video = (
ffmpeg.input('pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{width}x{height}',
r=fps).filter('fps', fps=fps, round='up').output(
'pipe:', format=format, pix_fmt='yuv420p', crf=crf, vcodec=vcodec,
**extra_args).global_args('-hide_banner').global_args('-loglevel', loglevel).run_async(
pipe_stdin=True, pipe_stdout=True))
ffmpeg_video2img = (
ffmpeg.input('pipe:', format=format).output('pipe:', format='rawvideo',
pix_fmt='rgb24').global_args('-hide_banner').global_args(
'-loglevel',
loglevel).run_async(pipe_stdin=True, pipe_stdout=True))
# read a sequence of images
for img_lq in img_lqs:
ffmpeg_img2video.stdin.write(img_lq.astype(np.uint8).tobytes())
ffmpeg_img2video.stdin.close()
video_bytes = ffmpeg_img2video.stdout.read()
ffmpeg_img2video.wait()
# ffmpeg: video to images
ffmpeg_video2img.stdin.write(video_bytes)
ffmpeg_video2img.stdin.close()
img_lqs_ffmpeg = []
while True:
in_bytes = ffmpeg_video2img.stdout.read(width * height * 3)
if not in_bytes:
break
in_frame = (np.frombuffer(in_bytes, np.uint8).reshape([height, width, 3]))
in_frame = in_frame.astype(np.float32) / 255.
img_lqs_ffmpeg.append(in_frame)
ffmpeg_video2img.wait()
assert len(img_lqs_ffmpeg) == self.num_frame, f'Wrong length: {len(img_lqs_ffmpeg)} vs {self.num_frame}'
except AssertionError as error:
logger = get_root_logger()
logger.warning(f'ffmpeg assertion error: {error}')
except Exception as error:
logger = get_root_logger()
logger.warning(f'ffmpeg exception error: {error}')
else:
img_lqs = img_lqs_ffmpeg
return img_lqs
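Splitting ffmpeg's rawvideo stdout back into frames relies on each rgb24 frame being exactly `width * height * 3` bytes; the read loop above can be mimicked with an in-memory stream:

```python
import io

width, height, num_frames = 8, 6, 4
frame_bytes = width * height * 3  # rgb24: 3 bytes per pixel
stream = io.BytesIO(bytes(frame_bytes * num_frames))
frames = []
while True:
    in_bytes = stream.read(frame_bytes)
    if not in_bytes:  # empty read means the stream is exhausted
        break
    frames.append(in_bytes)
print(len(frames))  # 4
```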
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
img_gts = self.get_gt_clip(index)
# ------------- generate LQ frames --------------#
# add blur
kernel = random_mixed_kernels(['iso', 'aniso'], [0.7, 0.3], 21, self.iso_blur_range, self.aniso_blur_range)
img_lqs = [cv2.filter2D(v, -1, kernel) for v in img_gts]
# add noise
img_lqs = [
random_add_gaussian_noise(v, sigma_range=self.noise_range, gray_prob=0.5, clip=True, rounds=False)
for v in img_lqs
]
# downsample
h, w = img_gts[0].shape[0:2]
width = w // self.scale
height = h // self.scale
resize_type = random.choices([cv2.INTER_AREA, cv2.INTER_LINEAR, cv2.INTER_CUBIC], [0.3, 0.3, 0.4])[0]
img_lqs = [cv2.resize(v, (width, height), interpolation=resize_type) for v in img_lqs]
# ffmpeg
img_lqs = [np.clip(img_lq * 255.0, 0, 255) for img_lq in img_lqs]
img_lqs = self.add_ffmpeg_compression(img_lqs, width, height)
# ------------- end --------------#
img_gts = img2tensor(img_gts)
img_lqs = img2tensor(img_lqs)
img_gts = torch.stack(img_gts, dim=0)
img_lqs = torch.stack(img_lqs, dim=0)
# img_lqs: (t, c, h, w)
# img_gts: (t, c, h, w)
return {'lq': img_lqs, 'gt': img_gts}
def __len__(self):
return len(self.keys)
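`__getitem__` above applies the classic degradation operators in a fixed order (blur → noise → downsample → ffmpeg compression); a toy stage-tracking sketch of that ordering:

```python
def blur(stages): return stages + ['blur']
def noise(stages): return stages + ['noise']
def downsample(stages): return stages + ['downsample']
def compress(stages): return stages + ['ffmpeg']

stages = []
for op in (blur, noise, downsample, compress):  # same order as __getitem__
    stages = op(stages)
print(stages)  # ['blur', 'noise', 'downsample', 'ffmpeg']
```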
================================================
FILE: animesr/data/ffmpeg_anime_lbo_dataset.py
================================================
import numpy as np
import random
import torch
from torch.nn import functional as F
from animesr.archs.simple_degradation_arch import SimpleDegradationArch
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_mixed_kernels
from basicsr.utils import FileClient, get_root_logger, img2tensor
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.img_process_util import filter2D
from basicsr.utils.registry import DATASET_REGISTRY
from .ffmpeg_anime_dataset import FFMPEGAnimeDataset
@DATASET_REGISTRY.register()
class FFMPEGAnimeLBODataset(FFMPEGAnimeDataset):
"""Anime datasets with both classic basic operators and learnable basic operators (LBO)"""
def __init__(self, opt):
super(FFMPEGAnimeLBODataset, self).__init__(opt)
self.rank, self.world_size = get_dist_info()
self.lbo = SimpleDegradationArch(downscale=2)
lbo_list = opt['degradation_model_path']
if not isinstance(lbo_list, list):
lbo_list = [lbo_list]
self.lbo_list = lbo_list
# print(f'degradation model path for {self.rank} {self.world_size}: {degradation_model_path}\n')
# the real load is at reload_degradation_model function
self.lbo.load_state_dict(torch.load(self.lbo_list[0], map_location=lambda storage, loc: storage)['params'])
self.lbo = self.lbo.to(f'cuda:{self.rank}').eval()
self.lbo_prob = opt.get('lbo_prob', 0.5)
def reload_degradation_model(self):
"""
__init__ will be only invoked once for one gpu worker, so if we want to
have num_worker_dataset * num_gpu degradation model, we must call this func in __getitem__
ref: https://discuss.pytorch.org/t/what-happened-when-set-num-workers-0-in-dataloader/138515
"""
degradation_model_path = random.choice(self.lbo_list)
self.lbo.load_state_dict(
torch.load(degradation_model_path, map_location=lambda storage, loc: storage)['params'])
print(f'reload degradation model path for {self.rank} {self.world_size}: {degradation_model_path}\n')
logger = get_root_logger()
logger.info(f'reload degradation model path for {self.rank} {self.world_size}: {degradation_model_path}\n')
@torch.no_grad()
def custom_resize(self, x, scale=2):
if random.random() < self.lbo_prob: # learned degradation model from real-world
x = self.lbo(x)
else: # classic synthetic
h, w = x.shape[2:]
width = w // scale
height = h // scale
mode = random.choice(['area', 'bilinear', 'bicubic'])
if mode == 'area':
align_corners = None
else:
align_corners = False
x = F.interpolate(x, size=(height, width), mode=mode, align_corners=align_corners)
return x
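`custom_resize` always halves the resolution (both the LBO and the classic interpolation branch are 2x), so a 4x pipeline calls it twice; the size arithmetic:

```python
def half_size(h, w):
    # one custom_resize pass: LBO or classic interpolation, both 2x down
    return h // 2, w // 2

h, w = 256, 256
h, w = half_size(h, w)   # first 2x pass
h, w = half_size(h, w)   # second 2x pass -> overall 4x
print((h, w))  # (64, 64)
```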
@torch.no_grad()
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
# reload a (possibly different) LBO for this sample; see reload_degradation_model
self.reload_degradation_model()
img_gts = self.get_gt_clip(index)
# ------------- generate LQ frames --------------#
# change to CUDA implementation
img_gts = img2tensor(img_gts)
img_gts = torch.stack(img_gts, dim=0)
img_gts = img_gts.to(f'cuda:{self.rank}')
# add blur
kernel = random_mixed_kernels(['iso', 'aniso'], [0.7, 0.3], 21, self.iso_blur_range, self.aniso_blur_range)
with torch.no_grad():
kernel = torch.FloatTensor(kernel).unsqueeze(0).expand(self.num_frame, 21, 21).to(f'cuda:{self.rank}')
img_lqs = filter2D(img_gts, kernel)
# add noise
img_lqs = random_add_gaussian_noise_pt(
img_lqs, sigma_range=self.noise_range, clip=True, rounds=False, gray_prob=0.5)
# downsample
img_lqs = self.custom_resize(img_lqs)
if self.scale == 4:
img_lqs = self.custom_resize(img_lqs)
height, width = img_lqs.shape[2:]
# back to numpy since ffmpeg compression operate on cpu
img_lqs = img_lqs.detach().clamp_(0, 1).permute(0, 2, 3, 1) * 255 # B, H, W, C
img_lqs = img_lqs.type(torch.uint8).cpu().numpy()[:, :, :, ::-1]
img_lqs = np.split(img_lqs, self.num_frame, axis=0)
img_lqs = [img_lq[0] for img_lq in img_lqs]
# ffmpeg
img_lqs = self.add_ffmpeg_compression(img_lqs, width, height)
# ------------- end --------------#
img_lqs = img2tensor(img_lqs)
img_lqs = torch.stack(img_lqs, dim=0)
# img_lqs: (t, c, h, w)
# img_gts: (t, c, h, w) on gpu
return {'lq': img_lqs, 'gt': img_gts.cpu()}
def __len__(self):
return len(self.keys)
================================================
FILE: animesr/data/paired_image_dataset.py
================================================
import glob
import os
from torch.utils import data as data
from torchvision.transforms.functional import normalize
from basicsr.data.transforms import augment, mod_crop, paired_random_crop
from basicsr.utils import FileClient, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class CustomPairedImageDataset(data.Dataset):
"""Paired image dataset for training LBO.
Read real-world LQ and GT frames pairs.
The organization of these gt&lq folder is similar to AVC-Train,
except that each folder contains 200 clips, and each clip contains 11 frames.
We will ignore the first frame, so there are finally 2000 training pair data.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_gt (str): Data root path for gt, also the pseudo HR path.
dataroot_lq (str): Data root path for lq.
io_backend (dict): IO backend type and other kwarg.
gt_size (int): Cropped patched size for gt patches.
use_hflip (bool): Use horizontal flips.
use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
scale (int): Scale factor, which will be added automatically.
phase (str): 'train' or 'val'.
"""
def __init__(self, opt):
super(CustomPairedImageDataset, self).__init__()
self.opt = opt
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.mean = opt['mean'] if 'mean' in opt else None
self.std = opt['std'] if 'std' in opt else None
self.mod_crop_scale = 8
self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
omit_first_frame = opt.get('omit_first_frame', True)
start_idx = 1 if omit_first_frame else 0
self.paths = []
clip_list = os.listdir(self.lq_folder)
for clip_name in clip_list:
lq_frame_list = sorted(glob.glob(f'{self.lq_folder}/{clip_name}/*.png'))
gt_frame_list = sorted(glob.glob(f'{self.gt_folder}/{clip_name}/*.png'))
assert len(lq_frame_list) == len(gt_frame_list)
for i in range(start_idx, len(lq_frame_list)):
# omit the first frame
self.paths.append({'lq_path': lq_frame_list[i], 'gt_path': gt_frame_list[i]})
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
scale = self.opt['scale']
# Load gt and lq images. Dimension order: HWC; channel order: BGR;
# image range: [0, 1], float32.
gt_path = self.paths[index]['gt_path']
img_bytes = self.file_client.get(gt_path, 'gt')
img_gt = imfrombytes(img_bytes, float32=True)
lq_path = self.paths[index]['lq_path']
img_bytes = self.file_client.get(lq_path, 'lq')
img_lq = imfrombytes(img_bytes, float32=True)
img_lq = mod_crop(img_lq, self.mod_crop_scale)
# augmentation for training
if self.opt['phase'] == 'train':
gt_size = self.opt['gt_size']
# random crop
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
# flip, rotation
img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
# normalize
if self.mean is not None or self.std is not None:
normalize(img_lq, self.mean, self.std, inplace=True)
normalize(img_gt, self.mean, self.std, inplace=True)
return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
def __len__(self):
return len(self.paths)
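Skipping the first frame of each clip yields the pair count stated in the class docstring; a sketch with hypothetical frame names (200 clips x 11 frames):

```python
num_clips, frames_per_clip = 200, 11
pairs = []
for clip in range(num_clips):
    frames = [f'{clip:03d}/{i:08d}.png' for i in range(frames_per_clip)]
    # omit the first frame of every clip, as in __init__ above
    pairs.extend((f, f) for f in frames[1:])
print(len(pairs))  # 2000
```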
================================================
FILE: animesr/models/__init__.py
================================================
import importlib
from os import path as osp
from basicsr.utils import scandir
# automatically scan and import model modules for registry
# scan all the files that end with '_model.py' under the model folder
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
# import all the model modules
_model_modules = [importlib.import_module(f'animesr.models.{file_name}') for file_name in model_filenames]
================================================
FILE: animesr/models/degradation_gan_model.py
================================================
from collections import OrderedDict
from basicsr.models.srgan_model import SRGANModel
from basicsr.utils.registry import MODEL_REGISTRY
@MODEL_REGISTRY.register()
class DegradationGANModel(SRGANModel):
"""Degradation model for real-world, hard-to-synthesis degradation."""
def feed_data(self, data):
# the degradation network maps HR to LR, so we swap the roles of lq and gt
self.lq = data['gt'].to(self.device)
if 'lq' in data:
self.gt = data['lq'].to(self.device)
def optimize_parameters(self, current_iter):
# optimize net_g
for p in self.net_d.parameters():
p.requires_grad = False
self.optimizer_g.zero_grad()
self.output = self.net_g(self.lq)
l_g_total = 0
loss_dict = OrderedDict()
if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
# pixel loss
if self.cri_pix:
l_g_pix = self.cri_pix(self.output, self.gt)
l_g_total += l_g_pix
loss_dict['l_g_pix'] = l_g_pix
# perceptual loss
if self.cri_perceptual:
l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
if l_g_percep is not None:
l_g_total += l_g_percep
loss_dict['l_g_percep'] = l_g_percep
if l_g_style is not None:
l_g_total += l_g_style
loss_dict['l_g_style'] = l_g_style
# gan loss
fake_g_pred = self.net_d(self.output)
l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
l_g_total += l_g_gan
loss_dict['l_g_gan'] = l_g_gan
l_g_total.backward()
self.optimizer_g.step()
# optimize net_d
for p in self.net_d.parameters():
p.requires_grad = True
self.optimizer_d.zero_grad()
# real
real_d_pred = self.net_d(self.gt)
l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
loss_dict['l_d_real'] = l_d_real
l_d_real.backward()
# fake
fake_d_pred = self.net_d(self.output.detach())
l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
loss_dict['l_d_fake'] = l_d_fake
l_d_fake.backward()
self.optimizer_d.step()
self.log_dict = self.reduce_loss_dict(loss_dict)
if self.ema_decay > 0:
self.model_ema(decay=self.ema_decay)
================================================
FILE: animesr/models/degradation_model.py
================================================
from collections import OrderedDict
from basicsr.losses import build_loss
from basicsr.models.sr_model import SRModel
from basicsr.utils.registry import MODEL_REGISTRY
@MODEL_REGISTRY.register()
class DegradationModel(SRModel):
"""Degradation model for real-world, hard-to-synthesis degradation."""
def init_training_settings(self):
self.net_g.train()
train_opt = self.opt['train']
# define losses
self.l1_pix = build_loss(train_opt['l1_opt']).to(self.device)
self.l2_pix = build_loss(train_opt['l2_opt']).to(self.device)
# set up optimizers and schedulers
self.setup_optimizers()
self.setup_schedulers()
def feed_data(self, data):
# the degradation network maps HR to LR, so we swap the roles of lq and gt
self.lq = data['gt'].to(self.device)
if 'lq' in data:
self.gt = data['lq'].to(self.device)
def optimize_parameters(self, current_iter):
self.optimizer_g.zero_grad()
self.output = self.net_g(self.lq)
l_total = 0
loss_dict = OrderedDict()
# l1 loss
l_l1 = self.l1_pix(self.output, self.gt)
l_total += l_l1
loss_dict['l_l1'] = l_l1
# l2 loss
l_l2 = self.l2_pix(self.output, self.gt)
l_total += l_l2
loss_dict['l_l2'] = l_l2
l_total.backward()
self.optimizer_g.step()
self.log_dict = self.reduce_loss_dict(loss_dict)
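`optimize_parameters` above sums an L1 and an L2 term into a single objective; a scalar sketch on plain lists (a toy stand-in for the tensor losses):

```python
def l1_loss(pred, gt):
    # mean absolute error
    return sum(abs(p - g) for p, g in zip(pred, gt)) / len(pred)

def l2_loss(pred, gt):
    # mean squared error
    return sum((p - g) ** 2 for p, g in zip(pred, gt)) / len(pred)

pred, gt = [0.0, 1.0, 2.0], [0.0, 0.0, 0.0]
l_total = l1_loss(pred, gt) + l2_loss(pred, gt)  # combined objective
```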
================================================
FILE: animesr/models/video_recurrent_gan_model.py
================================================
from collections import OrderedDict
from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.utils import get_root_logger
from basicsr.utils.registry import MODEL_REGISTRY
from .video_recurrent_model import VideoRecurrentCustomModel
@MODEL_REGISTRY.register()
class VideoRecurrentGANCustomModel(VideoRecurrentCustomModel):
"""Currently, the VideoRecurrentGANModel and multi-scale discriminator are not compatible,
so we use a custom model.
"""
def init_training_settings(self):
train_opt = self.opt['train']
self.ema_decay = train_opt.get('ema_decay', 0)
if self.ema_decay > 0:
logger = get_root_logger()
logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
# build network net_g with Exponential Moving Average (EMA)
# net_g_ema only used for testing on one GPU and saving.
# There is no need to wrap with DistributedDataParallel
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
# load pretrained model
load_path = self.opt['path'].get('pretrain_network_g', None)
if load_path is not None:
self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
else:
self.model_ema(0) # copy net_g weight
self.net_g_ema.eval()
# define network net_d
self.net_d = build_network(self.opt['network_d'])
self.net_d = self.model_to_device(self.net_d)
self.print_network(self.net_d)
# load pretrained models
load_path = self.opt['path'].get('pretrain_network_d', None)
if load_path is not None:
param_key = self.opt['path'].get('param_key_d', 'params')
self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)
self.net_g.train()
self.net_d.train()
# define losses
if train_opt.get('pixel_opt'):
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
else:
self.cri_pix = None
if train_opt.get('perceptual_opt'):
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
else:
self.cri_perceptual = None
if train_opt.get('gan_opt'):
self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
self.net_d_iters = train_opt.get('net_d_iters', 1)
self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
# set up optimizers and schedulers
self.setup_optimizers()
self.setup_schedulers()
def setup_optimizers(self):
train_opt = self.opt['train']
if train_opt['fix_flow']:
normal_params = []
flow_params = []
for name, param in self.net_g.named_parameters():
if 'spynet' in name: # The fix_flow now only works for spynet.
flow_params.append(param)
else:
normal_params.append(param)
optim_params = [
{ # add flow params first
'params': flow_params,
'lr': train_opt['lr_flow']
},
{
'params': normal_params,
'lr': train_opt['optim_g']['lr']
},
]
else:
optim_params = self.net_g.parameters()
# optimizer g
optim_type = train_opt['optim_g'].pop('type')
self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
self.optimizers.append(self.optimizer_g)
# optimizer d
optim_type = train_opt['optim_d'].pop('type')
self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
self.optimizers.append(self.optimizer_d)
def optimize_parameters(self, current_iter):
# optimize net_g
for p in self.net_d.parameters():
p.requires_grad = False
self.optimize_parameters_base(current_iter)
_, _, c, h, w = self.output.size()
pix_gt = self.gt
percep_gt = self.gt
gan_gt = self.gt
if self.opt.get('l1_gt_usm', False):
pix_gt = self.gt_usm
if self.opt.get('percep_gt_usm', False):
percep_gt = self.gt_usm
if self.opt.get('gan_gt_usm', False):
gan_gt = self.gt_usm
l_g_total = 0
loss_dict = OrderedDict()
if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
# pixel loss
if self.cri_pix:
l_g_pix = self.cri_pix(self.output, pix_gt)
l_g_total += l_g_pix
loss_dict['l_g_pix'] = l_g_pix
# perceptual loss
if self.cri_perceptual:
l_g_percep, l_g_style = self.cri_perceptual(self.output.view(-1, c, h, w), percep_gt.view(-1, c, h, w))
if l_g_percep is not None:
l_g_total += l_g_percep
loss_dict['l_g_percep'] = l_g_percep
if l_g_style is not None:
l_g_total += l_g_style
loss_dict['l_g_style'] = l_g_style
# gan loss
fake_g_pred = self.net_d(self.output.view(-1, c, h, w))
l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
l_g_total += l_g_gan
loss_dict['l_g_gan'] = l_g_gan
l_g_total.backward()
self.optimizer_g.step()
# optimize net_d
for p in self.net_d.parameters():
p.requires_grad = True
self.optimizer_d.zero_grad()
# real
# reshape to (b*n, c, h, w)
real_d_pred = self.net_d(gan_gt.view(-1, c, h, w))
l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
loss_dict['l_d_real'] = l_d_real
l_d_real.backward()
# fake
# reshape to (b*n, c, h, w)
fake_d_pred = self.net_d(self.output.view(-1, c, h, w).detach())
l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
loss_dict['l_d_fake'] = l_d_fake
l_d_fake.backward()
self.optimizer_d.step()
self.log_dict = self.reduce_loss_dict(loss_dict)
if self.ema_decay > 0:
self.model_ema(decay=self.ema_decay)
def save(self, epoch, current_iter):
if self.ema_decay > 0:
self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
else:
self.save_network(self.net_g, 'net_g', current_iter)
self.save_network(self.net_d, 'net_d', current_iter)
self.save_training_state(epoch, current_iter)
================================================
FILE: animesr/models/video_recurrent_model.py
================================================
import cv2
import os
import torch
from collections import OrderedDict
from os import path as osp
from torch import distributed as dist
from tqdm import tqdm
from basicsr.models.video_base_model import VideoBaseModel
from basicsr.utils import USMSharp, get_root_logger, imwrite, tensor2img
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.registry import MODEL_REGISTRY
@MODEL_REGISTRY.register()
class VideoRecurrentCustomModel(VideoBaseModel):
def __init__(self, opt):
super(VideoRecurrentCustomModel, self).__init__(opt)
if self.is_train:
self.fix_flow_iter = opt['train'].get('fix_flow')
self.idx = 1
lq_from_usm = opt['datasets']['train'].get('lq_from_usm', False)
assert lq_from_usm is False
self.usm_sharp_gt = opt['datasets']['train'].get('usm_sharp_gt', False)
if self.usm_sharp_gt:
usm_radius = opt['datasets']['train'].get('usm_radius', 50)
self.usm_sharpener = USMSharp(radius=usm_radius).cuda()
self.usm_weight = opt['datasets']['train'].get('usm_weight', 0.5)
self.usm_threshold = opt['datasets']['train'].get('usm_threshold', 10)
@torch.no_grad()
def feed_data(self, data):
self.lq = data['lq'].to(self.device)
if 'gt' in data:
self.gt = data['gt'].to(self.device)
if 'gt_usm' in data:
self.gt_usm = data['gt_usm'].to(self.device)
logger = get_root_logger()
logger.warning(
'since lq is not from gt_usm, '
'we should put the usm_sharp operation outside the dataloader to speed up training')
elif self.usm_sharp_gt:
b, n, c, h, w = self.gt.size()
self.gt_usm = self.usm_sharpener(
self.gt.view(b * n, c, h, w), weight=self.usm_weight,
threshold=self.usm_threshold).view(b, n, c, h, w)
# if self.opt['rank'] == 0 and 'debug' in self.opt['name']:
# import torchvision
# os.makedirs('tmp/gt', exist_ok=True)
# os.makedirs('tmp/gt_usm', exist_ok=True)
# os.makedirs('tmp/lq', exist_ok=True)
# print(self.idx)
# for i in range(15):
# torchvision.utils.save_image(
# self.lq[:, i, :, :, :],
# f'tmp/lq/lq{self.idx}_{i}.png',
# nrow=4,
# padding=2,
# normalize=True,
# range=(0, 1))
# torchvision.utils.save_image(
# self.gt[:, i, :, :, :],
# f'tmp/gt/gt{self.idx}_{i}.png',
# nrow=4,
# padding=2,
# normalize=True,
# range=(0, 1))
# torchvision.utils.save_image(
# self.gt_usm[:, i, :, :, :],
# f'tmp/gt_usm/gt_usm{self.idx}_{i}.png',
# nrow=4,
# padding=2,
# normalize=True,
# range=(0, 1))
# self.idx += 1
# if self.idx >= 20:
# exit()
def setup_optimizers(self):
train_opt = self.opt['train']
flow_lr_mul = train_opt.get('flow_lr_mul', 1)
logger = get_root_logger()
logger.info(f'Multiply the learning rate for the flow network by {flow_lr_mul}.')
if flow_lr_mul == 1:
optim_params = self.net_g.parameters()
else: # separate flow params and normal params for different lr
normal_params = []
flow_params = []
for name, param in self.net_g.named_parameters():
if 'spynet' in name:
flow_params.append(param)
else:
normal_params.append(param)
optim_params = [
{ # add normal params first
'params': normal_params,
'lr': train_opt['optim_g']['lr']
},
{
'params': flow_params,
'lr': train_opt['optim_g']['lr'] * flow_lr_mul
},
]
optim_type = train_opt['optim_g'].pop('type')
self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
self.optimizers.append(self.optimizer_g)
def optimize_parameters_base(self, current_iter):
if self.fix_flow_iter:
logger = get_root_logger()
if current_iter == 1:
logger.info(f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.')
for name, param in self.net_g.named_parameters():
if 'spynet' in name or 'edvr' in name:
param.requires_grad_(False)
elif current_iter == self.fix_flow_iter:
logger.warning('Train all the parameters.')
self.net_g.requires_grad_(True)
self.optimizer_g.zero_grad()
self.output = self.net_g(self.lq)
def optimize_parameters(self, current_iter):
self.optimize_parameters_base(current_iter)
_, _, c, h, w = self.output.size()
pix_gt = self.gt
percep_gt = self.gt
if self.opt.get('l1_gt_usm', False):
pix_gt = self.gt_usm
if self.opt.get('percep_gt_usm', False):
percep_gt = self.gt_usm
l_total = 0
loss_dict = OrderedDict()
# pixel loss
if self.cri_pix:
l_pix = self.cri_pix(self.output, pix_gt)
l_total += l_pix
loss_dict['l_pix'] = l_pix
# perceptual loss
if self.cri_perceptual:
l_percep, l_style = self.cri_perceptual(self.output.view(-1, c, h, w), percep_gt.view(-1, c, h, w))
if l_percep is not None:
l_total += l_percep
loss_dict['l_percep'] = l_percep
if l_style is not None:
l_total += l_style
loss_dict['l_style'] = l_style
l_total.backward()
self.optimizer_g.step()
self.log_dict = self.reduce_loss_dict(loss_dict)
if self.ema_decay > 0:
self.model_ema(decay=self.ema_decay)
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
"""dist_test actually, no gt, no metrics"""
dataset = dataloader.dataset
dataset_name = dataset.opt['name']
assert dataset_name.endswith('CoreFrames')
rank, world_size = get_dist_info()
num_folders = len(dataset)
num_pad = (world_size - (num_folders % world_size)) % world_size
if rank == 0:
pbar = tqdm(total=len(dataset), unit='folder')
os.makedirs(osp.join(self.opt['path']['visualization'], dataset_name, str(current_iter)), exist_ok=True)
if self.opt['dist']:
dist.barrier()
# Will evaluate (num_folders + num_pad) times, but only the first num_folders
# results will be recorded (the padding keeps all ranks in step and avoids deadlock).
for i in range(rank, num_folders + num_pad, world_size):
idx = min(i, num_folders - 1)
val_data = dataset[idx]
folder = val_data['folder']
# compute outputs
val_data['lq'].unsqueeze_(0)
self.feed_data(val_data)
val_data['lq'].squeeze_(0)
self.test()
visuals = self.get_current_visuals()
# tentative for out of GPU memory
del self.lq
del self.output
if 'gt' in visuals:
del self.gt
torch.cuda.empty_cache()
# evaluate
if i < num_folders:
for idx in range(visuals['result'].size(1)):
result = visuals['result'][0, idx, :, :, :]
result_img = tensor2img([result]) # uint8, bgr
# since we keep all frames, a x4 scale is unfriendly to storage space,
# so we save the frames at a default scale of 2
save_scale = self.opt.get('savescale', 2)
net_scale = self.opt.get('scale')
if save_scale != net_scale:
h, w = result_img.shape[0:2]
result_img = cv2.resize(
result_img, (w // net_scale * save_scale, h // net_scale * save_scale),
interpolation=cv2.INTER_LANCZOS4)
if save_img:
img_path = osp.join(self.opt['path']['visualization'], dataset_name, str(current_iter),
f"{folder}_{idx:08d}_{self.opt['name'][:5]}.png")
# image name only for REDS dataset
imwrite(result_img, img_path)
# progress bar
if rank == 0:
for _ in range(world_size):
pbar.update(1)
pbar.set_description(f'Folder: {folder}')
if rank == 0:
pbar.close()
def test(self):
n = self.lq.size(1)
self.net_g.eval()
flip_seq = self.opt['val'].get('flip_seq', False)
self.center_frame_only = self.opt['val'].get('center_frame_only', False)
if flip_seq:
self.lq = torch.cat([self.lq, self.lq.flip(1)], dim=1)
with torch.no_grad():
self.output = self.net_g(self.lq)
if flip_seq:
output_1 = self.output[:, :n, :, :, :]
output_2 = self.output[:, n:, :, :, :].flip(1)
self.output = 0.5 * (output_1 + output_2)
if self.center_frame_only:
self.output = self.output[:, n // 2, :, :, :]
self.net_g.train()
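The flip-sequence logic in `test()` above is a temporal self-ensemble: the clip is concatenated with its time-reversed copy, processed in a single forward pass, and the two halves are averaged after re-aligning the reversed half. A minimal sketch of that indexing, using a stand-in identity network (the real `net_g` is not needed to verify the alignment):

```python
import torch

def flip_seq_ensemble(net, lq):
    """Temporal self-ensemble: run the clip and its time-reversed
    copy through the network, then average the re-aligned outputs."""
    n = lq.size(1)  # number of frames; lq is (b, t, c, h, w)
    lq = torch.cat([lq, lq.flip(1)], dim=1)
    with torch.no_grad():
        out = net(lq)
    out_fwd = out[:, :n]           # outputs for the forward clip
    out_bwd = out[:, n:].flip(1)   # flip back so frames line up again
    return 0.5 * (out_fwd + out_bwd)

# with an identity "network" the ensemble must return the input unchanged
net = lambda x: x
clip = torch.randn(1, 5, 3, 8, 8)
assert torch.allclose(flip_seq_ensemble(net, clip), clip)
```

This roughly doubles inference cost but can stabilise results for recurrent models, since each frame is seen with both past-to-future and future-to-past context.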
================================================
FILE: animesr/test.py
================================================
# flake8: noqa
import os.path as osp
import animesr.archs
import animesr.data
import animesr.models
from basicsr.test import test_pipeline
if __name__ == '__main__':
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
test_pipeline(root_path)
================================================
FILE: animesr/train.py
================================================
# flake8: noqa
import os.path as osp
import animesr.archs
import animesr.data
import animesr.models
from basicsr.train import train_pipeline
if __name__ == '__main__':
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
train_pipeline(root_path)
================================================
FILE: animesr/utils/__init__.py
================================================
# -*- coding: utf-8 -*-
================================================
FILE: animesr/utils/inference_base.py
================================================
import argparse
import os.path
import torch
from animesr.archs.vsr_arch import MSRSWVSR
def get_base_argument_parser() -> argparse.ArgumentParser:
"""get the base argument parser for inference scripts"""
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', type=str, default='inputs', help='input test image folder or video path')
parser.add_argument('-o', '--output', type=str, default='results', help='save image/video path')
parser.add_argument(
'-n',
'--model_name',
type=str,
default='AnimeSR_v2',
help='Model names: AnimeSR_v2 | AnimeSR_v1-PaperModel. Default: AnimeSR_v2')
parser.add_argument(
'-s',
'--outscale',
type=int,
default=4,
help='The netscale is x4, but you can achieve an arbitrary output scale (e.g., x2) with the outscale argument. '
'The program will further perform a cheap resize operation after the AnimeSR output. '
'This is useful when you want to save disk space or avoid too large a resolution output')
parser.add_argument(
'--expname', type=str, default='animesr', help='A unique name to identify your current inference')
parser.add_argument(
'--netscale',
type=int,
default=4,
help='the released models are all x4 models, only change this if you train a x2 or x1 model by yourself')
parser.add_argument(
'--mod_scale',
type=int,
default=4,
help='the scale used for mod crop; since AnimeSR uses a multi-scale architecture, the edges should be divisible by 4')
parser.add_argument('--fps', type=int, default=None, help='fps of the sr videos')
parser.add_argument('--half', action='store_true', help='use half precision for inference')
return parser
def get_inference_model(args, device) -> MSRSWVSR:
"""return an on device model with eval mode"""
# set up model
model = MSRSWVSR(num_feat=64, num_block=[5, 3, 2], netscale=args.netscale)
model_path = f'weights/{args.model_name}.pth'
assert os.path.isfile(model_path), \
f'{model_path} does not exist, please make sure you successfully download the pretrained models ' \
f'and put them into the weights folder'
# load checkpoint
loadnet = torch.load(model_path)
model.load_state_dict(loadnet, strict=True)
model.eval()
model = model.to(device)
# num_parameters = sum(map(lambda x: x.numel(), model.parameters()))
# print(num_parameters)
# exit(0)
return model.half() if args.half else model
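The `--outscale`/`--netscale` help text above describes a two-stage scheme: the network always upscales by `netscale` (x4 for the released models), and any other output scale is reached by a cheap resize afterwards. A sketch of that post-resize step (the inference scripts use an OpenCV resize; `torch.nn.functional.interpolate` stands in here so the snippet is self-contained):

```python
import torch
import torch.nn.functional as F

def rescale_output(sr, netscale=4, outscale=2):
    """Bring a x`netscale` network output to the requested `outscale`
    with a cheap resize (illustrative; the repo scripts use cv2.resize)."""
    if outscale == netscale:
        return sr
    return F.interpolate(
        sr, scale_factor=outscale / netscale, mode='bilinear', align_corners=False)

# a 64x64 input upscaled x4 by the net (256x256), resized down to an x2 output
sr = torch.randn(1, 3, 256, 256)
assert rescale_output(sr, netscale=4, outscale=2).shape[-2:] == (128, 128)
```

The network cost is the same regardless of `outscale`; only the final resize and the saved file size change.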
================================================
FILE: animesr/utils/shot_detector.py
================================================
# The codes below partially refer to the PySceneDetect. According
# to its BSD 3-Clause License, we keep the following.
#
# PySceneDetect: Python-Based Video Scene Detector
# ---------------------------------------------------------------
# [ Site: http://www.bcastell.com/projects/PySceneDetect/ ]
# [ Github: https://github.com/Breakthrough/PySceneDetect/ ]
# [ Documentation: http://pyscenedetect.readthedocs.org/ ]
#
# Copyright (C) 2014-2020 Brandon Castellano <http://www.bcastell.com>.
#
# PySceneDetect is licensed under the BSD 3-Clause License; see the included
# LICENSE file, or visit one of the following pages for details:
# - https://github.com/Breakthrough/PySceneDetect/
# - http://www.bcastell.com/projects/PySceneDetect/
#
# This software uses Numpy, OpenCV, click, tqdm, simpletable, and pytest.
# See the included LICENSE files or one of the above URLs for more information.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import cv2
import glob
import numpy as np
import os
from tqdm import tqdm
DEFAULT_DOWNSCALE_FACTORS = {
3200: 12, # ~4k
2100: 8, # ~2k
1700: 6, # ~1080p
1200: 5,
900: 4, # ~720p
600: 3,
400: 1 # ~480p
}
def compute_downscale_factor(frame_width):
"""Compute Downscale Factor: Returns the optimal default downscale factor
based on a video's resolution (specifically, the width parameter).
Returns:
int: The default downscale factor to use with a video of frame_height x
frame_width.
"""
for width in sorted(DEFAULT_DOWNSCALE_FACTORS, reverse=True):
if frame_width >= width:
return DEFAULT_DOWNSCALE_FACTORS[width]
return 1
class ShotDetector(object):
"""Detects fast cuts using changes in colour and intensity between frames.
Detect shot boundary using HSV and LUV information.
"""
def __init__(self, threshold=30.0, min_shot_len=15):
super(ShotDetector, self).__init__()
self.hsv_threshold = threshold
self.delta_hsv_gap_threshold = 10
self.luv_threshold = 40
self.hsv_weight = 5
# minimum length (frames length) of any given shot
self.min_shot_len = min_shot_len
self.last_frame = None
self.last_shot_cut = None
self.last_hsv = None
self._metric_keys = [
'hsv_content_val', 'delta_hsv_hue', 'delta_hsv_sat', 'delta_hsv_lum', 'luv_content_val', 'delta_luv_hue',
'delta_luv_sat', 'delta_luv_lum'
]
self.cli_name = 'detect-content'
self.last_luv = None
self.cut_list = []
self._stats_manager = None  # referenced in detect_shots(); no StatsManager is used in this standalone detector
def add_cut(self, cut):
num_cuts = len(self.cut_list)
if num_cuts == 0:
self.cut_list.append([0, cut - 1])
else:
self.cut_list.append([self.cut_list[num_cuts - 1][1] + 1, cut - 1])
def process_frame(self, frame_num, frame_img):
"""Similar to ThresholdDetector, but using the HSV colour space
DIFFERENCE instead of single-frame RGB/grayscale intensity (thus cannot
detect slow fades with this method).
Args:
frame_num (int): Frame number of frame that is being passed.
frame_img (Optional[np.ndarray]): Decoded frame image to
perform shot detection on. Can be None *only* if the
self.is_processing_required() method
(inherited from the base ShotDetector class) returns True.
Returns:
List[int]: List of frames where shot cuts have been detected.
There may be 0 or more frames in the list, and not necessarily
the same as frame_num.
"""
cut_list = []
# if self.last_frame is not None:
# # Change in average of HSV (hsv), (h)ue only,
# # (s)aturation only, (l)uminance only.
delta_hsv_avg, delta_hsv_h, delta_hsv_s, delta_hsv_v = 0.0, 0.0, 0.0, 0.0
delta_luv_avg, delta_luv_h, delta_luv_s, delta_luv_v = 0.0, 0.0, 0.0, 0.0
if frame_num == 0:
self.last_frame = frame_img.copy()
return cut_list
else:
num_pixels = frame_img.shape[0] * frame_img.shape[1]
curr_luv = cv2.split(cv2.cvtColor(frame_img, cv2.COLOR_BGR2Luv))
curr_hsv = cv2.split(cv2.cvtColor(frame_img, cv2.COLOR_BGR2HSV))
last_hsv = self.last_hsv
last_luv = self.last_luv
if not last_hsv:
last_hsv = cv2.split(cv2.cvtColor(self.last_frame, cv2.COLOR_BGR2HSV))
last_luv = cv2.split(cv2.cvtColor(self.last_frame, cv2.COLOR_BGR2Luv))
delta_hsv = [0, 0, 0, 0]
for i in range(3):
num_pixels = curr_hsv[i].shape[0] * curr_hsv[i].shape[1]
curr_hsv[i] = curr_hsv[i].astype(np.int32)
last_hsv[i] = last_hsv[i].astype(np.int32)
delta_hsv[i] = np.sum(np.abs(curr_hsv[i] - last_hsv[i])) / float(num_pixels)
delta_hsv[3] = sum(delta_hsv[0:3]) / 3.0
delta_hsv_h, delta_hsv_s, delta_hsv_v, delta_hsv_avg = \
delta_hsv
delta_luv = [0, 0, 0, 0]
for i in range(3):
num_pixels = curr_luv[i].shape[0] * curr_luv[i].shape[1]
curr_luv[i] = curr_luv[i].astype(np.int32)
last_luv[i] = last_luv[i].astype(np.int32)
delta_luv[i] = np.sum(np.abs(curr_luv[i] - last_luv[i])) / float(num_pixels)
delta_luv[3] = sum(delta_luv[0:3]) / 3.0
delta_luv_h, delta_luv_s, delta_luv_v, delta_luv_avg = \
delta_luv
self.last_hsv = curr_hsv
self.last_luv = curr_luv
if delta_hsv_avg >= self.hsv_threshold and delta_hsv_avg - self.hsv_threshold >= self.delta_hsv_gap_threshold:
if self.last_shot_cut is None or ((frame_num - self.last_shot_cut) >= self.min_shot_len):
cut_list.append(frame_num)
self.last_shot_cut = frame_num
elif delta_hsv_avg >= self.hsv_threshold and \
delta_hsv_avg - self.hsv_threshold < \
self.delta_hsv_gap_threshold and \
delta_luv_avg + self.hsv_weight * \
(delta_hsv_avg - self.hsv_threshold) > self.luv_threshold:
if self.last_shot_cut is None or ((frame_num - self.last_shot_cut) >= self.min_shot_len):
cut_list.append(frame_num)
self.last_shot_cut = frame_num
self.last_frame = frame_img.copy()
return cut_list
def detect_shots(self, frame_source, frame_skip=0, show_progress=True, keep_resolution=False):
"""Perform shot detection on the given frame_source using the added
shotDetectors.
Blocks until all frames in the frame_source have been processed.
Results can be obtained by calling either the get_shot_list()
or get_cut_list() methods.
Arguments:
frame_source (shotdetect.video_manager.VideoManager or
cv2.VideoCapture):
A source of frames to process (using frame_source.read() as in
VideoCapture).
VideoManager is preferred as it allows concatenation of
multiple videos as well as seeking, by defining start time
and end time/duration.
end_time (int or FrameTimecode): Maximum number of frames to detect
(set to None to detect all available frames). Only needed for
OpenCV
VideoCapture objects; for VideoManager objects, use
set_duration() instead.
frame_skip (int): Not recommended except for extremely high
framerate videos.
Number of frames to skip (i.e. process every 1 in N+1 frames,
where N is frame_skip, processing only 1/N+1 percent of the
video,
speeding up the detection time at the expense of accuracy).
`frame_skip` **must** be 0 (the default) when using a
StatsManager.
show_progress (bool): If True, and the ``tqdm`` module is
available, displays
a progress bar with the progress, framerate, and expected
time to
complete processing the video frame source.
Raises:
ValueError: `frame_skip` **must** be 0 (the default)
if the shotManager
was constructed with a StatsManager object.
"""
if frame_skip > 0 and self._stats_manager is not None:
raise ValueError('frame_skip must be 0 when using a StatsManager.')
curr_frame = 0
frame_paths = sorted(glob.glob(os.path.join(frame_source, '*')))
total_frames = len(frame_paths)
end_frame = total_frames
progress_bar = None
if tqdm and show_progress:
progress_bar = tqdm(total=total_frames, unit='frames')
try:
while True:
if end_frame is not None and curr_frame >= end_frame:
break
frame_im = cv2.imread(frame_paths[curr_frame])
if not keep_resolution:
if curr_frame == 0:
downscale_factor = compute_downscale_factor(frame_im.shape[1])
frame_im = frame_im[::downscale_factor, ::downscale_factor, :]
cut = self.process_frame(curr_frame, frame_im)
if len(cut) != 0:
self.add_cut(cut[0])
curr_frame += 1
if progress_bar:
progress_bar.update(1)
finally:
if progress_bar:
progress_bar.close()
return self.cut_list
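`compute_downscale_factor` above picks the factor of the largest width bucket that the frame width reaches, falling back to 1 for very small frames. A quick self-contained check of that lookup (the table is copied here so the snippet runs on its own):

```python
# copy of the lookup table from shot_detector.py
DEFAULT_DOWNSCALE_FACTORS = {3200: 12, 2100: 8, 1700: 6, 1200: 5, 900: 4, 600: 3, 400: 1}

def compute_downscale_factor(frame_width):
    """Return the default downscale factor for a given frame width."""
    for width in sorted(DEFAULT_DOWNSCALE_FACTORS, reverse=True):
        if frame_width >= width:
            return DEFAULT_DOWNSCALE_FACTORS[width]
    return 1

assert compute_downscale_factor(3840) == 12  # ~4k
assert compute_downscale_factor(1920) == 6   # ~1080p
assert compute_downscale_factor(640) == 3
assert compute_downscale_factor(320) == 1    # below all buckets
```

Since cut detection only compares per-channel averages of frame differences, this striding (`frame_im[::factor, ::factor, :]`) trades little accuracy for a large speedup on high-resolution sources.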
================================================
FILE: animesr/utils/video_util.py
================================================
import glob
import os
import subprocess
default_ffmpeg_exe_path = 'ffmpeg'
default_ffprobe_exe_path = 'ffprobe'
default_ffmpeg_vcodec = 'h264'
default_ffmpeg_pix_fmt = 'yuv420p'
def get_video_fps(video_path, ret_type='float'):
"""Get the fps of the video.
Args:
video_path (str): the video path;
ret_type (str): the return type, it supports `str`, `float`, and `tuple` (numerator, denominator).
Returns:
fps (str): if ret_type is `str`.
fps (float): if ret_type is `float`.
fps (tuple): if ret_type is `tuple`, (numerator, denominator).
"""
global default_ffprobe_exe_path
ffprobe_exe_path = os.environ.get('ffprobe_exe_path', default_ffprobe_exe_path)
cmd = [
ffprobe_exe_path, '-v', 'quiet', '-select_streams', 'v', '-of', 'default=noprint_wrappers=1:nokey=1',
'-show_entries', 'stream=r_frame_rate', video_path
]
result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
fps = result.stdout.decode('utf-8').strip()
# e.g. 30/1
numerator, denominator = map(int, fps.split('/'))
if ret_type == 'float':
return numerator / denominator
elif ret_type == 'str':
return str(numerator / denominator)
else:
return numerator, denominator
def get_video_num_frames(video_path):
"""Get the video's total number of frames."""
global default_ffprobe_exe_path
ffprobe_exe_path = os.environ.get('ffprobe_exe_path', default_ffprobe_exe_path)
cmd = [
ffprobe_exe_path, '-v', 'quiet', '-select_streams', 'v', '-count_packets', '-of',
'default=noprint_wrappers=1:nokey=1', '-show_entries', 'stream=nb_read_packets', video_path
]
result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
nb_frames = result.stdout.decode('utf-8').strip()
return int(nb_frames)
def get_video_bitrate(video_path):
"""Get the bitrate of the video."""
global default_ffprobe_exe_path
ffprobe_exe_path = os.environ.get('ffprobe_exe_path', default_ffprobe_exe_path)
cmd = [
ffprobe_exe_path, '-v', 'quiet', '-select_streams', 'v', '-of', 'default=noprint_wrappers=1:nokey=1',
'-show_entries', 'stream=bit_rate', video_path
]
result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
bitrate = result.stdout.decode('utf-8').strip()
if bitrate == 'N/A':
return bitrate
return int(bitrate) // 1000
def get_video_resolution(video_path):
"""Get the resolution (h and w) of the video.
Args:
video_path (str): the video path;
Returns:
h, w (int)
"""
global default_ffprobe_exe_path
ffprobe_exe_path = os.environ.get('ffprobe_exe_path', default_ffprobe_exe_path)
cmd = [
ffprobe_exe_path, '-v', 'quiet', '-select_streams', 'v', '-of', 'csv=s=x:p=0', '-show_entries',
'stream=width,height', video_path
]
result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
resolution = result.stdout.decode('utf-8').strip()
# print(resolution)
w, h = resolution.split('x')
return int(h), int(w)
def video2frames(video_path, out_dir, force=False, high_quality=True, ss=None, to=None, vf=None):
"""Extract frames from the video
Args:
out_dir: where to save the frames
force: if out_dir is not empty, force=True will still extract the frames
high_quality: whether to use the highest quality
ss: start time, format HH:MM:SS[.xxx]; if None, extract the full video
to: end time, format HH:MM:SS[.xxx]; if None, extract the full video
vf: video filter
"""
global default_ffmpeg_exe_path
ffmpeg_exc_path = os.environ.get('ffmpeg_exe_path', default_ffmpeg_exe_path)
imgs = glob.glob(os.path.join(out_dir, '*.png'))
length = len(imgs)
if length > 0:
print(f'{out_dir} already has frames! force extract = {force}')
if not force:
return out_dir
print(f'extracting frames for {video_path}')
cmd = [
ffmpeg_exc_path,
f'-i {video_path}',
'-v error',
f'-ss {ss} -to {to}' if ss is not None and to is not None else '',
'-qscale:v 1 -qmin 1 -qmax 1 -vsync 0' if high_quality else '',
f'-vf {vf}' if vf is not None else '',
f'{out_dir}/%08d.png',
]
print(' '.join(cmd))
subprocess.call(' '.join(cmd), shell=True)
return out_dir
def frames2video(frames_dir, out_path, fps=25, filter='*', suffix=None):
"""Combine frames under a folder to video
Args:
frames_dir: input folder where the frames locate
out_path: the output video path
fps: output video fps
filter: glob pattern used to match the frames, e.g., '*'
suffix: the frame suffix, e.g., png or jpg; inferred from the first frame if None
"""
global default_ffmpeg_vcodec, default_ffmpeg_pix_fmt, default_ffmpeg_exe_path
ffmpeg_exc_path = os.environ.get('ffmpeg_exe_path', default_ffmpeg_exe_path)
vcodec = os.environ.get('ffmpeg_vcodec', default_ffmpeg_vcodec)
pix_fmt = os.environ.get('ffmpeg_pix_fmt', default_ffmpeg_pix_fmt)
if suffix is None:
images_names = os.listdir(frames_dir)
image_name = images_names[0]
suffix = image_name.split('.')[-1]
cmd = [
ffmpeg_exc_path,
'-y',
'-r', str(fps),
'-f', 'image2',
'-pattern_type', 'glob',
'-i', f'{frames_dir}/{filter}.{suffix}',
'-vcodec', vcodec,
'-pix_fmt', pix_fmt,
out_path
] # yapf: disable
print(' '.join(cmd))
subprocess.call(cmd)
return out_path
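`get_video_fps` relies on ffprobe's `r_frame_rate` being a rational string such as `30/1` or `30000/1001` (NTSC). That parsing step can be isolated and checked without invoking ffprobe at all (`parse_r_frame_rate` is a name introduced here for illustration):

```python
def parse_r_frame_rate(fps_str, ret_type='float'):
    """Parse an ffprobe r_frame_rate string like '30000/1001'."""
    numerator, denominator = map(int, fps_str.split('/'))
    if ret_type == 'float':
        return numerator / denominator
    if ret_type == 'str':
        return str(numerator / denominator)
    return numerator, denominator

assert parse_r_frame_rate('30/1') == 30.0
assert abs(parse_r_frame_rate('30000/1001') - 29.97) < 0.01  # NTSC
assert parse_r_frame_rate('24/1', ret_type='tuple') == (24, 1)
```

Keeping the numerator/denominator form available (the `tuple` return type) avoids rounding drift when the fps is later passed back to ffmpeg for re-encoding.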
================================================
FILE: cog.yaml
================================================
build:
gpu: true
cuda: "11.6.2"
python_version: "3.10"
system_packages:
- "libgl1-mesa-glx"
- "libglib2.0-0"
- "ffmpeg"
python_packages:
- "ipython==8.4.0"
- "torch==1.13.1"
- "torchvision==0.14.1"
- "ffmpeg-python==0.2.0"
- "facexlib==0.2.5"
- "basicsr==1.4.2"
- "opencv-python==4.7.0.68"
- "Pillow==9.4.0"
- "psutil==5.9.4"
- "tqdm==4.64.1"
predict: "predict.py:Predictor"
================================================
FILE: options/train_animesr_step1_gan_BasicOPonly.yml
================================================
# general settings
name: train_animesr_step1_gan_BasicOPonly
model_type: VideoRecurrentGANCustomModel
scale: 4
num_gpu: auto # set num_gpu: 0 for cpu mode
manual_seed: 0
# USM the ground-truth
l1_gt_usm: False
percep_gt_usm: False
gan_gt_usm: False
# dataset and data loader settings
datasets:
train:
name: AVC-Train
type: FFMPEGAnimeDataset
dataroot_gt: datasets/AVC-Train # TO_MODIFY
test_mode: False
io_backend:
type: disk
num_frame: 15
gt_size: 256
interval_list: [1, 2, 3]
random_reverse: True
use_flip: true
use_rot: true
usm_sharp_gt: False
# data loader
use_shuffle: true
num_worker_per_gpu: 20
batch_size_per_gpu: 4
dataset_enlarge_ratio: 1
prefetch_mode: ~
pin_memory: True
# network structures
network_g:
type: MSRSWVSR
num_feat: 64
num_block: [5, 3, 2]
network_d:
type: MultiScaleDiscriminator
num_in_ch: 3
num_feat: 64
num_layers: [3, 3, 3]
max_nf_mult: 8
norm_type: none
use_sigmoid: False
use_sn: True
use_downscale: True
# path
path:
pretrain_network_g: experiments/train_animesr_step1_net_BasicOPonly/models/net_g_300000.pth
param_key_g: params
strict_load_g: true
resume_state: ~
# training settings
train:
ema_decay: 0.999
optim_g:
type: Adam
lr: !!float 1e-4
weight_decay: 0
betas: [0.9, 0.99]
lr_flow: !!float 1e-5
optim_d:
type: Adam
lr: !!float 1e-4
weight_decay: 0
betas: [0.9, 0.99]
scheduler:
type: MultiStepLR
milestones: [300000]
gamma: 0.5
total_iter: 300000
warmup_iter: -1 # no warm up
fix_flow: ~
# losses
pixel_opt:
type: L1Loss
loss_weight: 1.0
reduction: mean
perceptual_opt:
type: PerceptualLoss
layer_weights:
# before relu
'conv1_2': 0.1
'conv2_2': 0.1
'conv3_4': 1
'conv4_4': 1
'conv5_4': 1
vgg_type: vgg19
use_input_norm: true
range_norm: false
perceptual_weight: 1.0
style_weight: 0
criterion: l1
gan_opt:
type: MultiScaleGANLoss
gan_type: lsgan
real_label_val: 1.0
fake_label_val: 0.0
loss_weight: !!float 1.0
net_d_iters: 1
net_d_init_iters: 0
# validation settings
#val:
# val_freq: !!float 1e4
# save_img: true
# logging settings
logger:
print_freq: 100
save_checkpoint_freq: !!float 5e3
use_tb_logger: true
wandb:
project: ~
resume_id: ~
# dist training settings
dist_params:
backend: nccl
port: 29500
find_unused_parameters: true
================================================
FILE: options/train_animesr_step1_net_BasicOPonly.yml
================================================
# general settings
name: train_animesr_step1_net_BasicOPonly
model_type: VideoRecurrentCustomModel
scale: 4
num_gpu: auto # set num_gpu: 0 for cpu mode
manual_seed: 0
# USM the ground-truth
l1_gt_usm: True
# dataset and data loader settings
datasets:
train:
name: AVC-Train
type: FFMPEGAnimeDataset
dataroot_gt: datasets/AVC-Train # TO_MODIFY
test_mode: False
io_backend:
type: disk
num_frame: 15
gt_size: 256
interval_list: [1, 2, 3]
random_reverse: True
use_flip: true
use_rot: true
usm_sharp_gt: True
usm_weight: 0.3
usm_radius: 50
# data loader
use_shuffle: true
num_worker_per_gpu: 20
batch_size_per_gpu: 4
dataset_enlarge_ratio: 1
prefetch_mode: ~
pin_memory: True
# network structures
network_g:
type: MSRSWVSR
num_feat: 64
num_block: [5, 3, 2]
# path
path:
resume_state: ~
# training settings
train:
ema_decay: 0.999
optim_g:
type: Adam
lr: !!float 2e-4
weight_decay: 0
betas: [0.9, 0.99]
scheduler:
type: MultiStepLR
milestones: [300000]
gamma: 0.5
total_iter: 300000
warmup_iter: -1 # no warm up
# losses
pixel_opt:
type: L1Loss
loss_weight: 1.0
reduction: mean
# validation settings
#val:
# val_freq: !!float 1e4
# save_img: true
# logging settings
logger:
print_freq: 100
save_checkpoint_freq: !!float 5e3
use_tb_logger: true
wandb:
project: ~
resume_id: ~
# dist training settings
dist_params:
backend: nccl
port: 29500
find_unused_parameters: true
================================================
FILE: options/train_animesr_step2_lbo_1_gan.yml
================================================
# general settings
name: train_animesr_step2_lbo_1_gan
model_type: DegradationGANModel
scale: 2
num_gpu: auto # set num_gpu: 0 for cpu mode
manual_seed: 0
# dataset and data loader settings
datasets:
train:
name: LBO_1
type: CustomPairedImageDataset
dataroot_gt: results/input_rescaling_strategy_lbo_1/frames # TO_MODIFY
dataroot_lq: datasets/lbo_training_data/real_world_video_to_train_lbo_1 # TO_MODIFY
io_backend:
type: disk
gt_size: 256
use_hflip: true
use_rot: true
# data loader
use_shuffle: true
num_worker_per_gpu: 12
batch_size_per_gpu: 16
dataset_enlarge_ratio: 200
prefetch_mode: ~
# network structures
network_g:
type: SimpleDegradationArch
num_in_ch: 3
num_out_ch: 3
num_feat: 64
downscale: 2
network_d:
type: MultiScaleDiscriminator
num_in_ch: 3
num_feat: 64
num_layers: [3]
max_nf_mult: 8
norm_type: none
use_sigmoid: False
use_sn: True
use_downscale: True
# path
path:
pretrain_network_g: experiments/train_animesr_step2_lbo_1_net/models/net_g_100000.pth
param_key_g: params
strict_load_g: true
resume_state: ~
# training settings
train:
optim_g:
type: Adam
lr: !!float 1e-4
weight_decay: 0
betas: [0.9, 0.99]
optim_d:
type: Adam
lr: !!float 1e-4
weight_decay: 0
betas: [0.9, 0.99]
scheduler:
type: MultiStepLR
milestones: [50000]
gamma: 0.5
total_iter: 100000
warmup_iter: -1 # no warm up
# losses
pixel_opt:
type: L1Loss
loss_weight: 1.0
reduction: mean
perceptual_opt:
type: PerceptualLoss
layer_weights:
# before relu
'conv1_2': 0.1
'conv2_2': 0.1
'conv3_4': 1
'conv4_4': 1
'conv5_4': 1
vgg_type: vgg19
use_input_norm: true
range_norm: false
perceptual_weight: 1.0
style_weight: 0
criterion: l1
gan_opt:
type: MultiScaleGANLoss
gan_type: lsgan
real_label_val: 1.0
fake_label_val: 0.0
loss_weight: !!float 1.0
# logging settings
logger:
print_freq: 100
save_checkpoint_freq: !!float 5e3
use_tb_logger: true
wandb:
project: ~
resume_id: ~
# dist training settings
dist_params:
backend: nccl
port: 29500
================================================
FILE: options/train_animesr_step2_lbo_1_net.yml
================================================
# general settings
name: train_animesr_step2_lbo_1_net
model_type: DegradationModel
scale: 2
num_gpu: auto # set num_gpu: 0 for cpu mode
manual_seed: 0
# dataset and data loader settings
datasets:
train:
name: LBO_1
type: CustomPairedImageDataset
dataroot_gt: results/input_rescaling_strategy_lbo_1/frames # TO_MODIFY
dataroot_lq: datasets/lbo_training_data/real_world_video_to_train_lbo_1 # TO_MODIFY
io_backend:
type: disk
gt_size: 256
use_hflip: true
use_rot: true
# data loader
use_shuffle: true
num_worker_per_gpu: 12
batch_size_per_gpu: 16
dataset_enlarge_ratio: 200
prefetch_mode: ~
# network structures
network_g:
type: SimpleDegradationArch
num_in_ch: 3
num_out_ch: 3
num_feat: 64
downscale: 2
# path
path:
resume_state: ~
# training settings
train:
optim_g:
type: Adam
lr: !!float 2e-4
weight_decay: 0
betas: [0.9, 0.99]
scheduler:
type: MultiStepLR
milestones: [50000]
gamma: 0.5
total_iter: 100000
warmup_iter: -1 # no warm up
# losses
l1_opt:
type: L1Loss
loss_weight: 1.0
reduction: mean
l2_opt:
type: MSELoss
loss_weight: 1.0
reduction: mean
# logging settings
logger:
print_freq: 100
save_checkpoint_freq: !!float 5e3
use_tb_logger: true
wandb:
project: ~
resume_id: ~
# dist training settings
dist_params:
backend: nccl
port: 29500
================================================
FILE: options/train_animesr_step3_gan_3LBOs.yml
================================================
# general settings
name: train_animesr_step3_gan_3LBOs
model_type: VideoRecurrentGANCustomModel
scale: 4
num_gpu: auto # set num_gpu: 0 for cpu mode
manual_seed: 0
# USM the ground-truth
l1_gt_usm: False
percep_gt_usm: False
gan_gt_usm: False
# dataset and data loader settings
datasets:
train:
name: AVC-Train
type: FFMPEGAnimeLBODataset
dataroot_gt: datasets/AVC-Train # TO_MODIFY
test_mode: False
io_backend:
type: disk
num_frame: 15
gt_size: 256
interval_list: [1, 2, 3]
random_reverse: True
use_flip: true
use_rot: true
usm_sharp_gt: False
degradation_model_path: [weights/pretrained_lbo_1.pth, weights/pretrained_lbo_2.pth, weights/pretrained_lbo_3.pth] # TO_MODIFY
# data loader
use_shuffle: true
num_worker_per_gpu: 5
batch_size_per_gpu: 4
dataset_enlarge_ratio: 1
prefetch_mode: ~
pin_memory: True
# network structures
network_g:
type: MSRSWVSR
num_feat: 64
num_block: [5, 3, 2]
network_d:
type: MultiScaleDiscriminator
num_in_ch: 3
num_feat: 64
num_layers: [3, 3, 3]
max_nf_mult: 8
norm_type: none
use_sigmoid: False
use_sn: True
use_downscale: True
# path
path:
pretrain_network_g: experiments/train_animesr_step1_net_BasicOPonly/models/net_g_300000.pth
param_key_g: params
strict_load_g: true
resume_state: ~
# training settings
train:
ema_decay: 0.999
optim_g:
type: Adam
lr: !!float 1e-4
weight_decay: 0
betas: [0.9, 0.99]
lr_flow: !!float 1e-5
optim_d:
type: Adam
lr: !!float 1e-4
weight_decay: 0
betas: [0.9, 0.99]
scheduler:
type: MultiStepLR
milestones: [300000]
gamma: 0.5
total_iter: 300000
warmup_iter: -1 # no warm up
fix_flow: ~
# losses
pixel_opt:
type: L1Loss
loss_weight: 1.0
reduction: mean
perceptual_opt:
type: PerceptualLoss
layer_weights:
# before relu
'conv1_2': 0.1
'conv2_2': 0.1
'conv3_4': 1
'conv4_4': 1
'conv5_4': 1
vgg_type: vgg19
use_input_norm: true
range_norm: false
perceptual_weight: 1.0
style_weight: 0
criterion: l1
gan_opt:
type: MultiScaleGANLoss
gan_type: lsgan
real_label_val: 1.0
fake_label_val: 0.0
loss_weight: !!float 1.0
net_d_iters: 1
net_d_init_iters: 0
# validation settings
#val:
# val_freq: !!float 1e4
# save_img: true
# logging settings
logger:
print_freq: 100
save_checkpoint_freq: !!float 5e3
use_tb_logger: true
wandb:
project: ~
resume_id: ~
# dist training settings
dist_params:
backend: nccl
port: 29500
find_unused_parameters: true
================================================
FILE: predict.py
================================================
import os
import shutil
import tempfile
from subprocess import call
from zipfile import ZipFile
from typing import Optional
import mimetypes
import torch
from cog import BasePredictor, Input, Path, BaseModel
call("python setup.py develop", shell=True)
class ModelOutput(BaseModel):
video: Path
sr_frames: Optional[Path]
class Predictor(BasePredictor):
@torch.inference_mode()
def predict(
self,
video: Path = Input(
description="Input video file",
default=None,
),
frames: Path = Input(
description="Zip file of frames of a video. Ignored when video is provided.",
default=None,
),
) -> ModelOutput:
"""Run a single prediction on the model"""
assert frames or video, "Please provide frames or a video input."
out_path = "cog_temp"
if os.path.exists(out_path):
shutil.rmtree(out_path)
os.makedirs(out_path, exist_ok=True)
if video:
print("processing video...")
cmd = (
"python scripts/inference_animesr_video.py -i "
+ str(video)
+ " -o "
+ out_path
+ " -n AnimeSR_v2 -s 4 --expname animesr_v2 --num_process_per_gpu 1"
)
call(cmd, shell=True)
else:
print("processing frames...")
unzip_frames = "cog_frames_temp"
if os.path.exists(unzip_frames):
shutil.rmtree(unzip_frames)
os.makedirs(unzip_frames)
with ZipFile(str(frames), "r") as zip_ref:
for zip_info in zip_ref.infolist():
if zip_info.filename[-1] == "/" or zip_info.filename.startswith(
"__MACOSX"
):
continue
mt = mimetypes.guess_type(zip_info.filename)
if mt and mt[0] and mt[0].startswith("image/"):
zip_info.filename = os.path.basename(zip_info.filename)
zip_ref.extract(zip_info, unzip_frames)
cmd = (
"python scripts/inference_animesr_frames.py -i "
+ unzip_frames
+ " -o "
+ out_path
+ " -n AnimeSR_v2 --expname animesr_v2 --save_video_too --fps 20"
)
call(cmd, shell=True)
frames_output = Path(tempfile.mkdtemp()) / "out.zip"
frames_out_dir = os.listdir(f"{out_path}/animesr_v2/frames")
assert len(frames_out_dir) == 1
frames_path = os.path.join(
f"{out_path}/animesr_v2/frames", frames_out_dir[0]
)
# by default, sr_frames will be saved in cog_temp/animesr_v2/frames
sr_frames_files = os.listdir(frames_path)
with ZipFile(str(frames_output), "w") as zip:
for img in sr_frames_files:
zip.write(os.path.join(frames_path, img))
# by default, video will be saved in cog_temp/animesr_v2/videos
video_out_dir = os.listdir(f"{out_path}/animesr_v2/videos")
assert len(video_out_dir) == 1
if video_out_dir[0].endswith(".mp4"):
source = os.path.join(f"{out_path}/animesr_v2/videos", video_out_dir[0])
else:
video_output = os.listdir(
f"{out_path}/animesr_v2/videos/{video_out_dir[0]}"
)[0]
source = os.path.join(
f"{out_path}/animesr_v2/videos", video_out_dir[0], video_output
)
video_path = Path(tempfile.mkdtemp()) / "out.mp4"
shutil.copy(source, str(video_path))
if video:
return ModelOutput(video=video_path)
return ModelOutput(sr_frames=frames_output, video=video_path)
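The zip-extraction loop in `predict()` keeps only image entries and skips directory entries and macOS metadata. That filter can be factored out and tested in isolation (`keep_entry` is a helper name introduced here for illustration):

```python
import mimetypes

def keep_entry(filename):
    """Mirror predict.py's zip filter: drop directory entries,
    __MACOSX metadata, and anything that is not an image."""
    if filename.endswith('/') or filename.startswith('__MACOSX'):
        return False
    mt, _ = mimetypes.guess_type(filename)
    return bool(mt and mt.startswith('image/'))

assert keep_entry('frames/0001.png')
assert not keep_entry('__MACOSX/._0001.png')
assert not keep_entry('frames/')       # directory entry
assert not keep_entry('readme.txt')    # not an image
```

Note that `predict.py` also flattens each kept entry with `os.path.basename` before extraction, so nested zips still produce a single flat frame folder for the inference script.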
================================================
FILE: requirements.txt
================================================
basicsr
facexlib
ffmpeg-python
numpy
opencv-python
pillow
psutil
torch
torchvision
tqdm
================================================
FILE: scripts/anime_videos_preprocessing.py
================================================
import argparse
import cv2
import glob
import numpy as np
import os
import shutil
import torch
import torchvision
from multiprocessing import Pool
from os import path as osp
from PIL import Image
from tqdm import tqdm
from animesr.utils import video_util
from animesr.utils.shot_detector import ShotDetector
from basicsr.archs.spynet_arch import SpyNet
from basicsr.utils import img2tensor
from basicsr.utils.download_util import download_file_from_google_drive
from facexlib.assessment import init_assessment_model
def main(args):
"""A script to prepare anime videos.
The preparation can be divided into the following steps:
1. use ffmpeg to extract frames
2. shot detection
3. estimate flow
4. detect black frames
5. use hyperIQA to evaluate the quality of frames
6. generate at most 5 clips per video
"""
opt = dict()
opt['debug'] = args.debug
opt['n_thread'] = args.n_thread
opt['ss_idx'] = args.ss_idx
opt['to_idx'] = args.to_idx
# params for step1: extract frames
opt['video_root'] = f'{args.dataroot}/raw_videos'
opt['save_frames_root'] = f'{args.dataroot}/frames'
opt['meta_files_root'] = f'{args.dataroot}/meta'
# params for step2: shot detection
opt['detect_shot_root'] = f'{args.dataroot}/detect_shot'
# params for step3: flow estimation
opt['estimate_flow_root'] = f'{args.dataroot}/estimate_flow'
opt['spy_pretrain_weight'] = 'experiments/pretrained_models/flownet/spynet_sintel_final-3d2a1287.pth'
opt['downscale_factor'] = 1
# params for step4: detect black frames
opt['black_flag_root'] = f'{args.dataroot}/black_flag'
opt['black_threshold'] = 0.98
# params for step5: image quality assessment
opt['num_patch_per_iqa'] = 5
opt['iqa_score_root'] = f'{args.dataroot}/iqa_score'
# params for step6: generate clips
opt['num_frames_per_clip'] = args.n_frames_per_clip
opt['num_clips_per_video'] = args.n_clips_per_video
opt['select_clips_root'] = f'{args.dataroot}/{args.select_clip_root}'
opt['select_clips_meta'] = osp.join(opt['select_clips_root'], 'meta_info')
opt['select_clips_frames'] = osp.join(opt['select_clips_root'], 'frames')
opt['select_done_flags'] = osp.join(opt['select_clips_root'], 'done_flags')
if '1' in args.run:
run_step1(opt)
if '2' in args.run:
run_step2(opt)
if '3' in args.run:
run_step3(opt)
if '4' in args.run:
run_step4(opt)
if '5' in args.run:
run_step5(opt)
if '6' in args.run:
run_step6(opt)
# -------------------------------------------------------------------- #
# --------------------------- step1 ---------------------------------- #
# -------------------------------------------------------------------- #
def run_step1(opt):
"""extract frames
1. read all video files under video_root folder
2. filter out the videos that have already been processed
3. use multiple processes to extract frames from the remaining videos
"""
video_root = opt['video_root']
frames_root = opt['save_frames_root']
meta_root = opt['meta_files_root']
os.makedirs(frames_root, exist_ok=True)
os.makedirs(meta_root, exist_ok=True)
if not osp.isdir(video_root):
print(f'path {video_root} is not a valid folder, exit.')
return
videos_path = sorted(glob.glob(osp.join(video_root, '*')))
if opt['debug']:
videos_path = videos_path[:3]
else:
videos_path = videos_path[opt['ss_idx']:opt['to_idx']]
pbar = tqdm(total=len(videos_path), unit='video', desc='step1')
pool = Pool(opt['n_thread'])
for video_path in videos_path:
video_name = osp.splitext(osp.basename(video_path))[0]
if video_name.startswith('.'):
print(f'skip {video_name}')
continue
frame_path = osp.join(frames_root, video_name)
meta_path = osp.join(meta_root, f'{video_name}.txt')
pool.apply_async(
worker1, args=(opt, video_name, video_path, frame_path, meta_path), callback=lambda arg: pbar.update(1))
pool.close()
pool.join()
def worker1(opt, video_name, video_path, frame_path, meta_path):
# get info of video
fps = video_util.get_video_fps(video_path)
h, w = video_util.get_video_resolution(video_path)
num_frames = video_util.get_video_num_frames(video_path)
bit_rate = video_util.get_video_bitrate(video_path)
# check whether this video has been processed
flag = True
num_extracted_frames = 0
if osp.exists(frame_path):
num_extracted_frames = len(glob.glob(osp.join(frame_path, '*.png')))
if num_extracted_frames == num_frames:
print(f'skip {video_path} since {num_frames} frames have already been extracted.')
flag = False
else:
print(f'only {num_extracted_frames} of {num_frames} frames have been extracted for {video_path}, re-run.')
# extract frames
os.makedirs(frame_path, exist_ok=True)
video_util.video2frames(video_path, frame_path, force=flag, high_quality=True)
if flag:
num_extracted_frames = len(glob.glob(osp.join(frame_path, '*.png')))
# write some metadata to meta file
with open(meta_path, 'w') as f:
f.write(f'Video Name: {video_name}\n')
f.write(f'H: {h}\n')
f.write(f'W: {w}\n')
f.write(f'FPS: {fps}\n')
f.write(f'Bit Rate: {bit_rate}kbps\n')
f.write(f'{num_extracted_frames}/{num_frames} have been extracted\n')
# -------------------------------------------------------------------- #
# --------------------------- step2 ---------------------------------- #
# -------------------------------------------------------------------- #
def run_step2(opt):
"""shot detection. refer to lijian's pipeline"""
detect_shot_root = opt['detect_shot_root']
meta_root = opt['meta_files_root']
os.makedirs(detect_shot_root, exist_ok=True)
if not osp.exists(meta_root):
print('no videos have run step1, exit.')
return
# get the video which has been extracted frames
videos_name = sorted(glob.glob(osp.join(meta_root, '*.txt')))
videos_name = [osp.splitext(osp.basename(video_name))[0] for video_name in videos_name]
if opt['debug']:
videos_name = videos_name[:3]
else:
videos_name = videos_name[opt['ss_idx']:opt['to_idx']]
pbar = tqdm(total=len(videos_name), unit='video', desc='step2')
pool = Pool(opt['n_thread'])
for video_name in videos_name:
pool.apply_async(worker2, args=(opt, video_name), callback=lambda arg: pbar.update(1))
pool.close()
pool.join()
def worker2(opt, video_name):
video_frame_path = osp.join(opt['save_frames_root'], video_name)
detect_shot_file_path = osp.join(opt['detect_shot_root'], f'{video_name}.txt')
if osp.exists(detect_shot_file_path):
print(f'skip {video_name} since {detect_shot_file_path} already exists.')
return
detector = ShotDetector()
shot_list = detector.detect_shots(video_frame_path)
with open(detect_shot_file_path, 'w') as f:
for shot in shot_list:
f.write(f'{shot[0]} {shot[1]}\n')
# -------------------------------------------------------------------- #
# --------------------------- step3 ---------------------------------- #
# -------------------------------------------------------------------- #
def run_step3(opt):
estimate_flow_root = opt['estimate_flow_root']
meta_root = opt['meta_files_root']
os.makedirs(estimate_flow_root, exist_ok=True)
if not osp.exists(meta_root):
print('no videos have run step1, exit.')
return
# download the spynet checkpoint first
if not osp.exists(opt['spy_pretrain_weight']):
download_file_from_google_drive('1VZz1cikwTRVX7zXoD247DB7n5Tj_LQpF', opt['spy_pretrain_weight'])
# get the video which has been extracted frames
videos_name = sorted(glob.glob(osp.join(meta_root, '*.txt')))
videos_name = [osp.splitext(osp.basename(video_name))[0] for video_name in videos_name]
if opt['debug']:
videos_name = videos_name[:3]
else:
videos_name = videos_name[opt['ss_idx']:opt['to_idx']]
pbar = tqdm(total=len(videos_name), unit='video', desc='step3')
num_gpus = torch.cuda.device_count()
ctx = torch.multiprocessing.get_context('spawn')
pool = ctx.Pool(min(3 * num_gpus, opt['n_thread']))
for idx, video_name in enumerate(videos_name):
pool.apply_async(
worker3, args=(opt, video_name, torch.device(idx % num_gpus)), callback=lambda arg: pbar.update(1))
pool.close()
pool.join()
def read_img(img_path, device, downscale_factor=1):
img = cv2.imread(img_path)
h, w = img.shape[0:2]
if downscale_factor != 1:
img = cv2.resize(img, (w // downscale_factor, h // downscale_factor), interpolation=cv2.INTER_LANCZOS4)
img = img2tensor(img)
img = img.unsqueeze(0).to(device)
return img
@torch.no_grad()
def worker3(opt, video_name, device):
video_frame_path = osp.join(opt['save_frames_root'], video_name)
frames_path = sorted(glob.glob(osp.join(video_frame_path, '*.png')))
estimate_flow_file_path = osp.join(opt['estimate_flow_root'], f'{video_name}.txt')
if osp.exists(estimate_flow_file_path):
with open(estimate_flow_file_path, 'r') as f:
lines = f.readlines()
length = len(lines)
if length == len(frames_path):
print(f'skip {video_name} since {length}/{len(frames_path)} frames are done.')
return
else:
print(f're-run {video_name} since only {length}/{len(frames_path)} frames are done.')
spynet = SpyNet(load_path=opt['spy_pretrain_weight']).eval().to(device)
downscale_factor = opt['downscale_factor']
flow_out_list = []
pbar = tqdm(total=len(frames_path), unit='frame', desc='worker3')
pre_img = None
for idx, frame_path in enumerate(frames_path):
img_name = osp.basename(frame_path)
cur_img = read_img(frame_path, device, downscale_factor=downscale_factor)
if pre_img is not None:
flow = spynet(cur_img, pre_img)
flow = flow.abs()
flow_max = flow.max().item()
flow_avg = flow.mean().item() * 2.0 # according to lijian's hyper-parameter
elif idx == 0:
flow_max = 0.0
flow_avg = 0.0
else:
raise RuntimeError(f'pre_img is none at {idx}')
flow_out_list.append(f'{img_name} {flow_max:.6f} {flow_avg:.6f}\n')
pre_img = cur_img
pbar.update(1)
with open(estimate_flow_file_path, 'w') as f:
for line in flow_out_list:
f.write(line)
# -------------------------------------------------------------------- #
# --------------------------- step4 ---------------------------------- #
# -------------------------------------------------------------------- #
def run_step4(opt):
black_flag_root = opt['black_flag_root']
meta_root = opt['meta_files_root']
os.makedirs(black_flag_root, exist_ok=True)
if not osp.exists(meta_root):
print('no videos have run step1, exit.')
return
# get the video which has been extracted frames
videos_name = sorted(glob.glob(osp.join(meta_root, '*.txt')))
videos_name = [osp.splitext(osp.basename(video_name))[0] for video_name in videos_name]
if opt['debug']:
videos_name = videos_name[:3]
os.makedirs('tmp_black', exist_ok=True)
else:
videos_name = videos_name[opt['ss_idx']:opt['to_idx']]
pbar = tqdm(total=len(videos_name), unit='video', desc='step4')
pool = Pool(opt['n_thread'])
for idx, video_name in enumerate(videos_name):
pool.apply_async(worker4, args=(opt, video_name), callback=lambda arg: pbar.update(1))
pool.close()
pool.join()
def worker4(opt, video_name):
video_frame_path = osp.join(opt['save_frames_root'], video_name)
black_flag_path = osp.join(opt['black_flag_root'], f'{video_name}.txt')
if osp.exists(black_flag_path):
print(f'skip {video_name} since {black_flag_path} already exists.')
return
frames_path = sorted(glob.glob(osp.join(video_frame_path, '*.png')))
out_list = []
pbar = tqdm(total=len(frames_path), unit='frame', desc='worker4')
for frame_path in frames_path:
img = cv2.imread(frame_path)
img_name = osp.basename(frame_path)
h, w = img.shape[0:2]
total_pixels = h * w
img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
hist = cv2.calcHist([img_gray], [0], None, [256], [0.0, 255.0])
max_pixel = max(hist)[0]
percentage = max_pixel / total_pixels
if percentage > opt['black_threshold']:
out_list.append(f'{img_name} {0} {percentage:.6f}\n')
if opt['debug']:
cv2.imwrite(osp.join('tmp_black', f'{video_name}_{img_name}'), img)
else:
out_list.append(f'{img_name} {1} {percentage:.6f}\n')
pbar.update(1)
with open(black_flag_path, 'w') as f:
for line in out_list:
f.write(line)
# -------------------------------------------------------------------- #
# --------------------------- step5 ---------------------------------- #
# -------------------------------------------------------------------- #
def run_step5(opt):
iqa_score_root = opt['iqa_score_root']
meta_root = opt['meta_files_root']
os.makedirs(iqa_score_root, exist_ok=True)
if not osp.exists(meta_root):
print('no videos have run step1, exit.')
return
# get the video which has been extracted frames
videos_name = sorted(glob.glob(osp.join(meta_root, '*.txt')))
videos_name = [osp.splitext(osp.basename(video_name))[0] for video_name in videos_name]
if opt['debug']:
videos_name = videos_name[:3]
os.makedirs('tmp_low_iqa', exist_ok=True)
else:
videos_name = videos_name[opt['ss_idx']:opt['to_idx']]
pbar = tqdm(total=len(videos_name), unit='video', desc='step5')
num_gpus = torch.cuda.device_count()
ctx = torch.multiprocessing.get_context('spawn')
pool = ctx.Pool(min(3 * num_gpus, opt['n_thread']))
for idx, video_name in enumerate(videos_name):
pool.apply_async(
worker5, args=(opt, video_name, torch.device(idx % num_gpus)), callback=lambda arg: pbar.update(1))
pool.close()
pool.join()
@torch.no_grad()
def worker5(opt, video_name, device):
video_frame_path = osp.join(opt['save_frames_root'], video_name)
frames_path = sorted(glob.glob(osp.join(video_frame_path, '*.png')))
iqa_score_path = osp.join(opt['iqa_score_root'], f'{video_name}.txt')
if osp.exists(iqa_score_path):
with open(iqa_score_path, 'r') as f:
lines = f.readlines()
length = len(lines)
if length == len(frames_path):
print(f'skip {video_name} since {length}/{len(frames_path)} frames are done.')
return
else:
print(f're-run {video_name} since only {length}/{len(frames_path)} frames are done.')
assess_net = init_assessment_model('hypernet', device=device)
assess_net = assess_net.half()
# specified transformation in original hyperIQA
transforms_resize = torchvision.transforms.Compose([
torchvision.transforms.Resize((512, 384)),
])
transforms_crop = torchvision.transforms.Compose([
torchvision.transforms.RandomCrop(size=224),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])
iqa_out_list = []
pbar = tqdm(total=len(frames_path), unit='frame', desc='worker5')
for idx, frame_path in enumerate(frames_path):
img_name = osp.basename(frame_path)
cv2_img = cv2.imread(frame_path)
# BGR -> RGB
img = cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB)
img = Image.fromarray(img)
patchs = []
img_resize = transforms_resize(img)
for _ in range(opt['num_patch_per_iqa']):
patchs.append(transforms_crop(img_resize))
patch = torch.stack(patchs, dim=0).to(device)
pred = assess_net(patch.half())
score = pred.mean().item()
iqa_out_list.append(f'{img_name} {score:.6f}\n')
if opt['debug'] and score < 50.0:
cv2.imwrite(osp.join('tmp_low_iqa', f'{video_name}_{img_name}'), cv2_img)
pbar.update(1)
with open(iqa_score_path, 'w') as f:
for line in iqa_out_list:
f.write(line)
# -------------------------------------------------------------------- #
# --------------------------- step6 ---------------------------------- #
# -------------------------------------------------------------------- #
def filter_frozen_shots(shots, flows):
"""select clips from input video."""
flag_shot = np.ones(len(shots))
for idx, shot in enumerate(shots):
shot = shot.split(' ')
start = int(shot[0])
end = int(shot[1])
flow_in_shot = []
for i in range(start, end + 1, 1):
if i == 0:
continue
else:
flow_in_shot.append(float(flows[i].split(' ')[2]))
flow_in_shot = np.array(flow_in_shot)
flow_std = np.std(flow_in_shot)
if flow_std < 14.0:
flag_shot[idx] = 0
return flag_shot
def generate_clips(shots, flows, filter_frames, hyperiqa, max_length=500):
"""
hyperiqa [0, 100]
flows [0, 15000] (may be larger)
"""
clips = []
clip_scores = []
clip = []
shot_flow = 0
shot_hyperiqa = 0
for shot in shots:
shot = shot.split(' ')
start = int(shot[0])
end = int(shot[1])
pre_black = 0
for i in range(start, end + 1, 1):
if i == start:
stat = 0
pre_black = 1 # the first frame in a shot does not need flow
else:
stat = 1
black_frame_thr = float(filter_frames[i].split(' ')[2])
# treat the frame as black when 90% or more of its pixels are identical
if black_frame_thr < 0.90:
black_frame = 0
else:
black_frame = 1
# if current frame is a black frame, delete
if black_frame == 1:
pre_black = 1
elif pre_black == 0:
flow = float(flows[i].split(' ')[1])
shot_flow += flow
else:
pre_black = 0
flow = float(flows[i].split(' ')[1])
# accumulate the hyperIQA score for non-black frames
if black_frame == 0:
curr_hyperiqa = float(hyperiqa[i].split(' ')[1])
shot_hyperiqa += curr_hyperiqa
clip.append(f'{i+1:08d} {stat} {flow} {curr_hyperiqa}')
if len(clip) == max_length:
clips.append(clip.copy())
clip_score = shot_flow / 150.0 + shot_hyperiqa
clip_score = clip_score / len(clip)
clip_scores.append(clip_score)
clip = []
shot_flow = 0
shot_hyperiqa = 0
# print(len(clip))
# if len(clip) > 0:
# clips.append(clip.copy())
# clip_score = shot_flow / 150.0 + shot_hyperiqa
# clip_score = clip_score / len(clip)
# clip_scores.append(clip_score)
sorted_shot = np.argsort(-np.array(clip_scores))
return [clips[i] for i in sorted_shot], [clip_scores[i] for i in sorted_shot]
def run_step6(opt):
meta_root = opt['meta_files_root']
if not osp.exists(meta_root):
print('no videos have run step1, exit.')
return
# get the video which has been extracted frames
videos_name = sorted(glob.glob(osp.join(meta_root, '*.txt')))
videos_name = [osp.splitext(osp.basename(video_name))[0] for video_name in videos_name]
if opt['debug']:
videos_name = videos_name[:3]
else:
videos_name = videos_name[opt['ss_idx']:opt['to_idx']]
pbar = tqdm(total=len(videos_name), unit='video', desc='step6')
os.makedirs(opt['select_clips_meta'], exist_ok=True)
os.makedirs(opt['select_clips_frames'], exist_ok=True)
os.makedirs(opt['select_done_flags'], exist_ok=True)
pool = Pool(opt['n_thread'])
for video_name in videos_name:
pool.apply_async(worker6, args=(opt, video_name), callback=lambda arg: pbar.update(1))
pool.close()
pool.join()
def worker6(opt, video_name):
select_clips_meta = opt['select_clips_meta']
select_clips_frames = opt['select_clips_frames']
select_done_flags = opt['select_done_flags']
if osp.exists(osp.join(select_done_flags, f'{video_name}.txt')):
print(f'skip {video_name}.')
return
with open(osp.join(opt['detect_shot_root'], f'{video_name}.txt'), 'r') as f:
shots = f.readlines()
shots = [shot.strip() for shot in shots]
with open(osp.join(opt['estimate_flow_root'], f'{video_name}.txt'), 'r') as f:
flows = f.readlines()
flows = [flow.strip() for flow in flows]
with open(osp.join(opt['black_flag_root'], f'{video_name}.txt'), 'r') as f:
black_flags = f.readlines()
black_flags = [black_flag.strip() for black_flag in black_flags]
with open(osp.join(opt['iqa_score_root'], f'{video_name}.txt'), 'r') as f:
iqa_scores = f.readlines()
iqa_scores = [iqa_score.strip() for iqa_score in iqa_scores]
flag_shot = filter_frozen_shots(shots, flows)
flag = np.where(flag_shot == 1)
flag = flag[0].tolist()
filtered_shots = [shots[i] for i in flag]
clips, scores = generate_clips(
filtered_shots, flows, black_flags, iqa_scores, max_length=opt['num_frames_per_clip'])
with open(osp.join(select_clips_meta, f'{video_name}.txt'), 'w') as f:
for i, clip in enumerate(clips):
os.makedirs(osp.join(select_clips_frames, f'{video_name}_{i}'), exist_ok=True)
for idx, info in enumerate(clip):
f.write(f'clip: {i:02d} {info} {scores[i]}\n')
img_name = info.split(' ')[0] + '.png'
shutil.copy(
osp.join(opt['save_frames_root'], video_name, img_name),
osp.join(select_clips_frames, f'{video_name}_{i}', f'{idx:08d}.png'))
if i >= opt['num_clips_per_video'] - 1:
break
with open(osp.join(select_done_flags, f'{video_name}.txt'), 'w') as f:
f.write(f'{i+1} clips are selected for {video_name}.')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--dataroot',
type=str,
required=True,
help='dataset root, dataroot/raw_videos should contain your HQ videos to be processed.')
parser.add_argument('--n_thread', type=int, default=4, help='Thread number.')
parser.add_argument('--run', type=str, default='123456', help='run which steps')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--ss_idx', type=int, default=None, help='ss index')
parser.add_argument('--to_idx', type=int, default=None, help='to index')
parser.add_argument('--n_frames_per_clip', type=int, default=100)
parser.add_argument('--n_clips_per_video', type=int, default=1)
parser.add_argument('--select_clip_root', type=str, default='select_clips')
args = parser.parse_args()
main(args)
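The core of the black-frame check in `worker4` is the fraction of pixels that share a single gray level. A minimal NumPy-only sketch of that check (the `is_black_frame` helper and the synthetic frame are illustrative; the script itself uses `cv2.calcHist` on real extracted frames):

```python
import numpy as np

def is_black_frame(gray, threshold=0.98):
    """Return (flag, percentage): flag is True when more than `threshold` of all
    pixels share a single intensity, mirroring the calcHist check in worker4."""
    hist = np.bincount(gray.ravel(), minlength=256)  # 256-bin grayscale histogram
    percentage = hist.max() / gray.size              # share of the dominant bin
    return bool(percentage > threshold), float(percentage)

# a 100x100 frame where 99% of pixels are zero (black) and 1% are bright
frame = np.zeros((100, 100), dtype=np.uint8)
frame[:1, :] = 200  # 100 of the 10000 pixels differ
flag, pct = is_black_frame(frame)  # flag is True, pct == 0.99
```

Note the check is really a "near-uniform frame" test: a frame that is 99% solid white would be flagged just as a black one would, which is the desired behavior for filtering uninformative training frames.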
================================================
FILE: scripts/inference_animesr_frames.py
================================================
"""inference AnimeSR on frames"""
import argparse
import cv2
import glob
import numpy as np
import os
import psutil
import queue
import threading
import time
import torch
from os import path as osp
from tqdm import tqdm
from animesr.utils.inference_base import get_base_argument_parser, get_inference_model
from animesr.utils.video_util import frames2video
from basicsr.data.transforms import mod_crop
from basicsr.utils.img_util import img2tensor, tensor2img
def read_img(path, require_mod_crop=True, mod_scale=4, input_rescaling_factor=1.0):
""" read an image tensor from a given path
Args:
path: image path
require_mod_crop: whether to apply mod crop. Since the arch is multi-scale, mod crop is needed by default
mod_scale: scale factor for mod_crop
Returns:
torch.Tensor: size(1, c, h, w)
"""
img = cv2.imread(path)
img = img.astype(np.float32) / 255.
if input_rescaling_factor != 1.0:
h, w = img.shape[:2]
img = cv2.resize(
img, (int(w * input_rescaling_factor), int(h * input_rescaling_factor)), interpolation=cv2.INTER_LANCZOS4)
if require_mod_crop:
img = mod_crop(img, mod_scale)
img = img2tensor(img, bgr2rgb=True, float32=True)
return img.unsqueeze(0)
class IOConsumer(threading.Thread):
"""Since IO time can take up a significant portion of the total inference time,
we use multiple threads to write frames individually.
"""
def __init__(self, args: argparse.Namespace, que, qid):
super().__init__()
self._queue = que
self.qid = qid
self.args = args
def run(self):
while True:
msg = self._queue.get()
if isinstance(msg, str) and msg == 'quit':
break
output = msg['output']
imgname = msg['imgname']
out_img = tensor2img(output.squeeze(0))
if self.args.outscale != self.args.netscale:
h, w = out_img.shape[:2]
out_img = cv2.resize(
out_img, (int(
w * self.args.outscale / self.args.netscale), int(h * self.args.outscale / self.args.netscale)),
interpolation=cv2.INTER_LANCZOS4)
cv2.imwrite(imgname, out_img)
print(f'IO for worker {self.qid} is done.')
@torch.no_grad()
def main():
"""Inference demo for AnimeSR.
It is mainly for restoring anime frames.
"""
parser = get_base_argument_parser()
parser.add_argument('--input_rescaling_factor', type=float, default=1.0)
parser.add_argument('--num_io_consumer', type=int, default=3, help='number of IO consumers')
parser.add_argument(
'--sample_interval',
type=int,
default=1,
help='save 1 frame for every $sample_interval frames. this will be useful for calculating the metrics')
parser.add_argument('--save_video_too', action='store_true')
args = parser.parse_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = get_inference_model(args, device)
# prepare output dir
frame_output = osp.join(args.output, args.expname, 'frames')
os.makedirs(frame_output, exist_ok=True)
# the input format can be:
# 1. clip folder which contains frames
# or 2. a folder which contains several clips
first_level_dir = len(glob.glob(osp.join(args.input, '*.png'))) > 0
if args.input.endswith('/'):
args.input = args.input[:-1]
if first_level_dir:
videos_name = [osp.basename(args.input)]
args.input = osp.dirname(args.input)
else:
videos_name = sorted(os.listdir(args.input))
pbar1 = tqdm(total=len(videos_name), unit='video', desc='inference')
que = queue.Queue()
consumers = [IOConsumer(args, que, f'IO_{i}') for i in range(args.num_io_consumer)]
for consumer in consumers:
consumer.start()
for video_name in videos_name:
video_folder_path = osp.join(args.input, video_name)
imgs_list = sorted(glob.glob(osp.join(video_folder_path, '*')))
num_imgs = len(imgs_list)
os.makedirs(osp.join(frame_output, video_name), exist_ok=True)
# prepare
prev = read_img(
imgs_list[0],
require_mod_crop=True,
mod_scale=args.mod_scale,
input_rescaling_factor=args.input_rescaling_factor).to(device)
cur = prev
nxt = read_img(
imgs_list[min(1, num_imgs - 1)],
require_mod_crop=True,
mod_scale=args.mod_scale,
input_rescaling_factor=args.input_rescaling_factor).to(device)
c, h, w = prev.size()[-3:]
state = prev.new_zeros(1, 64, h, w)
out = prev.new_zeros(1, c, h * args.netscale, w * args.netscale)
pbar2 = tqdm(total=num_imgs, unit='frame', desc='inference')
tot_model_time = 0
cnt_model_time = 0
for idx in range(num_imgs):
torch.cuda.synchronize()
start = time.time()
img_name = osp.splitext(osp.basename(imgs_list[idx]))[0]
out, state = model.cell(torch.cat((prev, cur, nxt), dim=1), out, state)
torch.cuda.synchronize()
model_time = time.time() - start
tot_model_time += model_time
cnt_model_time += 1
if (idx + 1) % args.sample_interval == 0:
# put the output frame to the queue to be consumed
que.put({'output': out.cpu().clone(), 'imgname': osp.join(frame_output, video_name, f'{img_name}.png')})
torch.cuda.synchronize()
start = time.time()
prev = cur
cur = nxt
nxt = read_img(
imgs_list[min(idx + 2, num_imgs - 1)],
require_mod_crop=True,
mod_scale=args.mod_scale,
input_rescaling_factor=args.input_rescaling_factor).to(device)
torch.cuda.synchronize()
read_time = time.time() - start
pbar2.update(1)
pbar2.set_description(f'read_time: {read_time}, model_time: {tot_model_time/cnt_model_time}')
mem = psutil.virtual_memory()
# since the producer (model inference) is faster than the consumers (I/O),
# sleep when memory usage is high so the consumers can catch up and avoid OOM
if mem.percent > 80.0:
time.sleep(30)
pbar1.update(1)
for _ in range(args.num_io_consumer):
que.put('quit')
for consumer in consumers:
consumer.join()
if not args.save_video_too:
return
# convert the frames to videos
video_output = osp.join(args.output, args.expname, 'videos')
os.makedirs(video_output, exist_ok=True)
for video_name in videos_name:
out_path = osp.join(video_output, f'{video_name}.mp4')
frames2video(
osp.join(frame_output, video_name), out_path, fps=24 if args.fps is None else args.fps, suffix='png')
if __name__ == '__main__':
main()
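The `IOConsumer` pattern above (worker threads drain a shared queue until a `'quit'` sentinel arrives, one sentinel per consumer) can be reduced to a small self-contained sketch. `Consumer` and `process` are illustrative names; appending to `results` stands in for the `tensor2img` + `cv2.imwrite` work in the real script:

```python
import queue
import threading

class Consumer(threading.Thread):
    """Drain a queue until the 'quit' sentinel, like IOConsumer.run()."""

    def __init__(self, que, results):
        super().__init__()
        self._queue = que
        self.results = results  # shared sink standing in for disk writes

    def run(self):
        while True:
            msg = self._queue.get()
            if isinstance(msg, str) and msg == 'quit':
                break
            # stand-in for tensor2img + cv2.imwrite in the real script
            self.results.append(msg['imgname'])

def process(names, num_consumers=3):
    que = queue.Queue()
    results = []
    consumers = [Consumer(que, results) for _ in range(num_consumers)]
    for c in consumers:
        c.start()
    for name in names:              # producer: the inference loop
        que.put({'imgname': name})
    for _ in range(num_consumers):  # one sentinel per consumer thread
        que.put('quit')
    for c in consumers:
        c.join()
    return sorted(results)
```

Putting exactly one sentinel per consumer guarantees every thread sees a `'quit'` and exits, so `join()` cannot hang; `queue.Queue` handles all the locking between the producer and consumers.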
================================================
FILE: scripts/inference_animesr_video.py
================================================
import cv2
import ffmpeg
import glob
import mimetypes
import numpy as np
import os
import shutil
import subprocess
import torch
from os import path as osp
from tqdm import tqdm
from animesr.utils import video_util
from animesr.utils.inference_base import get_base_argument_parser, get_inference_model
from basicsr.data.transforms import mod_crop
from basicsr.utils.img_util import img2tensor, tensor2img
from basicsr.utils.logger import AvgTimer
def get_video_meta_info(video_path):
"""get the meta info of the video by using ffprobe with python interface"""
ret = {}
probe = ffmpeg.probe(video_path)
video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video']
has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams'])
ret['width'] = video_streams[0]['width']
ret['height'] = video_streams[0]['height']
ret['fps'] = eval(video_streams[0]['avg_frame_rate'])
ret['audio'] = ffmpeg.input(video_path).audio if has_audio else None
try:
ret['nb_frames'] = int(video_streams[0]['nb_frames'])
except KeyError: # bilibili-transcoded videos don't have nb_frames
ret['duration'] = float(probe['format']['duration'])
ret['nb_frames'] = int(ret['duration'] * ret['fps'])
print(ret['duration'], ret['nb_frames'])
return ret
def get_sub_video(args, num_process, process_idx):
"""Cut the whole video into num_process parts, return the process_idx-th part"""
if num_process == 1:
return args.input
meta = get_video_meta_info(args.input)
duration = int(meta['nb_frames'] / meta['fps'])
part_time = duration // num_process
print(f'duration: {duration}, part_time: {part_time}')
out_path = osp.join(args.output, 'inp_sub_videos', f'{process_idx:03d}.mp4')
cmd = [
args.ffmpeg_bin,
f'-i {args.input}',
f'-ss {part_time * process_idx}',
f'-to {part_time * (process_idx + 1)}' if process_idx != num_process - 1 else '',
'-async 1',
out_path,
'-y',
]
print(' '.join(cmd))
subprocess.call(' '.join(cmd), shell=True)
return out_path
class Reader:
"""read frames from a video stream or frames list"""
def __init__(self, args, total_workers=1, worker_idx=0, device=torch.device('cuda')):
self.args = args
input_type = mimetypes.guess_type(args.input)[0]
self.input_type = 'folder' if input_type is None else input_type
self.paths = [] # for image&folder type
self.audio = None
self.input_fps = None
if self.input_type.startswith('video'):
video_path = get_sub_video(args, total_workers, worker_idx)
# read bgr from stream, which is the same format as opencv
self.stream_reader = (
ffmpeg
.input(video_path)
.output('pipe:', format='rawvideo', pix_fmt='bgr24', loglevel='error')
.run_async(pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin)
) # yapf: disable # noqa
meta = get_video_meta_info(video_path)
self.width = meta['width']
self.height = meta['height']
self.input_fps = meta['fps']
self.audio = meta['audio']
self.nb_frames = meta['nb_frames']
else:
if self.input_type.startswith('image'):
self.paths = [args.input]
else:
paths = sorted(glob.glob(os.path.join(args.input, '*')))
tot_frames = len(paths)
num_frame_per_worker = tot_frames // total_workers + (1 if tot_frames % total_workers else 0)
self.paths = paths[num_frame_per_worker * worker_idx:num_frame_per_worker * (worker_idx + 1)]
self.nb_frames = len(self.paths)
assert self.nb_frames > 0, 'empty folder'
from PIL import Image
tmp_img = Image.open(self.paths[0]) # lazy load
self.width, self.height = tmp_img.size
self.idx = 0
self.device = device
def get_resolution(self):
return self.height, self.width
def get_fps(self):
"""the fps of sr video is set to the user input fps first, followed by the input fps,
If the first two values are None, then the commonly used fps 24 is set"""
if self.args.fps is not None:
return self.args.fps
elif self.input_fps is not None:
return self.input_fps
return 24
def get_audio(self):
return self.audio
def __len__(self):
"""return the number of frames for this worker, however, this may be not accurate for video stream"""
return self.nb_frames
def get_frame_from_stream(self):
img_bytes = self.stream_reader.stdout.read(self.width * self.height * 3) # 3 bytes for one pixel
if not img_bytes:
# end of stream
return None
img = np.frombuffer(img_bytes, np.uint8).reshape([self.height, self.width, 3])
return img
def get_frame_from_list(self):
if self.idx >= self.nb_frames:
return None
img = cv2.imread(self.paths[self.idx])
self.idx += 1
return img
def get_frame(self):
if self.input_type.startswith('video'):
img = self.get_frame_from_stream()
else:
img = self.get_frame_from_list()
if img is None:
raise StopIteration
# bgr uint8 numpy -> rgb float32 [0, 1] tensor on device
img = img.astype(np.float32) / 255.
img = mod_crop(img, self.args.mod_scale)
img = img2tensor(img, bgr2rgb=True, float32=True).unsqueeze(0).to(self.device)
if self.args.half:
# half precision won't have a noticeable impact on visual quality
img = img.half()
return img
def close(self):
# close the video stream
if self.input_type.startswith('video'):
self.stream_reader.stdin.close()
self.stream_reader.wait()
class Writer:
"""write frames to a video stream"""
def __init__(self, args, audio, height, width, video_save_path, fps):
out_width, out_height = int(width * args.outscale), int(height * args.outscale)
if out_height > 2160:
print('You are generating a video larger than 4K, which will be very slow due to IO speed.',
'We highly recommend decreasing the outscale (i.e., -s).')
vsp = video_save_path
if audio is not None:
self.stream_writer = (
ffmpeg
.input('pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{out_width}x{out_height}', framerate=fps)
.output(audio, vsp, pix_fmt='yuv420p', vcodec='libx264', loglevel='error', acodec='copy')
.overwrite_output()
.run_async(pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin)
) # yapf: disable # noqa
else:
self.stream_writer = (
ffmpeg
.input('pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{out_width}x{out_height}', framerate=fps)
.output(vsp, pix_fmt='yuv420p', vcodec='libx264', loglevel='error')
.overwrite_output()
.run_async(pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin)
) # yapf: disable # noqa
self.out_width = out_width
self.out_height = out_height
self.args = args
def write_frame(self, frame):
if self.args.outscale != self.args.netscale:
frame = cv2.resize(frame, (self.out_width, self.out_height), interpolation=cv2.INTER_LANCZOS4)
self.stream_writer.stdin.write(frame.tobytes())
def close(self):
self.stream_writer.stdin.close()
self.stream_writer.wait()
@torch.no_grad()
def inference_video(args, video_save_path, device=None, total_workers=1, worker_idx=0):
# prepare model
model = get_inference_model(args, device)
# prepare reader and writer
reader = Reader(args, total_workers, worker_idx, device=device)
audio = reader.get_audio()
height, width = reader.get_resolution()
height = height - height % args.mod_scale
width = width - width % args.mod_scale
fps = reader.get_fps()
writer = Writer(args, audio, height, width, video_save_path, fps)
# initialize pre/cur/nxt frames, pre sr frame, and pre hidden state for inference
end_flag = False
prev = reader.get_frame()
cur = prev
try:
nxt = reader.get_frame()
except StopIteration:
end_flag = True
nxt = cur
state = prev.new_zeros(1, 64, height, width)
out = prev.new_zeros(1, 3, height * args.netscale, width * args.netscale)
pbar = tqdm(total=len(reader), unit='frame', desc='inference')
model_timer = AvgTimer() # model inference time tracker
i_timer = AvgTimer() # I(input read) time tracker
o_timer = AvgTimer() # O(output write) time tracker
while True:
# inference at current step
torch.cuda.synchronize(device=device)
model_timer.start()
out, state = model.cell(torch.cat((prev, cur, nxt), dim=1), out, state)
torch.cuda.synchronize(device=device)
model_timer.record()
# write current sr frame to video stream
torch.cuda.synchronize(device=device)
o_timer.start()
output_frame = tensor2img(out, rgb2bgr=False)
writer.write_frame(output_frame)
torch.cuda.synchronize(device=device)
o_timer.record()
# if end of stream, break
if end_flag:
break
# move the sliding window
torch.cuda.synchronize(device=device)
i_timer.start()
prev = cur
cur = nxt
try:
nxt = reader.get_frame()
except StopIteration:
nxt = cur
end_flag = True
torch.cuda.synchronize(device=device)
i_timer.record()
# update & print information
pbar.update(1)
pbar.set_description(
f'I: {i_timer.get_avg_time():.4f} O: {o_timer.get_avg_time():.4f} Model: {model_timer.get_avg_time():.4f}')
reader.close()
writer.close()
def run(args):
if args.suffix is None:
args.suffix = ''
else:
args.suffix = f'_{args.suffix}'
video_save_path = osp.join(args.output, f'{args.video_name}{args.suffix}.mp4')
# set up multiprocessing
num_gpus = torch.cuda.device_count()
num_process = num_gpus * args.num_process_per_gpu
if num_process == 1:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
inference_video(args, video_save_path, device=device)
return
ctx = torch.multiprocessing.get_context('spawn')
pool = ctx.Pool(num_process)
out_sub_videos_dir = osp.join(args.output, 'out_sub_videos')
os.makedirs(out_sub_videos_dir, exist_ok=True)
os.makedirs(osp.join(args.output, 'inp_sub_videos'), exist_ok=True)
pbar = tqdm(total=num_process, unit='sub_video', desc='inference')
for i in range(num_process):
sub_video_save_path = osp.join(out_sub_videos_dir, f'{i:03d}.mp4')
pool.apply_async(
inference_video,
args=(args, sub_video_save_path, torch.device(i % num_gpus), num_process, i),
callback=lambda arg: pbar.update(1))
pool.close()
pool.join()
# combine sub videos
# prepare vidlist.txt
with open(f'{args.output}/vidlist.txt', 'w') as f:
for i in range(num_process):
f.write(f'file \'out_sub_videos/{i:03d}.mp4\'\n')
# To avoid video&audio desync as mentioned in https://github.com/xinntao/Real-ESRGAN/issues/388
# we use the solution provided in https://stackoverflow.com/a/52156277 to solve this issue
cmd = [
args.ffmpeg_bin,
'-f', 'concat',
'-safe', '0',
'-i', f'{args.output}/vidlist.txt',
'-c:v', 'copy',
'-af', 'aresample=async=1000',
video_save_path,
'-y',
] # yapf: disable
print(' '.join(cmd))
subprocess.call(cmd)
shutil.rmtree(out_sub_videos_dir)
shutil.rmtree(osp.join(args.output, 'inp_sub_videos'))
os.remove(f'{args.output}/vidlist.txt')
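The `vidlist.txt` written in `run` above follows ffmpeg's concat-demuxer format: one `file 'path'` line per sub-video, in playback order. A minimal sketch of that list generation (the helper name `make_vidlist` is ours, not part of the repo):

```python
def make_vidlist(num_process, subdir='out_sub_videos'):
    """Build the concat-demuxer list content for the per-worker sub-videos,
    matching the f.write loop in run() above (zero-padded 3-digit names)."""
    return ''.join(f"file '{subdir}/{i:03d}.mp4'\n" for i in range(num_process))
```

The quoted relative paths are why the ffmpeg command passes `-safe 0`; without it the concat demuxer rejects paths it considers unsafe.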
def main():
"""Inference demo for AnimeSR.
It is mainly for restoring anime videos.
"""
parser = get_base_argument_parser()
parser.add_argument(
'--extract_frame_first',
action='store_true',
help='if input is a video, you can still extract the frames first; otherwise AnimeSR will read from the video stream')
parser.add_argument(
'--num_process_per_gpu', type=int, default=1, help='the total process is number_process_per_gpu * num_gpu')
parser.add_argument(
'--suffix', type=str, default=None, help='you can add a suffix string to the sr video name, for example, x2')
args = parser.parse_args()
args.ffmpeg_bin = os.environ.get('ffmpeg_exe_path', 'ffmpeg')
args.input = args.input.rstrip('/').rstrip('\\')
if mimetypes.guess_type(args.input)[0] is not None and mimetypes.guess_type(args.input)[0].startswith('video'):
is_video = True
else:
is_video = False
if args.extract_frame_first and not is_video:
args.extract_frame_first = False
# prepare input and output
args.video_name = osp.splitext(osp.basename(args.input))[0]
args.output = osp.join(args.output, args.expname, 'videos', args.video_name)
os.makedirs(args.output, exist_ok=True)
if args.extract_frame_first:
inp_extracted_frames = osp.join(args.output, 'inp_extracted_frames')
os.makedirs(inp_extracted_frames, exist_ok=True)
video_util.video2frames(args.input, inp_extracted_frames, force=True, high_quality=True)
video_meta = get_video_meta_info(args.input)
args.fps = video_meta['fps']
args.input = inp_extracted_frames
run(args)
if args.extract_frame_first:
shutil.rmtree(args.input)
if __name__ == '__main__':
main()
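The core of `inference_video` above is a three-frame sliding window (prev/cur/nxt) fed to a recurrent cell together with the previous output and hidden state, with the first and last frames duplicated at the boundaries. That windowing logic can be sketched in isolation with plain Python (the names `sliding_window` and `toy_cell` are illustrative, not from the repo):

```python
def sliding_window(frames, cell, init_state=None):
    """Apply cell((prev, cur, nxt), prev_out, state) over frames,
    duplicating the first/last frame at the boundaries, as in
    inference_video above."""
    it = iter(frames)
    prev = next(it)
    cur = prev  # first frame is duplicated as its own predecessor
    nxt = next(it, None)
    end = nxt is None
    if end:
        nxt = cur  # single-frame input: duplicate as successor too
    out, state = None, init_state
    outputs = []
    while True:
        out, state = cell((prev, cur, nxt), out, state)
        outputs.append(out)
        if end:
            break
        prev, cur = cur, nxt  # advance the window by one frame
        nxt = next(it, None)
        if nxt is None:
            nxt = cur  # last frame duplicated as its own successor
            end = True
    return outputs

def toy_cell(window, prev_out, state):
    """Stand-in for model.cell: returns its window and counts steps."""
    return window, (state or 0) + 1
```

Each input frame yields exactly one output, so a stream of N frames produces N super-resolved frames even though every step sees a three-frame context.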
================================================
FILE: scripts/metrics/MANIQA/inference_MANIQA.py
================================================
import argparse
import os
import random
import torch
from pipal_data import NTIRE2022
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from utils import Normalize, ToTensor, crop_image
def parse_args():
parser = argparse.ArgumentParser(description='Inference script of MANIQA')
parser.add_argument('--model_path', help='checkpoint file', required=True)
parser.add_argument('--input_dir', help='directory of the input video', required=True)
parser.add_argument(
'--output_dir',
help='directory of the output results',
default='output/ensemble_attentionIQA2_finetune_e2/AnimeSR')
args = parser.parse_args()
return args
def main():
args = parse_args()
# configuration
batch_size = 10
num_workers = 8
average_iters = 20
crop_size = 224
os.makedirs(args.output_dir, exist_ok=True)
model = torch.load(args.model_path)
# map to cuda, if available
cuda_flag = False
if torch.cuda.is_available():
model = model.cuda()
cuda_flag = True
model.eval()
total_avg_score = []
subfolder_namelist = []
for subfolder_name in sorted(os.listdir(args.input_dir)):
avg_score = 0.0
subfolder_root = os.path.join(args.input_dir, subfolder_name)
if os.path.isdir(subfolder_root) and subfolder_name != 'assemble-folder':
# data load
val_dataset = NTIRE2022(
ref_path=subfolder_root,
dis_path=subfolder_root,
transform=transforms.Compose([Normalize(0.5, 0.5), ToTensor()]),
)
val_loader = DataLoader(
dataset=val_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=True, shuffle=False)
name_list, pred_list = [], []
with open(os.path.join(args.output_dir, f'{subfolder_name}.txt'), 'w') as f:
for data in tqdm(val_loader):
pred = 0
for i in range(average_iters):
if cuda_flag:
x_d = data['d_img_org'].cuda()
b, c, h, w = x_d.shape
top = random.randint(0, h - crop_size)
left = random.randint(0, w - crop_size)
img = crop_image(top, left, crop_size, img=x_d)
with torch.no_grad():
pred += model(img)
pred /= average_iters
d_name = data['d_name']
pred = pred.cpu().numpy()
name_list.extend(d_name)
pred_list.extend(pred)
for i in range(len(name_list)):
f.write(f'{name_list[i]}, {float(pred_list[i][0]): .6f}\n')
avg_score += float(pred_list[i][0])
avg_score /= len(name_list)
f.write(f'The average score of {subfolder_name} is {avg_score:.6f}')
f.close()
subfolder_namelist.append(subfolder_name)
total_avg_score.append(avg_score)
with open(os.path.join(args.output_dir, 'average.txt'), 'w') as f:
for idx, average_score in enumerate(total_avg_score):
string = f'Folder {subfolder_namelist[idx]}, Average Score: {average_score:.6f}\n'
f.write(string)
print(f'Folder {subfolder_namelist[idx]}, Average Score: {average_score:.6f}')
print(f'Average Score of {len(subfolder_namelist)} Folders: {sum(total_avg_score) / len(total_avg_score):.6f}')
string = f'Average Score of {len(subfolder_namelist)} Folders: {sum(total_avg_score) / len(total_avg_score):.6f}' # noqa E501
f.write(string)
f.close()
if __name__ == '__main__':
main()
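The inner loop above scores `average_iters` random `crop_size`×`crop_size` crops of each image and averages the predictions, which reduces the variance of the patch-based IQA score. Stripped of the model and tensors, the pattern looks like this (2-D list images and the names `crop`/`averaged_score` are ours, for illustration):

```python
import random

def crop(img, top, left, size):
    """Cut a size x size patch from a 2-D list image."""
    return [row[left:left + size] for row in img[top:top + size]]

def averaged_score(img, score_fn, crop_size, iters, rng=None):
    """Average score_fn over `iters` random crops, mirroring the
    random-crop loop in the MANIQA inference above."""
    rng = rng or random.Random(0)
    h, w = len(img), len(img[0])
    total = 0.0
    for _ in range(iters):
        top = rng.randint(0, h - crop_size)
        left = rng.randint(0, w - crop_size)
        total += score_fn(crop(img, top, left, crop_size))
    return total / iters
```

For a spatially uniform image every crop scores the same, so the average equals the single-crop score; for real images the averaging smooths out crop-position luck.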
================================================
FILE: scripts/metrics/MANIQA/models/model_attentionIQA2.py
================================================
# flake8: noqa
import timm
import torch
from einops import rearrange
from models.swin import SwinTransformer
from timm.models.vision_transformer import Block
from torch import nn
class ChannelAttn(nn.Module):
def __init__(self, dim, drop=0.1):
super().__init__()
self.c_q = nn.Linear(dim, dim)
self.c_k = nn.Linear(dim, dim)
self.c_v = nn.Linear(dim, dim)
self.norm_fact = dim**-0.5
self.softmax = nn.Softmax(dim=-1)
self.attn_drop = nn.Dropout(drop)
self.proj_drop = nn.Dropout(drop)
def forward(self, x):
_x = x
B, C, N = x.shape
q = self.c_q(x)
k = self.c_k(x)
v = self.c_v(x)
attn = q @ k.transpose(-2, -1) * self.norm_fact
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, C, N)
x = self.proj_drop(x)
x = x + _x
return x
class SaveOutput:
def __init__(self):
self.outputs = []
def __call__(self, module, module_in, module_out):
self.outputs.append(module_out)
def clear(self):
self.outputs = []
class AttentionIQA(nn.Module):
def __init__(self,
embed_dim=72,
num_outputs=1,
patch_size=8,
drop=0.1,
depths=[2, 2],
window_size=4,
dim_mlp=768,
num_heads=[4, 4],
img_size=224,
num_channel_attn=2,
**kwargs):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.input_size = img_size // patch_size
self.patches_resolution = (img_size // patch_size, img_size // patch_size)
self.vit = timm.create_model('vit_base_patch8_224', pretrained=True)
self.save_output = SaveOutput()
hook_handles = []
for layer in self.vit.modules():
if isinstance(layer, Block):
handle = layer.register_forward_hook(self.save_output)
hook_handles.append(handle)
self.channel_attn1 = nn.Sequential(*[ChannelAttn(self.input_size**2) for i in range(num_channel_attn)])
self.channel_attn2 = nn.Sequential(*[ChannelAttn(self.input_size**2) for i in range(num_channel_attn)])
self.conv1 = nn.Conv2d(embed_dim * 4, embed_dim, 1, 1, 0)
self.swintransformer1 = SwinTransformer(
patches_resolution=self.patches_resolution,
depths=depths,
num_heads=num_heads,
embed_dim=embed_dim,
window_size=window_size,
dim_mlp=dim_mlp)
self.swintransformer2 = SwinTransformer(
patches_resolution=self.patches_resolution,
depths=depths,
num_heads=num_heads,
embed_dim=embed_dim // 2,
window_size=window_size,
dim_mlp=dim_mlp)
self.conv2 = nn.Conv2d(embed_dim, embed_dim // 2, 1, 1, 0)
self.fc_score = nn.Sequential(
nn.Linear(embed_dim // 2, embed_dim // 2), nn.ReLU(), nn.Dropout(drop),
nn.Linear(embed_dim // 2, num_outputs), nn.ReLU())
self.fc_weight = nn.Sequential(
nn.Linear(embed_dim // 2, embed_dim // 2), nn.ReLU(), nn.Dropout(drop),
nn.Linear(embed_dim // 2, num_outputs), nn.Sigmoid())
def extract_feature(self, save_output):
x6 = save_output.outputs[6][:, 1:]
x7 = save_output.outputs[7][:, 1:]
x8 = save_output.outputs[8][:, 1:]
x9 = save_output.outputs[9][:, 1:]
x = torch.cat((x6, x7, x8, x9), dim=2)
return x
def forward(self, x):
_x = self.vit(x)
x = self.extract_feature(self.save_output)
self.save_output.outputs.clear()
# stage 1
x = rearrange(x, 'b (h w) c -> b c (h w)', h=self.input_size, w=self.input_size)
x = self.channel_attn1(x)
x = rearrange(x, 'b c (h w) -> b c h w', h=self.input_size, w=self.input_size)
x = self.conv1(x)
x = self.swintransformer1(x)
# stage2
x = rearrange(x, 'b c h w -> b c (h w)', h=self.input_size, w=self.input_size)
x = self.channel_attn2(x)
x = rearrange(x, 'b c (h w) -> b c h w', h=self.input_size, w=self.input_size)
x = self.conv2(x)
x = self.swintransformer2(x)
x = rearrange(x, 'b c h w -> b (h w) c', h=self.input_size, w=self.input_size)
f = self.fc_score(x)
w = self.fc_weight(x)
s = torch.sum(f * w, dim=1) / torch.sum(w, dim=1)
return s
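The last three lines of `forward` above pool per-patch quality scores `f` with learned per-patch weights `w` via `s = Σ(f·w) / Σ(w)`, so patches the network deems important contribute more to the final score. On plain numbers (the helper name `weighted_pool` is ours):

```python
def weighted_pool(scores, weights):
    """Importance-weighted pooling of per-patch scores, as in the
    torch.sum(f * w) / torch.sum(w) step of AttentionIQA.forward."""
    assert len(scores) == len(weights)
    return sum(f * w for f, w in zip(scores, weights)) / sum(weights)
```

With uniform weights this reduces to the plain mean; skewing the weights pulls the result toward the up-weighted patches.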
================================================
FILE: scripts/metrics/MANIQA/models/swin.py
================================================
"""
isort:skip_file
"""
# flake8: noqa
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from torch import nn
""" attention decoder mask """
def get_attn_decoder_mask(seq):
subsequent_mask = torch.ones_like(seq).unsqueeze(-1).expand(seq.size(0), seq.size(1), seq.size(1))
subsequent_mask = subsequent_mask.triu(diagonal=1) # upper triangular part of a matrix(2-D)
return subsequent_mask
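`get_attn_decoder_mask` above builds a strictly-upper-triangular mask (`triu(diagonal=1)`): entry (i, j) is 1 when j > i, marking future positions that position i must not attend to. The same mask in plain Python (the helper name `causal_mask` is ours):

```python
def causal_mask(n):
    """1 where attention is blocked (future positions), matching
    triu(diagonal=1) in get_attn_decoder_mask above."""
    return [[1 if j > i else 0 for j in range(n)] for i in range(n)]
```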
""" attention pad mask """
def get_attn_pad_mask(seq_q, seq_k, i_pad):
batch_size, len_q = seq_q.size()
batch_size, len_k = seq_k.size()
pad_attn_mask = seq_k.data.eq(i_pad)
pad_attn_mask = pad_attn_mask.unsqueeze(1).expand(batch_size, len_q, len_k)
return pad_attn_mask
class DecoderWindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both shifted and non-shifted windows.
SYMBOL INDEX (190 symbols across 24 files)
FILE: animesr/archs/discriminator_arch.py
function get_conv_layer (line 9) | def get_conv_layer(input_nc, ndf, kernel_size, stride, padding, bias=Tru...
class UNetDiscriminatorSN (line 16) | class UNetDiscriminatorSN(nn.Module):
method __init__ (line 19) | def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
method forward (line 40) | def forward(self, x):
class PatchDiscriminator (line 72) | class PatchDiscriminator(nn.Module):
method __init__ (line 79) | def __init__(self,
method _get_norm_layer (line 141) | def _get_norm_layer(self, norm_type='batch'):
method forward (line 155) | def forward(self, x):
class MultiScaleDiscriminator (line 160) | class MultiScaleDiscriminator(nn.Module):
method __init__ (line 171) | def __init__(self,
method forward (line 204) | def forward(self, x):
FILE: animesr/archs/simple_degradation_arch.py
class SimpleDegradationArch (line 8) | class SimpleDegradationArch(nn.Module):
method __init__ (line 13) | def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, downscale=2):
method forward (line 37) | def forward(self, x):
FILE: animesr/archs/vsr_arch.py
class RightAlignMSConvResidualBlocks (line 9) | class RightAlignMSConvResidualBlocks(nn.Module):
method __init__ (line 12) | def __init__(self, num_in_ch=3, num_state_ch=64, num_out_ch=64, num_bl...
method up (line 47) | def up(self, x, scale=2):
method forward (line 55) | def forward(self, x):
class MSRSWVSR (line 79) | class MSRSWVSR(nn.Module):
method __init__ (line 85) | def __init__(self, num_feat=64, num_block=(5, 3, 2), netscale=4):
method cell (line 96) | def cell(self, x, fb, state):
method forward (line 108) | def forward(self, x):
FILE: animesr/data/data_utils.py
function random_crop (line 5) | def random_crop(imgs, patch_size, top=None, left=None):
FILE: animesr/data/ffmpeg_anime_dataset.py
class FFMPEGAnimeDataset (line 19) | class FFMPEGAnimeDataset(data.Dataset):
method __init__ (line 22) | def __init__(self, opt):
method get_gt_clip (line 63) | def get_gt_clip(self, index):
method add_ffmpeg_compression (line 114) | def add_ffmpeg_compression(self, img_lqs, width, height):
method __getitem__ (line 175) | def __getitem__(self, index):
method __len__ (line 209) | def __len__(self):
FILE: animesr/data/ffmpeg_anime_lbo_dataset.py
class FFMPEGAnimeLBODataset (line 16) | class FFMPEGAnimeLBODataset(FFMPEGAnimeDataset):
method __init__ (line 19) | def __init__(self, opt):
method reload_degradation_model (line 35) | def reload_degradation_model(self):
method custom_resize (line 49) | def custom_resize(self, x, scale=2):
method __getitem__ (line 66) | def __getitem__(self, index):
method __len__ (line 108) | def __len__(self):
FILE: animesr/data/paired_image_dataset.py
class CustomPairedImageDataset (line 12) | class CustomPairedImageDataset(data.Dataset):
method __init__ (line 33) | def __init__(self, opt):
method __getitem__ (line 58) | def __getitem__(self, index):
method __len__ (line 91) | def __len__(self):
FILE: animesr/models/degradation_gan_model.py
class DegradationGANModel (line 8) | class DegradationGANModel(SRGANModel):
method feed_data (line 11) | def feed_data(self, data):
method optimize_parameters (line 17) | def optimize_parameters(self, current_iter):
FILE: animesr/models/degradation_model.py
class DegradationModel (line 9) | class DegradationModel(SRModel):
method init_training_settings (line 12) | def init_training_settings(self):
method feed_data (line 24) | def feed_data(self, data):
method optimize_parameters (line 30) | def optimize_parameters(self, current_iter):
FILE: animesr/models/video_recurrent_gan_model.py
class VideoRecurrentGANCustomModel (line 11) | class VideoRecurrentGANCustomModel(VideoRecurrentCustomModel):
method init_training_settings (line 16) | def init_training_settings(self):
method setup_optimizers (line 70) | def setup_optimizers(self):
method optimize_parameters (line 103) | def optimize_parameters(self, current_iter):
method save (line 172) | def save(self, epoch, current_iter):
FILE: animesr/models/video_recurrent_model.py
class VideoRecurrentCustomModel (line 16) | class VideoRecurrentCustomModel(VideoBaseModel):
method __init__ (line 18) | def __init__(self, opt):
method feed_data (line 34) | def feed_data(self, data):
method setup_optimizers (line 82) | def setup_optimizers(self):
method optimize_parameters_base (line 112) | def optimize_parameters_base(self, current_iter):
method optimize_parameters (line 127) | def optimize_parameters(self, current_iter):
method dist_validation (line 163) | def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
method test (line 231) | def test(self):
FILE: animesr/utils/inference_base.py
function get_base_argument_parser (line 8) | def get_base_argument_parser() -> argparse.ArgumentParser:
function get_inference_model (line 45) | def get_inference_model(args, device) -> MSRSWVSR:
FILE: animesr/utils/shot_detector.py
function compute_downscale_factor (line 44) | def compute_downscale_factor(frame_width):
class ShotDetector (line 58) | class ShotDetector(object):
method __init__ (line 64) | def __init__(self, threshold=30.0, min_shot_len=15):
method add_cut (line 83) | def add_cut(self, cut):
method process_frame (line 90) | def process_frame(self, frame_num, frame_img):
method detect_shots (line 167) | def detect_shots(self, frame_source, frame_skip=0, show_progress=True,...
FILE: animesr/utils/video_util.py
function get_video_fps (line 11) | def get_video_fps(video_path, ret_type='float'):
function get_video_num_frames (line 46) | def get_video_num_frames(video_path):
function get_video_bitrate (line 64) | def get_video_bitrate(video_path):
function get_video_resolution (line 85) | def get_video_resolution(video_path):
function video2frames (line 113) | def video2frames(video_path, out_dir, force=False, high_quality=True, ss...
function frames2video (line 152) | def frames2video(frames_dir, out_path, fps=25, filter='*', suffix=None):
FILE: predict.py
class ModelOutput (line 16) | class ModelOutput(BaseModel):
class Predictor (line 21) | class Predictor(BasePredictor):
method predict (line 23) | def predict(
FILE: scripts/anime_videos_preprocessing.py
function main (line 22) | def main(args):
function run_step1 (line 89) | def run_step1(opt):
function worker1 (line 127) | def worker1(opt, video_name, video_path, frame_path, meta_path):
function run_step2 (line 166) | def run_step2(opt):
function worker2 (line 192) | def worker2(opt, video_name):
function run_step3 (line 211) | def run_step3(opt):
function read_img (line 244) | def read_img(img_path, device, downscale_factor=1):
function worker3 (line 255) | def worker3(opt, video_name, device):
function run_step4 (line 306) | def run_step4(opt):
function worker4 (line 333) | def worker4(opt, video_name):
function run_step5 (line 373) | def run_step5(opt):
function worker5 (line 404) | def worker5(opt, video_name, device):
function filter_frozen_shots (line 466) | def filter_frozen_shots(shots, flows):
function generate_clips (line 492) | def generate_clips(shots, flows, filter_frames, hyperiqa, max_length=500):
function run_step6 (line 561) | def run_step6(opt):
function worker6 (line 589) | def worker6(opt, video_name):
FILE: scripts/inference_animesr_frames.py
function read_img (line 21) | def read_img(path, require_mod_crop=True, mod_scale=4, input_rescaling_f...
class IOConsumer (line 47) | class IOConsumer(threading.Thread):
method __init__ (line 52) | def __init__(self, args: argparse.Namespace, que, qid):
method run (line 58) | def run(self):
function main (line 79) | def main():
FILE: scripts/inference_animesr_video.py
function get_video_meta_info (line 20) | def get_video_meta_info(video_path):
function get_sub_video (line 39) | def get_sub_video(args, num_process, process_idx):
class Reader (line 62) | class Reader:
method __init__ (line 65) | def __init__(self, args, total_workers=1, worker_idx=0, device=torch.d...
method get_resolution (line 105) | def get_resolution(self):
method get_fps (line 108) | def get_fps(self):
method get_audio (line 117) | def get_audio(self):
method __len__ (line 120) | def __len__(self):
method get_frame_from_stream (line 124) | def get_frame_from_stream(self):
method get_frame_from_list (line 132) | def get_frame_from_list(self):
method get_frame (line 139) | def get_frame(self):
method close (line 157) | def close(self):
class Writer (line 164) | class Writer:
method __init__ (line 167) | def __init__(self, args, audio, height, width, video_save_path, fps):
method write_frame (line 195) | def write_frame(self, frame):
method close (line 200) | def close(self):
function inference_video (line 206) | def inference_video(args, video_save_path, device=None, total_workers=1,...
function run (line 277) | def run(args):
function main (line 332) | def main():
FILE: scripts/metrics/MANIQA/inference_MANIQA.py
function parse_args (line 12) | def parse_args():
function main (line 25) | def main():
FILE: scripts/metrics/MANIQA/models/model_attentionIQA2.py
class ChannelAttn (line 10) | class ChannelAttn(nn.Module):
method __init__ (line 12) | def __init__(self, dim, drop=0.1):
method forward (line 22) | def forward(self, x):
class SaveOutput (line 38) | class SaveOutput:
method __init__ (line 40) | def __init__(self):
method __call__ (line 43) | def __call__(self, module, module_in, module_out):
method clear (line 46) | def clear(self):
class AttentionIQA (line 50) | class AttentionIQA(nn.Module):
method __init__ (line 52) | def __init__(self,
method extract_feature (line 105) | def extract_feature(self, save_output):
method forward (line 113) | def forward(self, x):
FILE: scripts/metrics/MANIQA/models/swin.py
function get_attn_decoder_mask (line 14) | def get_attn_decoder_mask(seq):
function get_attn_pad_mask (line 23) | def get_attn_pad_mask(seq_q, seq_k, i_pad):
class DecoderWindowAttention (line 31) | class DecoderWindowAttention(nn.Module):
method __init__ (line 45) | def __init__(self, dim, window_size, num_heads, qk_scale=None, attn_dr...
method forward (line 83) | def forward(self, q, k, v, mask=None, attn_mask=None):
method extra_repr (line 119) | def extra_repr(self) -> str:
method flops (line 122) | def flops(self, N):
class DecoderLayer (line 139) | class DecoderLayer(nn.Module):
method __init__ (line 141) | def __init__(self,
method partition (line 196) | def partition(self, inputs, B, H, W, C, shift_size=0):
method reverse (line 208) | def reverse(self, inputs, B, H, W, C, shift_size=0):
method forward (line 220) | def forward(self, mask_dec_inputs, enc_outputs, self_attn_mask, dec_en...
class SwinDecoder (line 262) | class SwinDecoder(nn.Module):
method __init__ (line 264) | def __init__(self,
method forward (line 294) | def forward(self, y_embed, enc_outputs):
class Mlp (line 328) | class Mlp(nn.Module):
method __init__ (line 330) | def __init__(self, in_features, hidden_features=None, out_features=Non...
method forward (line 339) | def forward(self, x):
function window_partition (line 348) | def window_partition(x, window_size):
function window_reverse (line 363) | def window_reverse(windows, window_size, H, W):
class WindowAttention (line 380) | class WindowAttention(nn.Module):
method __init__ (line 394) | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scal...
method forward (line 428) | def forward(self, x, mask=None):
method extra_repr (line 461) | def extra_repr(self) -> str:
method flops (line 464) | def flops(self, N):
class SwinBlock (line 478) | class SwinBlock(nn.Module):
method __init__ (line 497) | def __init__(self,
method forward (line 562) | def forward(self, x):
method extra_repr (line 601) | def extra_repr(self) -> str:
method flops (line 605) | def flops(self):
class BasicLayer (line 620) | class BasicLayer(nn.Module):
method __init__ (line 640) | def __init__(self,
method forward (line 686) | def forward(self, x):
method extra_repr (line 697) | def extra_repr(self) -> str:
method flops (line 700) | def flops(self):
class SwinTransformer (line 709) | class SwinTransformer(nn.Module):
method __init__ (line 711) | def __init__(self,
method forward (line 762) | def forward(self, x):
FILE: scripts/metrics/MANIQA/pipal_data.py
class NTIRE2022 (line 7) | class NTIRE2022(torch.utils.data.Dataset):
method __init__ (line 9) | def __init__(self, ref_path, dis_path, transform):
method __len__ (line 22) | def __len__(self):
method __getitem__ (line 25) | def __getitem__(self, idx):
FILE: scripts/metrics/MANIQA/utils.py
function crop_image (line 5) | def crop_image(top, left, patch_size, img=None):
class RandCrop (line 10) | class RandCrop(object):
method __init__ (line 12) | def __init__(self, patch_size, num_crop):
method __call__ (line 16) | def __call__(self, sample):
class Normalize (line 41) | class Normalize(object):
method __init__ (line 43) | def __init__(self, mean, var):
method __call__ (line 47) | def __call__(self, sample):
class RandHorizontalFlip (line 59) | class RandHorizontalFlip(object):
method __init__ (line 61) | def __init__(self):
method __call__ (line 64) | def __call__(self, sample):
class ToTensor (line 77) | class ToTensor(object):
method __init__ (line 79) | def __init__(self):
method __call__ (line 82) | def __call__(self, sample):
FILE: setup.py
function readme (line 12) | def readme():
function get_git_hash (line 18) | def get_git_hash():
function get_hash (line 43) | def get_hash():
function write_version_py (line 58) | def write_version_py():
function get_version (line 75) | def get_version():
function get_requirements (line 81) | def get_requirements(filename='requirements.txt'):
},
{
"path": "animesr/utils/video_util.py",
"chars": 5611,
"preview": "import glob\nimport os\nimport subprocess\n\ndefault_ffmpeg_exe_path = 'ffmpeg'\ndefault_ffprobe_exe_path = 'ffprobe'\ndefault"
},
{
"path": "cog.yaml",
"chars": 438,
"preview": "build:\n gpu: true\n cuda: \"11.6.2\"\n python_version: \"3.10\"\n system_packages:\n - \"libgl1-mesa-glx\"\n - \"libglib2."
},
{
"path": "options/train_animesr_step1_gan_BasicOPonly.yml",
"chars": 2506,
"preview": "# general settings\nname: train_animesr_step1_gan_BasicOPonly\nmodel_type: VideoRecurrentGANCustomModel\nscale: 4\nnum_gpu: "
},
{
"path": "options/train_animesr_step1_net_BasicOPonly.yml",
"chars": 1555,
"preview": "# general settings\nname: train_animesr_step1_net_BasicOPonly\nmodel_type: VideoRecurrentCustomModel\nscale: 4\nnum_gpu: aut"
},
{
"path": "options/train_animesr_step2_lbo_1_gan.yml",
"chars": 2220,
"preview": "# general settings\nname: train_animesr_step2_lbo_1_gan\nmodel_type: DegradationGANModel\nscale: 2\nnum_gpu: auto # set num"
},
{
"path": "options/train_animesr_step2_lbo_1_net.yml",
"chars": 1428,
"preview": "# general settings\nname: train_animesr_step2_lbo_1_net\nmodel_type: DegradationModel\nscale: 2\nnum_gpu: auto # set num_gp"
},
{
"path": "options/train_animesr_step3_gan_3LBOs.yml",
"chars": 2634,
"preview": "# general settings\nname: train_animesr_step3_gan_3LBOs\nmodel_type: VideoRecurrentGANCustomModel\nscale: 4\nnum_gpu: auto "
},
{
"path": "predict.py",
"chars": 3853,
"preview": "import os\nimport shutil\nimport tempfile\nfrom subprocess import call\nfrom zipfile import ZipFile\nfrom typing import Optio"
},
{
"path": "requirements.txt",
"chars": 88,
"preview": "basicsr\nfacexlib\nffmpeg-python\nnumpy\nopencv-python\npillow\npsutil\ntorch\ntorchvision\ntqdm\n"
},
{
"path": "scripts/anime_videos_preprocessing.py",
"chars": 23504,
"preview": "import argparse\nimport cv2\nimport glob\nimport numpy as np\nimport os\nimport shutil\nimport torch\nimport torchvision\nfrom m"
},
{
"path": "scripts/inference_animesr_frames.py",
"chars": 7046,
"preview": "\"\"\"inference AnimeSR on frames\"\"\"\nimport argparse\nimport cv2\nimport glob\nimport numpy as np\nimport os\nimport psutil\nimpo"
},
{
"path": "scripts/inference_animesr_video.py",
"chars": 14080,
"preview": "import cv2\nimport ffmpeg\nimport glob\nimport mimetypes\nimport numpy as np\nimport os\nimport shutil\nimport subprocess\nimpor"
},
{
"path": "scripts/metrics/MANIQA/inference_MANIQA.py",
"chars": 3844,
"preview": "import argparse\nimport os\nimport random\nimport torch\nfrom pipal_data import NTIRE2022\nfrom torch.utils.data import DataL"
},
{
"path": "scripts/metrics/MANIQA/models/model_attentionIQA2.py",
"chars": 4630,
"preview": "# flake8: noqa\nimport timm\nimport torch\nfrom einops import rearrange\nfrom models.swin import SwinTransformer\nfrom timm.m"
},
{
"path": "scripts/metrics/MANIQA/models/swin.py",
"chars": 31957,
"preview": "\"\"\"\nisort:skip_file\n\"\"\"\n# flake8: noqa\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as che"
},
{
"path": "scripts/metrics/MANIQA/pipal_data.py",
"chars": 1491,
"preview": "import cv2\nimport numpy as np\nimport os\nimport torch\n\n\nclass NTIRE2022(torch.utils.data.Dataset):\n\n def __init__(self"
},
{
"path": "scripts/metrics/MANIQA/utils.py",
"chars": 2619,
"preview": "import numpy as np\nimport torch\n\n\ndef crop_image(top, left, patch_size, img=None):\n tmp_img = img[:, :, top:top + pat"
},
{
"path": "scripts/metrics/README.md",
"chars": 1162,
"preview": "# Instruction for calculating metrics\n\n## Prepare the frames\nFor fast evaluation, we measure 1 frames every 10 frames, t"
},
{
"path": "setup.cfg",
"chars": 577,
"preview": "[flake8]\nignore =\n # line break before binary operator (W503)\n W503,\n # line break after binary operator (W504)"
},
{
"path": "setup.py",
"chars": 3354,
"preview": "#!/usr/bin/env python\n\nfrom setuptools import find_packages, setup\n\nimport os\nimport subprocess\nimport time\n\nversion_fil"
}
]
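Each entry in the index above is a small JSON object with a `path`, a `chars` field (file size in characters), and a truncated `preview`. A minimal sketch of consuming such an index programmatically (plain Python; the three sample entries are copied from the index above, a real consumer would load the full array instead):

```python
import json

# Three entries copied from the per-file index above; a real consumer
# would parse the complete JSON array emitted by the extraction.
index_json = """
[
  {"path": "requirements.txt", "chars": 88},
  {"path": "setup.py", "chars": 3354},
  {"path": "animesr/archs/vsr_arch.py", "chars": 5574}
]
"""

entries = json.loads(index_json)

# Total size of the listed files, in characters.
total_chars = sum(e["chars"] for e in entries)

# Paths restricted to the core animesr package.
package_files = [e["path"] for e in entries if e["path"].startswith("animesr/")]

print(total_chars)    # 9016 for this three-entry sample
print(package_files)  # ['animesr/archs/vsr_arch.py']
```

Summing `chars` over all 47 entries reproduces the 228.9 KB total reported for the extraction.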
About this extraction
This page contains the full source code of the TencentARC/AnimeSR GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction covers 47 files (228.9 KB, approximately 59.8k tokens) and includes a symbol index of 190 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract.