Repository: TencentARC/AnimeSR Branch: main Commit: 80a24bf7a527 Files: 47 Total size: 228.9 KB Directory structure: gitextract_0qygtin4/ ├── .github/ │ └── workflows/ │ └── pylint.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── Training.md ├── VERSION ├── animesr/ │ ├── __init__.py │ ├── archs/ │ │ ├── __init__.py │ │ ├── discriminator_arch.py │ │ ├── simple_degradation_arch.py │ │ └── vsr_arch.py │ ├── data/ │ │ ├── __init__.py │ │ ├── data_utils.py │ │ ├── ffmpeg_anime_dataset.py │ │ ├── ffmpeg_anime_lbo_dataset.py │ │ └── paired_image_dataset.py │ ├── models/ │ │ ├── __init__.py │ │ ├── degradation_gan_model.py │ │ ├── degradation_model.py │ │ ├── video_recurrent_gan_model.py │ │ └── video_recurrent_model.py │ ├── test.py │ ├── train.py │ └── utils/ │ ├── __init__.py │ ├── inference_base.py │ ├── shot_detector.py │ └── video_util.py ├── cog.yaml ├── options/ │ ├── train_animesr_step1_gan_BasicOPonly.yml │ ├── train_animesr_step1_net_BasicOPonly.yml │ ├── train_animesr_step2_lbo_1_gan.yml │ ├── train_animesr_step2_lbo_1_net.yml │ └── train_animesr_step3_gan_3LBOs.yml ├── predict.py ├── requirements.txt ├── scripts/ │ ├── anime_videos_preprocessing.py │ ├── inference_animesr_frames.py │ ├── inference_animesr_video.py │ └── metrics/ │ ├── MANIQA/ │ │ ├── inference_MANIQA.py │ │ ├── models/ │ │ │ ├── model_attentionIQA2.py │ │ │ └── swin.py │ │ ├── pipal_data.py │ │ └── utils.py │ └── README.md ├── setup.cfg └── setup.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/pylint.yml ================================================ name: PyLint on: [push, pull_request] jobs: build: runs-on: ubuntu-latest strategy: matrix: python-version: [3.8] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} 
- name: Install dependencies run: | python -m pip install --upgrade pip pip install flake8 yapf isort - name: Lint run: | flake8 . isort --check-only --diff animesr/ options/ scripts/ setup.py yapf -r -d animesr/ options/ scripts/ setup.py ================================================ FILE: .gitignore ================================================ datasets/* experiments/* results/* tb_logger/* wandb/* tmp/* weights/* inputs/* *.DS_Store # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ .idea/ ================================================ FILE: .pre-commit-config.yaml ================================================ repos: # flake8 - repo: https://github.com/PyCQA/flake8 rev: 3.7.9 hooks: - id: flake8 args: ["--config=setup.cfg", "--ignore=W504, W503"] # modify known_third_party - repo: https://github.com/asottile/seed-isort-config rev: v2.2.0 hooks: - id: seed-isort-config args: ["--exclude=scripts/metrics/MANIQA"] # isort - repo: https://github.com/timothycrosley/isort rev: 5.2.2 hooks: - id: isort # yapf - repo: https://github.com/pre-commit/mirrors-yapf rev: v0.30.0 hooks: - id: yapf # pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks rev: v3.2.0 hooks: - id: trailing-whitespace # Trim trailing whitespace - id: check-yaml # Attempt to load all yaml files to verify syntax - id: check-merge-conflict # Check for files that contain merge conflict strings - id: end-of-file-fixer # Make sure files end in a newline and only a newline - id: requirements-txt-fixer # Sort entries in requirements.txt and remove incorrect entry for pkg-resources==0.0.0 - id: mixed-line-ending # Replace or check mixed line ending args: ["--fix=lf"] ================================================ FILE: LICENSE ================================================ Tencent is pleased to support the open source community by making AnimeSR available. Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. AnimeSR is licensed under the Apache License Version 2.0 except for the third-party components listed below. 
Terms of the Apache License Version 2.0: --------------------------------------------- Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. “License” shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. “Licensor” shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. “Legal Entity” shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, “control” means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. “You” (or “Your”) shall mean an individual or Legal Entity exercising permissions granted by this License. “Source” form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. “Object” form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. “Work” shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). “Derivative Works” shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. 
For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. “Contribution” shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as “Not a Contribution.” “Contributor” shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: You must give any other recipients of the Work or Derivative Works a copy of this License; and You must cause any modified files to carry prominent notices stating that You changed the files; and You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and If the Work includes a “NOTICE” text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; 
within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 
END OF TERMS AND CONDITIONS Other dependencies and licenses: Open Source Software Licensed under the Apache License Version 2.0: -------------------------------------------------------------------- 1. ffmpeg-python Copyright 2017 Karl Kroening 2. basicsr Copyright 2018-2022 BasicSR Authors Terms of the Apache License Version 2.0: -------------------------------------------------------------------- Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: You must give any other recipients of the Work or Derivative Works a copy of this License; and You must cause any modified files to carry prominent notices stating that You changed the files; and You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; 
within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 
END OF TERMS AND CONDITIONS Open Source Software Licensed under the BSD 3-Clause License: -------------------------------------------------------------------- 1. torch From PyTorch: Copyright (c) 2016- Facebook, Inc (Adam Paszke) Copyright (c) 2014- Facebook, Inc (Soumith Chintala) Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) Copyright (c) 2011-2013 NYU (Clement Farabet) Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) Copyright (c) 2006 Idiap Research Institute (Samy Bengio) Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) From Caffe2: Copyright (c) 2016-present, Facebook Inc. All rights reserved. All contributions by Facebook: Copyright (c) 2016 Facebook Inc. All contributions by Google: Copyright (c) 2015 Google Inc. All rights reserved. All contributions by Yangqing Jia: Copyright (c) 2015 Yangqing Jia All rights reserved. All contributions by Kakao Brain: Copyright 2019-2020 Kakao Brain All contributions by Cruise LLC: Copyright (c) 2022 Cruise LLC. All rights reserved. All contributions from Caffe: Copyright(c) 2013, 2014, 2015, the respective contributors All rights reserved. All other contributions: Copyright(c) 2015, 2016 the respective contributors All rights reserved. Caffe2 uses a copyright model similar to Caffe: each contributor holds copyright over their contributions to Caffe2. The project versioning records all such contribution and copyright details. If a contributor wants to further mark their specific copyright on a particular contribution, they should indicate their copyright solely in the commit message of the change when it is committed. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America and IDIAP Research Institute nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 2. torchvision Copyright (c) Soumith Chintala 2016, All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 3. numpy Copyright (c) 2005-2022, NumPy Developers. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of the NumPy Developers nor the names of any contributors may be used to endorse or promote products derived from this software without specific prior written permission. 
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 4. psutil Copyright (c) 2009, Jay Loden, Dave Daeschler, Giampaolo Rodola' Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of the psutil authors nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. Open Source Software Licensed under the HPND License: -------------------------------------------------------------------- 1. Pillow The Python Imaging Library (PIL) is Copyright © 1997-2011 by Secret Labs AB Copyright © 1995-2011 by Fredrik Lundh Pillow is the friendly PIL fork. It is Copyright © 2010-2022 by Alex Clark and contributors Like PIL, Pillow is licensed under the open source HPND License: By obtaining, using, and/or copying this software and/or its associated documentation, you agree that you have read, understood, and will comply with the following terms and conditions: Permission to use, copy, modify, and distribute this software and its associated documentation for any purpose and without fee is hereby granted, provided that the above copyright notice appears in all copies, and that both that copyright notice and this permission notice appear in supporting documentation, and that the name of Secret Labs AB or the author not be used in advertising or publicity pertaining to distribution of the software without specific, written prior permission. SECRET LABS AB AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL SECRET LABS AB OR THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. Open Source Software Licensed under the MIT License: -------------------------------------------------------------------- 1. opencv-python Copyright (c) Olli-Pekka Heinisuo Terms of the MIT License: -------------------------------------------------------------------- Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. Open Source Software Licensed under the MIT and MPL v2.0 Licenses: -------------------------------------------------------------------- 1. tqdm `tqdm` is a product of collaborative work. Unless otherwise stated, all authors (see commit logs) retain copyright for their respective work, and release the work under the MIT licence (text below). 
Exceptions or notable authors are listed below in reverse chronological order: * files: * MPLv2.0 2015-2021 (c) Casper da Costa-Luis [casperdcl](https://github.com/casperdcl). * files: tqdm/_tqdm.py MIT 2016 (c) [PR #96] on behalf of Google Inc. * files: tqdm/_tqdm.py setup.py README.rst MANIFEST.in .gitignore MIT 2013 (c) Noam Yorav-Raphael, original author. [PR #96]: https://github.com/tqdm/tqdm/pull/96 Mozilla Public Licence (MPL) v. 2.0 - Exhibit A ----------------------------------------------- This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this project, You can obtain one at https://mozilla.org/MPL/2.0/. MIT License (MIT) ----------------- Copyright (c) 2013 noamraph Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: README.md ================================================ # AnimeSR (NeurIPS 2022) ### :open_book: AnimeSR: Learning Real-World Super-Resolution Models for Animation Videos > [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2206.07038)
> [Yanze Wu](https://github.com/ToTheBeginning), [Xintao Wang](https://xinntao.github.io/), [Gen Li](https://scholar.google.com.hk/citations?user=jBxlX7oAAAAJ), [Ying Shan](https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=en)
> [Tencent ARC Lab](https://arc.tencent.com/en/index); Platform Technologies, Tencent Online Video ### :triangular_flag_on_post: Updates * **2022.11.28**: release codes&models. * **2022.08.29**: release AVC-Train and AVC-Test. ## Web Demo and API [![Replicate](https://replicate.com/cjwbw/animesr/badge)](https://replicate.com/cjwbw/animesr) ## Video Demos https://user-images.githubusercontent.com/11482921/204205018-d69e2e51-fbdc-4766-8293-a40ffce3ed25.mp4 https://user-images.githubusercontent.com/11482921/204205109-35866094-fa7f-413b-8b43-bb479b42dfb6.mp4 ## :wrench: Dependencies and Installation - Python >= 3.7 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html)) - [PyTorch >= 1.7](https://pytorch.org/) - Other required packages in `requirements.txt` ### Installation 1. Clone repo ```bash git clone https://github.com/TencentARC/AnimeSR.git cd AnimeSR ``` 2. Install ```bash # Install dependent packages pip install -r requirements.txt # Install AnimeSR python setup.py develop ``` ## :zap: Quick Inference Download the pre-trained AnimeSR models [[Google Drive](https://drive.google.com/drive/folders/1gwNTbKLUjt5FlgT6PQQnBz5wFzmNUX8g?usp=share_link)], and put them into the [weights](weights/) folder. Currently, the available pre-trained models are: - `AnimeSR_v1-PaperModel.pth`: v1 model, also the paper model. You can use this model for paper results reproducing. - `AnimeSR_v2.pth`: v2 model. Compare with v1, this version has better naturalness, fewer artifacts, and better texture/background restoration. If you want better results, use this model. AnimeSR supports both frames and videos as input for inference. We provide several sample test cases in [google drive](https://drive.google.com/drive/folders/1K8JzGNqY_pHahYBv51iUUI7_hZXmamt2?usp=share_link), you can download it and put them to [inputs](inputs/) folder. 
**Inference on Frames** ```bash python scripts/inference_animesr_frames.py -i inputs/tom_and_jerry -n AnimeSR_v2 --expname animesr_v2 --save_video_too --fps 20 ``` ```console Usage: -i --input Input frames folder/root. Support first level dir (i.e., input/*.png) and second level dir (i.e., input/*/*.png) -n --model_name AnimeSR model name. Default: AnimeSR_v2, can also be AnimeSR_v1-PaperModel -s --outscale The netscale is x4, but you can achieve arbitrary output scale (e.g., x2 or x1) with the argument outscale. The program will further perform cheap resize operation after the AnimeSR output. Default: 4 -o --output Output root. Default: results -expname Identify the name of your current inference. The outputs will be saved in $output/$expname -save_video_too Save the output frames to video. Default: off -fps The fps of the (possible) saved videos. Default: 24 ``` After run the above command, you will get the SR frames in `results/animesr_v2/frames` and the SR video in `results/animesr_v2/videos`. **Inference on Video** ```bash # single gpu and single process inference CUDA_VISIBLE_DEVICES=0 python scripts/inference_animesr_video.py -i inputs/TheMonkeyKing1965.mp4 -n AnimeSR_v2 -s 4 --expname animesr_v2 --num_process_per_gpu 1 --suffix 1gpu1process # single gpu and multi process inference (you can use multi-processing to improve GPU utilization) CUDA_VISIBLE_DEVICES=0 python scripts/inference_animesr_video.py -i inputs/TheMonkeyKing1965.mp4 -n AnimeSR_v2 -s 4 --expname animesr_v2 --num_process_per_gpu 3 --suffix 1gpu3process # multi gpu and multi process inference CUDA_VISIBLE_DEVICES=0,1 python scripts/inference_animesr_video.py -i inputs/TheMonkeyKing1965.mp4 -n AnimeSR_v2 -s 4 --expname animesr_v2 --num_process_per_gpu 3 --suffix 2gpu6process ``` ```console Usage: -i --input Input video path or extracted frames folder -n --model_name AnimeSR model name. 
Default: AnimeSR_v2, can also be AnimeSR_v1-PaperModel -s --outscale The netscale is x4, but you can achieve arbitrary output scale (e.g., x2 or x1) with the argument outscale. The program will further perform cheap resize operation after the AnimeSR output. Default: 4 -o -output Output root. Default: results -expname Identify the name of your current inference. The outputs will be saved in $output/$expname -fps The fps of the (possible) saved videos. Default: None -extract_frame_first If input is a video, you can still extract the frames first, other wise AnimeSR will read from stream -num_process_per_gpu Since the slow I/O speed will make GPU utilization not high enough, so as long as the video memory is sufficient, we recommend placing multiple processes on one GPU to increase the utilization of each GPU. The total process will be number_process_per_gpu * num_gpu -suffix You can add a suffix string to the sr video name, for example, 1gpu3processx2 which means the SR video is generated with one GPU and three process and the outscale is x2 -half Use half precision for inference, it won't make big impact on the visual results ``` SR videos are saved in `results/animesr_v2/videos/$video_name` folder. If you are looking for portable executable files, you can try our [realesr-animevideov3](https://github.com/xinntao/Real-ESRGAN/blob/master/docs/anime_video_model.md) model which shares the similar technology with AnimeSR. ## :computer: Training See [Training.md](Training.md) ## Request for AVC-Dataset 1. Download and carefully read the [LICENSE AGREEMENT](assets/LICENSE%20AGREEMENT.pdf) PDF file. 2. If you understand, acknowledge, and agree to all the terms specified in the [LICENSE AGREEMENT](assets/LICENSE%20AGREEMENT.pdf). Please email `wuyanze123@gmail.com` with the **LICENSE AGREEMENT PDF** file, **your name**, and **institution**. We will keep the license and send the download link of AVC dataset to you. 
## Acknowledgement This project is built based on [BasicSR](https://github.com/XPixelGroup/BasicSR). ## Citation If you find this project useful for your research, please consider citing our paper: ```bibtex @InProceedings{wu2022animesr, author={Wu, Yanze and Wang, Xintao and Li, Gen and Shan, Ying}, title={AnimeSR: Learning Real-World Super-Resolution Models for Animation Videos}, booktitle={Advances in Neural Information Processing Systems}, year={2022} } ``` ## :e-mail: Contact If you have any questions, please email `wuyanze123@gmail.com`. ================================================ FILE: Training.md ================================================ # :computer: How to Train AnimeSR - [Overview](#overview) - [Dataset Preparation](#dataset-preparation) - [Training](#training) - [Training step 1](#training-step-1) - [Training step 2](#training-step-2) - [Training step 3](#training-step-3) - [The Pre-Trained Checkpoints](#the-pre-trained-checkpoints) - [Other Tips](#other-tips) - [How to build your own (training) dataset ?](#how-to-build-your-own-training-dataset-) ## Overview The training has been divided into three steps. 1. Training a Video Super-Resolution (VSR) model with a degradation model that only contains the classic basic operators (*i.e.*, blur, noise, downscale, compression). 2. Training several **L**earnable **B**asic **O**perators (**LBO**s). Using the VSR model from step 1 and the input-rescaling strategy to generate pseudo HR for real-world LR. The paired pseudo HR-LR data are used to train the LBO in a supervised manner. 3. Training the final VSR model with a degradation model containing both classic basic operators and learnable basic operators. Specifically, the model training in each step consists of two stages. In the first stage, the model is trained with L1 loss from scratch. In the second stage, the model is fine-tuned with the combination of L1 loss, perceptual loss, and GAN loss. 
## Dataset Preparation We use AVC-Train dataset for our training. The AVC dataset is released upon request, please refer to [Request for AVC-Dataset](README.md#request-for-avc-dataset). After you download the AVC-Train dataset, put all the clips into one folder (dataset root). The dataset root should contain 553 folders (clips), each folder contains 100 frames, from `00000000.png` to `00000099.png`. If you want to build your own (training) dataset or enlarge AVC-Train dataset, please refer to [How to build your own (training) dataset](#how-to-build-your-own-training-dataset). ## Training As described in the paper, all the training is performed on four NVIDIA A100 GPUs in an internal cluster. You may need to adjust the batchsize according to the CUDA memory of your GPU card. Before the training, you should modify the [option files](options/) accordingly. For example, you should modify the `dataroot_gt` to your own dataset root. We have commented all the lines you should modify with the `TO_MODIFY` tag. ### Training step 1 1. Train `Net` model Before the training, you should modify the [yaml file](options/train_animesr_step1_net_BasicOPonly.yml) accordingly. For example, you should modify the `dataroot_gt` to your own dataset root. ```bash CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=12345 realanimevsr/train.py -opt options/train_animesr_step1_net_BasicOPonly.yml --launcher pytorch --auto_resume ``` 2. Train `GAN` model The GAN model is fine-tuned from the `Net` model, as specified in `pretrain_network_g` ```bash CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=12345 realanimevsr/train.py -opt options/train_animesr_step1_gan_BasicOPonly.yml --launcher pytorch --auto_resume ``` ### Training step 2 The input frames for training LBO in the paper are included in the AVC dataset download link we sent you. 
These frames came from three real-world LR animation videos, and ~2,000 frames are selected from each video. In order to obtain the paired data required for training LBO, you will need to use the VSR model obtained in step 1 and the input-rescaling strategy as described in the paper to process these input frames to obtain pseudo HR-LR paired data. ```bash # make the soft link for the VSR model obtained in step 1 ln -s experiments/train_animesr_step1_net_BasicOPonly/models/net_g_300000.pth weights/step1_vsr_gan_model.pth # using input-rescaling strategy to inference python scripts/inference_animesr_frames.py -i datasets/lbo_training_data/real_world_video_to_train_lbo_1 -n step1_vsr_gan_model --input_rescaling_factor 0.5 --mod_scale 8 --expname input_rescaling_strategy_lbo_1 ``` After the inference, you can train the LBO. Note that we only provide one [option file](options/train_animesr_step2_lbo_1_net.yml) for training `Net` model and one [option file](options/train_animesr_step2_lbo_1_gan.yml) for training `GAN` model. If you want to train multiple LBOs, just copy&paste those option files and modify the `name`, `dataroot_gt`, and `dataroot_lq`. ```bash # train Net model CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=12345 realanimevsr/train.py -opt options/train_animesr_step2_lbo_1_net.yml --launcher pytorch --auto_resume # train GAN model CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=12345 realanimevsr/train.py -opt options/train_animesr_step2_lbo_1_gan.yml --launcher pytorch --auto_resume ``` ### Training step 3 Before the training, you will need to modify the `degradation_model_path` to the pre-trained LBO path. 
```bash CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=12345 realanimevsr/train.py -opt options/train_animesr_step3_gan_3LBOs.yml --launcher pytorch --auto_resume ``` ## Evaluation See [evaluation readme](scripts/metrics/README.md). ## The Pre-Trained Checkpoints You can download the checkpoints of all steps in [google drive](https://drive.google.com/drive/folders/1hCXhKNZYBADXsS_weHO2z3HhNE-Eg_jw?usp=share_link). ## Other Tips #### How to build your own (training) dataset ? Suppose you have a batch of HQ (high resolution, high bitrate, high quality) animation video, we provide the [anime_videos_preprocessing.py](scripts/anime_videos_preprocessing.py) script to help you to prepare training clips from the raw videos. The preprocessing consists of 6 steps: 1. use FFmpeg to extract frames. Note that this step will take up a lot of disk space. 2. shot detection using [PySceneDetect](https://github.com/Breakthrough/PySceneDetect/) 3. flow estimation using spynet 4. black frames detection 5. image quality assessment using hyperIQA 6. generate clips for each video ```console Usage: python scripts/anime_videos_preprocessing.py --dataroot datasets/YOUR_OWN_ANIME --n_thread 4 --run 1 --dataroot dataset root, dataroot/raw_videos should contains your HQ videos to be processed --n_thread number of workers to process in parallel --run which step to run. Since each step may take a long time, we recommend performing it step by step. And after each step, check whether the output files are as expected --n_frames_per_clip number of frames per clip. Default 100. You can increase the number if you want more training data --n_clips_per_video number of clips per video. Default 1. 
import functools

from torch import nn as nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm

from basicsr.utils.registry import ARCH_REGISTRY


def get_conv_layer(input_nc, ndf, kernel_size, stride, padding, bias=True, use_sn=False):
    """Build a Conv2d layer, optionally wrapped with spectral normalization."""
    conv = nn.Conv2d(input_nc, ndf, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
    return spectral_norm(conv) if use_sn else conv


@ARCH_REGISTRY.register()
class UNetDiscriminatorSN(nn.Module):
    """Defines a U-Net discriminator with spectral normalization (SN).

    copy from real-esrgan
    """

    def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
        super(UNetDiscriminatorSN, self).__init__()
        self.skip_connection = skip_connection
        norm = spectral_norm
        # encoder (downsampling) path
        self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
        self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
        self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
        self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
        # decoder (upsampling) path
        self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
        self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
        self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
        # extra refinement convs before the 1-channel prediction
        self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)

    def forward(self, x):
        # encoder
        feat0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
        feat1 = F.leaky_relu(self.conv1(feat0), negative_slope=0.2, inplace=True)
        feat2 = F.leaky_relu(self.conv2(feat1), negative_slope=0.2, inplace=True)
        feat3 = F.leaky_relu(self.conv3(feat2), negative_slope=0.2, inplace=True)

        # decoder with optional U-Net skip connections from the encoder
        up3 = F.interpolate(feat3, scale_factor=2, mode='bilinear', align_corners=False)
        feat4 = F.leaky_relu(self.conv4(up3), negative_slope=0.2, inplace=True)
        if self.skip_connection:
            feat4 = feat4 + feat2
        up4 = F.interpolate(feat4, scale_factor=2, mode='bilinear', align_corners=False)
        feat5 = F.leaky_relu(self.conv5(up4), negative_slope=0.2, inplace=True)
        if self.skip_connection:
            feat5 = feat5 + feat1
        up5 = F.interpolate(feat5, scale_factor=2, mode='bilinear', align_corners=False)
        feat6 = F.leaky_relu(self.conv6(up5), negative_slope=0.2, inplace=True)
        if self.skip_connection:
            feat6 = feat6 + feat0

        # extra refinement
        out = F.leaky_relu(self.conv7(feat6), negative_slope=0.2, inplace=True)
        out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
        return self.conv9(out)


@ARCH_REGISTRY.register()
class PatchDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator, the receptive field of default config is 70x70.

    Args:
        use_sn (bool): Use spectra_norm or not, if use_sn is True,
            then norm_type should be none.
    """

    def __init__(self, num_in_ch, num_feat=64, num_layers=3, max_nf_mult=8, norm_type='batch',
                 use_sigmoid=False, use_sn=False):
        super(PatchDiscriminator, self).__init__()

        norm_layer = self._get_norm_layer(norm_type)
        # BatchNorm2d has affine parameters, so a bias in the preceding conv is redundant
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kernel = 4
        pad = 1
        layers = [
            get_conv_layer(num_in_ch, num_feat, kernel_size=kernel, stride=2, padding=pad, use_sn=use_sn),
            nn.LeakyReLU(0.2, True),
        ]

        # gradually increase the number of filters (capped at max_nf_mult)
        mult = 1
        for n in range(1, num_layers):
            prev_mult, mult = mult, min(2**n, max_nf_mult)
            layers += [
                get_conv_layer(
                    num_feat * prev_mult,
                    num_feat * mult,
                    kernel_size=kernel,
                    stride=2,
                    padding=pad,
                    bias=use_bias,
                    use_sn=use_sn),
                norm_layer(num_feat * mult),
                nn.LeakyReLU(0.2, True),
            ]

        # one more conv at stride 1
        prev_mult, mult = mult, min(2**num_layers, max_nf_mult)
        layers += [
            get_conv_layer(
                num_feat * prev_mult,
                num_feat * mult,
                kernel_size=kernel,
                stride=1,
                padding=pad,
                bias=use_bias,
                use_sn=use_sn),
            norm_layer(num_feat * mult),
            nn.LeakyReLU(0.2, True),
        ]

        # output 1 channel prediction map
        layers.append(get_conv_layer(num_feat * mult, 1, kernel_size=kernel, stride=1, padding=pad, use_sn=use_sn))
        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def _get_norm_layer(self, norm_type='batch'):
        """Map a norm-type name to a layer constructor."""
        if norm_type == 'batch':
            return functools.partial(nn.BatchNorm2d, affine=True)
        if norm_type == 'instance':
            return functools.partial(nn.InstanceNorm2d, affine=False)
        if norm_type == 'batchnorm2d':
            return nn.BatchNorm2d
        if norm_type == 'none':
            return nn.Identity
        raise NotImplementedError(f'normalization layer [{norm_type}] is not found')

    def forward(self, x):
        return self.model(x)


@ARCH_REGISTRY.register()
class MultiScaleDiscriminator(nn.Module):
    """Define a multi-scale discriminator, each discriminator is a instance of PatchDiscriminator.

    Args:
        num_layers (int or list): If the type of this variable is int, then degrade to PatchDiscriminator.
            If the type of this variable is list, then the length of the list is
            the number of discriminators.
        use_downscale (bool): Progressive downscale the input to feed into different discriminators.
            If set to True, then the discriminators are usually the same.
    """

    def __init__(self, num_in_ch, num_feat=64, num_layers=3, max_nf_mult=8, norm_type='batch',
                 use_sigmoid=False, use_sn=False, use_downscale=False):
        super(MultiScaleDiscriminator, self).__init__()

        if isinstance(num_layers, int):
            num_layers = [num_layers]

        # progressive downscaling only makes sense when every discriminator
        # shares the same config
        if use_downscale:
            assert len(set(num_layers)) == 1
        self.use_downscale = use_downscale

        self.num_dis = len(num_layers)
        self.dis_list = nn.ModuleList([
            PatchDiscriminator(
                num_in_ch,
                num_feat=num_feat,
                num_layers=nl,
                max_nf_mult=max_nf_mult,
                norm_type=norm_type,
                use_sigmoid=use_sigmoid,
                use_sn=use_sn,
            ) for nl in num_layers
        ])

    def forward(self, x):
        outs = []
        y = x
        h, w = x.size()[2:]
        for idx, dis in enumerate(self.dis_list):
            # every discriminator after the first sees a 2x-downscaled input
            if idx != 0 and self.use_downscale:
                y = F.interpolate(y, size=(h // 2, w // 2), mode='bilinear', align_corners=True)
                h, w = y.size()[2:]
            outs.append(dis(y))
        return outs
@ARCH_REGISTRY.register()
class SimpleDegradationArch(nn.Module):
    """simple degradation architecture which consists several conv and non-linear layer

    it learns the mapping from HR to LR
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, downscale=2):
        """
        :param num_in_ch: input is a pseudo HR image, channel is 3
        :param num_out_ch: output is an LR image, channel is also 3
        :param num_feat: we use a small network, hidden dimension is 64
        :param downscale: suppose (h, w) is the height&width of a real-world LR video.
            Firstly, we select the best rescaling factor (usually around 0.5) for this LR video.
            Secondly, we obtain the pseudo HR frames and resize them to (2h, 2w).
            To learn the mapping from pseudo HR to LR, LBO contains a pixel-unshuffle layer
            with a scale factor of 2 to perform the downsampling at the beginning.
        """
        super(SimpleDegradationArch, self).__init__()
        # pixel-unshuffle folds a (downscale x downscale) neighborhood into channels
        in_ch = num_in_ch * downscale * downscale
        self.main = nn.Sequential(
            nn.Conv2d(in_ch, num_feat, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(num_feat, num_feat, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(num_feat, num_out_ch, 3, 1, 1),
        )
        self.downscale = downscale
        default_init_weights(self.main)

    def forward(self, x):
        # downsample via pixel-unshuffle, then map to the LR image
        return self.main(pixel_unshuffle(x, self.downscale))
class RightAlignMSConvResidualBlocks(nn.Module):
    """right align multi-scale ConvResidualBlocks, currently only support 3 scales (1, 2, 4)

    "Right align" means the coarser scales run fewer residual blocks and those
    blocks execute during the *last* iterations of the finest-scale loop, so
    every scale finishes on the same iteration.
    """

    def __init__(self, num_in_ch=3, num_state_ch=64, num_out_ch=64, num_block=(5, 3, 2)):
        super().__init__()
        assert len(num_block) == 3
        # finer scales must run at least as many blocks as coarser ones
        assert num_block[0] >= num_block[1] >= num_block[2]
        self.num_block = num_block

        # stride-1 / stride-2 / stride-2 entry convs produce the s1/s2/s4 features
        self.conv_s1_first = nn.Sequential(
            nn.Conv2d(num_in_ch, num_state_ch, 3, 1, 1, bias=True), nn.LeakyReLU(negative_slope=0.1, inplace=True))
        self.conv_s2_first = nn.Sequential(
            nn.Conv2d(num_state_ch, num_state_ch, 3, 2, 1, bias=True), nn.LeakyReLU(negative_slope=0.1, inplace=True))
        self.conv_s4_first = nn.Sequential(
            nn.Conv2d(num_state_ch, num_state_ch, 3, 2, 1, bias=True),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
        )

        self.body_s1_first = nn.ModuleList()
        for _ in range(num_block[0]):
            self.body_s1_first.append(ResidualBlockNoBN(num_feat=num_state_ch))
        self.body_s2_first = nn.ModuleList()
        for _ in range(num_block[1]):
            self.body_s2_first.append(ResidualBlockNoBN(num_feat=num_state_ch))
        self.body_s4_first = nn.ModuleList()
        for _ in range(num_block[2]):
            self.body_s4_first.append(ResidualBlockNoBN(num_feat=num_state_ch))

        self.upsample_x2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.upsample_x4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)

        # fuse the three scales back to num_out_ch channels at full resolution
        self.fusion = nn.Sequential(
            nn.Conv2d(3 * num_state_ch, 2 * num_out_ch, 3, 1, 1, bias=True),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(2 * num_out_ch, num_out_ch, 3, 1, 1, bias=True),
        )

    def up(self, x, scale=2):
        """Upsample ``x`` by ``scale``; an int (the 0 sentinel) passes through unchanged."""
        if isinstance(x, int):
            return x
        elif scale == 2:
            return self.upsample_x2(x)
        else:
            return self.upsample_x4(x)

    def forward(self, x):
        x_s1 = self.conv_s1_first(x)
        x_s2 = self.conv_s2_first(x_s1)
        x_s4 = self.conv_s4_first(x_s2)

        # flags mark whether the coarser branches have started producing
        # updated features worth feeding back into the finer branches
        flag_s2 = False
        flag_s4 = False
        for i in range(0, self.num_block[0]):
            x_s1 = self.body_s1_first[i](
                x_s1 + (self.up(x_s2, 2) if flag_s2 else 0) + (self.up(x_s4, 4) if flag_s4 else 0))
            if i >= self.num_block[0] - self.num_block[1]:
                x_s2 = self.body_s2_first[i - self.num_block[0] + self.num_block[1]](
                    x_s2 + (self.up(x_s4, 2) if flag_s4 else 0))
                flag_s2 = True
            if i >= self.num_block[0] - self.num_block[2]:
                x_s4 = self.body_s4_first[i - self.num_block[0] + self.num_block[2]](x_s4)
                flag_s4 = True

        x_fusion = self.fusion(torch.cat((x_s1, self.upsample_x2(x_s2), self.upsample_x4(x_s4)), dim=1))

        return x_fusion


@ARCH_REGISTRY.register()
class MSRSWVSR(nn.Module):
    """Multi-Scale, unidirectional Recurrent, Sliding Window (MSRSW)

    The implementation refers to paper: Efficient Video Super-Resolution through
    Recurrent Latent Space Propagation
    """

    def __init__(self, num_feat=64, num_block=(5, 3, 2), netscale=4):
        super(MSRSWVSR, self).__init__()
        self.num_feat = num_feat
        # 3(img channel) * 3(prev cur nxt 3 imgs) + 3(hr img channel) * netscale * netscale + num_feat
        self.recurrent_cell = RightAlignMSConvResidualBlocks(3 * 3 + 3 * netscale * netscale + num_feat, num_feat,
                                                             num_feat + 3 * netscale * netscale, num_block)
        self.lrelu = nn.LeakyReLU(negative_slope=0.1)
        self.pixel_shuffle = nn.PixelShuffle(netscale)
        self.netscale = netscale

    def cell(self, x, fb, state):
        """One recurrent step.

        :param x: concatenation of prev/cur/nxt LR frames (9 channels)
        :param fb: previous SR output (feedback), at HR resolution
        :param state: previous hidden state, at LR resolution
        :return: (sr frame, new hidden state)
        """
        # the current frame serves as the residual base for the SR output
        res = x[:, 3:6]
        # pre frame, cur frame, nxt frame, pre sr frame, pre hidden state
        inp = torch.cat((x, pixel_unshuffle(fb, self.netscale), state), dim=1)
        # the out contains both state and sr frame
        out = self.recurrent_cell(inp)
        out_img = self.pixel_shuffle(out[:, :3 * self.netscale * self.netscale]) + F.interpolate(
            res, scale_factor=self.netscale, mode='bilinear', align_corners=False)
        out_state = self.lrelu(out[:, 3 * self.netscale * self.netscale:])
        return out_img, out_state

    def forward(self, x):
        b, n, c, h, w = x.size()
        # initialize previous sr frame and previous hidden state as zero tensor
        out = x.new_zeros(b, c, h * self.netscale, w * self.netscale)
        state = x.new_zeros(b, self.num_feat, h, w)
        out_l = []
        for i in range(n):
            # clamp the sliding window at the sequence boundaries: the first
            # frame reuses itself as "previous" and the last frame reuses
            # itself as "next". This also fixes an IndexError the original
            # code raised for single-frame input (n == 1), where the i == 0
            # branch unconditionally read x[:, i + 1]. For n >= 2 the
            # behavior is identical to before.
            prv = x[:, i - 1] if i > 0 else x[:, i]
            nxt = x[:, i + 1] if i + 1 < n else x[:, i]
            out, state = self.cell(torch.cat((prv, x[:, i], nxt), dim=1), out, state)
            out_l.append(out)
        return torch.stack(out_l, dim=1)
import random

import torch


def random_crop(imgs, patch_size, top=None, left=None):
    """Crop aligned square patches from one or more images.

    :param imgs: can be (list of) tensor, cv2 img
    :param patch_size: patch size, usually 256
    :param top: will sample if is None
    :param left: will sample if is None
    :return: cropped patches from input imgs
    """
    if not isinstance(imgs, list):
        imgs = [imgs]

    # determine input type: Tensor (..., h, w) or Numpy array (h, w, ...)
    is_tensor = torch.is_tensor(imgs[0])
    if is_tensor:
        h, w = imgs[0].size()[-2:]
    else:
        h, w = imgs[0].shape[0:2]

    # sample the crop offsets when the caller did not pin them
    if top is None:
        top = random.randint(0, h - patch_size)
    if left is None:
        left = random.randint(0, w - patch_size)

    # the same (top, left) is applied to every image so the crops stay aligned
    if is_tensor:
        cropped = [v[:, :, top:top + patch_size, left:left + patch_size] for v in imgs]
    else:
        cropped = [v[top:top + patch_size, left:left + patch_size, ...] for v in imgs]

    # unwrap a singleton result to mirror the input shape
    return cropped[0] if len(cropped) == 1 else cropped
for v in imgs] if len(imgs) == 1: imgs = imgs[0] return imgs ================================================ FILE: animesr/data/ffmpeg_anime_dataset.py ================================================ import cv2 import ffmpeg import glob import numpy as np import os import random import torch from os import path as osp from torch.utils import data as data from basicsr.data.degradations import random_add_gaussian_noise, random_mixed_kernels from basicsr.data.transforms import augment from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor from basicsr.utils.registry import DATASET_REGISTRY from .data_utils import random_crop @DATASET_REGISTRY.register() class FFMPEGAnimeDataset(data.Dataset): """Anime datasets with only classic basic operators""" def __init__(self, opt): super(FFMPEGAnimeDataset, self).__init__() self.opt = opt self.num_frame = opt['num_frame'] self.num_half_frames = opt['num_frame'] // 2 self.keys = [] self.clip_frames = {} self.gt_root = opt['dataroot_gt'] logger = get_root_logger() clip_names = os.listdir(self.gt_root) for clip_name in clip_names: num_frames = len(glob.glob(osp.join(self.gt_root, clip_name, '*.png'))) self.keys.extend([f'{clip_name}/{i:08d}' for i in range(num_frames)]) self.clip_frames[clip_name] = num_frames # file client (io backend) self.file_client = None self.io_backend_opt = opt['io_backend'] self.is_lmdb = False self.iso_blur_range = opt.get('iso_blur_range', [0.2, 4]) self.aniso_blur_range = opt.get('aniso_blur_range', [0.8, 3]) self.noise_range = opt.get('noise_range', [0, 10]) self.crf_range = opt.get('crf_range', [18, 35]) self.ffmpeg_profile_names = opt.get('ffmpeg_profile_names', ['baseline', 'main', 'high']) self.ffmpeg_profile_probs = opt.get('ffmpeg_profile_probs', [0.1, 0.2, 0.7]) self.scale = opt.get('scale', 4) assert self.scale in (2, 4) # temporal augmentation configs self.interval_list = opt.get('interval_list', [1]) self.random_reverse = opt.get('random_reverse', False) interval_str 
= ','.join(str(x) for x in self.interval_list) logger.info(f'Temporal augmentation interval list: [{interval_str}]; ' f'random reverse is {self.random_reverse}.') def get_gt_clip(self, index): """ get the GT(hr) clip with self.num_frame frames :param index: the index from __getitem__ :return: a list of images, with numpy(cv2) format """ key = self.keys[index] # get clip from this key frame (if possible) clip_name, frame_name = key.split('/') # key example: 000/00000000 # determine the "interval" of neighboring frames interval = random.choice(self.interval_list) # ensure not exceeding the borders center_frame_idx = int(frame_name) start_frame_idx = center_frame_idx - self.num_half_frames * interval end_frame_idx = center_frame_idx + self.num_half_frames * interval # if the index doesn't satisfy the requirement, resample it if (start_frame_idx < 0) or (end_frame_idx >= self.clip_frames[clip_name]): center_frame_idx = random.randint(self.num_half_frames * interval, self.clip_frames[clip_name] - 1 - self.num_half_frames * interval) start_frame_idx = center_frame_idx - self.num_half_frames * interval end_frame_idx = center_frame_idx + self.num_half_frames * interval # determine the neighbor frames neighbor_list = list(range(start_frame_idx, end_frame_idx + 1, interval)) # random reverse if self.random_reverse and random.random() < 0.5: neighbor_list.reverse() # get the neighboring GT frames img_gts = [] for neighbor in neighbor_list: if self.is_lmdb: img_gt_path = f'{clip_name}/{neighbor:08d}' else: img_gt_path = osp.join(self.gt_root, clip_name, f'{neighbor:08d}.png') # get GT img_bytes = self.file_client.get(img_gt_path, 'gt') img_gt = imfrombytes(img_bytes, float32=True) img_gts.append(img_gt) # random crop img_gts = random_crop(img_gts, self.opt['gt_size']) # augmentation img_gts = augment(img_gts, self.opt['use_flip'], self.opt['use_rot']) return img_gts def add_ffmpeg_compression(self, img_lqs, width, height): # ffmpeg loglevel = 'error' format = 'h264' fps = 
random.choices([24, 25, 30, 50, 60], [0.2, 0.2, 0.2, 0.2, 0.2])[0] # still have problems fps = 25 crf = np.random.uniform(self.crf_range[0], self.crf_range[1]) try: extra_args = dict() if format == 'h264': vcodec = 'libx264' profile = random.choices(self.ffmpeg_profile_names, self.ffmpeg_profile_probs)[0] extra_args['profile:v'] = profile ffmpeg_img2video = ( ffmpeg.input('pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{width}x{height}', r=fps).filter('fps', fps=fps, round='up').output( 'pipe:', format=format, pix_fmt='yuv420p', crf=crf, vcodec=vcodec, **extra_args).global_args('-hide_banner').global_args('-loglevel', loglevel).run_async( pipe_stdin=True, pipe_stdout=True)) ffmpeg_video2img = ( ffmpeg.input('pipe:', format=format).output('pipe:', format='rawvideo', pix_fmt='rgb24').global_args('-hide_banner').global_args( '-loglevel', loglevel).run_async(pipe_stdin=True, pipe_stdout=True)) # read a sequence of images for img_lq in img_lqs: ffmpeg_img2video.stdin.write(img_lq.astype(np.uint8).tobytes()) ffmpeg_img2video.stdin.close() video_bytes = ffmpeg_img2video.stdout.read() ffmpeg_img2video.wait() # ffmpeg: video to images ffmpeg_video2img.stdin.write(video_bytes) ffmpeg_video2img.stdin.close() img_lqs_ffmpeg = [] while True: in_bytes = ffmpeg_video2img.stdout.read(width * height * 3) if not in_bytes: break in_frame = (np.frombuffer(in_bytes, np.uint8).reshape([height, width, 3])) in_frame = in_frame.astype(np.float32) / 255. 
img_lqs_ffmpeg.append(in_frame) ffmpeg_video2img.wait() assert len(img_lqs_ffmpeg) == self.num_frame, 'Wrong length' except AssertionError as error: logger = get_root_logger() logger.warn(f'ffmpeg assertion error: {error}') except Exception as error: logger = get_root_logger() logger.warn(f'ffmpeg exception error: {error}') else: img_lqs = img_lqs_ffmpeg return img_lqs def __getitem__(self, index): if self.file_client is None: self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) img_gts = self.get_gt_clip(index) # ------------- generate LQ frames --------------# # add blur kernel = random_mixed_kernels(['iso', 'aniso'], [0.7, 0.3], 21, self.iso_blur_range, self.aniso_blur_range) img_lqs = [cv2.filter2D(v, -1, kernel) for v in img_gts] # add noise img_lqs = [ random_add_gaussian_noise(v, sigma_range=self.noise_range, gray_prob=0.5, clip=True, rounds=False) for v in img_lqs ] # downsample h, w = img_gts[0].shape[0:2] width = w // self.scale height = h // self.scale resize_type = random.choices([cv2.INTER_AREA, cv2.INTER_LINEAR, cv2.INTER_CUBIC], [0.3, 0.3, 0.4])[0] img_lqs = [cv2.resize(v, (width, height), interpolation=resize_type) for v in img_lqs] # ffmpeg img_lqs = [np.clip(img_lq * 255.0, 0, 255) for img_lq in img_lqs] img_lqs = self.add_ffmpeg_compression(img_lqs, width, height) # ------------- end --------------# img_gts = img2tensor(img_gts) img_lqs = img2tensor(img_lqs) img_gts = torch.stack(img_gts, dim=0) img_lqs = torch.stack(img_lqs, dim=0) # img_lqs: (t, c, h, w) # img_gts: (t, c, h, w) return {'lq': img_lqs, 'gt': img_gts} def __len__(self): return len(self.keys) ================================================ FILE: animesr/data/ffmpeg_anime_lbo_dataset.py ================================================ import numpy as np import random import torch from torch.nn import functional as F from animesr.archs.simple_degradation_arch import SimpleDegradationArch from basicsr.data.degradations import 
random_add_gaussian_noise_pt, random_mixed_kernels
from basicsr.utils import FileClient, get_root_logger, img2tensor
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.img_process_util import filter2D
from basicsr.utils.registry import DATASET_REGISTRY

from .ffmpeg_anime_dataset import FFMPEGAnimeDataset


@DATASET_REGISTRY.register()
class FFMPEGAnimeLBODataset(FFMPEGAnimeDataset):
    """Anime datasets with both classic basic operators and learnable basic operators (LBO)"""

    def __init__(self, opt):
        super(FFMPEGAnimeLBODataset, self).__init__(opt)
        self.rank, self.world_size = get_dist_info()
        # single LBO network instance; its weights are re-loaded per item
        # (see reload_degradation_model for why)
        self.lbo = SimpleDegradationArch(downscale=2)
        lbo_list = opt['degradation_model_path']
        if not isinstance(lbo_list, list):
            lbo_list = [lbo_list]
        self.lbo_list = lbo_list
        # print(f'degradation model path for {self.rank} {self.world_size}: {degradation_model_path}\n')
        # the real load is at reload_degradation_model function
        self.lbo.load_state_dict(torch.load(self.lbo_list[0], map_location=lambda storage, loc: storage)['params'])
        self.lbo = self.lbo.to(f'cuda:{self.rank}').eval()
        # probability of using the learned LBO instead of classic interpolation
        self.lbo_prob = opt.get('lbo_prob', 0.5)

    def reload_degradation_model(self):
        """
        __init__ will be only invoked once for one gpu worker, so if we want to have
        num_worker_dataset * num_gpu degradation model, we must call this func in __getitem__
        ref: https://discuss.pytorch.org/t/what-happened-when-set-num-workers-0-in-dataloader/138515
        """
        # pick one checkpoint at random so different workers/items use different LBOs
        degradation_model_path = random.choice(self.lbo_list)
        self.lbo.load_state_dict(
            torch.load(degradation_model_path, map_location=lambda storage, loc: storage)['params'])
        print(f'reload degradation model path for {self.rank} {self.world_size}: {degradation_model_path}\n')
        logger = get_root_logger()
        logger.info(f'reload degradation model path for {self.rank} {self.world_size}: {degradation_model_path}\n')

    @torch.no_grad()
    def custom_resize(self, x, scale=2):
        # randomly choose between the learned LBO and a classic 2x downscale
        if random.random() < self.lbo_prob:
            # learned degradation model from real-world
            x = self.lbo(x)
        else:
            # classic synthetic
            h, w = x.shape[2:]
            width = w // scale
            height = h // scale
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            if mode == 'area':
                # F.interpolate forbids align_corners for 'area' mode
                align_corners = None
            else:
                align_corners = False
            x = F.interpolate(x, size=(height, width), mode=mode, align_corners=align_corners)
        return x

    @torch.no_grad()
    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)  # called only once
        self.reload_degradation_model()

        img_gts = self.get_gt_clip(index)

        # ------------- generate LQ frames --------------#
        # change to CUDA implementation
        img_gts = img2tensor(img_gts)
        img_gts = torch.stack(img_gts, dim=0)
        img_gts = img_gts.to(f'cuda:{self.rank}')
        # add blur
        kernel = random_mixed_kernels(['iso', 'aniso'], [0.7, 0.3], 21, self.iso_blur_range, self.aniso_blur_range)
        with torch.no_grad():
            # same kernel is shared by every frame of the clip
            kernel = torch.FloatTensor(kernel).unsqueeze(0).expand(self.num_frame, 21, 21).to(f'cuda:{self.rank}')
            img_lqs = filter2D(img_gts, kernel)
        # add noise
        img_lqs = random_add_gaussian_noise_pt(
            img_lqs, sigma_range=self.noise_range, clip=True, rounds=False, gray_prob=0.5)
        # downsample
        img_lqs = self.custom_resize(img_lqs)
        if self.scale == 4:
            # each custom_resize call downscales by 2, so apply it twice for x4
            img_lqs = self.custom_resize(img_lqs)
        height, width = img_lqs.shape[2:]
        # back to numpy since ffmpeg compression operate on cpu
        img_lqs = img_lqs.detach().clamp_(0, 1).permute(0, 2, 3, 1) * 255  # B, H, W, C
        img_lqs = img_lqs.type(torch.uint8).cpu().numpy()[:, :, :, ::-1]
        img_lqs = np.split(img_lqs, self.num_frame, axis=0)
        img_lqs = [img_lq[0] for img_lq in img_lqs]
        # ffmpeg
        img_lqs = self.add_ffmpeg_compression(img_lqs, width, height)
        # ------------- end --------------#

        img_lqs = img2tensor(img_lqs)
        img_lqs = torch.stack(img_lqs, dim=0)

        # img_lqs: (t, c, h, w)
        # img_gts: (t, c, h, w) on gpu
        return {'lq': img_lqs, 'gt': img_gts.cpu()}

    def __len__(self):
        return len(self.keys)

================================================ FILE: animesr/data/paired_image_dataset.py
================================================ import glob import os from torch.utils import data as data from torchvision.transforms.functional import normalize from basicsr.data.transforms import augment, mod_crop, paired_random_crop from basicsr.utils import FileClient, imfrombytes, img2tensor from basicsr.utils.registry import DATASET_REGISTRY @DATASET_REGISTRY.register() class CustomPairedImageDataset(data.Dataset): """Paired image dataset for training LBO. Read real-world LQ and GT frames pairs. The organization of these gt&lq folder is similar to AVC-Train, except that each folder contains 200 clips, and each clip contains 11 frames. We will ignore the first frame, so there are finally 2000 training pair data. Args: opt (dict): Config for train datasets. It contains the following keys: dataroot_gt (str): Data root path for gt, also the pseudo HR path. dataroot_lq (str): Data root path for lq. io_backend (dict): IO backend type and other kwarg. gt_size (int): Cropped patched size for gt patches. use_hflip (bool): Use horizontal flips. use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). scale (bool): Scale, which will be added automatically. phase (str): 'train' or 'val'. 
""" def __init__(self, opt): super(CustomPairedImageDataset, self).__init__() self.opt = opt # file client (io backend) self.file_client = None self.io_backend_opt = opt['io_backend'] self.mean = opt['mean'] if 'mean' in opt else None self.std = opt['std'] if 'std' in opt else None self.mod_crop_scale = 8 self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] omit_first_frame = opt.get('omit_first_frame', True) start_idx = 1 if omit_first_frame else 0 self.paths = [] clip_list = os.listdir(self.lq_folder) for clip_name in clip_list: lq_frame_list = sorted(glob.glob(f'{self.lq_folder}/{clip_name}/*.png')) gt_frame_list = sorted(glob.glob(f'{self.gt_folder}/{clip_name}/*.png')) assert len(lq_frame_list) == len(gt_frame_list) for i in range(start_idx, len(lq_frame_list)): # omit the first frame self.paths.append(dict([('lq_path', lq_frame_list[i]), ('gt_path', gt_frame_list[i])])) def __getitem__(self, index): if self.file_client is None: self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) scale = self.opt['scale'] # Load gt and lq images. Dimension order: HWC; channel order: BGR; # image range: [0, 1], float32. 
gt_path = self.paths[index]['gt_path'] img_bytes = self.file_client.get(gt_path, 'gt') img_gt = imfrombytes(img_bytes, float32=True) lq_path = self.paths[index]['lq_path'] img_bytes = self.file_client.get(lq_path, 'lq') img_lq = imfrombytes(img_bytes, float32=True) img_lq = mod_crop(img_lq, self.mod_crop_scale) # augmentation for training if self.opt['phase'] == 'train': gt_size = self.opt['gt_size'] # random crop img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) # flip, rotation img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) # BGR to RGB, HWC to CHW, numpy to tensor img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) # normalize if self.mean is not None or self.std is not None: normalize(img_lq, self.mean, self.std, inplace=True) normalize(img_gt, self.mean, self.std, inplace=True) return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} def __len__(self): return len(self.paths) ================================================ FILE: animesr/models/__init__.py ================================================ import importlib from os import path as osp from basicsr.utils import scandir # automatically scan and import model modules for registry # scan all the files that end with '_model.py' under the model folder model_folder = osp.dirname(osp.abspath(__file__)) model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] # import all the model modules _model_modules = [importlib.import_module(f'animesr.models.{file_name}') for file_name in model_filenames] ================================================ FILE: animesr/models/degradation_gan_model.py ================================================ from collections import OrderedDict from basicsr.models.srgan_model import SRGANModel from basicsr.utils.registry import MODEL_REGISTRY @MODEL_REGISTRY.register() class DegradationGANModel(SRGANModel): """Degradation 
model for real-world, hard-to-synthesize degradation."""

    def feed_data(self, data):
        # we reverse the order of lq and gt for convenient implementation:
        # the degradation generator maps GT(HR) -> LQ, so GT is the network input
        self.lq = data['gt'].to(self.device)
        if 'lq' in data:
            self.gt = data['lq'].to(self.device)

    def optimize_parameters(self, current_iter):
        # optimize net_g (freeze the discriminator while updating the generator)
        for p in self.net_d.parameters():
            p.requires_grad = False

        self.optimizer_g.zero_grad()
        self.output = self.net_g(self.lq)

        l_g_total = 0
        loss_dict = OrderedDict()
        # only update the generator every net_d_iters steps and after warm-up
        if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
            # pixel loss
            if self.cri_pix:
                l_g_pix = self.cri_pix(self.output, self.gt)
                l_g_total += l_g_pix
                loss_dict['l_g_pix'] = l_g_pix
            # perceptual loss
            if self.cri_perceptual:
                l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
                if l_g_percep is not None:
                    l_g_total += l_g_percep
                    loss_dict['l_g_percep'] = l_g_percep
                if l_g_style is not None:
                    l_g_total += l_g_style
                    loss_dict['l_g_style'] = l_g_style
            # gan loss
            fake_g_pred = self.net_d(self.output)
            l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict['l_g_gan'] = l_g_gan

            l_g_total.backward()
            self.optimizer_g.step()

        # optimize net_d
        for p in self.net_d.parameters():
            p.requires_grad = True

        self.optimizer_d.zero_grad()
        # real
        real_d_pred = self.net_d(self.gt)
        l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
        loss_dict['l_d_real'] = l_d_real
        l_d_real.backward()
        # fake (detach so discriminator gradients do not flow into net_g)
        fake_d_pred = self.net_d(self.output.detach())
        l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
        loss_dict['l_d_fake'] = l_d_fake
        l_d_fake.backward()
        self.optimizer_d.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)

================================================ FILE: animesr/models/degradation_model.py ================================================
from collections import OrderedDict

from basicsr.losses import build_loss
from basicsr.models.sr_model import SRModel
from basicsr.utils.registry import MODEL_REGISTRY


@MODEL_REGISTRY.register()
class DegradationModel(SRModel):
    """Degradation model for real-world, hard-to-synthesize degradation."""

    def init_training_settings(self):
        self.net_g.train()
        train_opt = self.opt['train']

        # define losses (both l1 and l2 pixel losses are applied together)
        self.l1_pix = build_loss(train_opt['l1_opt']).to(self.device)
        self.l2_pix = build_loss(train_opt['l2_opt']).to(self.device)

        # set up optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()

    def feed_data(self, data):
        # we reverse the order of lq and gt for convenient implementation
        self.lq = data['gt'].to(self.device)
        if 'lq' in data:
            self.gt = data['lq'].to(self.device)

    def optimize_parameters(self, current_iter):
        self.optimizer_g.zero_grad()
        self.output = self.net_g(self.lq)

        l_total = 0
        loss_dict = OrderedDict()
        # l1 loss
        l_l1 = self.l1_pix(self.output, self.gt)
        l_total += l_l1
        loss_dict['l_l1'] = l_l1
        # l2 loss
        l_l2 = self.l2_pix(self.output, self.gt)
        l_total += l_l2
        loss_dict['l_l2'] = l_l2

        l_total.backward()
        self.optimizer_g.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

================================================ FILE: animesr/models/video_recurrent_gan_model.py ================================================
from collections import OrderedDict

from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.utils import get_root_logger
from basicsr.utils.registry import MODEL_REGISTRY

from .video_recurrent_model import VideoRecurrentCustomModel


@MODEL_REGISTRY.register()
class VideoRecurrentGANCustomModel(VideoRecurrentCustomModel):
    """Currently, the VideoRecurrentGANModel and multi-scale discriminator are not compatible,
    so we use a custom model.
""" def init_training_settings(self): train_opt = self.opt['train'] self.ema_decay = train_opt.get('ema_decay', 0) if self.ema_decay > 0: logger = get_root_logger() logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') # build network net_g with Exponential Moving Average (EMA) # net_g_ema only used for testing on one GPU and saving. # There is no need to wrap with DistributedDataParallel self.net_g_ema = build_network(self.opt['network_g']).to(self.device) # load pretrained model load_path = self.opt['path'].get('pretrain_network_g', None) if load_path is not None: self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') else: self.model_ema(0) # copy net_g weight self.net_g_ema.eval() # define network net_d self.net_d = build_network(self.opt['network_d']) self.net_d = self.model_to_device(self.net_d) self.print_network(self.net_d) # load pretrained models load_path = self.opt['path'].get('pretrain_network_d', None) if load_path is not None: param_key = self.opt['path'].get('param_key_d', 'params') self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key) self.net_g.train() self.net_d.train() # define losses if train_opt.get('pixel_opt'): self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) else: self.cri_pix = None if train_opt.get('perceptual_opt'): self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) else: self.cri_perceptual = None if train_opt.get('gan_opt'): self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) self.net_d_iters = train_opt.get('net_d_iters', 1) self.net_d_init_iters = train_opt.get('net_d_init_iters', 0) # set up optimizers and schedulers self.setup_optimizers() self.setup_schedulers() def setup_optimizers(self): train_opt = self.opt['train'] if train_opt['fix_flow']: normal_params = [] flow_params = [] for name, param in self.net_g.named_parameters(): if 'spynet' in name: # The fix_flow 
now only works for spynet. flow_params.append(param) else: normal_params.append(param) optim_params = [ { # add flow params first 'params': flow_params, 'lr': train_opt['lr_flow'] }, { 'params': normal_params, 'lr': train_opt['optim_g']['lr'] }, ] else: optim_params = self.net_g.parameters() # optimizer g optim_type = train_opt['optim_g'].pop('type') self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) self.optimizers.append(self.optimizer_g) # optimizer d optim_type = train_opt['optim_d'].pop('type') self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d']) self.optimizers.append(self.optimizer_d) def optimize_parameters(self, current_iter): # optimize net_g for p in self.net_d.parameters(): p.requires_grad = False self.optimize_parameters_base(current_iter) _, _, c, h, w = self.output.size() pix_gt = self.gt percep_gt = self.gt gan_gt = self.gt if self.opt.get('l1_gt_usm', False): pix_gt = self.gt_usm if self.opt.get('percep_gt_usm', False): percep_gt = self.gt_usm if self.opt.get('gan_gt_usm', False): gan_gt = self.gt_usm l_g_total = 0 loss_dict = OrderedDict() if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): # pixel loss if self.cri_pix: l_g_pix = self.cri_pix(self.output, pix_gt) l_g_total += l_g_pix loss_dict['l_g_pix'] = l_g_pix # perceptual loss if self.cri_perceptual: l_g_percep, l_g_style = self.cri_perceptual(self.output.view(-1, c, h, w), percep_gt.view(-1, c, h, w)) if l_g_percep is not None: l_g_total += l_g_percep loss_dict['l_g_percep'] = l_g_percep if l_g_style is not None: l_g_total += l_g_style loss_dict['l_g_style'] = l_g_style # gan loss fake_g_pred = self.net_d(self.output.view(-1, c, h, w)) l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) l_g_total += l_g_gan loss_dict['l_g_gan'] = l_g_gan l_g_total.backward() self.optimizer_g.step() # optimize net_d for p in self.net_d.parameters(): p.requires_grad = True 
self.optimizer_d.zero_grad() # real # reshape to (b*n, c, h, w) real_d_pred = self.net_d(gan_gt.view(-1, c, h, w)) l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) loss_dict['l_d_real'] = l_d_real l_d_real.backward() # fake # reshape to (b*n, c, h, w) fake_d_pred = self.net_d(self.output.view(-1, c, h, w).detach()) l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) loss_dict['l_d_fake'] = l_d_fake l_d_fake.backward() self.optimizer_d.step() self.log_dict = self.reduce_loss_dict(loss_dict) if self.ema_decay > 0: self.model_ema(decay=self.ema_decay) def save(self, epoch, current_iter): if self.ema_decay > 0: self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) else: self.save_network(self.net_g, 'net_g', current_iter) self.save_network(self.net_d, 'net_d', current_iter) self.save_training_state(epoch, current_iter) ================================================ FILE: animesr/models/video_recurrent_model.py ================================================ import cv2 import os import torch from collections import OrderedDict from os import path as osp from torch import distributed as dist from tqdm import tqdm from basicsr.models.video_base_model import VideoBaseModel from basicsr.utils import USMSharp, get_root_logger, imwrite, tensor2img from basicsr.utils.dist_util import get_dist_info from basicsr.utils.registry import MODEL_REGISTRY @MODEL_REGISTRY.register() class VideoRecurrentCustomModel(VideoBaseModel): def __init__(self, opt): super(VideoRecurrentCustomModel, self).__init__(opt) if self.is_train: self.fix_flow_iter = opt['train'].get('fix_flow') self.idx = 1 lq_from_usm = opt['datasets']['train'].get('lq_from_usm', False) assert lq_from_usm is False self.usm_sharp_gt = opt['datasets']['train'].get('usm_sharp_gt', False) if self.usm_sharp_gt: usm_radius = opt['datasets']['train'].get('usm_radius', 50) self.usm_sharpener = USMSharp(radius=usm_radius).cuda() self.usm_weight = 
opt['datasets']['train'].get('usm_weight', 0.5) self.usm_threshold = opt['datasets']['train'].get('usm_threshold', 10) @torch.no_grad() def feed_data(self, data): self.lq = data['lq'].to(self.device) if 'gt' in data: self.gt = data['gt'].to(self.device) if 'gt_usm' in data: self.gt_usm = data['gt_usm'].to(self.device) logger = get_root_logger() logger.warning( 'since lq is not from gt_usm, ' 'we should put the usm_sharp operation outside the dataloader to speed up the traning time') elif self.usm_sharp_gt: b, n, c, h, w = self.gt.size() self.gt_usm = self.usm_sharpener( self.gt.view(b * n, c, h, w), weight=self.usm_weight, threshold=self.usm_threshold).view(b, n, c, h, w) # if self.opt['rank'] == 0 and 'debug' in self.opt['name']: # import torchvision # os.makedirs('tmp/gt', exist_ok=True) # os.makedirs('tmp/gt_usm', exist_ok=True) # os.makedirs('tmp/lq', exist_ok=True) # print(self.idx) # for i in range(15): # torchvision.utils.save_image( # self.lq[:, i, :, :, :], # f'tmp/lq/lq{self.idx}_{i}.png', # nrow=4, # padding=2, # normalize=True, # range=(0, 1)) # torchvision.utils.save_image( # self.gt[:, i, :, :, :], # f'tmp/gt/gt{self.idx}_{i}.png', # nrow=4, # padding=2, # normalize=True, # range=(0, 1)) # torchvision.utils.save_image( # self.gt_usm[:, i, :, :, :], # f'tmp/gt_usm/gt_usm{self.idx}_{i}.png', # nrow=4, # padding=2, # normalize=True, # range=(0, 1)) # self.idx += 1 # if self.idx >= 20: # exit() def setup_optimizers(self): train_opt = self.opt['train'] flow_lr_mul = train_opt.get('flow_lr_mul', 1) logger = get_root_logger() logger.info(f'Multiple the learning rate for flow network with {flow_lr_mul}.') if flow_lr_mul == 1: optim_params = self.net_g.parameters() else: # separate flow params and normal params for different lr normal_params = [] flow_params = [] for name, param in self.net_g.named_parameters(): if 'spynet' in name: flow_params.append(param) else: normal_params.append(param) optim_params = [ { # add normal params first 'params': normal_params, 
'lr': train_opt['optim_g']['lr'] }, { 'params': flow_params, 'lr': train_opt['optim_g']['lr'] * flow_lr_mul }, ] optim_type = train_opt['optim_g'].pop('type') self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) self.optimizers.append(self.optimizer_g) def optimize_parameters_base(self, current_iter): if self.fix_flow_iter: logger = get_root_logger() if current_iter == 1: logger.info(f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.') for name, param in self.net_g.named_parameters(): if 'spynet' in name or 'edvr' in name: param.requires_grad_(False) elif current_iter == self.fix_flow_iter: logger.warning('Train all the parameters.') self.net_g.requires_grad_(True) self.optimizer_g.zero_grad() self.output = self.net_g(self.lq) def optimize_parameters(self, current_iter): self.optimize_parameters_base(current_iter) _, _, c, h, w = self.output.size() pix_gt = self.gt percep_gt = self.gt if self.opt.get('l1_gt_usm', False): pix_gt = self.gt_usm if self.opt.get('percep_gt_usm', False): percep_gt = self.gt_usm l_total = 0 loss_dict = OrderedDict() # pixel loss if self.cri_pix: l_pix = self.cri_pix(self.output, pix_gt) l_total += l_pix loss_dict['l_pix'] = l_pix # perceptual loss if self.cri_perceptual: l_percep, l_style = self.cri_perceptual(self.output.view(-1, c, h, w), percep_gt.view(-1, c, h, w)) if l_percep is not None: l_total += l_percep loss_dict['l_percep'] = l_percep if l_style is not None: l_total += l_style loss_dict['l_style'] = l_style l_total.backward() self.optimizer_g.step() self.log_dict = self.reduce_loss_dict(loss_dict) if self.ema_decay > 0: self.model_ema(decay=self.ema_decay) def dist_validation(self, dataloader, current_iter, tb_logger, save_img): """dist_test actually, no gt, no metrics""" dataset = dataloader.dataset dataset_name = dataset.opt['name'] assert dataset_name.endswith('CoreFrames') rank, world_size = get_dist_info() num_folders = len(dataset) num_pad = (world_size - 
(num_folders % world_size)) % world_size if rank == 0: pbar = tqdm(total=len(dataset), unit='folder') os.makedirs(osp.join(self.opt['path']['visualization'], dataset_name, str(current_iter)), exist_ok=True) if self.opt['dist']: dist.barrier() # Will evaluate (num_folders + num_pad) times, but only the first num_folders results will be recorded. # (To avoid wait-dead) for i in range(rank, num_folders + num_pad, world_size): idx = min(i, num_folders - 1) val_data = dataset[idx] folder = val_data['folder'] # compute outputs val_data['lq'].unsqueeze_(0) self.feed_data(val_data) val_data['lq'].squeeze_(0) self.test() visuals = self.get_current_visuals() # tentative for out of GPU memory del self.lq del self.output if 'gt' in visuals: del self.gt torch.cuda.empty_cache() # evaluate if i < num_folders: for idx in range(visuals['result'].size(1)): result = visuals['result'][0, idx, :, :, :] result_img = tensor2img([result]) # uint8, bgr # since we keep all frames, scale of 4 is not very friendly to storage space # so we use a default scale of 2 to save the frames save_scale = self.opt.get('savescale', 2) net_scale = self.opt.get('scale') if save_scale != net_scale: h, w = result_img.shape[0:2] result_img = cv2.resize( result_img, (w // net_scale * save_scale, h // net_scale * save_scale), interpolation=cv2.INTER_LANCZOS4) if save_img: img_path = osp.join(self.opt['path']['visualization'], dataset_name, str(current_iter), f"{folder}_{idx:08d}_{self.opt['name'][:5]}.png") # image name only for REDS dataset imwrite(result_img, img_path) # progress bar if rank == 0: for _ in range(world_size): pbar.update(1) pbar.set_description(f'Folder: {folder}') if rank == 0: pbar.close() def test(self): n = self.lq.size(1) self.net_g.eval() flip_seq = self.opt['val'].get('flip_seq', False) self.center_frame_only = self.opt['val'].get('center_frame_only', False) if flip_seq: self.lq = torch.cat([self.lq, self.lq.flip(1)], dim=1) with torch.no_grad(): self.output = self.net_g(self.lq) if 
flip_seq: output_1 = self.output[:, :n, :, :, :] output_2 = self.output[:, n:, :, :, :].flip(1) self.output = 0.5 * (output_1 + output_2) if self.center_frame_only: self.output = self.output[:, n // 2, :, :, :] self.net_g.train() ================================================ FILE: animesr/test.py ================================================ # flake8: noqa import os.path as osp import animesr.archs import animesr.data import animesr.models from basicsr.test import test_pipeline if __name__ == '__main__': root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) test_pipeline(root_path) ================================================ FILE: animesr/train.py ================================================ # flake8: noqa import os.path as osp import animesr.archs import animesr.data import animesr.models from basicsr.train import train_pipeline if __name__ == '__main__': root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) train_pipeline(root_path) ================================================ FILE: animesr/utils/__init__.py ================================================ # -*- coding: utf-8 -*- ================================================ FILE: animesr/utils/inference_base.py ================================================ import argparse import os.path import torch from animesr.archs.vsr_arch import MSRSWVSR def get_base_argument_parser() -> argparse.ArgumentParser: """get the base argument parser for inference scripts""" parser = argparse.ArgumentParser() parser.add_argument('-i', '--input', type=str, default='inputs', help='input test image folder or video path') parser.add_argument('-o', '--output', type=str, default='results', help='save image/video path') parser.add_argument( '-n', '--model_name', type=str, default='AnimeSR_v2', help='Model names: AnimeSR_v2 | AnimeSR_v1-PaperModel. 
Default:AnimeSR_v2') parser.add_argument( '-s', '--outscale', type=int, default=4, help='The netscale is x4, but you can achieve arbitrary output scale (e.g., x2) with the argument outscale' 'The program will further perform cheap resize operation after the AnimeSR output. ' 'This is useful when you want to save disk space or avoid too large-resolution output') parser.add_argument( '--expname', type=str, default='animesr', help='A unique name to identify your current inference') parser.add_argument( '--netscale', type=int, default=4, help='the released models are all x4 models, only change this if you train a x2 or x1 model by yourself') parser.add_argument( '--mod_scale', type=int, default=4, help='the scale used for mod crop, since AnimeSR use a multi-scale arch, so the edge should be divisible by 4') parser.add_argument('--fps', type=int, default=None, help='fps of the sr videos') parser.add_argument('--half', action='store_true', help='use half precision to inference') return parser def get_inference_model(args, device) -> MSRSWVSR: """return an on device model with eval mode""" # set up model model = MSRSWVSR(num_feat=64, num_block=[5, 3, 2], netscale=args.netscale) model_path = f'weights/{args.model_name}.pth' assert os.path.isfile(model_path), \ f'{model_path} does not exist, please make sure you successfully download the pretrained models ' \ f'and put them into the weights folder' # load checkpoint loadnet = torch.load(model_path) model.load_state_dict(loadnet, strict=True) model.eval() model = model.to(device) # num_parameters = sum(map(lambda x: x.numel(), model.parameters())) # print(num_parameters) # exit(0) return model.half() if args.half else model ================================================ FILE: animesr/utils/shot_detector.py ================================================ # The codes below partially refer to the PySceneDetect. According # to its BSD 3-Clause License, we keep the following. 
# Width thresholds (px) -> default spatial downscale factor applied before
# shot detection (detection does not need full resolution).
DEFAULT_DOWNSCALE_FACTORS = {
    3200: 12,  # ~4k
    2100: 8,  # ~2k
    1700: 6,  # ~1080p
    1200: 5,
    900: 4,  # ~720p
    600: 3,
    400: 1  # ~480p
}


def compute_downscale_factor(frame_width):
    """Return the optimal default downscale factor for a video's resolution.

    Args:
        frame_width (int): width of the video frames in pixels.

    Returns:
        int: the default downscale factor to use with a video of
            frame_height x frame_width; 1 when the width is below every
            threshold.
    """
    for width in sorted(DEFAULT_DOWNSCALE_FACTORS, reverse=True):
        if frame_width >= width:
            return DEFAULT_DOWNSCALE_FACTORS[width]
    return 1


class ShotDetector(object):
    """Detects fast cuts using changes in colour and intensity between frames.

    Detect shot boundary using HSV and LUV information.
    """

    def __init__(self, threshold=30.0, min_shot_len=15):
        """
        Args:
            threshold (float): average HSV difference above which a frame is a
                shot-cut candidate.
            min_shot_len (int): minimum length (in frames) of any given shot.
        """
        super(ShotDetector, self).__init__()
        self.hsv_threshold = threshold
        self.delta_hsv_gap_threshold = 10
        self.luv_threshold = 40
        self.hsv_weight = 5
        # minimum length (frames length) of any given shot
        self.min_shot_len = min_shot_len
        self.last_frame = None
        self.last_shot_cut = None
        self.last_hsv = None
        self._metric_keys = [
            'hsv_content_val', 'delta_hsv_hue', 'delta_hsv_sat', 'delta_hsv_lum', 'luv_content_val',
            'delta_luv_hue', 'delta_luv_sat', 'delta_luv_lum'
        ]
        self.cli_name = 'detect-content'
        self.last_luv = None
        self.cut_list = []

    def add_cut(self, cut):
        """Close the current shot at frame ``cut`` - 1 and record it.

        Each entry of ``cut_list`` is ``[first_frame, last_frame]`` of a shot.
        """
        num_cuts = len(self.cut_list)
        if num_cuts == 0:
            self.cut_list.append([0, cut - 1])
        else:
            self.cut_list.append([self.cut_list[num_cuts - 1][1] + 1, cut - 1])

    def process_frame(self, frame_num, frame_img):
        """Similar to ThresholdDetector, but using the HSV colour space
        DIFFERENCE instead of single-frame RGB/grayscale intensity (thus
        cannot detect slow fades with this method).

        Args:
            frame_num (int): Frame number of frame that is being passed.
            frame_img (numpy.ndarray): Decoded BGR frame image to perform shot
                detection on.

        Returns:
            List[int]: List of frames where shot cuts have been detected.
            There may be 0 or more frames in the list, and not necessarily the
            same as frame_num.
        """
        cut_list = []
        delta_hsv_avg, delta_hsv_h, delta_hsv_s, delta_hsv_v = 0.0, 0.0, 0.0, 0.0
        delta_luv_avg, delta_luv_h, delta_luv_s, delta_luv_v = 0.0, 0.0, 0.0, 0.0
        if frame_num == 0:
            # first frame: nothing to compare against yet
            self.last_frame = frame_img.copy()
            return cut_list
        else:
            num_pixels = frame_img.shape[0] * frame_img.shape[1]
            curr_luv = cv2.split(cv2.cvtColor(frame_img, cv2.COLOR_BGR2Luv))
            curr_hsv = cv2.split(cv2.cvtColor(frame_img, cv2.COLOR_BGR2HSV))
            last_hsv = self.last_hsv
            last_luv = self.last_luv
            if not last_hsv:
                # no cached conversion from the previous call; convert last_frame now
                last_hsv = cv2.split(cv2.cvtColor(self.last_frame, cv2.COLOR_BGR2HSV))
                last_luv = cv2.split(cv2.cvtColor(self.last_frame, cv2.COLOR_BGR2Luv))

            # mean absolute per-pixel change for each HSV channel + their average
            delta_hsv = [0, 0, 0, 0]
            for i in range(3):
                num_pixels = curr_hsv[i].shape[0] * curr_hsv[i].shape[1]
                curr_hsv[i] = curr_hsv[i].astype(np.int32)
                last_hsv[i] = last_hsv[i].astype(np.int32)
                delta_hsv[i] = np.sum(np.abs(curr_hsv[i] - last_hsv[i])) / float(num_pixels)
            delta_hsv[3] = sum(delta_hsv[0:3]) / 3.0
            delta_hsv_h, delta_hsv_s, delta_hsv_v, delta_hsv_avg = delta_hsv

            # same statistics in the LUV colour space
            delta_luv = [0, 0, 0, 0]
            for i in range(3):
                num_pixels = curr_luv[i].shape[0] * curr_luv[i].shape[1]
                curr_luv[i] = curr_luv[i].astype(np.int32)
                last_luv[i] = last_luv[i].astype(np.int32)
                delta_luv[i] = np.sum(np.abs(curr_luv[i] - last_luv[i])) / float(num_pixels)
            delta_luv[3] = sum(delta_luv[0:3]) / 3.0
            delta_luv_h, delta_luv_s, delta_luv_v, delta_luv_avg = delta_luv

            # cache the converted channels for the next call
            self.last_hsv = curr_hsv
            self.last_luv = curr_luv

            if delta_hsv_avg >= self.hsv_threshold and delta_hsv_avg - self.hsv_threshold >= self.delta_hsv_gap_threshold:
                # HSV change is decisively above threshold -> cut (respecting min shot length)
                if self.last_shot_cut is None or ((frame_num - self.last_shot_cut) >= self.min_shot_len):
                    cut_list.append(frame_num)
                    self.last_shot_cut = frame_num
            elif delta_hsv_avg >= self.hsv_threshold and \
                    delta_hsv_avg - self.hsv_threshold < \
                    self.delta_hsv_gap_threshold and \
                    delta_luv_avg + self.hsv_weight * \
                    (delta_hsv_avg - self.hsv_threshold) > self.luv_threshold:
                # borderline HSV change: use LUV evidence to confirm the cut
                if self.last_shot_cut is None or ((frame_num - self.last_shot_cut) >= self.min_shot_len):
                    cut_list.append(frame_num)
                    self.last_shot_cut = frame_num

            self.last_frame = frame_img.copy()
            return cut_list

    def detect_shots(self, frame_source, frame_skip=0, show_progress=True, keep_resolution=False):
        """Perform shot detection on the given frame_source.

        Blocks until all frames in the frame_source have been processed.
        Results can be obtained by calling either the get_shot_list() or
        get_cut_list() methods.

        Arguments:
            frame_source (str): a folder of extracted frame images, read in
                sorted filename order.
            frame_skip (int): Not recommended except for extremely high
                framerate videos. `frame_skip` **must** be 0 (the default)
                when using a StatsManager.
                NOTE(review): the loop below never actually skips frames --
                frame_skip is only validated, not applied.
            show_progress (bool): If True, and the ``tqdm`` module is
                available, displays a progress bar with the progress,
                framerate, and expected time to complete processing the video
                frame source.
            keep_resolution (bool): if False, frames are downscaled according
                to ``compute_downscale_factor`` before detection.

        Returns:
            list: ``self.cut_list``, a list of ``[first_frame, last_frame]``
            shot ranges.

        Raises:
            ValueError: `frame_skip` **must** be 0 (the default) if the
                shotManager was constructed with a StatsManager object.
        """
        # bug fix: _stats_manager is never initialised in __init__, so the
        # original plain attribute access raised AttributeError (instead of
        # the intended ValueError) whenever frame_skip > 0.
        if frame_skip > 0 and getattr(self, '_stats_manager', None) is not None:
            raise ValueError('frame_skip must be 0 when using a StatsManager.')
        curr_frame = 0
        frame_paths = sorted(glob.glob(os.path.join(frame_source, '*')))
        total_frames = len(frame_paths)
        end_frame = total_frames
        progress_bar = None
        if tqdm and show_progress:
            progress_bar = tqdm(total=total_frames, unit='frames')
        try:
            while True:
                if end_frame is not None and curr_frame >= end_frame:
                    break
                frame_im = cv2.imread(frame_paths[curr_frame])
                if not keep_resolution:
                    # the downscale factor is computed once, from the first frame's width
                    if curr_frame == 0:
                        downscale_factor = compute_downscale_factor(frame_im.shape[1])
                    frame_im = frame_im[::downscale_factor, ::downscale_factor, :]
                cut = self.process_frame(curr_frame, frame_im)
                if len(cut) != 0:
                    self.add_cut(cut[0])
                curr_frame += 1
                if progress_bar:
                    progress_bar.update(1)
        finally:
            if progress_bar:
                progress_bar.close()
        return self.cut_list
30/1 numerator, denominator = map(lambda x: int(x), fps.split('/')) if ret_type == 'float': return numerator / denominator elif ret_type == 'str': return str(numerator / denominator) else: return numerator, denominator def get_video_num_frames(video_path): """Get the video's total number of frames.""" global default_ffprobe_exe_path ffprobe_exe_path = os.environ.get('ffprobe_exe_path', default_ffprobe_exe_path) cmd = [ ffprobe_exe_path, '-v', 'quiet', '-select_streams', 'v', '-count_packets', '-of', 'default=noprint_wrappers=1:nokey=1', '-show_entries', 'stream=nb_read_packets', video_path ] result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) nb_frames = result.stdout.decode('utf-8').strip() return int(nb_frames) def get_video_bitrate(video_path): """Get the bitrate of the video.""" global default_ffprobe_exe_path ffprobe_exe_path = os.environ.get('ffprobe_exe_path', default_ffprobe_exe_path) cmd = [ ffprobe_exe_path, '-v', 'quiet', '-select_streams', 'v', '-of', 'default=noprint_wrappers=1:nokey=1', '-show_entries', 'stream=bit_rate', video_path ] result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) bitrate = result.stdout.decode('utf-8').strip() if bitrate == 'N/A': return bitrate return int(bitrate) // 1000 def get_video_resolution(video_path): """Get the resolution (h and w) of the video. 
Args: video_path (str): the video path; Returns: h, w (int) """ global default_ffprobe_exe_path ffprobe_exe_path = os.environ.get('ffprobe_exe_path', default_ffprobe_exe_path) cmd = [ ffprobe_exe_path, '-v', 'quiet', '-select_streams', 'v', '-of', 'csv=s=x:p=0', '-show_entries', 'stream=width,height', video_path ] result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) resolution = result.stdout.decode('utf-8').strip() # print(resolution) w, h = resolution.split('x') return int(h), int(w) def video2frames(video_path, out_dir, force=False, high_quality=True, ss=None, to=None, vf=None): """Extract frames from the video Args: out_dir: where to save the frames force: if out_dir is not empty, forceTrue will still extract frames high_quality: whether to use the highest quality ss: start time, format HH.MM.SS[.xxx], if None, extract full video to: end time, format HH.MM.SS[.xxx], if None, extract full video vf: video filter """ global default_ffmpeg_exe_path ffmpeg_exc_path = os.environ.get('ffmpeg_exe_path', default_ffmpeg_exe_path) imgs = glob.glob(os.path.join(out_dir, '*.png')) length = len(imgs) if length > 0: print(f'{out_dir} already has frames!, force extract = {force}') if not force: return out_dir print(f'extracting frames for {video_path}') cmd = [ ffmpeg_exc_path, f'-i {video_path}', '-v error', f'-ss {ss} -to {to}' if ss is not None and to is not None else '', '-qscale:v 1 -qmin 1 -qmax 1 -vsync 0' if high_quality else '', f'-vf {vf}' if vf is not None else '', f'{out_dir}/%08d.png', ] print(' '.join(cmd)) subprocess.call(' '.join(cmd), shell=True) return out_dir def frames2video(frames_dir, out_path, fps=25, filter='*', suffix=None): """Combine frames under a folder to video Args: frames_dir: input folder where frames locate out_path: the output video path fps: output video fps suffix: the frame suffix, e.g., png jpg """ global default_ffmpeg_vcodec, default_ffmpeg_pix_fmt, default_ffmpeg_exe_path ffmpeg_exc_path = 
os.environ.get('ffmpeg_exe_path', default_ffmpeg_exe_path) vcodec = os.environ.get('ffmpeg_vcodec', default_ffmpeg_vcodec) pix_fmt = os.environ.get('ffmpeg_pix_fmt', default_ffmpeg_pix_fmt) if suffix is None: images_names = os.listdir(frames_dir) image_name = images_names[0] suffix = image_name.split('.')[-1] cmd = [ ffmpeg_exc_path, '-y', '-r', str(fps), '-f', 'image2', '-pattern_type', 'glob', '-i', f'{frames_dir}/{filter}.{suffix}', '-vcodec', vcodec, '-pix_fmt', pix_fmt, out_path ] # yapf: disable print(' '.join(cmd)) subprocess.call(cmd) return out_path ================================================ FILE: cog.yaml ================================================ build: gpu: true cuda: "11.6.2" python_version: "3.10" system_packages: - "libgl1-mesa-glx" - "libglib2.0-0" - "ffmpeg" python_packages: - "ipython==8.4.0" - "torch==1.13.1" - "torchvision==0.14.1" - "ffmpeg-python==0.2.0" - "facexlib==0.2.5" - "basicsr==1.4.2" - "opencv-python==4.7.0.68" - "Pillow==9.4.0" - "psutil==5.9.4" - "tqdm==4.64.1" predict: "predict.py:Predictor" ================================================ FILE: options/train_animesr_step1_gan_BasicOPonly.yml ================================================ # general settings name: train_animesr_step1_gan_BasicOPonly model_type: VideoRecurrentGANCustomModel scale: 4 num_gpu: auto # set num_gpu: 0 for cpu mode manual_seed: 0 # USM the ground-truth l1_gt_usm: False percep_gt_usm: False gan_gt_usm: False # dataset and data loader settings datasets: train: name: AVC-Train type: FFMPEGAnimeDataset dataroot_gt: datasets/AVC-Train # TO_MODIFY test_mode: False io_backend: type: disk num_frame: 15 gt_size: 256 interval_list: [1, 2, 3] random_reverse: True use_flip: true use_rot: true usm_sharp_gt: False # data loader use_shuffle: true num_worker_per_gpu: 20 batch_size_per_gpu: 4 dataset_enlarge_ratio: 1 prefetch_mode: ~ pin_memory: True # network structures network_g: type: MSRSWVSR num_feat: 64 num_block: [5, 3, 2] network_d: type: 
MultiScaleDiscriminator num_in_ch: 3 num_feat: 64 num_layers: [3, 3, 3] max_nf_mult: 8 norm_type: none use_sigmoid: False use_sn: True use_downscale: True # path path: pretrain_network_g: experiments/train_animesr_step1_net_BasicOPonly/models/net_g_300000.pth param_key_g: params strict_load_g: true resume_state: ~ # training settings train: ema_decay: 0.999 optim_g: type: Adam lr: !!float 1e-4 weight_decay: 0 betas: [0.9, 0.99] lr_flow: !!float 1e-5 optim_d: type: Adam lr: !!float 1e-4 weight_decay: 0 betas: [0.9, 0.99] scheduler: type: MultiStepLR milestones: [300000] gamma: 0.5 total_iter: 300000 warmup_iter: -1 # no warm up fix_flow: ~ # losses pixel_opt: type: L1Loss loss_weight: 1.0 reduction: mean perceptual_opt: type: PerceptualLoss layer_weights: # before relu 'conv1_2': 0.1 'conv2_2': 0.1 'conv3_4': 1 'conv4_4': 1 'conv5_4': 1 vgg_type: vgg19 use_input_norm: true range_norm: false perceptual_weight: 1.0 style_weight: 0 criterion: l1 gan_opt: type: MultiScaleGANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 loss_weight: !!float 1.0 net_d_iters: 1 net_d_init_iters: 0 # validation settings #val: # val_freq: !!float 1e4 # save_img: true # logging settings logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 use_tb_logger: true wandb: project: ~ resume_id: ~ # dist training settings dist_params: backend: nccl port: 29500 find_unused_parameters: true ================================================ FILE: options/train_animesr_step1_net_BasicOPonly.yml ================================================ # general settings name: train_animesr_step1_net_BasicOPonly model_type: VideoRecurrentCustomModel scale: 4 num_gpu: auto # set num_gpu: 0 for cpu mode manual_seed: 0 # USM the ground-truth l1_gt_usm: True # dataset and data loader settings datasets: train: name: AVC-Train type: FFMPEGAnimeDataset dataroot_gt: datasets/AVC-Train # TO_MODIFY test_mode: False io_backend: type: disk num_frame: 15 gt_size: 256 interval_list: [1, 2, 3] random_reverse: 
True use_flip: true use_rot: true usm_sharp_gt: True usm_weight: 0.3 usm_radius: 50 # data loader use_shuffle: true num_worker_per_gpu: 20 batch_size_per_gpu: 4 dataset_enlarge_ratio: 1 prefetch_mode: ~ pin_memory: True # network structures network_g: type: MSRSWVSR num_feat: 64 num_block: [5, 3, 2] # path path: resume_state: ~ # training settings train: ema_decay: 0.999 optim_g: type: Adam lr: !!float 2e-4 weight_decay: 0 betas: [0.9, 0.99] scheduler: type: MultiStepLR milestones: [300000] gamma: 0.5 total_iter: 300000 warmup_iter: -1 # no warm up # losses pixel_opt: type: L1Loss loss_weight: 1.0 reduction: mean # validation settings #val: # val_freq: !!float 1e4 # save_img: true # logging settings logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 use_tb_logger: true wandb: project: ~ resume_id: ~ # dist training settings dist_params: backend: nccl port: 29500 find_unused_parameters: true ================================================ FILE: options/train_animesr_step2_lbo_1_gan.yml ================================================ # general settings name: train_animesr_step2_lbo_1_gan model_type: DegradationGANModel scale: 2 num_gpu: auto # set num_gpu: 0 for cpu mode manual_seed: 0 # dataset and data loader settings datasets: train: name: LBO_1 type: CustomPairedImageDataset dataroot_gt: results/input_rescaling_strategy_lbo_1/frames # TO_MODIFY dataroot_lq: datasets/lbo_training_data/real_world_video_to_train_lbo_1 # TO_MODIFY io_backend: type: disk gt_size: 256 use_hflip: true use_rot: true # data loader use_shuffle: true num_worker_per_gpu: 12 batch_size_per_gpu: 16 dataset_enlarge_ratio: 200 prefetch_mode: ~ # network structures network_g: type: SimpleDegradationArch num_in_ch: 3 num_out_ch: 3 num_feat: 64 downscale: 2 network_d: type: MultiScaleDiscriminator num_in_ch: 3 num_feat: 64 num_layers: [3] max_nf_mult: 8 norm_type: none use_sigmoid: False use_sn: True use_downscale: True # path path: pretrain_network_g: 
experiments/train_animesr_step2_lbo_1_net/models/net_g_100000.pth param_key_g: params strict_load_g: true resume_state: ~ # training settings train: optim_g: type: Adam lr: !!float 1e-4 weight_decay: 0 betas: [0.9, 0.99] optim_d: type: Adam lr: !!float 1e-4 weight_decay: 0 betas: [0.9, 0.99] scheduler: type: MultiStepLR milestones: [50000] gamma: 0.5 total_iter: 100000 warmup_iter: -1 # no warm up # losses pixel_opt: type: L1Loss loss_weight: 1.0 reduction: mean perceptual_opt: type: PerceptualLoss layer_weights: # before relu 'conv1_2': 0.1 'conv2_2': 0.1 'conv3_4': 1 'conv4_4': 1 'conv5_4': 1 vgg_type: vgg19 use_input_norm: true range_norm: false perceptual_weight: 1.0 style_weight: 0 criterion: l1 gan_opt: type: MultiScaleGANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 loss_weight: !!float 1.0 # logging settings logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 use_tb_logger: true wandb: project: ~ resume_id: ~ # dist training settings dist_params: backend: nccl port: 29500 ================================================ FILE: options/train_animesr_step2_lbo_1_net.yml ================================================ # general settings name: train_animesr_step2_lbo_1_net model_type: DegradationModel scale: 2 num_gpu: auto # set num_gpu: 0 for cpu mode manual_seed: 0 # dataset and data loader settings datasets: train: name: LBO_1 type: CustomPairedImageDataset dataroot_gt: results/input_rescaling_strategy_lbo_1/frames # TO_MODIFY dataroot_lq: datasets/lbo_training_data/real_world_video_to_train_lbo_1 # TO_MODIFY io_backend: type: disk gt_size: 256 use_hflip: true use_rot: true # data loader use_shuffle: true num_worker_per_gpu: 12 batch_size_per_gpu: 16 dataset_enlarge_ratio: 200 prefetch_mode: ~ # network structures network_g: type: SimpleDegradationArch num_in_ch: 3 num_out_ch: 3 num_feat: 64 downscale: 2 # path path: resume_state: ~ # training settings train: optim_g: type: Adam lr: !!float 2e-4 weight_decay: 0 betas: [0.9, 0.99] 
scheduler: type: MultiStepLR milestones: [50000] gamma: 0.5 total_iter: 100000 warmup_iter: -1 # no warm up # losses l1_opt: type: L1Loss loss_weight: 1.0 reduction: mean l2_opt: type: MSELoss loss_weight: 1.0 reduction: mean # logging settings logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 use_tb_logger: true wandb: project: ~ resume_id: ~ # dist training settings dist_params: backend: nccl port: 29500 ================================================ FILE: options/train_animesr_step3_gan_3LBOs.yml ================================================ # general settings name: train_animesr_step3_gan_3LBOs model_type: VideoRecurrentGANCustomModel scale: 4 num_gpu: auto # set num_gpu: 0 for cpu mode manual_seed: 0 # USM the ground-truth l1_gt_usm: False percep_gt_usm: False gan_gt_usm: False # dataset and data loader settings datasets: train: name: AVC-Train type: FFMPEGAnimeLBODataset dataroot_gt: datasets/AVC-Train # TO_MODIFY test_mode: False io_backend: type: disk num_frame: 15 gt_size: 256 interval_list: [1, 2, 3] random_reverse: True use_flip: true use_rot: true usm_sharp_gt: False degradation_model_path: [weights/pretrained_lbo_1.pth, weights/pretrained_lbo_2.pth, weights/pretrained_lbo_3.pth] # TO_MODIFY # data loader use_shuffle: true num_worker_per_gpu: 5 batch_size_per_gpu: 4 dataset_enlarge_ratio: 1 prefetch_mode: ~ pin_memory: True # network structures network_g: type: MSRSWVSR num_feat: 64 num_block: [5, 3, 2] network_d: type: MultiScaleDiscriminator num_in_ch: 3 num_feat: 64 num_layers: [3, 3, 3] max_nf_mult: 8 norm_type: none use_sigmoid: False use_sn: True use_downscale: True # path path: pretrain_network_g: experiments/train_animesr_step1_net_BasicOPonly/models/net_g_300000.pth param_key_g: params strict_load_g: true resume_state: ~ # training settings train: ema_decay: 0.999 optim_g: type: Adam lr: !!float 1e-4 weight_decay: 0 betas: [0.9, 0.99] lr_flow: !!float 1e-5 optim_d: type: Adam lr: !!float 1e-4 weight_decay: 0 betas: [0.9, 0.99] 
class Predictor(BasePredictor):
    """Cog predictor that runs AnimeSR on an input video or a zip of frames."""

    @torch.inference_mode()
    def predict(
        self,
        video: Path = Input(
            description="Input video file",
            default=None,
        ),
        frames: Path = Input(
            description="Zip file of frames of a video. Ignored when video is provided.",
            default=None,
        ),
    ) -> ModelOutput:
        """Run a single prediction on the model"""
        # fixed message typo: "frames of video" -> "frames or video"
        assert frames or video, "Please provide frames or video input."
        out_path = "cog_temp"
        if os.path.exists(out_path):
            shutil.rmtree(out_path)
        os.makedirs(out_path, exist_ok=True)

        if video:
            print("processing video...")
            cmd = (
                "python scripts/inference_animesr_video.py -i "
                + str(video)
                + " -o "
                + out_path
                + " -n AnimeSR_v2 -s 4 --expname animesr_v2 --num_process_per_gpu 1"
            )
            call(cmd, shell=True)
        else:
            print("processing frames...")
            unzip_frames = "cog_frames_temp"
            if os.path.exists(unzip_frames):
                shutil.rmtree(unzip_frames)
            os.makedirs(unzip_frames)
            with ZipFile(str(frames), "r") as zip_ref:
                for zip_info in zip_ref.infolist():
                    # skip directory entries and macOS metadata
                    if zip_info.filename[-1] == "/" or zip_info.filename.startswith(
                        "__MACOSX"
                    ):
                        continue
                    mt = mimetypes.guess_type(zip_info.filename)
                    if mt and mt[0] and mt[0].startswith("image/"):
                        # flatten nested archive paths so every frame lands
                        # directly inside unzip_frames
                        zip_info.filename = os.path.basename(zip_info.filename)
                        zip_ref.extract(zip_info, unzip_frames)
            cmd = (
                "python scripts/inference_animesr_frames.py -i "
                + unzip_frames
                + " -o "
                + out_path
                + " -n AnimeSR_v2 --expname animesr_v2 --save_video_too --fps 20"
            )
            call(cmd, shell=True)

        # by default, sr_frames will be saved in cog_temp/animesr_v2/frames
        frames_output = Path(tempfile.mkdtemp()) / "out.zip"
        frames_out_dir = os.listdir(f"{out_path}/animesr_v2/frames")
        assert len(frames_out_dir) == 1
        frames_path = os.path.join(
            f"{out_path}/animesr_v2/frames", frames_out_dir[0]
        )
        sr_frames_files = os.listdir(frames_path)
        # renamed local from `zip` -> `zip_out`: don't shadow the builtin
        with ZipFile(str(frames_output), "w") as zip_out:
            for img in sr_frames_files:
                zip_out.write(os.path.join(frames_path, img))

        # by default, video will be saved in cog_temp/animesr_v2/videos
        video_out_dir = os.listdir(f"{out_path}/animesr_v2/videos")
        assert len(video_out_dir) == 1
        if video_out_dir[0].endswith(".mp4"):
            source = os.path.join(f"{out_path}/animesr_v2/videos", video_out_dir[0])
        else:
            video_output = os.listdir(
                f"{out_path}/animesr_v2/videos/{video_out_dir[0]}"
            )[0]
            source = os.path.join(
                f"{out_path}/animesr_v2/videos", video_out_dir[0], video_output
            )
        video_path = Path(tempfile.mkdtemp()) / "out.mp4"
        shutil.copy(source, str(video_path))

        if video:
            return ModelOutput(video=video_path)
        return ModelOutput(sr_frames=frames_output, video=video_path)
def main(args):
    """A script to prepare anime videos. The preparation can be divided into following steps:
    1. use ffmpeg to extract frames
    2. shot detection
    3. estimate flow
    4. detect black frames
    5. use hyperIQA to evaluate the quality of frames
    6. generate at most 5 clips per video
    """
    opt = dict()
    opt['debug'] = args.debug
    opt['n_thread'] = args.n_thread
    opt['ss_idx'] = args.ss_idx
    opt['to_idx'] = args.to_idx
    # params for step1: extract frames
    opt['video_root'] = f'{args.dataroot}/raw_videos'
    opt['save_frames_root'] = f'{args.dataroot}/frames'
    opt['meta_files_root'] = f'{args.dataroot}/meta'
    # params for step2: shot detection
    opt['detect_shot_root'] = f'{args.dataroot}/detect_shot'
    # params for step3: flow estimation
    opt['estimate_flow_root'] = f'{args.dataroot}/estimate_flow'
    opt['spy_pretrain_weight'] = 'experiments/pretrained_models/flownet/spynet_sintel_final-3d2a1287.pth'
    opt['downscale_factor'] = 1
    # params for step4: detect black frames
    opt['black_flag_root'] = f'{args.dataroot}/black_flag'
    opt['black_threshold'] = 0.98
    # params for step5: image quality assessment
    opt['num_patch_per_iqa'] = 5
    opt['iqa_score_root'] = f'{args.dataroot}/iqa_score'
    # params for step6: generate clips
    opt['num_frames_per_clip'] = args.n_frames_per_clip
    opt['num_clips_per_video'] = args.n_clips_per_video
    opt['select_clips_root'] = f'{args.dataroot}/{args.select_clip_root}'
    opt['select_clips_meta'] = osp.join(opt['select_clips_root'], 'meta_info')
    opt['select_clips_frames'] = osp.join(opt['select_clips_root'], 'frames')
    opt['select_done_flags'] = osp.join(opt['select_clips_root'], 'done_flags')

    # each step can be switched on independently, e.g. --run 135
    if '1' in args.run:
        run_step1(opt)
    if '2' in args.run:
        run_step2(opt)
    if '3' in args.run:
        run_step3(opt)
    if '4' in args.run:
        run_step4(opt)
    if '5' in args.run:
        run_step5(opt)
    if '6' in args.run:
        run_step6(opt)


# -------------------------------------------------------------------- #
# --------------------------- step1 ---------------------------------- #
# -------------------------------------------------------------------- #
def run_step1(opt):
    """Extract frames.

    1. read all video files under video_root folder
    2. filter out the videos that already have been processed
    3. use multi-process to extract the remaining videos
    """
    video_root = opt['video_root']
    frames_root = opt['save_frames_root']
    meta_root = opt['meta_files_root']
    os.makedirs(frames_root, exist_ok=True)
    os.makedirs(meta_root, exist_ok=True)
    if not osp.isdir(video_root):
        print(f'path {video_root} is not a valid folder, exit.')
        # bugfix: the message promises an exit, but the original fell through
        # and crashed later on the glob of a non-existent folder
        return

    videos_path = sorted(glob.glob(osp.join(video_root, '*')))
    if opt['debug']:
        videos_path = videos_path[:3]
    else:
        videos_path = videos_path[opt['ss_idx']:opt['to_idx']]

    pbar = tqdm(total=len(videos_path), unit='video', desc='step1')
    pool = Pool(opt['n_thread'])
    for video_path in videos_path:
        video_name = osp.splitext(osp.basename(video_path))[0]
        if video_name.startswith('.'):  # skip hidden files such as .DS_Store
            print(f'skip {video_name}')
            continue
        frame_path = osp.join(frames_root, video_name)
        meta_path = osp.join(meta_root, f'{video_name}.txt')
        pool.apply_async(
            worker1,
            args=(opt, video_name, video_path, frame_path, meta_path),
            callback=lambda arg: pbar.update(1))
    pool.close()
    pool.join()


def worker1(opt, video_name, video_path, frame_path, meta_path):
    """Extract all frames of one video and write its metadata file."""
    # get info of video
    fps = video_util.get_video_fps(video_path)
    h, w = video_util.get_video_resolution(video_path)
    num_frames = video_util.get_video_num_frames(video_path)
    bit_rate = video_util.get_video_bitrate(video_path)

    # check whether this video has been processed
    flag = True
    num_extracted_frames = 0
    if osp.exists(frame_path):
        num_extracted_frames = len(glob.glob(osp.join(frame_path, '*.png')))
        if num_extracted_frames == num_frames:
            print(f'skip {video_path} since there are already {num_frames} frames have been extracted.')
            flag = False
        else:
            print(f'{num_extracted_frames} of {num_frames} have been extracted for {video_path}, re-run.')

    # extract frames
    os.makedirs(frame_path, exist_ok=True)
    video_util.video2frames(video_path, frame_path, force=flag, high_quality=True)
    if flag:
        num_extracted_frames = len(glob.glob(osp.join(frame_path, '*.png')))

    # write some metadata to meta file
    with open(meta_path, 'w') as f:
        f.write(f'Video Name: {video_name}\n')
        f.write(f'H: {h}\n')
        f.write(f'W: {w}\n')
        f.write(f'FPS: {fps}\n')
        f.write(f'Bit Rate: {bit_rate}kbps\n')
        f.write(f'{num_extracted_frames}/{num_frames} have been extracted\n')


# -------------------------------------------------------------------- #
# --------------------------- step2 ---------------------------------- #
# -------------------------------------------------------------------- #
def run_step2(opt):
    """shot detection. refer to lijian's pipeline"""
    detect_shot_root = opt['detect_shot_root']
    meta_root = opt['meta_files_root']
    os.makedirs(detect_shot_root, exist_ok=True)
    if not osp.exists(meta_root):
        print('no videos has run step1, exit.')
        return

    # get the video which has been extracted frames
    videos_name = sorted(glob.glob(osp.join(meta_root, '*.txt')))
    videos_name = [osp.splitext(osp.basename(video_name))[0] for video_name in videos_name]
    if opt['debug']:
        videos_name = videos_name[:3]
    else:
        videos_name = videos_name[opt['ss_idx']:opt['to_idx']]

    pbar = tqdm(total=len(videos_name), unit='video', desc='step2')
    pool = Pool(opt['n_thread'])
    for video_name in videos_name:
        pool.apply_async(worker2, args=(opt, video_name), callback=lambda arg: pbar.update(1))
    pool.close()
    pool.join()


def worker2(opt, video_name):
    """Run shot detection for one video and save [start end] pairs."""
    video_frame_path = osp.join(opt['save_frames_root'], video_name)
    detect_shot_file_path = osp.join(opt['detect_shot_root'], f'{video_name}.txt')
    if osp.exists(detect_shot_file_path):
        print(f'skip {video_name} since {detect_shot_file_path} already exist.')
        return
    detector = ShotDetector()
    shot_list = detector.detect_shots(video_frame_path)
    with open(detect_shot_file_path, 'w') as f:
        for shot in shot_list:
            f.write(f'{shot[0]} {shot[1]}\n')


# -------------------------------------------------------------------- #
# --------------------------- step3 ---------------------------------- #
# -------------------------------------------------------------------- #
def run_step3(opt):
    """Estimate per-frame optical-flow statistics with SpyNet (one GPU worker per video)."""
    estimate_flow_root = opt['estimate_flow_root']
    meta_root = opt['meta_files_root']
    os.makedirs(estimate_flow_root, exist_ok=True)
    if not osp.exists(meta_root):
        print('no videos has run step1, exit.')
        return

    # download the spynet checkpoint first
    if not osp.exists(opt['spy_pretrain_weight']):
        download_file_from_google_drive('1VZz1cikwTRVX7zXoD247DB7n5Tj_LQpF', opt['spy_pretrain_weight'])

    # get the video which has been extracted frames
    videos_name = sorted(glob.glob(osp.join(meta_root, '*.txt')))
    videos_name = [osp.splitext(osp.basename(video_name))[0] for video_name in videos_name]
    if opt['debug']:
        videos_name = videos_name[:3]
    else:
        videos_name = videos_name[opt['ss_idx']:opt['to_idx']]

    pbar = tqdm(total=len(videos_name), unit='video', desc='step3')
    num_gpus = torch.cuda.device_count()
    # 'spawn' context is required to use CUDA inside subprocesses
    ctx = torch.multiprocessing.get_context('spawn')
    pool = ctx.Pool(min(3 * num_gpus, opt['n_thread']))
    for idx, video_name in enumerate(videos_name):
        pool.apply_async(
            worker3, args=(opt, video_name, torch.device(idx % num_gpus)), callback=lambda arg: pbar.update(1))
    pool.close()
    pool.join()


def read_img(img_path, device, downscale_factor=1):
    """Read an image from disk as a (1, c, h, w) float tensor on `device`.

    Args:
        img_path (str): path of the image.
        device (torch.device): target device of the tensor.
        downscale_factor (int): optionally downscale h and w by this factor.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    img = cv2.imread(img_path)
    if img is None:
        # bugfix: cv2.imread silently returns None for missing/corrupt files,
        # which previously surfaced later as an opaque AttributeError
        raise FileNotFoundError(f'failed to read image: {img_path}')
    h, w = img.shape[0:2]
    if downscale_factor != 1:
        img = cv2.resize(img, (w // downscale_factor, h // downscale_factor), interpolation=cv2.INTER_LANCZOS4)
    img = img2tensor(img)
    img = img.unsqueeze(0).to(device)
    return img


@torch.no_grad()
def worker3(opt, video_name, device):
    """Compute max/mean |flow| between consecutive frames of one video and save them."""
    video_frame_path = osp.join(opt['save_frames_root'], video_name)
    frames_path = sorted(glob.glob(osp.join(video_frame_path, '*.png')))
    estimate_flow_file_path = osp.join(opt['estimate_flow_root'], f'{video_name}.txt')
    if osp.exists(estimate_flow_file_path):
        with open(estimate_flow_file_path, 'r') as f:
            lines = f.readlines()
        length = len(lines)
        if length == len(frames_path):
            print(f'skip {video_name} since {length}/{len(frames_path)} have done.')
            return
        else:
            print(f're-run {video_name} since only {length}/{len(frames_path)} have done.')

    spynet = SpyNet(load_path=opt['spy_pretrain_weight']).eval().to(device)
    downscale_factor = opt['downscale_factor']
    flow_out_list = []
    pbar = tqdm(total=len(frames_path), unit='frame', desc='worker3')
    pre_img = None
    for idx, frame_path in enumerate(frames_path):
        img_name = osp.basename(frame_path)
        cur_img = read_img(frame_path, device, downscale_factor=downscale_factor)
        if pre_img is not None:
            flow = spynet(cur_img, pre_img)
            flow = flow.abs()
            flow_max = flow.max().item()
            flow_avg = flow.mean().item() * 2.0  # according to lijian's hyper-parameter
        elif idx == 0:
            # the very first frame has no predecessor, so no flow
            flow_max = 0.0
            flow_avg = 0.0
        else:
            raise RuntimeError(f'pre_img is none at {idx}')
        flow_out_list.append(f'{img_name} {flow_max:.6f} {flow_avg:.6f}\n')
        pre_img = cur_img
        pbar.update(1)

    with open(estimate_flow_file_path, 'w') as f:
        for line in flow_out_list:
            f.write(line)


# -------------------------------------------------------------------- #
# --------------------------- step4 ---------------------------------- #
# -------------------------------------------------------------------- #
def run_step4(opt):
    """Detect (nearly) black frames by gray-histogram concentration."""
    black_flag_root = opt['black_flag_root']
    meta_root = opt['meta_files_root']
    os.makedirs(black_flag_root, exist_ok=True)
    if not osp.exists(meta_root):
        print('no videos has run step1, exit.')
        return

    # get the video which has been extracted frames
    videos_name = sorted(glob.glob(osp.join(meta_root, '*.txt')))
    videos_name = [osp.splitext(osp.basename(video_name))[0] for video_name in videos_name]
    if opt['debug']:
        videos_name = videos_name[:3]
        os.makedirs('tmp_black', exist_ok=True)
    else:
        videos_name = videos_name[opt['ss_idx']:opt['to_idx']]

    pbar = tqdm(total=len(videos_name), unit='video', desc='step4')
    pool = Pool(opt['n_thread'])
    for idx, video_name in enumerate(videos_name):
        pool.apply_async(worker4, args=(opt, video_name), callback=lambda arg: pbar.update(1))
    pool.close()
    pool.join()


def worker4(opt, video_name):
    """Flag frames where one gray level dominates more than `black_threshold` of pixels."""
    video_frame_path = osp.join(opt['save_frames_root'], video_name)
    black_flag_path = osp.join(opt['black_flag_root'], f'{video_name}.txt')
    if osp.exists(black_flag_path):
        print(f'skip {video_name} since {black_flag_path} already exists.')
        return
    frames_path = sorted(glob.glob(osp.join(video_frame_path, '*.png')))
    out_list = []
    pbar = tqdm(total=len(frames_path), unit='frame', desc='worker4')
    for frame_path in frames_path:
        img = cv2.imread(frame_path)
        img_name = osp.basename(frame_path)
        h, w = img.shape[0:2]
        total_pixels = h * w
        img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        hist = cv2.calcHist([img_gray], [0], None, [256], [0.0, 255.0])
        # fraction of pixels falling in the single most common gray bin
        max_pixel = max(hist)[0]
        percentage = max_pixel / total_pixels
        if percentage > opt['black_threshold']:
            out_list.append(f'{img_name} {0} {percentage:.6f}\n')
            if opt['debug']:
                cv2.imwrite(osp.join('tmp_black', f'{video_name}_{img_name}'), img)
        else:
            out_list.append(f'{img_name} {1} {percentage:.6f}\n')
        pbar.update(1)
    with open(black_flag_path, 'w') as f:
        for line in out_list:
            f.write(line)
# -------------------------------------------------------------------- #
# --------------------------- step5 ---------------------------------- #
# -------------------------------------------------------------------- #
def run_step5(opt):
    """Score frame quality with hyperIQA (one GPU worker per video)."""
    iqa_score_root = opt['iqa_score_root']
    meta_root = opt['meta_files_root']
    os.makedirs(iqa_score_root, exist_ok=True)
    if not osp.exists(meta_root):
        print('no videos has run step1, exit.')
        return

    # get the video which has been extracted frames
    videos_name = sorted(glob.glob(osp.join(meta_root, '*.txt')))
    videos_name = [osp.splitext(osp.basename(video_name))[0] for video_name in videos_name]
    if opt['debug']:
        videos_name = videos_name[:3]
        os.makedirs('tmp_low_iqa', exist_ok=True)
    else:
        videos_name = videos_name[opt['ss_idx']:opt['to_idx']]

    pbar = tqdm(total=len(videos_name), unit='video', desc='step5')
    num_gpus = torch.cuda.device_count()
    ctx = torch.multiprocessing.get_context('spawn')
    pool = ctx.Pool(min(3 * num_gpus, opt['n_thread']))
    for idx, video_name in enumerate(videos_name):
        pool.apply_async(
            worker5, args=(opt, video_name, torch.device(idx % num_gpus)), callback=lambda arg: pbar.update(1))
    pool.close()
    pool.join()


@torch.no_grad()
def worker5(opt, video_name, device):
    """Run hyperIQA on random crops of every frame and save the mean score per frame."""
    video_frame_path = osp.join(opt['save_frames_root'], video_name)
    frames_path = sorted(glob.glob(osp.join(video_frame_path, '*.png')))
    iqa_score_path = osp.join(opt['iqa_score_root'], f'{video_name}.txt')
    if osp.exists(iqa_score_path):
        with open(iqa_score_path, 'r') as f:
            lines = f.readlines()
        length = len(lines)
        if length == len(frames_path):
            print(f'skip {video_name} since {length}/{len(frames_path)} have done.')
            return
        else:
            print(f're-run {video_name} since only {length}/{len(frames_path)} have done.')

    assess_net = init_assessment_model('hypernet', device=device)
    assess_net = assess_net.half()

    # specified transformation in original hyperIQA
    transforms_resize = torchvision.transforms.Compose([
        torchvision.transforms.Resize((512, 384)),
    ])
    transforms_crop = torchvision.transforms.Compose([
        torchvision.transforms.RandomCrop(size=224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])

    iqa_out_list = []
    # bugfix: progress-bar label used to say 'worker3' (copy-paste)
    pbar = tqdm(total=len(frames_path), unit='frame', desc='worker5')
    for idx, frame_path in enumerate(frames_path):
        img_name = osp.basename(frame_path)
        cv2_img = cv2.imread(frame_path)
        # BGR -> RGB
        img = cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(img)
        patchs = []
        img_resize = transforms_resize(img)
        # average the prediction over several random crops for stability
        for _ in range(opt['num_patch_per_iqa']):
            patchs.append(transforms_crop(img_resize))
        patch = torch.stack(patchs, dim=0).to(device)
        pred = assess_net(patch.half())
        score = pred.mean().item()
        iqa_out_list.append(f'{img_name} {score:.6f}\n')
        if opt['debug'] and score < 50.0:
            cv2.imwrite(osp.join('tmp_low_iqa', f'{video_name}_{img_name}'), cv2_img)
        pbar.update(1)

    with open(iqa_score_path, 'w') as f:
        for line in iqa_out_list:
            f.write(line)
# -------------------------------------------------------------------- #
# --------------------------- step6 ---------------------------------- #
# -------------------------------------------------------------------- #
def filter_frozen_shots(shots, flows, std_thr=14.0):
    """Drop shots whose motion is (nearly) frozen.

    Args:
        shots (list[str]): 'start end' frame-index pairs, one per shot.
        flows (list[str]): 'name flow_max flow_avg' per frame (step3 output).
        std_thr (float): a shot is kept only when the std of its per-frame
            average flow is >= this threshold (generalized from the previous
            hard-coded 14.0; default keeps the old behavior).

    Returns:
        np.ndarray: 1/0 keep-flag per shot.
    """
    flag_shot = np.ones(len(shots))
    for idx, shot in enumerate(shots):
        shot = shot.split(' ')
        start = int(shot[0])
        end = int(shot[1])
        flow_in_shot = []
        for i in range(start, end + 1, 1):
            if i == 0:
                # frame 0 has no flow entry w.r.t. a predecessor
                continue
            flow_in_shot.append(float(flows[i].split(' ')[2]))
        flow_std = np.std(np.array(flow_in_shot))
        if flow_std < std_thr:
            flag_shot[idx] = 0
    return flag_shot


def generate_clips(shots, flows, filter_frames, hyperiqa, max_length=500):
    """Cut shots into fixed-length clips and score them.

    hyperiqa scores are in [0, 100]; flows are roughly in [0, 15000] (may be
    larger), hence the /150.0 rescaling when mixing the two into one score.
    Trailing clips shorter than `max_length` are deliberately discarded.

    Returns:
        (list[list[str]], list[float]): clips and scores, sorted by score desc.
    """
    clips = []
    clip_scores = []
    clip = []
    shot_flow = 0
    shot_hyperiqa = 0
    for shot in shots:
        shot = shot.split(' ')
        start = int(shot[0])
        end = int(shot[1])
        pre_black = 0
        for i in range(start, end + 1, 1):
            if i == start:
                stat = 0
                pre_black = 1  # the first frame in shot do not need flow
            else:
                stat = 1
            black_frame_thr = float(filter_frames[i].split(' ')[2])
            # drop img when 90% of pixels are identical
            if black_frame_thr < 0.90:
                black_frame = 0
            else:
                black_frame = 1
            # if current frame is a black frame, delete
            if black_frame == 1:
                pre_black = 1
            elif pre_black == 0:
                flow = float(flows[i].split(' ')[1])
                shot_flow += flow
            else:
                # first usable frame after a black/shot boundary: its flow is
                # recorded for the clip line but not accumulated into shot_flow
                pre_black = 0
                flow = float(flows[i].split(' ')[1])
            # calcu hyperiqa for non-black frames
            if black_frame == 0:
                curr_hyperiqa = float(hyperiqa[i].split(' ')[1])
                shot_hyperiqa += curr_hyperiqa
                clip.append(f'{i+1:08d} {stat} {flow} {curr_hyperiqa}')
            if len(clip) == max_length:
                clips.append(clip.copy())
                clip_score = shot_flow / 150.0 + shot_hyperiqa
                clip_score = clip_score / len(clip)
                clip_scores.append(clip_score)
                clip = []
                shot_flow = 0
                shot_hyperiqa = 0
    sorted_shot = np.argsort(-np.array(clip_scores))
    return [clips[i] for i in sorted_shot], [clip_scores[i] for i in sorted_shot]


def run_step6(opt):
    """Select the best clips per video using the outputs of steps 2-5."""
    meta_root = opt['meta_files_root']
    if not osp.exists(meta_root):
        print('no videos has run step1, exit.')
        return

    # get the video which has been extracted frames
    videos_name = sorted(glob.glob(osp.join(meta_root, '*.txt')))
    videos_name = [osp.splitext(osp.basename(video_name))[0] for video_name in videos_name]
    if opt['debug']:
        videos_name = videos_name[:3]
    else:
        videos_name = videos_name[opt['ss_idx']:opt['to_idx']]

    pbar = tqdm(total=len(videos_name), unit='video', desc='step6')
    os.makedirs(opt['select_clips_meta'], exist_ok=True)
    os.makedirs(opt['select_clips_frames'], exist_ok=True)
    os.makedirs(opt['select_done_flags'], exist_ok=True)
    pool = Pool(opt['n_thread'])
    for video_name in videos_name:
        pool.apply_async(worker6, args=(opt, video_name), callback=lambda arg: pbar.update(1))
    pool.close()
    pool.join()


def worker6(opt, video_name):
    """Score clips of one video, copy the top-k clips' frames out, and mark done."""
    select_clips_meta = opt['select_clips_meta']
    select_clips_frames = opt['select_clips_frames']
    select_done_flags = opt['select_done_flags']
    if osp.exists(osp.join(select_done_flags, f'{video_name}.txt')):
        print(f'skip {video_name}.')
        return
    with open(osp.join(opt['detect_shot_root'], f'{video_name}.txt'), 'r') as f:
        shots = [shot.strip() for shot in f.readlines()]
    with open(osp.join(opt['estimate_flow_root'], f'{video_name}.txt'), 'r') as f:
        flows = [flow.strip() for flow in f.readlines()]
    with open(osp.join(opt['black_flag_root'], f'{video_name}.txt'), 'r') as f:
        black_flags = [black_flag.strip() for black_flag in f.readlines()]
    with open(osp.join(opt['iqa_score_root'], f'{video_name}.txt'), 'r') as f:
        iqa_scores = [iqa_score.strip() for iqa_score in f.readlines()]

    flag_shot = filter_frozen_shots(shots, flows)
    keep_idx = np.where(flag_shot == 1)[0].tolist()
    filtered_shots = [shots[i] for i in keep_idx]
    clips, scores = generate_clips(
        filtered_shots, flows, black_flags, iqa_scores, max_length=opt['num_frames_per_clip'])

    num_selected = 0  # bugfix: the old code wrote f'{i+1} ...' -> NameError when clips is empty
    with open(osp.join(select_clips_meta, f'{video_name}.txt'), 'w') as f:
        for i, clip in enumerate(clips):
            os.makedirs(osp.join(select_clips_frames, f'{video_name}_{i}'), exist_ok=True)
            for idx, info in enumerate(clip):
                f.write(f'clip: {i:02d} {info} {scores[i]}\n')
                img_name = info.split(' ')[0] + '.png'
                shutil.copy(
                    osp.join(opt['save_frames_root'], video_name, img_name),
                    osp.join(select_clips_frames, f'{video_name}_{i}', f'{idx:08d}.png'))
            num_selected += 1
            if i >= opt['num_clips_per_video'] - 1:
                break
    with open(osp.join(select_done_flags, f'{video_name}.txt'), 'w') as f:
        f.write(f'{num_selected} clips are selected for {video_name}.')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--dataroot',
        type=str,
        required=True,
        help='dataset root, dataroot/raw_videos should contains your HQ videos to be processed.')
    parser.add_argument('--n_thread', type=int, default=4, help='Thread number.')
    parser.add_argument('--run', type=str, default='123456', help='run which steps')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--ss_idx', type=int, default=None, help='ss index')
    parser.add_argument('--to_idx', type=int, default=None, help='to index')
    parser.add_argument('--n_frames_per_clip', type=int, default=100)
    parser.add_argument('--n_clips_per_video', type=int, default=1)
    parser.add_argument('--select_clip_root', type=str, default='select_clips')
    args = parser.parse_args()
    main(args)
"""inference AnimeSR on frames"""
import argparse
import cv2
import glob
import numpy as np
import os
import psutil
import queue
import threading
import time
import torch
from os import path as osp
from tqdm import tqdm

from animesr.utils.inference_base import get_base_argument_parser, get_inference_model
from animesr.utils.video_util import frames2video
from basicsr.data.transforms import mod_crop
from basicsr.utils.img_util import img2tensor, tensor2img


def read_img(path, require_mod_crop=True, mod_scale=4, input_rescaling_factor=1.0):
    """Load one frame as a (1, c, h, w) tensor.

    Args:
        path: image path
        require_mod_crop: mod crop or not. since the arch is multi-scale, so mod crop
            is needed by default
        mod_scale: scale factor for mod_crop

    Returns:
        torch.Tensor: size(1, c, h, w)
    """
    frame = cv2.imread(path)
    frame = frame.astype(np.float32) / 255.
    if input_rescaling_factor != 1.0:
        height, width = frame.shape[:2]
        target_size = (int(width * input_rescaling_factor), int(height * input_rescaling_factor))
        frame = cv2.resize(frame, target_size, interpolation=cv2.INTER_LANCZOS4)
    if require_mod_crop:
        frame = mod_crop(frame, mod_scale)
    frame = img2tensor(frame, bgr2rgb=True, float32=True)
    return frame.unsqueeze(0)


class IOConsumer(threading.Thread):
    """Background writer thread.

    Since IO time can take up a significant portion of the total inference time,
    several of these threads drain a shared queue and write frames to disk.
    """

    def __init__(self, args: argparse.Namespace, que, qid):
        super().__init__()
        self._queue = que
        self.qid = qid
        self.args = args

    def run(self):
        while True:
            task = self._queue.get()
            if isinstance(task, str) and task == 'quit':
                break
            result = tensor2img(task['output'].squeeze(0))
            if self.args.outscale != self.args.netscale:
                # rescale from net output size to the requested output size
                height, width = result.shape[:2]
                scale = self.args.outscale / self.args.netscale
                result = cv2.resize(
                    result, (int(width * scale), int(height * scale)), interpolation=cv2.INTER_LANCZOS4)
            cv2.imwrite(task['imgname'], result)
        print(f'IO for worker {self.qid} is done.')


@torch.no_grad()
def main():
    """Inference demo for AnimeSR. It mainly for restoring anime frames."""
    parser = get_base_argument_parser()
    parser.add_argument('--input_rescaling_factor', type=float, default=1.0)
    parser.add_argument('--num_io_consumer', type=int, default=3, help='number of IO consumer')
    parser.add_argument(
        '--sample_interval',
        type=int,
        default=1,
        help='save 1 frame for every $sample_interval frames. this will be useful for calculating the metrics')
    parser.add_argument('--save_video_too', action='store_true')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = get_inference_model(args, device)

    # prepare output dir
    frame_output = osp.join(args.output, args.expname, 'frames')
    os.makedirs(frame_output, exist_ok=True)

    # the input format can be:
    # 1. clip folder which contains frames
    # or 2. a folder which contains several clips
    first_level_dir = len(glob.glob(osp.join(args.input, '*.png'))) > 0
    if args.input.endswith('/'):
        args.input = args.input[:-1]
    if first_level_dir:
        videos_name = [osp.basename(args.input)]
        args.input = osp.dirname(args.input)
    else:
        videos_name = sorted(os.listdir(args.input))

    video_pbar = tqdm(total=len(videos_name), unit='video', desc='inference')

    # spin up the IO consumer threads before producing any output
    io_queue = queue.Queue()
    consumers = [IOConsumer(args, io_queue, f'IO_{i}') for i in range(args.num_io_consumer)]
    for consumer in consumers:
        consumer.start()

    for video_name in videos_name:
        video_folder_path = osp.join(args.input, video_name)
        imgs_list = sorted(glob.glob(osp.join(video_folder_path, '*')))
        num_imgs = len(imgs_list)
        os.makedirs(osp.join(frame_output, video_name), exist_ok=True)

        # initialize the sliding window (prev, cur, nxt) plus sr frame and state
        prev = read_img(
            imgs_list[0],
            require_mod_crop=True,
            mod_scale=args.mod_scale,
            input_rescaling_factor=args.input_rescaling_factor).to(device)
        cur = prev
        nxt = read_img(
            imgs_list[min(1, num_imgs - 1)],
            require_mod_crop=True,
            mod_scale=args.mod_scale,
            input_rescaling_factor=args.input_rescaling_factor).to(device)
        c, h, w = prev.size()[-3:]
        state = prev.new_zeros(1, 64, h, w)
        out = prev.new_zeros(1, c, h * args.netscale, w * args.netscale)

        frame_pbar = tqdm(total=num_imgs, unit='frame', desc='inference')
        tot_model_time = 0
        cnt_model_time = 0
        for idx in range(num_imgs):
            torch.cuda.synchronize()
            start = time.time()
            img_name = osp.splitext(osp.basename(imgs_list[idx]))[0]
            out, state = model.cell(torch.cat((prev, cur, nxt), dim=1), out, state)
            torch.cuda.synchronize()
            tot_model_time += time.time() - start
            cnt_model_time += 1

            if (idx + 1) % args.sample_interval == 0:
                # put the output frame to the queue to be consumed
                io_queue.put({'output': out.cpu().clone(), 'imgname': osp.join(frame_output, video_name, f'{img_name}.png')})

            torch.cuda.synchronize()
            start = time.time()
            prev = cur
            cur = nxt
            nxt = read_img(
                imgs_list[min(idx + 2, num_imgs - 1)],
                require_mod_crop=True,
                mod_scale=args.mod_scale,
                input_rescaling_factor=args.input_rescaling_factor).to(device)
            torch.cuda.synchronize()
            read_time = time.time() - start

            frame_pbar.update(1)
            frame_pbar.set_description(f'read_time: {read_time}, model_time: {tot_model_time/cnt_model_time}')

            # since the speed of producer (model inference) is faster than the consumer (I/O)
            # if there is a risk of OOM, just sleep to let the consumer work
            mem = psutil.virtual_memory()
            if mem.percent > 80.0:
                time.sleep(30)

        video_pbar.update(1)

    # signal every consumer to finish, then wait for them
    for _ in range(args.num_io_consumer):
        io_queue.put('quit')
    for consumer in consumers:
        consumer.join()

    if not args.save_video_too:
        return

    # convert the frames to videos
    video_output = osp.join(args.output, args.expname, 'videos')
    os.makedirs(video_output, exist_ok=True)
    for video_name in videos_name:
        out_path = osp.join(video_output, f'{video_name}.mp4')
        frames2video(
            osp.join(frame_output, video_name), out_path, fps=24 if args.fps is None else args.fps, suffix='png')


if __name__ == '__main__':
    main()


# ================================================
# FILE: scripts/inference_animesr_video.py
# ================================================
import cv2
import ffmpeg
import glob
import mimetypes
import numpy as np
import os
import shutil
import subprocess
import torch
from os import path as osp
from tqdm import tqdm

from animesr.utils import video_util
from animesr.utils.inference_base import get_base_argument_parser, get_inference_model
from basicsr.data.transforms import mod_crop
from basicsr.utils.img_util import img2tensor, tensor2img
from basicsr.utils.logger import AvgTimer
def parse_frame_rate(rate_str):
    """Parse ffprobe's rational frame-rate string ('num/den' or a plain number).

    Internal helper. Matches the numeric result of the previous eval()-based
    parsing (int for a bare integer, true division for 'num/den') without
    executing externally supplied metadata as code.
    """
    num, _, den = rate_str.partition('/')
    if den:
        return int(num) / int(den)
    try:
        return int(num)
    except ValueError:
        return float(num)


def get_video_meta_info(video_path):
    """get the meta info of the video by using ffprobe with python interface"""
    ret = {}
    probe = ffmpeg.probe(video_path)
    video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video']
    has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams'])
    ret['width'] = video_streams[0]['width']
    ret['height'] = video_streams[0]['height']
    # avg_frame_rate is a rational such as '24000/1001'; parse it safely
    # instead of eval() on external metadata
    ret['fps'] = parse_frame_rate(video_streams[0]['avg_frame_rate'])
    ret['audio'] = ffmpeg.input(video_path).audio if has_audio else None
    try:
        ret['nb_frames'] = int(video_streams[0]['nb_frames'])
    except KeyError:  # bilibili transcoder dont have nb_frames
        ret['duration'] = float(probe['format']['duration'])
        ret['nb_frames'] = int(ret['duration'] * ret['fps'])
        print(ret['duration'], ret['nb_frames'])
    return ret


def get_sub_video(args, num_process, process_idx):
    """Cut the whole video into num_process parts, return the process_idx-th part"""
    if num_process == 1:
        return args.input
    meta = get_video_meta_info(args.input)
    duration = int(meta['nb_frames'] / meta['fps'])
    part_time = duration // num_process
    print(f'duration: {duration}, part_time: {part_time}')
    out_path = osp.join(args.output, 'inp_sub_videos', f'{process_idx:03d}.mp4')
    cmd = [
        args.ffmpeg_bin,
        f'-i {args.input}',
        f'-ss {part_time * process_idx}',
        # the last worker takes everything to the end of the video
        f'-to {part_time * (process_idx + 1)}' if process_idx != num_process - 1 else '',
        '-async 1',
        out_path,
        '-y',
    ]
    print(' '.join(cmd))
    subprocess.call(' '.join(cmd), shell=True)
    return out_path


class Reader:
    """read frames from a video stream or frames list"""

    def __init__(self, args, total_workers=1, worker_idx=0, device=torch.device('cuda')):
        self.args = args
        input_type = mimetypes.guess_type(args.input)[0]
        self.input_type = 'folder' if input_type is None else input_type
        self.paths = []  # for image&folder type
        self.audio = None
        self.input_fps = None
        if self.input_type.startswith('video'):
            video_path = get_sub_video(args, total_workers, worker_idx)
            # read bgr from stream, which is the same format as opencv
            self.stream_reader = (
                ffmpeg
                .input(video_path)
                .output('pipe:', format='rawvideo', pix_fmt='bgr24', loglevel='error')
                .run_async(pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin)
            )  # yapf: disable  # noqa
            meta = get_video_meta_info(video_path)
            self.width = meta['width']
            self.height = meta['height']
            self.input_fps = meta['fps']
            self.audio = meta['audio']
            self.nb_frames = meta['nb_frames']
        else:
            if self.input_type.startswith('image'):
                self.paths = [args.input]
            else:
                # split the frame list evenly across workers
                paths = sorted(glob.glob(os.path.join(args.input, '*')))
                tot_frames = len(paths)
                num_frame_per_worker = tot_frames // total_workers + (1 if tot_frames % total_workers else 0)
                self.paths = paths[num_frame_per_worker * worker_idx:num_frame_per_worker * (worker_idx + 1)]
            self.nb_frames = len(self.paths)
            assert self.nb_frames > 0, 'empty folder'
            from PIL import Image
            tmp_img = Image.open(self.paths[0])  # lazy load
            self.width, self.height = tmp_img.size
        self.idx = 0
        self.device = device

    def get_resolution(self):
        return self.height, self.width

    def get_fps(self):
        """the fps of sr video is set to the user input fps first, followed by the input fps,
        If the first two values are None, then the commonly used fps 24 is set"""
        if self.args.fps is not None:
            return self.args.fps
        elif self.input_fps is not None:
            return self.input_fps
        return 24

    def get_audio(self):
        return self.audio

    def __len__(self):
        """return the number of frames for this worker, however, this may be not accurate for video stream"""
        return self.nb_frames

    def get_frame_from_stream(self):
        img_bytes = self.stream_reader.stdout.read(self.width * self.height * 3)  # 3 bytes for one pixel
        if not img_bytes:  # end of stream
            return None
        img = np.frombuffer(img_bytes, np.uint8).reshape([self.height, self.width, 3])
        return img

    def get_frame_from_list(self):
        if self.idx >= self.nb_frames:
            return None
        img = cv2.imread(self.paths[self.idx])
        self.idx += 1
        return img

    def get_frame(self):
        if self.input_type.startswith('video'):
            img = self.get_frame_from_stream()
        else:
            img = self.get_frame_from_list()
        if img is None:
            raise StopIteration
        # bgr uint8 numpy -> rgb float32 [0, 1] tensor on device
        img = img.astype(np.float32) / 255.
        img = mod_crop(img, self.args.mod_scale)
        img = img2tensor(img, bgr2rgb=True, float32=True).unsqueeze(0).to(self.device)
        if self.args.half:
            # half precision won't make a big impact on visuals
            img = img.half()
        return img

    def close(self):
        # close the video stream
        if self.input_type.startswith('video'):
            self.stream_reader.stdin.close()
            self.stream_reader.wait()
class Writer:
    """write frames to a video stream"""

    def __init__(self, args, audio, height, width, video_save_path, fps):
        out_width, out_height = int(width * args.outscale), int(height * args.outscale)
        if out_height > 2160:
            print('You are generating video that is larger than 4K, which will be very slow due to IO speed.',
                  'We highly recommend to decrease the outscale(aka, -s).')
        vsp = video_save_path
        if audio is not None:
            self.stream_writer = (
                ffmpeg
                .input('pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{out_width}x{out_height}', framerate=fps)
                .output(audio, vsp, pix_fmt='yuv420p', vcodec='libx264', loglevel='error', acodec='copy')
                .overwrite_output()
                .run_async(pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin)
            )  # yapf: disable  # noqa
        else:
            self.stream_writer = (
                ffmpeg
                .input('pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{out_width}x{out_height}', framerate=fps)
                .output(vsp, pix_fmt='yuv420p', vcodec='libx264', loglevel='error')
                .overwrite_output()
                .run_async(pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin)
            )  # yapf: disable  # noqa
        self.out_width = out_width
        self.out_height = out_height
        self.args = args

    def write_frame(self, frame):
        if self.args.outscale != self.args.netscale:
            frame = cv2.resize(frame, (self.out_width, self.out_height), interpolation=cv2.INTER_LANCZOS4)
        self.stream_writer.stdin.write(frame.tobytes())

    def close(self):
        self.stream_writer.stdin.close()
        self.stream_writer.wait()


@torch.no_grad()
def inference_video(args, video_save_path, device=None, total_workers=1, worker_idx=0):
    """Super-resolve one (sub-)video: read frames, run the recurrent cell, write frames."""
    # prepare model
    model = get_inference_model(args, device)

    # prepare reader and writer
    reader = Reader(args, total_workers, worker_idx, device=device)
    audio = reader.get_audio()
    height, width = reader.get_resolution()
    # the arch is multi-scale, so the input must be mod-cropped
    height = height - height % args.mod_scale
    width = width - width % args.mod_scale
    fps = reader.get_fps()
    writer = Writer(args, audio, height, width, video_save_path, fps)

    # initialize pre/cur/nxt frames, pre sr frame, and pre hidden state for inference
    end_flag = False
    prev = reader.get_frame()
    cur = prev
    try:
        nxt = reader.get_frame()
    except StopIteration:
        end_flag = True
        nxt = cur
    state = prev.new_zeros(1, 64, height, width)
    out = prev.new_zeros(1, 3, height * args.netscale, width * args.netscale)

    pbar = tqdm(total=len(reader), unit='frame', desc='inference')
    model_timer = AvgTimer()  # model inference time tracker
    i_timer = AvgTimer()  # I(input read) time tracker
    o_timer = AvgTimer()  # O(output write) time tracker
    while True:
        # inference at current step
        torch.cuda.synchronize(device=device)
        model_timer.start()
        out, state = model.cell(torch.cat((prev, cur, nxt), dim=1), out, state)
        torch.cuda.synchronize(device=device)
        model_timer.record()

        # write current sr frame to video stream
        torch.cuda.synchronize(device=device)
        o_timer.start()
        output_frame = tensor2img(out, rgb2bgr=False)
        writer.write_frame(output_frame)
        torch.cuda.synchronize(device=device)
        o_timer.record()

        # if end of stream, break
        if end_flag:
            break

        # move the sliding window
        torch.cuda.synchronize(device=device)
        i_timer.start()
        prev = cur
        cur = nxt
        try:
            nxt = reader.get_frame()
        except StopIteration:
            # last frame: repeat it so the window stays full for one more step
            nxt = cur
            end_flag = True
        torch.cuda.synchronize(device=device)
        i_timer.record()

        # update&print information
        pbar.update(1)
        pbar.set_description(
            f'I: {i_timer.get_avg_time():.4f} O: {o_timer.get_avg_time():.4f} Model: {model_timer.get_avg_time():.4f}')

    reader.close()
    writer.close()


def run(args):
    """Dispatch inference over one or several GPU worker processes and merge the results."""
    if args.suffix is None:
        args.suffix = ''
    else:
        args.suffix = f'_{args.suffix}'
    video_save_path = osp.join(args.output, f'{args.video_name}{args.suffix}.mp4')

    # set up multiprocessing
    num_gpus = torch.cuda.device_count()
    num_process = num_gpus * args.num_process_per_gpu
    if num_process == 1:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        inference_video(args, video_save_path, device=device)
        return

    ctx = torch.multiprocessing.get_context('spawn')
    pool = ctx.Pool(num_process)
    out_sub_videos_dir = osp.join(args.output, 'out_sub_videos')
    os.makedirs(out_sub_videos_dir, exist_ok=True)
    os.makedirs(osp.join(args.output, 'inp_sub_videos'), exist_ok=True)
    pbar = tqdm(total=num_process, unit='sub_video', desc='inference')
    for i in range(num_process):
        sub_video_save_path = osp.join(out_sub_videos_dir, f'{i:03d}.mp4')
        pool.apply_async(
            inference_video,
            args=(args, sub_video_save_path, torch.device(i % num_gpus), num_process, i),
            callback=lambda arg: pbar.update(1))
    pool.close()
    pool.join()

    # combine sub videos
    # prepare vidlist.txt
    with open(f'{args.output}/vidlist.txt', 'w') as f:
        for i in range(num_process):
            f.write(f'file \'out_sub_videos/{i:03d}.mp4\'\n')

    # To avoid video&audio desync as mentioned in https://github.com/xinntao/Real-ESRGAN/issues/388
    # we use the solution provided in https://stackoverflow.com/a/52156277 to solve this issue
    cmd = [
        args.ffmpeg_bin, '-f', 'concat', '-safe', '0', '-i', f'{args.output}/vidlist.txt', '-c:v', 'copy', '-af',
        'aresample=async=1000', video_save_path, '-y',
    ]  # yapf: disable
    print(' '.join(cmd))
    subprocess.call(cmd)

    shutil.rmtree(out_sub_videos_dir)
    shutil.rmtree(osp.join(args.output, 'inp_sub_videos'))
    os.remove(f'{args.output}/vidlist.txt')


def main():
    """Inference demo for AnimeSR. It mainly for restoring anime videos."""
    parser = get_base_argument_parser()
    parser.add_argument(
        '--extract_frame_first',
        action='store_true',
        # bugfix: help string typo 'other wise' -> 'otherwise'
        help='if input is a video, you can still extract the frames first, otherwise AnimeSR will read from stream')
    parser.add_argument(
        '--num_process_per_gpu', type=int, default=1, help='the total process is number_process_per_gpu * num_gpu')
    parser.add_argument(
        '--suffix', type=str, default=None, help='you can add a suffix string to the sr video name, for example, x2')
    args = parser.parse_args()

    args.ffmpeg_bin = os.environ.get('ffmpeg_exe_path', 'ffmpeg')
    args.input = args.input.rstrip('/').rstrip('\\')
    if mimetypes.guess_type(args.input)[0] is not None and mimetypes.guess_type(args.input)[0].startswith('video'):
        is_video = True
    else:
        is_video = False
    if args.extract_frame_first and not is_video:
        args.extract_frame_first = False

    # prepare input and output
    args.video_name = osp.splitext(osp.basename(args.input))[0]
    args.output = osp.join(args.output, args.expname, 'videos', args.video_name)
    os.makedirs(args.output, exist_ok=True)

    if args.extract_frame_first:
        inp_extracted_frames = osp.join(args.output, 'inp_extracted_frames')
        os.makedirs(inp_extracted_frames, exist_ok=True)
        video_util.video2frames(args.input, inp_extracted_frames, force=True, high_quality=True)
        video_meta = get_video_meta_info(args.input)
        args.fps = video_meta['fps']
        args.input = inp_extracted_frames

    run(args)

    if args.extract_frame_first:
        shutil.rmtree(args.input)


if __name__ == '__main__':
    main()


# ================================================
# FILE: scripts/metrics/MANIQA/inference_MANIQA.py
# ================================================
import argparse
import os
import random
import torch
from pipal_data import NTIRE2022
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from utils import Normalize, ToTensor, crop_image


def parse_args():
    """Parse command-line arguments for the MANIQA scoring script."""
    # NOTE(review): the description mentions RealBasicVSR — looks copy-pasted;
    # this script scores folders of frames with a MANIQA model.
    parser = argparse.ArgumentParser(description='Inference script of RealBasicVSR')
    parser.add_argument('--model_path', help='checkpoint file', required=True)
    parser.add_argument('--input_dir', help='directory of the input video', required=True)
    parser.add_argument(
        '--output_dir', help='directory of the output results', default='output/ensemble_attentionIQA2_finetune_e2/AnimeSR')
    args = parser.parse_args()
    return args


def main():
    """Score every sub-folder of frames under ``--input_dir`` with a MANIQA model.

    For each image, the score is averaged over ``average_iters`` random
    224x224 crops. Per-folder scores go to ``<output_dir>/<folder>.txt`` and
    the overall summary to ``<output_dir>/average.txt``.
    """
    args = parse_args()

    # configuration
    batch_size = 10
    num_workers = 8
    average_iters = 20  # random crops averaged per image
    crop_size = 224

    os.makedirs(args.output_dir, exist_ok=True)

    # NOTE(review): torch.load of a full pickled model — only run on trusted
    # checkpoints (arbitrary code execution risk with untrusted files).
    model = torch.load(args.model_path)

    # map to cuda, if available
    cuda_flag = False
    if torch.cuda.is_available():
        model = model.cuda()
        cuda_flag = True
    model.eval()

    total_avg_score = []
    subfolder_namelist = []
    for subfolder_name in sorted(os.listdir(args.input_dir)):
        avg_score = 0.0
        subfolder_root = os.path.join(args.input_dir, subfolder_name)
        if os.path.isdir(subfolder_root) and subfolder_name != 'assemble-folder':
            # data load (ref == dis: MANIQA here is used no-reference style)
            val_dataset = NTIRE2022(
                ref_path=subfolder_root,
                dis_path=subfolder_root,
                transform=transforms.Compose([Normalize(0.5, 0.5), ToTensor()]),
            )
            # NOTE(review): drop_last=True in an evaluation loader silently
            # skips the final incomplete batch, so some images are never
            # scored — confirm this is intended.
            val_loader = DataLoader(
                dataset=val_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=True, shuffle=False)

            name_list, pred_list = [], []
            with open(os.path.join(args.output_dir, f'{subfolder_name}.txt'), 'w') as f:
                for data in tqdm(val_loader):
                    pred = 0
                    for i in range(average_iters):
                        # NOTE(review): x_d is only assigned when CUDA is
                        # available — on a CPU-only machine this raises
                        # NameError; verify the CPU path.
                        if cuda_flag:
                            x_d = data['d_img_org'].cuda()
                        b, c, h, w = x_d.shape
                        top = random.randint(0, h - crop_size)
                        left = random.randint(0, w - crop_size)
                        img = crop_image(top, left, crop_size, img=x_d)
                        with torch.no_grad():
                            pred += model(img)
                    pred /= average_iters
                    d_name = data['d_name']
                    pred = pred.cpu().numpy()
                    name_list.extend(d_name)
                    pred_list.extend(pred)
                for i in range(len(name_list)):
                    f.write(f'{name_list[i]}, {float(pred_list[i][0]): .6f}\n')
                    avg_score += float(pred_list[i][0])
                avg_score /= len(name_list)
                f.write(f'The average score of {subfolder_name} is {avg_score:.6f}')
                # redundant: the `with` block already closes the file
                f.close()
            subfolder_namelist.append(subfolder_name)
            total_avg_score.append(avg_score)

    # write the cross-folder summary
    with open(os.path.join(args.output_dir, 'average.txt'), 'w') as f:
        for idx, averge_score in enumerate(total_avg_score):
            string = f'Folder {subfolder_namelist[idx]}, Average Score: {averge_score:.6f}\n'
            f.write(string)
            print(f'Folder {subfolder_namelist[idx]}, Average Score: {averge_score:.6f}')
        print(f'Average Score of {len(subfolder_namelist)} Folders: {sum(total_avg_score) / len(total_avg_score):.6f}')
        string = f'Average Score of {len(subfolder_namelist)} Folders: {sum(total_avg_score) / len(total_avg_score):.6f}'  # noqa E501
        f.write(string)
        # redundant: the `with` block already closes the file
        f.close()


if __name__ == '__main__':
    main()


================================================
FILE: scripts/metrics/MANIQA/models/model_attentionIQA2.py
================================================
# flake8: noqa
import timm
import torch
from einops import rearrange
from models.swin import SwinTransformer
from timm.models.vision_transformer import Block
from torch import nn


class ChannelAttn(nn.Module):
    """Self-attention over the channel dimension of (B, C, N) features."""

    def __init__(self, dim, drop=0.1):
        super().__init__()
        self.c_q = nn.Linear(dim, dim)
        self.c_k = nn.Linear(dim, dim)
        self.c_v = nn.Linear(dim, dim)
        self.norm_fact = dim**-0.5  # 1/sqrt(dim) attention scaling
        self.softmax = nn.Softmax(dim=-1)
        self.attn_drop = nn.Dropout(drop)
        self.proj_drop = nn.Dropout(drop)

    def forward(self, x):
        _x = x  # kept for the residual connection
        B, C, N = x.shape
        q = self.c_q(x)
        k = self.c_k(x)
        v = self.c_v(x)
        attn = q @ k.transpose(-2, -1) * self.norm_fact
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        # NOTE(review): transpose(1, 2) followed by reshape(B, C, N) permutes
        # values rather than being a no-op — presumably intentional (matches
        # upstream MANIQA), but worth confirming.
        x = (attn @ v).transpose(1, 2).reshape(B, C, N)
        x = self.proj_drop(x)
        x = x + _x
        return x


class SaveOutput:
    """Forward-hook callable that accumulates module outputs in a list."""

    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)

    def clear(self):
        self.outputs = []


class AttentionIQA(nn.Module):
    """MANIQA-style IQA head on top of ViT features.

    ViT block outputs are captured via forward hooks, refined by channel
    attention and two Swin stages, then pooled into a weighted quality score.
    """

    def __init__(self,
                 embed_dim=72,
                 num_outputs=1,
                 patch_size=8,
                 drop=0.1,
                 depths=[2, 2],
                 window_size=4,
                 dim_mlp=768,
                 num_heads=[4, 4],
                 img_size=224,
                 num_channel_attn=2,
                 **kwargs):
        # NOTE: mutable default args (depths/num_heads) — not mutated here,
        # but shared across instances.
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.input_size = img_size // patch_size
        self.patches_resolution = (img_size // patch_size, img_size // patch_size)
        # pretrained ViT backbone; requires network/weights at first use
        self.vit = timm.create_model('vit_base_patch8_224', pretrained=True)
        self.save_output = SaveOutput()
        # hook every transformer Block so intermediate features can be read
        hook_handles = []
        for layer in self.vit.modules():
            if isinstance(layer, Block):
                handle = layer.register_forward_hook(self.save_output)
                hook_handles.append(handle)
        self.channel_attn1 = nn.Sequential(*[ChannelAttn(self.input_size**2) for i in range(num_channel_attn)])
        self.channel_attn2 = nn.Sequential(*[ChannelAttn(self.input_size**2) for i in range(num_channel_attn)])
        # 1x1 conv fusing the 4 concatenated ViT blocks down to embed_dim
        self.conv1 = nn.Conv2d(embed_dim * 4, embed_dim, 1, 1, 0)
        self.swintransformer1 = SwinTransformer(
            patches_resolution=self.patches_resolution,
            depths=depths,
            num_heads=num_heads,
            embed_dim=embed_dim,
            window_size=window_size,
            dim_mlp=dim_mlp)
        self.swintransformer2 = SwinTransformer(
            patches_resolution=self.patches_resolution,
            depths=depths,
            num_heads=num_heads,
            embed_dim=embed_dim // 2,
            window_size=window_size,
            dim_mlp=dim_mlp)
        self.conv2 = nn.Conv2d(embed_dim, embed_dim // 2, 1, 1, 0)
        # per-patch score and per-patch weight heads
        self.fc_score = nn.Sequential(
            nn.Linear(embed_dim // 2, embed_dim // 2), nn.ReLU(), nn.Dropout(drop), nn.Linear(embed_dim // 2, num_outputs),
            nn.ReLU())
        self.fc_weight = nn.Sequential(
            nn.Linear(embed_dim // 2, embed_dim // 2), nn.ReLU(), nn.Dropout(drop), nn.Linear(embed_dim // 2, num_outputs),
            nn.Sigmoid())

    def extract_feature(self, save_output):
        """Concatenate hooked outputs of ViT blocks 6-9 (CLS token dropped)."""
        x6 = save_output.outputs[6][:, 1:]
        x7 = save_output.outputs[7][:, 1:]
        x8 = save_output.outputs[8][:, 1:]
        x9 = save_output.outputs[9][:, 1:]
        x = torch.cat((x6, x7, x8, x9), dim=2)
        return x

    def forward(self, x):
        # run the backbone only for its hook side effects; _x is unused
        _x = self.vit(x)
        x = self.extract_feature(self.save_output)
        self.save_output.outputs.clear()
        # stage 1
        x = rearrange(x, 'b (h w) c -> b c (h w)', h=self.input_size, w=self.input_size)
        x = self.channel_attn1(x)
        x = rearrange(x, 'b c (h w) -> b c h w', h=self.input_size, w=self.input_size)
        x = self.conv1(x)
        x = self.swintransformer1(x)
        # stage2
        x = rearrange(x, 'b c h w -> b c (h w)', h=self.input_size, w=self.input_size)
        x = self.channel_attn2(x)
        x = rearrange(x, 'b c (h w) -> b c h w', h=self.input_size, w=self.input_size)
        x = self.conv2(x)
        x = self.swintransformer2(x)
        x = rearrange(x, 'b c h w -> b (h w) c', h=self.input_size, w=self.input_size)
        # weighted average of per-patch scores -> scalar quality per image
        f = self.fc_score(x)
        w = self.fc_weight(x)
        s = torch.sum(f * w, dim=1) / torch.sum(w, dim=1)
        return s


================================================
FILE: scripts/metrics/MANIQA/models/swin.py
================================================
""" isort:skip_file """
# flake8: noqa
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from torch import nn
""" attention decoder mask """


def get_attn_decoder_mask(seq):
    """Return an upper-triangular (causal) mask of shape (B, L, L)."""
    subsequent_mask = torch.ones_like(seq).unsqueeze(-1).expand(seq.size(0), seq.size(1), seq.size(1))
    subsequent_mask = subsequent_mask.triu(diagonal=1)  # upper triangular part of a matrix(2-D)
    return subsequent_mask


""" attention pad mask """


def get_attn_pad_mask(seq_q, seq_k, i_pad):
    """Return a (B, len_q, len_k) mask that is True where seq_k equals i_pad."""
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.data.eq(i_pad)
    pad_attn_mask = pad_attn_mask.unsqueeze(1).expand(batch_size, len_q, len_k)
    return pad_attn_mask


class DecoderWindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value.
            Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        # separate q/k/v projections (cross attention takes q, k, v separately)
        self.W_Q = nn.Linear(dim, dim)
        self.W_K = nn.Linear(dim, dim)
        self.W_V = nn.Linear(dim, dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None, attn_mask=None):
        """Cross window attention.

        Args:
            q, k, v: features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
            attn_mask: boolean mask; True positions are filled with -1e9
                before softmax (in place, on the raw attention logits).
        """
        B_, N, C = q.shape
        q = self.W_Q(q).view(B_, N, self.num_heads, C // self.num_heads).transpose(1, 2)
        k = self.W_K(k).view(B_, N, self.num_heads, C // self.num_heads).transpose(1, 2)
        v = self.W_V(v).view(B_, N, self.num_heads, C // self.num_heads).transpose(1, 2)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        # in-place fill; assumes attn_mask is always provided (no None check)
        attn.masked_fill_(attn_mask, -1e9)
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            # add the shifted-window mask, broadcast over batch
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        #  x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops


""" decoder layer """


class DecoderLayer(nn.Module):
    """One decoder layer: W-MSA cross attention, SW-MSA cross attention, MLP,
    each with a residual connection and pre-LayerNorm."""

    def __init__(self,
                 input_resolution=(28, 28),
                 embed_dim=256,
                 layer_norm_epsilon=1e-12,
                 dim_mlp=1024,
                 num_heads=4,
                 dim_head=128,
                 window_size=7,
                 shift_size=0,
                 i_layer=0,
                 act_layer=nn.GELU,
                 drop=0.,
                 drop_path=0.):
        super().__init__()
        self.i_layer = i_layer
        self.shift_size = shift_size
        self.window_size = window_size
        self.input_resolution = input_resolution
        # NOTE(review): self.conv is only used by commented-out code in
        # forward() — currently dead parameters.
        self.conv = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        self.layer_norm1 = nn.LayerNorm(embed_dim, eps=layer_norm_epsilon)
        self.dec_enc_attn_wmsa = DecoderWindowAttention(
            dim=embed_dim,
            window_size=to_2tuple(window_size),
            num_heads=num_heads,
        )
        self.layer_norm2 = nn.LayerNorm(embed_dim, eps=layer_norm_epsilon)
        self.dec_enc_attn_swmsa = DecoderWindowAttention(
            dim=embed_dim,
            window_size=to_2tuple(window_size),
            num_heads=num_heads,
        )
        self.layer_norm3 = nn.LayerNorm(embed_dim, eps=layer_norm_epsilon)
        self.mlp = Mlp(in_features=embed_dim, hidden_features=dim_mlp, act_layer=act_layer, drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # calculate attention mask for SW-MSA
        # (assumes shift_size > 0; a zero shift would make the slices empty)
        H, W = self.input_resolution
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        self.register_buffer("attn_mask", attn_mask)

    def partition(self, inputs, B, H, W, C, shift_size=0):
        """(B, H*W, C) -> (nW*B, window_size*window_size, C), optional cyclic shift."""
        # partition mask_dec_inputs
        inputs = inputs.view(B, H, W, C)
        if shift_size > 0:
            shifted_inputs = torch.roll(inputs, shifts=(-shift_size, -shift_size), dims=(1, 2))
        else:
            shifted_inputs = inputs
        windows_inputs = window_partition(shifted_inputs, self.window_size)  # nW*B, window_size, window_size, C
        windows_inputs = windows_inputs.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
        return windows_inputs

    def reverse(self, inputs, B, H, W, C, shift_size=0):
        """Inverse of partition(): windows back to (B, H*W, C), undoing the shift."""
        # merge windows
        inputs = inputs.view(-1, self.window_size, self.window_size, C)
        inputs = window_reverse(inputs, self.window_size, H, W)  # B H' W' C
        # reverse cyclic shift
        if shift_size > 0:
            inputs = torch.roll(inputs, shifts=(shift_size, shift_size), dims=(1, 2))
        else:
            inputs = inputs
        inputs = inputs.view(B, H * W, C)
        return inputs

    def forward(self, mask_dec_inputs, enc_outputs, self_attn_mask, dec_enc_attn_mask):
        # NOTE: self_attn_mask is accepted but never used in this layer.
        H, W = self.input_resolution[0], self.input_resolution[1]
        B, L, C = mask_dec_inputs.shape
        assert L == H * W, "input feature has wrong size"

        # --- block 1: non-shifted window cross attention ---
        dec_enc_att_outputs = mask_dec_inputs
        shortcut1 = dec_enc_att_outputs
        dec_enc_att_outputs = self.layer_norm1(dec_enc_att_outputs)
        enc_outputs = self.partition(enc_outputs, B, H, W, C, shift_size=0)
        dec_enc_att_outputs = self.partition(dec_enc_att_outputs, B, H, W, C, shift_size=0)
        dec_enc_att_outputs = self.dec_enc_attn_wmsa(
            q=dec_enc_att_outputs, k=enc_outputs, v=enc_outputs, mask=None, attn_mask=dec_enc_attn_mask)
        dec_enc_att_outputs = self.reverse(dec_enc_att_outputs, B, H, W, C, shift_size=0)
        enc_outputs = self.reverse(enc_outputs, B, H, W, C, shift_size=0)
        dec_enc_att_outputs = shortcut1 + self.drop_path(dec_enc_att_outputs)

        # --- block 2: shifted window cross attention ---
        shortcut2 = dec_enc_att_outputs
        dec_enc_att_outputs = self.layer_norm2(dec_enc_att_outputs)
        enc_outputs = self.partition(enc_outputs, B, H, W, C, shift_size=self.window_size // 2)
        dec_enc_att_outputs = self.partition(dec_enc_att_outputs, B, H, W, C, shift_size=self.window_size // 2)
        dec_enc_att_outputs = self.dec_enc_attn_swmsa(
            q=dec_enc_att_outputs, k=enc_outputs, v=enc_outputs, mask=self.attn_mask, attn_mask=dec_enc_attn_mask)
        dec_enc_att_outputs = self.reverse(dec_enc_att_outputs, B, H, W, C, shift_size=self.window_size // 2)
        enc_outputs = self.reverse(enc_outputs, B, H, W, C, shift_size=self.window_size // 2)
        dec_enc_att_outputs = shortcut2 + self.drop_path(dec_enc_att_outputs)

        # --- block 3: MLP ---
        shortcut3 = dec_enc_att_outputs
        dec_enc_att_outputs = self.layer_norm3(dec_enc_att_outputs)
        dec_enc_att_outputs = self.mlp(dec_enc_att_outputs)
        dec_enc_att_outputs = shortcut3 + self.drop_path(dec_enc_att_outputs)
        dec_enc_att_outputs = rearrange(
            dec_enc_att_outputs, 'b (h w) c -> b c h w', h=self.input_resolution[0], w=self.input_resolution[1])
        # if self.i_layer % 2 == 0:
        #     dec_enc_att_outputs = self.conv(dec_enc_att_outputs)
        dec_enc_att_outputs = rearrange(dec_enc_att_outputs, 'b c h w -> b (h w) c')
        return dec_enc_att_outputs


""" decoder """


class SwinDecoder(nn.Module):
    """Stack of DecoderLayers; builds attention masks per layer on the fly."""

    def __init__(self,
                 input_resolution=(28, 28),
                 embed_dim=256,
                 num_heads=4,
                 num_layers=2,
                 drop=0.1,
                 i_pad=0,
                 dim_mlp=1024,
                 window_size=7,
                 drop_path_rate=0.1):
        super().__init__()
        self.window_size = window_size
        self.embed_dim = embed_dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.i_pad = i_pad
        self.dropout = nn.Dropout(drop)
        self.layers = nn.ModuleList()
        # linearly increasing stochastic depth across layers
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
        for i_layer in range(num_layers):
            layer = DecoderLayer(
                input_resolution=(input_resolution[0], input_resolution[1]),
                embed_dim=embed_dim,
                dim_mlp=dim_mlp,
                window_size=window_size,
                i_layer=i_layer + 1,
                shift_size=window_size // 2,
                drop_path=dpr[i_layer])
            self.layers.append(layer)

    def forward(self, y_embed, enc_outputs):
        inputs_embed = y_embed
        B, C, H, W = y_embed.shape
        inputs_embed = rearrange(inputs_embed, 'b c h w -> b (h w) c')
        dec_outputs = self.dropout(inputs_embed)
        idx = 1
        down_rate = 0  # never changes in this implementation
        for layer in self.layers:
            window_num = int((self.input_resolution[0] // 2 ** down_rate) // self.window_size) * \
                int((self.input_resolution[1] // 2 ** down_rate) // self.window_size)
            enc_inputs_length = self.window_size * self.window_size
            dec_inputs_length = self.window_size * self.window_size
            # NOTE(review): .cuda() hard-codes the device — breaks CPU-only
            # runs and multi-device placement; confirm intended.
            mask_enc_inputs = torch.ones(B * window_num, enc_inputs_length).cuda()
            mask_dec_inputs = torch.ones(B * window_num, dec_inputs_length).cuda()
            dec_attn_pad_mask = get_attn_pad_mask(mask_dec_inputs, mask_dec_inputs, self.i_pad)
            dec_attn_decoder_mask = get_attn_decoder_mask(mask_dec_inputs)
            dec_self_attn_mask = torch.gt((dec_attn_pad_mask + dec_attn_decoder_mask), 0)
            dec_enc_attn_mask = get_attn_pad_mask(mask_dec_inputs, mask_enc_inputs, self.i_pad)
            dec_self_attn_mask = dec_self_attn_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
            dec_enc_attn_mask = dec_enc_attn_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
            dec_outputs = layer(dec_outputs, enc_outputs[idx - 1], dec_self_attn_mask, dec_enc_attn_mask)
            idx += 1
        dec_outputs = rearrange(
            dec_outputs,
            'b (h w) c -> b c h w',
            h=self.input_resolution[0] // 2**down_rate,
            w=self.input_resolution[1] // 2**down_rate)
        return dec_outputs


class Mlp(nn.Module):
    """Standard two-layer transformer MLP with activation and dropout."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    r""" Window
based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        # fused qkv projection (self attention)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # add the shifted-window mask, broadcast over batch
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        #  x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops


class SwinBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
shift_size (int): Shift size for SW-MSA. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 act_layer (nn.Module, optional): Activation layer. Default: nn.GELU norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, dim_mlp=1024., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() self.dim = dim self.input_resolution = input_resolution self.num_heads = num_heads self.window_size = window_size self.shift_size = shift_size self.dim_mlp = dim_mlp if min(self.input_resolution) <= self.window_size: # if window size is larger than input resolution, we don't partition windows self.shift_size = 0 self.window_size = min(self.input_resolution) assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" self.norm1 = norm_layer(dim) self.attn = WindowAttention( dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = self.dim_mlp self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) if self.shift_size > 0: # calculate attention mask for SW-MSA H, W = self.input_resolution img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None self.register_buffer("attn_mask", attn_mask) def forward(self, x): H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" shortcut = x x = self.norm1(x) x = x.view(B, H, W, C) # cyclic shift if self.shift_size > 0: shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) else: shifted_x = x # partition windows x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C # W-MSA/SW-MSA attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C # reverse cyclic shift if self.shift_size > 0: x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) else: x = shifted_x x = 
x.view(B, H * W, C) # FFN x = shortcut + self.drop_path(x) x = x + self.drop_path(self.mlp(self.norm2(x))) return x def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" def flops(self): flops = 0 H, W = self.input_resolution # norm1 flops += self.dim * H * W # W-MSA/SW-MSA nW = H * W / self.window_size / self.window_size flops += nW * self.attn.flops(self.window_size * self.window_size) # mlp flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio # norm2 flops += self.dim * H * W return flops class BasicLayer(nn.Module): """ A basic Swin Transformer layer for one stage. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resolution. depth (int): Number of blocks. num_heads (int): Number of attention heads. window_size (int): Local window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
""" def __init__(self, dim, input_resolution, depth, num_heads, window_size=7, dim_mlp=1024, qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): super().__init__() self.dim = dim self.conv = nn.Conv2d(dim, dim, 3, 1, 1) self.input_resolution = input_resolution self.depth = depth self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList([ SwinBlock( dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, dim_mlp=dim_mlp, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) for i in range(depth) ]) # patch merging layer if downsample is not None: self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) else: self.downsample = None def forward(self, x): for blk in self.blocks: if self.use_checkpoint: x = checkpoint.checkpoint(blk, x) else: x = blk(x) x = rearrange(x, 'b (h w) c -> b c h w', h=self.input_resolution[0], w=self.input_resolution[1]) x = F.relu(self.conv(x)) x = rearrange(x, 'b c h w -> b (h w) c') return x def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" def flops(self): flops = 0 for blk in self.blocks: flops += blk.flops() if self.downsample is not None: flops += self.downsample.flops() return flops class SwinTransformer(nn.Module): def __init__(self, patches_resolution, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], embed_dim=256, drop=0.1, drop_rate=0., drop_path_rate=0.1, dropout=0., window_size=7, dim_mlp=1024, qkv_bias=True, qk_scale=None, attn_drop_rate=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, **kwargs): super().__init__() self.embed_dim = embed_dim self.depths = depths self.num_heads = num_heads self.window_size = window_size self.pos_drop 
= nn.Dropout(p=drop_rate) self.dropout = nn.Dropout(p=drop) self.num_features = embed_dim self.num_layers = len(depths) self.patches_resolution = (patches_resolution[0], patches_resolution[1]) self.downsample = nn.Conv2d(self.embed_dim, self.embed_dim, kernel_size=3, stride=2, padding=1) # stochastic depth dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] self.layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer( dim=self.embed_dim, input_resolution=patches_resolution, depth=self.depths[i_layer], num_heads=self.num_heads[i_layer], window_size=self.window_size, dim_mlp=dim_mlp, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=dropout, attn_drop=attn_drop_rate, drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])], norm_layer=norm_layer, downsample=downsample, use_checkpoint=use_checkpoint) self.layers.append(layer) def forward(self, x): x = self.dropout(x) x = self.pos_drop(x) x = rearrange(x, 'b c h w -> b (h w) c') for layer in self.layers: _x = x x = layer(x) x = 0.13 * x + _x x = rearrange(x, 'b (h w) c -> b c h w', h=self.patches_resolution[0], w=self.patches_resolution[1]) return x ================================================ FILE: scripts/metrics/MANIQA/pipal_data.py ================================================ import cv2 import numpy as np import os import torch class NTIRE2022(torch.utils.data.Dataset): def __init__(self, ref_path, dis_path, transform): super(NTIRE2022, self).__init__() self.ref_path = ref_path self.dis_path = dis_path self.transform = transform ref_files_data, dis_files_data = [], [] for dis in os.listdir(dis_path): ref = dis ref_files_data.append(ref) dis_files_data.append(dis) self.data_dict = {'r_img_list': ref_files_data, 'd_img_list': dis_files_data} def __len__(self): return len(self.data_dict['r_img_list']) def __getitem__(self, idx): # r_img: H x W x C -> C x H x W r_img_name = self.data_dict['r_img_list'][idx] r_img = cv2.imread(os.path.join(self.ref_path, 
r_img_name), cv2.IMREAD_COLOR)
        r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB)
        # scale to [0, 1] float32, then HWC -> CHW
        r_img = np.array(r_img).astype('float32') / 255
        r_img = np.transpose(r_img, (2, 0, 1))
        d_img_name = self.data_dict['d_img_list'][idx]
        d_img = cv2.imread(os.path.join(self.dis_path, d_img_name), cv2.IMREAD_COLOR)
        d_img = cv2.cvtColor(d_img, cv2.COLOR_BGR2RGB)
        d_img = np.array(d_img).astype('float32') / 255
        d_img = np.transpose(d_img, (2, 0, 1))
        sample = {'r_img_org': r_img, 'd_img_org': d_img, 'd_name': d_img_name}
        if self.transform:
            sample = self.transform(sample)
        return sample


================================================ FILE: scripts/metrics/MANIQA/utils.py ================================================
import numpy as np
import torch


def crop_image(top, left, patch_size, img=None):
    """Return the square `patch_size` patch of `img` at (top, left).

    `img` is indexed as (N, C, H, W); the leading two axes are kept whole.
    """
    tmp_img = img[:, :, top:top + patch_size, left:left + patch_size]
    return tmp_img


class RandCrop(object):
    """Average `num_crop` random aligned crops of the reference/distorted pair.

    The same (top, left) offset is used for both images of each crop so the
    pair stays spatially aligned.
    """

    def __init__(self, patch_size, num_crop):
        self.patch_size = patch_size
        self.num_crop = num_crop

    def __call__(self, sample):
        # r_img : C x H x W (numpy)
        r_img, d_img = sample['r_img_org'], sample['d_img_org']
        d_name = sample['d_name']
        c, h, w = d_img.shape
        new_h = self.patch_size
        new_w = self.patch_size
        ret_r_img = np.zeros((c, self.patch_size, self.patch_size))
        ret_d_img = np.zeros((c, self.patch_size, self.patch_size))
        for _ in range(self.num_crop):
            # NOTE(review): np.random.randint(0, h - new_h) raises ValueError when
            # h == new_h (or w == new_w) — assumes inputs are strictly larger than
            # patch_size; confirm against the caller's image sizes.
            top = np.random.randint(0, h - new_h)
            left = np.random.randint(0, w - new_w)
            tmp_r_img = r_img[:, top:top + new_h, left:left + new_w]
            tmp_d_img = d_img[:, top:top + new_h, left:left + new_w]
            # accumulate, then divide below: output is the mean of the crops
            ret_r_img += tmp_r_img
            ret_d_img += tmp_d_img
        ret_r_img /= self.num_crop
        ret_d_img /= self.num_crop
        sample = {'r_img_org': ret_r_img, 'd_img_org': ret_d_img, 'd_name': d_name}
        return sample


class Normalize(object):
    """Channel-wise normalization: (img - mean) / var, applied to both images."""

    def __init__(self, mean, var):
        self.mean = mean
        self.var = var

    def __call__(self, sample):
        # r_img: C x H x W (numpy)
        r_img, d_img = sample['r_img_org'], sample['d_img_org']
        d_name = sample['d_name']
        r_img = (r_img - self.mean) / self.var
        d_img = (d_img - self.mean) / self.var
        sample = {'r_img_org': r_img, 'd_img_org': d_img, 'd_name': d_name}
        return sample


class RandHorizontalFlip(object):
    """Flip both images left-right with probability ~0.5 (same decision for both)."""

    def __init__(self):
        pass

    def __call__(self, sample):
        r_img, d_img = sample['r_img_org'], sample['d_img_org']
        d_name = sample['d_name']
        prob_lr = np.random.random()
        # np.fliplr needs HxWxC
        # NOTE(review): inputs here are CxHxW per the pipeline above, so fliplr
        # flips axis 1 (H, a vertical flip) rather than W — confirm intent.
        if prob_lr > 0.5:
            d_img = np.fliplr(d_img).copy()
            r_img = np.fliplr(r_img).copy()
        sample = {'r_img_org': r_img, 'd_img_org': d_img, 'd_name': d_name}
        return sample


class ToTensor(object):
    """Convert both numpy images of the sample to float32 torch tensors."""

    def __init__(self):
        pass

    def __call__(self, sample):
        r_img, d_img = sample['r_img_org'], sample['d_img_org']
        d_name = sample['d_name']
        d_img = torch.from_numpy(d_img).type(torch.FloatTensor)
        r_img = torch.from_numpy(r_img).type(torch.FloatTensor)
        sample = {'r_img_org': r_img, 'd_img_org': d_img, 'd_name': d_name}
        return sample


================================================ FILE: scripts/metrics/README.md ================================================
# Instruction for calculating metrics

## Prepare the frames

For fast evaluation, we measure 1 frames every 10 frames, that is, the 0-th frame, 10-th frame, 20-th frame, etc. will participate the metrics calculation. This can be achieved by `sample_interval` argument.

```bash
python scripts/inference_animesr_frames.py -i AVC-RealLQ-ROOT -n AnimeSR_v1-PaperModel --expname animesr_v1_si10 --sample_interval 10
```

## MANIQA calculation

### requirements

`pip install timm==0.5.4`

### checkpoint

we use the ensemble model provided by the authors to compute MANIQA. You can download the checkpoint from [Ondrive](https://1drv.ms/u/s!Akkg8_btnkS7mU7eb_dFDHkW05B2?e=G922hL).

### inference:

```bash
# cd into scripts/metrics/MANIQA
python inference_MANIQA.py --model_path MANIQA_CKPT_YOU_JUST_DOWNLOADED --input_dir ../../../results/animesr_v1_si10/frames --output_dir output/ensemble_attentionIQA2_finetune_e2/AnimeSR_v1_si10
```

note that the result has certain randomness, but the error should be relatively small.
### license

the MANIQA codes&checkpoint are original from [MANIQA](https://github.com/IIGROUP/MANIQA) and @[TianheWu](https://github.com/TianheWu).


================================================ FILE: setup.cfg ================================================
[flake8]
ignore =
    # line break before binary operator (W503)
    W503,
    # line break after binary operator (W504)
    W504,
max-line-length=120

[yapf]
based_on_style = pep8
column_limit = 120
blank_line_before_nested_class_or_def = true
split_before_expression_after_opening_paren = true

[isort]
line_length = 120
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = basicsr,facexlib,animesr
known_third_party = PIL,cv2,ffmpeg,numpy,psutil,torch,torchvision,tqdm
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY


================================================ FILE: setup.py ================================================
#!/usr/bin/env python
from setuptools import find_packages, setup

import os
import subprocess
import time

# path of the generated version module, written by write_version_py() below
version_file = 'animesr/version.py'


def readme():
    """Return the contents of README.md for use as the package long description."""
    with open('README.md', encoding='utf-8') as f:
        content = f.read()
    return content


def get_git_hash():
    """Return the current git commit sha as a string, or 'unknown' when git fails."""

    def _minimal_ext_cmd(cmd):
        # construct minimal environment
        env = {}
        for k in ['SYSTEMROOT', 'PATH', 'HOME']:
            v = os.environ.get(k)
            if v is not None:
                env[k] = v
        # LANGUAGE is used on win32
        env['LANGUAGE'] = 'C'
        env['LANG'] = 'C'
        env['LC_ALL'] = 'C'
        out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
        return out

    try:
        out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
        sha = out.strip().decode('ascii')
    except OSError:
        sha = 'unknown'

    return sha


def get_hash():
    """Resolve a short sha: prefer the git checkout, fall back to the version file.

    Raises ImportError when the version file exists but cannot be imported.
    """
    if os.path.exists('.git'):
        sha = get_git_hash()[:7]
    elif os.path.exists(version_file):
        try:
            from animesr.version import __version__
            sha = __version__.split('+')[-1]
        except ImportError:
            raise ImportError('Unable to get git version')
    else:
        sha = 'unknown'

    return sha


def write_version_py():
    """Generate animesr/version.py from the VERSION file plus the current git sha."""
    content = """# GENERATED VERSION FILE
# TIME: {}
__version__ = '{}'
__gitsha__ = '{}'
version_info = ({})
"""
    sha = get_hash()
    with open('VERSION', 'r') as f:
        SHORT_VERSION = f.read().strip()
    # digits stay bare in version_info, non-digit parts (e.g. 'dev') get quoted
    VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])

    version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
    with open(version_file, 'w') as f:
        f.write(version_file_str)


def get_version():
    """Read __version__ back out of the generated version file.

    NOTE(review): relies on exec() populating this function's locals(); this is
    a fragile pattern in general, though it is the upstream BasicSR idiom.
    """
    with open(version_file, 'r') as f:
        exec(compile(f.read(), version_file, 'exec'))
    return locals()['__version__']


def get_requirements(filename='requirements.txt'):
    """Parse install requirements from the given file, one requirement per line."""
    here = os.path.dirname(os.path.realpath(__file__))
    with open(os.path.join(here, filename), 'r') as f:
        requires = [line.replace('\n', '') for line in f.readlines()]
    return requires


if __name__ == '__main__':
    # regenerate the version module before packaging so metadata stays in sync
    write_version_py()
    setup(
        name='animesr',
        version=get_version(),
        description='AnimeSR: Learning Real-World Super-Resolution Models for Animation Videos (NeurIPS 2022)',
        long_description=readme(),
        long_description_content_type='text/markdown',
        author='Yanze Wu',
        author_email='wuyanze123@gmail.com',
        keywords='computer vision, pytorch, image restoration, super-resolution',
        url='https://github.com/TencentARC/AnimeSR',
        include_package_data=True,
        packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
        classifiers=[
            'Development Status :: 4 - Beta',
            'License :: OSI Approved :: Apache Software License',
            'Operating System :: OS Independent',
            'Programming Language :: Python :: 3',
            'Programming Language :: Python :: 3.7',
            'Programming Language :: Python :: 3.8',
        ],
        license='BSD-3-Clause License',
        setup_requires=['cython', 'numpy'],
        install_requires=get_requirements(),
        zip_safe=False)