[
  {
    "path": ".github/workflows/pylint.yml",
    "content": "name: PyLint\n\non: [push, pull_request]\n\njobs:\n  build:\n\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        python-version: [3.8]\n\n    steps:\n    - uses: actions/checkout@v2\n    - name: Set up Python ${{ matrix.python-version }}\n      uses: actions/setup-python@v2\n      with:\n        python-version: ${{ matrix.python-version }}\n\n    - name: Install dependencies\n      run: |\n        python -m pip install --upgrade pip\n        pip install flake8 yapf isort\n\n    - name: Lint\n      run: |\n        flake8 .\n        isort --check-only --diff animesr/ options/ scripts/ setup.py\n        yapf -r -d animesr/ options/ scripts/ setup.py\n"
  },
  {
    "path": ".gitignore",
    "content": "datasets/*\nexperiments/*\nresults/*\ntb_logger/*\nwandb/*\ntmp/*\nweights/*\ninputs/*\n\n*.DS_Store\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n.idea/\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n  # flake8\n  - repo: https://github.com/PyCQA/flake8\n    rev: 3.7.9\n    hooks:\n      - id: flake8\n        args: [\"--config=setup.cfg\", \"--ignore=W504, W503\"]\n\n  # modify known_third_party\n  - repo: https://github.com/asottile/seed-isort-config\n    rev: v2.2.0\n    hooks:\n      - id: seed-isort-config\n        args: [\"--exclude=scripts/metrics/MANIQA\"]\n\n  # isort\n  - repo: https://github.com/timothycrosley/isort\n    rev: 5.2.2\n    hooks:\n      - id: isort\n\n  # yapf\n  - repo: https://github.com/pre-commit/mirrors-yapf\n    rev: v0.30.0\n    hooks:\n      - id: yapf\n\n  # pre-commit-hooks\n  - repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v3.2.0\n    hooks:\n      - id: trailing-whitespace  # Trim trailing whitespace\n      - id: check-yaml  # Attempt to load all yaml files to verify syntax\n      - id: check-merge-conflict  # Check for files that contain merge conflict strings\n      - id: end-of-file-fixer  # Make sure files end in a newline and only a newline\n      - id: requirements-txt-fixer  # Sort entries in requirements.txt and remove incorrect entry for pkg-resources==0.0.0\n      - id: mixed-line-ending  # Replace or check mixed line ending\n        args: [\"--fix=lf\"]\n"
  },
  {
    "path": "LICENSE",
    "content": "Tencent is pleased to support the open source community by making AnimeSR available.\n\nCopyright (C) 2022 THL A29 Limited, a Tencent company.  All rights reserved.\n\nAnimeSR is licensed under the Apache License Version 2.0 except for the third-party components listed below.\n\n\nTerms of the Apache License Version 2.0:\n---------------------------------------------\nApache License\n\nVersion 2.0, January 2004\n\nhttp://www.apache.org/licenses/\n\nTERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n1. Definitions.\n\n“License” shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.\n\n“Licensor” shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.\n\n“Legal Entity” shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, “control” means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.\n\n“You” (or “Your”) shall mean an individual or Legal Entity exercising permissions granted by this License.\n\n“Source” form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.\n\n“Object” form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.\n\n“Work” shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).\n\n“Derivative Works” shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.\n\n“Contribution” shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as “Not a Contribution.”\n\n“Contributor” shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.\n\n2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.\n\n3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.\n\n4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:\n\nYou must give any other recipients of the Work or Derivative Works a copy of this License; and\n\nYou must cause any modified files to carry prominent notices stating that You changed the files; and\n\nYou must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and\n\nIf the Work includes a “NOTICE” text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.\n\nYou may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.\n\n5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.\n\n6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.\n\n7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.\n\n8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.\n\n9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.\n\nEND OF TERMS AND CONDITIONS\n\n\n\nOther dependencies and licenses:\n\nOpen Source Software Licensed under the Apache License Version 2.0:\n--------------------------------------------------------------------\n1. ffmpeg-python\n Copyright 2017 Karl Kroening\n\n2. basicsr\nCopyright 2018-2022 BasicSR Authors\n\nTerms of the Apache License Version 2.0:\n--------------------------------------------------------------------\nApache License\n\nVersion 2.0, January 2004\n\nhttp://www.apache.org/licenses/\n\nTERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n1. Definitions.\n\n\"License\" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.\n\n\"Licensor\" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.\n\n\"Legal Entity\" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, \"control\" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.\n\n\"You\" (or \"Your\") shall mean an individual or Legal Entity exercising permissions granted by this License.\n\n\"Source\" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.\n\n\"Object\" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.\n\n\"Work\" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).\n\n\"Derivative Works\" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.\n\n\"Contribution\" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, \"submitted\" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as \"Not a Contribution.\"\n\n\"Contributor\" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.\n\n2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.\n\n3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.\n\n4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:\n\nYou must give any other recipients of the Work or Derivative Works a copy of this License; and\n\nYou must cause any modified files to carry prominent notices stating that You changed the files; and\n\nYou must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and\n\nIf the Work includes a \"NOTICE\" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.\n\nYou may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.\n\n5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.\n\n6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.\n\n7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.\n\n8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.\n\n9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.\n\nEND OF TERMS AND CONDITIONS\n\n\nOpen Source Software Licensed under the BSD 3-Clause License:\n--------------------------------------------------------------------\n1. torch\nFrom PyTorch:\n\nCopyright (c) 2016-     Facebook, Inc            (Adam Paszke)\nCopyright (c) 2014-     Facebook, Inc            (Soumith Chintala)\nCopyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)\nCopyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)\nCopyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)\nCopyright (c) 2011-2013 NYU                      (Clement Farabet)\nCopyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)\nCopyright (c) 2006      Idiap Research Institute (Samy Bengio)\nCopyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)\n\nFrom Caffe2:\n\nCopyright (c) 2016-present, Facebook Inc. All rights reserved.\n\nAll contributions by Facebook:\nCopyright (c) 2016 Facebook Inc.\n\nAll contributions by Google:\nCopyright (c) 2015 Google Inc.\nAll rights reserved.\n\nAll contributions by Yangqing Jia:\nCopyright (c) 2015 Yangqing Jia\nAll rights reserved.\n\nAll contributions by Kakao Brain:\nCopyright 2019-2020 Kakao Brain\n\nAll contributions by Cruise LLC:\nCopyright (c) 2022 Cruise LLC.\nAll rights reserved.\n\nAll contributions from Caffe:\nCopyright(c) 2013, 2014, 2015, the respective contributors\nAll rights reserved.\n\nAll other contributions:\nCopyright(c) 2015, 2016 the respective contributors\nAll rights reserved.\n\nCaffe2 uses a copyright model similar to Caffe: each contributor holds\ncopyright over their contributions to Caffe2. The project versioning records\nall such contribution and copyright details. If a contributor wants to further\nmark their specific copyright on a particular contribution, they should\nindicate their copyright solely in the commit message of the change when it is\ncommitted.\n\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n1. Redistributions of source code must retain the above copyright\n   notice, this list of conditions and the following disclaimer.\n\n2. Redistributions in binary form must reproduce the above copyright\n   notice, this list of conditions and the following disclaimer in the\n   documentation and/or other materials provided with the distribution.\n\n3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America\n   and IDIAP Research Institute nor the names of its contributors may be\n   used to endorse or promote products derived from this software without\n   specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\nARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE\nLIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\nCONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\nSUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\nINTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\nCONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\nARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\nPOSSIBILITY OF SUCH DAMAGE.\n\n2. torchvision\nCopyright (c) Soumith Chintala 2016,\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n* Redistributions of source code must retain the above copyright notice, this\n  list of conditions and the following disclaimer.\n\n* Redistributions in binary form must reproduce the above copyright notice,\n  this list of conditions and the following disclaimer in the documentation\n  and/or other materials provided with the distribution.\n\n* Neither the name of the copyright holder nor the names of its\n  contributors may be used to endorse or promote products derived from\n  this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\nFOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\nDAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\nSERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\nCAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\nOR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n3. numpy\nCopyright (c) 2005-2022, NumPy Developers.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are\nmet:\n\n    * Redistributions of source code must retain the above copyright\n       notice, this list of conditions and the following disclaimer.\n\n    * Redistributions in binary form must reproduce the above\n       copyright notice, this list of conditions and the following\n       disclaimer in the documentation and/or other materials provided\n       with the distribution.\n\n    * Neither the name of the NumPy Developers nor the names of any\n       contributors may be used to endorse or promote products derived\n       from this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS\n\"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT\nLIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR\nA PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT\nOWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,\nSPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT\nLIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,\nDATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY\nTHEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n4. psutil\nCopyright (c) 2009, Jay Loden, Dave Daeschler, Giampaolo Rodola'\n\nRedistribution and use in source and binary forms, with or without modification,\nare permitted provided that the following conditions are met:\n\n * Redistributions of source code must retain the above copyright notice, this\n   list of conditions and the following disclaimer.\n\n * Redistributions in binary form must reproduce the above copyright notice,\n   this list of conditions and the following disclaimer in the documentation\n   and/or other materials provided with the distribution.\n\n * Neither the name of the psutil authors nor the names of its contributors\n   may be used to endorse or promote products derived from this software without\n   specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\nANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\nWARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR\nANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\nLOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON\nANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\nSOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n\nOpen Source Software Licensed under the HPND License:\n--------------------------------------------------------------------\n1. Pillow\n\nThe Python Imaging Library (PIL) is\n\n    Copyright © 1997-2011 by Secret Labs AB\n    Copyright © 1995-2011 by Fredrik Lundh\n\nPillow is the friendly PIL fork. It is\n\n    Copyright © 2010-2022 by Alex Clark and contributors\n\nLike PIL, Pillow is licensed under the open source HPND License:\n\nBy obtaining, using, and/or copying this software and/or its associated\ndocumentation, you agree that you have read, understood, and will comply\nwith the following terms and conditions:\n\nPermission to use, copy, modify, and distribute this software and its\nassociated documentation for any purpose and without fee is hereby granted,\nprovided that the above copyright notice appears in all copies, and that\nboth that copyright notice and this permission notice appear in supporting\ndocumentation, and that the name of Secret Labs AB or the author not be\nused in advertising or publicity pertaining to distribution of the software\nwithout specific, written prior permission.\n\nSECRET LABS AB AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS\nSOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS.\nIN NO EVENT SHALL SECRET LABS AB OR THE AUTHOR BE LIABLE FOR ANY SPECIAL,\nINDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM\nLOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE\nOR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR\nPERFORMANCE OF THIS SOFTWARE.\n\n\nOpen Source Software Licensed under the MIT License:\n--------------------------------------------------------------------\n1. opencv-python\nCopyright (c) Olli-Pekka Heinisuo\n\n\nTerms of the MIT License:\n--------------------------------------------------------------------\nPermission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the \"Software\"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n\n\nOpen Source Software Licensed under the MIT and MPL v2.0 Licenses:\n--------------------------------------------------------------------\n1. tqdm\n`tqdm` is a product of collaborative work.\nUnless otherwise stated, all authors (see commit logs) retain copyright\nfor their respective work, and release the work under the MIT licence\n(text below).\n\nExceptions or notable authors are listed below\nin reverse chronological order:\n\n* files: *\n  MPLv2.0 2015-2021 (c) Casper da Costa-Luis\n  [casperdcl](https://github.com/casperdcl).\n* files: tqdm/_tqdm.py\n  MIT 2016 (c) [PR #96] on behalf of Google Inc.\n* files: tqdm/_tqdm.py setup.py README.rst MANIFEST.in .gitignore\n  MIT 2013 (c) Noam Yorav-Raphael, original author.\n\n[PR #96]: https://github.com/tqdm/tqdm/pull/96\n\n\nMozilla Public Licence (MPL) v. 2.0 - Exhibit A\n-----------------------------------------------\n\nThis Source Code Form is subject to the terms of the\nMozilla Public License, v. 2.0.\nIf a copy of the MPL was not distributed with this project,\nYou can obtain one at https://mozilla.org/MPL/2.0/.\n\n\nMIT License (MIT)\n-----------------\n\nCopyright (c) 2013 noamraph\n\nPermission is hereby granted, free of charge, to any person obtaining a copy of\nthis software and associated documentation files (the \"Software\"), to deal in\nthe Software without restriction, including without limitation the rights to\nuse, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of\nthe Software, and to permit persons to whom the Software is furnished to do so,\nsubject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS\nFOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR\nCOPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER\nIN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN\nCONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# AnimeSR (NeurIPS 2022)\n\n### :open_book: AnimeSR: Learning Real-World Super-Resolution Models for Animation Videos\n> [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2206.07038)<br>\n> [Yanze Wu](https://github.com/ToTheBeginning), [Xintao Wang](https://xinntao.github.io/), [Gen Li](https://scholar.google.com.hk/citations?user=jBxlX7oAAAAJ), [Ying Shan](https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=en) <br>\n> [Tencent ARC Lab](https://arc.tencent.com/en/index); Platform Technologies, Tencent Online Video\n\n\n### :triangular_flag_on_post: Updates\n* **2022.11.28**: release codes&models.\n* **2022.08.29**: release AVC-Train and AVC-Test.\n\n\n## Web Demo and API\n\n[![Replicate](https://replicate.com/cjwbw/animesr/badge)](https://replicate.com/cjwbw/animesr) \n\n## Video Demos\n\nhttps://user-images.githubusercontent.com/11482921/204205018-d69e2e51-fbdc-4766-8293-a40ffce3ed25.mp4\n\nhttps://user-images.githubusercontent.com/11482921/204205109-35866094-fa7f-413b-8b43-bb479b42dfb6.mp4\n\n\n\n## :wrench: Dependencies and Installation\n- Python >= 3.7 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))\n- [PyTorch >= 1.7](https://pytorch.org/)\n- Other required packages in `requirements.txt`\n\n### Installation\n\n1. Clone repo\n\n    ```bash\n    git clone https://github.com/TencentARC/AnimeSR.git\n    cd AnimeSR\n    ```\n2. Install\n\n    ```bash\n    # Install dependent packages\n    pip install -r requirements.txt\n\n    # Install AnimeSR\n    python setup.py develop\n    ```\n\n## :zap: Quick Inference\nDownload the pre-trained AnimeSR models [[Google Drive](https://drive.google.com/drive/folders/1gwNTbKLUjt5FlgT6PQQnBz5wFzmNUX8g?usp=share_link)], and put them into the [weights](weights/) folder. Currently, the available pre-trained models are:\n- `AnimeSR_v1-PaperModel.pth`: v1 model, also the paper model. 
You can use this model for paper results reproducing.\n- `AnimeSR_v2.pth`: v2 model. Compare with v1, this version has better naturalness, fewer artifacts, and better texture/background restoration. If you want better results, use this model.\n\nAnimeSR supports both frames and videos as input for inference. We provide several sample test cases in [google drive](https://drive.google.com/drive/folders/1K8JzGNqY_pHahYBv51iUUI7_hZXmamt2?usp=share_link), you can download it and put them to [inputs](inputs/) folder.\n\n**Inference on Frames**\n```bash\npython scripts/inference_animesr_frames.py -i inputs/tom_and_jerry -n AnimeSR_v2 --expname animesr_v2 --save_video_too --fps 20\n```\n```console\nUsage:\n  -i --input           Input frames folder/root. Support first level dir (i.e., input/*.png) and second level dir (i.e., input/*/*.png)\n  -n --model_name      AnimeSR model name. Default: AnimeSR_v2, can also be AnimeSR_v1-PaperModel\n  -s --outscale        The netscale is x4, but you can achieve arbitrary output scale (e.g., x2 or x1) with the argument outscale.\n                       The program will further perform cheap resize operation after the AnimeSR output. Default: 4\n  -o --output          Output root. Default: results\n  -expname             Identify the name of your current inference. The outputs will be saved in $output/$expname\n  -save_video_too      Save the output frames to video. Default: off\n  -fps                 The fps of the (possible) saved videos. 
Default: 24\n```\nAfter running the above command, you will get the SR frames in `results/animesr_v2/frames` and the SR video in `results/animesr_v2/videos`.\n\n**Inference on Video**\n```bash\n# single gpu and single process inference\nCUDA_VISIBLE_DEVICES=0 python scripts/inference_animesr_video.py -i inputs/TheMonkeyKing1965.mp4 -n AnimeSR_v2 -s 4 --expname animesr_v2 --num_process_per_gpu 1 --suffix 1gpu1process\n# single gpu and multi process inference (you can use multi-processing to improve GPU utilization)\nCUDA_VISIBLE_DEVICES=0 python scripts/inference_animesr_video.py -i inputs/TheMonkeyKing1965.mp4 -n AnimeSR_v2 -s 4 --expname animesr_v2 --num_process_per_gpu 3 --suffix 1gpu3process\n# multi gpu and multi process inference\nCUDA_VISIBLE_DEVICES=0,1 python scripts/inference_animesr_video.py -i inputs/TheMonkeyKing1965.mp4 -n AnimeSR_v2 -s 4 --expname animesr_v2 --num_process_per_gpu 3 --suffix 2gpu6process\n```\n```console\nUsage:\n  -i --input           Input video path or extracted frames folder\n  -n --model_name      AnimeSR model name. Default: AnimeSR_v2, can also be AnimeSR_v1-PaperModel\n  -s --outscale        The netscale is x4, but you can achieve arbitrary output scale (e.g., x2 or x1) with the argument outscale.\n                       The program will further perform cheap resize operation after the AnimeSR output. Default: 4\n  -o --output          Output root. Default: results\n  -expname             Identify the name of your current inference. The outputs will be saved in $output/$expname\n  -fps                 The fps of the (possible) saved videos. 
Default: None\n  -extract_frame_first If input is a video, you can still extract the frames first, otherwise AnimeSR will read from stream\n  -num_process_per_gpu Since the slow I/O speed will make GPU utilization not high enough, as long as the\n                       video memory is sufficient, we recommend placing multiple processes on one GPU to increase the utilization of each GPU.\n                       The total number of processes will be number_process_per_gpu * num_gpu\n  -suffix              You can add a suffix string to the sr video name, for example, 1gpu3processx2 which means the SR video is generated with one GPU and three processes and the outscale is x2\n  -half                Use half precision for inference, it won't make a big impact on the visual results\n```\nSR videos are saved in `results/animesr_v2/videos/$video_name` folder.\n\nIf you are looking for portable executable files, you can try our [realesr-animevideov3](https://github.com/xinntao/Real-ESRGAN/blob/master/docs/anime_video_model.md) model which shares the similar technology with AnimeSR.\n\n\n\n## :computer: Training\nSee [Training.md](Training.md)\n\n## Request for AVC-Dataset\n1. Download and carefully read the [LICENSE AGREEMENT](assets/LICENSE%20AGREEMENT.pdf) PDF file.\n2. If you understand, acknowledge, and agree to all the terms specified in the [LICENSE AGREEMENT](assets/LICENSE%20AGREEMENT.pdf). Please email `wuyanze123@gmail.com` with the **LICENSE AGREEMENT PDF** file, **your name**, and **institution**. 
We will keep the license and send the download link of AVC dataset to you.\n\n\n## Acknowledgement\nThis project is built based on [BasicSR](https://github.com/XPixelGroup/BasicSR).\n\n##  Citation\nIf you find this project useful for your research, please consider citing our paper:\n```bibtex\n@InProceedings{wu2022animesr,\n  author={Wu, Yanze and Wang, Xintao and Li, Gen and Shan, Ying},\n  title={AnimeSR: Learning Real-World Super-Resolution Models for Animation Videos},\n  booktitle={Advances in Neural Information Processing Systems},\n  year={2022}\n}\n```\n\n## :e-mail: Contact\nIf you have any questions, please email `wuyanze123@gmail.com`.\n"
  },
  {
    "path": "Training.md",
    "content": "# :computer: How to Train AnimeSR\n\n- [Overview](#overview)\n- [Dataset Preparation](#dataset-preparation)\n- [Training](#training)\n  - [Training step 1](#training-step-1)\n  - [Training step 2](#training-step-2)\n  - [Training step 3](#training-step-3)\n- [The Pre-Trained Checkpoints](#the-pre-trained-checkpoints)\n- [Other Tips](#other-tips)\n    - [How to build your own (training) dataset ？](#how-to-build-your-own-training-dataset-)\n\n\n## Overview\nThe training has been divided into three steps.\n1. Training a Video Super-Resolution (VSR) model with a degradation model that only contains the classic basic operators (*i.e.*, blur, noise, downscale, compression).\n2. Training several **L**earnable **B**asic **O**perators (**LBO**s). Using the VSR model from step 1 and the input-rescaling strategy to generate pseudo HR for real-world LR. The paired pseudo HR-LR data are used to train the LBO in a supervised manner.\n3. Training the final VSR model with a degradation model containing both classic basic operators and learnable basic operators.\n\nSpecifically, the model training in each step consists of two stages. In the first stage, the model is trained with L1 loss from scratch. In the second stage, the model is fine-tuned with the combination of L1 loss, perceptual loss, and GAN loss.\n\n## Dataset Preparation\nWe use AVC-Train dataset for our training. The AVC dataset is released under request, please refer to [Request for AVC-Dataset](README.md#request-for-avc-dataset).\nAfter you download the AVC-Train dataset, put all the clips into one folder (dataset root). 
The dataset root should contain 553 folders (clips), each folder contains 100 frames, from `00000000.png` to `00000099.png`.\n\nIf you want to build your own (training) dataset or enlarge AVC-Train dataset, please refer to [How to build your own (training) dataset](#how-to-build-your-own-training-dataset).\n\n## Training\nAs described in the paper, all the training is performed on four NVIDIA A100 GPUs in an internal cluster. You may need to adjust the batchsize according to the CUDA memory of your GPU card.\n\nBefore the training, you should modify the [option files](options/) accordingly. For example, you should modify the `dataroot_gt` to your own dataset root. We have commented all the lines you should modify with the `TO_MODIFY` tag.\n### Training step 1\n1. Train `Net` model\n\n   Before the training, you should modify the [yaml file](options/train_animesr_step1_net_BasicOPonly.yml) accordingly. For example, you should modify the `dataroot_gt` to your own dataset root.\n   ```bash\n   CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=12345 realanimevsr/train.py -opt options/train_animesr_step1_net_BasicOPonly.yml --launcher pytorch --auto_resume\n   ```\n2. Train `GAN` model\n\n   The GAN model is fine-tuned from the `Net` model, as specified in `pretrain_network_g`\n   ```bash\n   CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=12345 realanimevsr/train.py -opt options/train_animesr_step1_gan_BasicOPonly.yml --launcher pytorch --auto_resume\n   ```\n\n### Training step 2\nThe input frames for training LBO in the paper are included in the AVC dataset download link we sent you. 
These frames came from three real-world LR animation videos, and ~2,000 frames are selected from each video.\n\nIn order to obtain the paired data required for training LBO, you will need to use the VSR model obtained in step 1 and the input-rescaling strategy as described in the paper to process these input frames to obtain pseudo HR-LR paired data.\n```bash\n# make the soft link for the VSR model obtained in step 1\nln -s experiments/train_animesr_step1_net_BasicOPonly/models/net_g_300000.pth weights/step1_vsr_gan_model.pth\n# using input-rescaling strategy to inference\npython scripts/inference_animesr_frames.py -i datasets/lbo_training_data/real_world_video_to_train_lbo_1 -n step1_vsr_gan_model --input_rescaling_factor 0.5 --mod_scale 8 --expname input_rescaling_strategy_lbo_1\n```\nAfter the inference, you can train the LBO. Note that we only provide one [option file](options/train_animesr_step2_lbo_1_net.yml) for training `Net` model and one [option file](options/train_animesr_step2_lbo_1_gan.yml) for training `GAN` model. 
If you want to train multiple LBOs, just copy&paste those option files and modify the `name`, `dataroot_gt`, and `dataroot_lq`.\n```bash\n# train Net model\nCUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=12345 realanimevsr/train.py -opt options/train_animesr_step2_lbo_1_net.yml --launcher pytorch --auto_resume\n# train GAN model\nCUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=12345 realanimevsr/train.py -opt options/train_animesr_step2_lbo_1_gan.yml --launcher pytorch --auto_resume\n```\n\n\n### Training step 3\nBefore the training, you will need to modify the `degradation_model_path` to the pre-trained LBO path.\n```bash\nCUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=12345 realanimevsr/train.py -opt options/train_animesr_step3_gan_3LBOs.yml --launcher pytorch --auto_resume\n```\n\n## Evaluation\nSee [evaluation readme](scripts/metrics/README.md).\n\n## The Pre-Trained Checkpoints\nYou can download the checkpoints of all steps in [google drive](https://drive.google.com/drive/folders/1hCXhKNZYBADXsS_weHO2z3HhNE-Eg_jw?usp=share_link).\n\n\n## Other Tips\n#### How to build your own (training) dataset ？\nSuppose you have a batch of HQ (high resolution, high bitrate, high quality) animation video, we provide the [anime_videos_preprocessing.py](scripts/anime_videos_preprocessing.py) script to help you to prepare training clips from the raw videos.\n\nThe preprocessing consists of 6 steps:\n1. use FFmpeg to extract frames. Note that this step will take up a lot of disk space.\n2. shot detection using [PySceneDetect](https://github.com/Breakthrough/PySceneDetect/)\n3. flow estimation using spynet\n4. black frames detection\n5. image quality assessment using hyperIQA\n6. 
generate clips for each video\n\n```console\nUsage: python scripts/anime_videos_preprocessing.py --dataroot datasets/YOUR_OWN_ANIME --n_thread 4 --run 1\n  --dataroot           dataset root, dataroot/raw_videos should contains your HQ videos to be processed\n  --n_thread           number of workers to process in parallel\n  --run                which step to run. Since each step may take a long time, we recommend performing it step by step.\n                       And after each step, check whether the output files are as expected\n  --n_frames_per_clip  number of frames per clip. Default 100. You can increase the number if you want more training data\n  --n_clips_per_video  number of clips per video. Default 1.  You can increase the number if you want more training data\n```\nAfter you finish all the steps, you will get the clips in `dataroot/select_clips`\n"
  },
  {
    "path": "VERSION",
    "content": "0.1.0\n"
  },
  {
    "path": "animesr/__init__.py",
    "content": "# flake8: noqa\nfrom .archs import *\nfrom .data import *\nfrom .models import *\n\n# from .version import __gitsha__, __version__\n"
  },
  {
    "path": "animesr/archs/__init__.py",
    "content": "import importlib\nfrom os import path as osp\n\nfrom basicsr.utils import scandir\n\n# automatically scan and import arch modules for registry\n# scan all the files that end with '_arch.py' under the archs folder\narch_folder = osp.dirname(osp.abspath(__file__))\narch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]\n# import all the arch modules\n_arch_modules = [importlib.import_module(f'animesr.archs.{file_name}') for file_name in arch_filenames]\n"
  },
  {
    "path": "animesr/archs/discriminator_arch.py",
    "content": "import functools\nfrom torch import nn as nn\nfrom torch.nn import functional as F\nfrom torch.nn.utils import spectral_norm\n\nfrom basicsr.utils.registry import ARCH_REGISTRY\n\n\ndef get_conv_layer(input_nc, ndf, kernel_size, stride, padding, bias=True, use_sn=False):\n    if not use_sn:\n        return nn.Conv2d(input_nc, ndf, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)\n    return spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias))\n\n\n@ARCH_REGISTRY.register()\nclass UNetDiscriminatorSN(nn.Module):\n    \"\"\"Defines a U-Net discriminator with spectral normalization (SN). copy from real-esrgan\"\"\"\n\n    def __init__(self, num_in_ch, num_feat=64, skip_connection=True):\n        super(UNetDiscriminatorSN, self).__init__()\n        self.skip_connection = skip_connection\n        norm = spectral_norm\n\n        self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)\n\n        self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))\n        self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))\n        self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))\n        # upsample\n        self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))\n        self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))\n        self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))\n\n        # extra\n        self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))\n        self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))\n\n        self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)\n\n    def forward(self, x):\n        x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)\n        x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)\n        x2 = F.leaky_relu(self.conv2(x1), 
negative_slope=0.2, inplace=True)\n        x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)\n\n        # upsample\n        x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)\n        x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)\n\n        if self.skip_connection:\n            x4 = x4 + x2\n        x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)\n        x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)\n\n        if self.skip_connection:\n            x5 = x5 + x1\n        x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)\n        x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)\n\n        if self.skip_connection:\n            x6 = x6 + x0\n\n        # extra\n        out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)\n        out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)\n        out = self.conv9(out)\n\n        return out\n\n\n@ARCH_REGISTRY.register()\nclass PatchDiscriminator(nn.Module):\n    \"\"\"Defines a PatchGAN discriminator, the receptive field of default config is 70x70.\n\n    Args:\n        use_sn (bool): Use spectra_norm or not, if use_sn is True, then norm_type should be none.\n    \"\"\"\n\n    def __init__(self,\n                 num_in_ch,\n                 num_feat=64,\n                 num_layers=3,\n                 max_nf_mult=8,\n                 norm_type='batch',\n                 use_sigmoid=False,\n                 use_sn=False):\n        super(PatchDiscriminator, self).__init__()\n\n        norm_layer = self._get_norm_layer(norm_type)\n        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters\n            use_bias = norm_layer.func != nn.BatchNorm2d\n        else:\n            use_bias = norm_layer != nn.BatchNorm2d\n\n        kw = 4\n        padw = 1\n        sequence = [\n            
get_conv_layer(num_in_ch, num_feat, kernel_size=kw, stride=2, padding=padw, use_sn=use_sn),\n            nn.LeakyReLU(0.2, True)\n        ]\n        nf_mult = 1\n        nf_mult_prev = 1\n        for n in range(1, num_layers):  # gradually increase the number of filters\n            nf_mult_prev = nf_mult\n            nf_mult = min(2**n, max_nf_mult)\n            sequence += [\n                get_conv_layer(\n                    num_feat * nf_mult_prev,\n                    num_feat * nf_mult,\n                    kernel_size=kw,\n                    stride=2,\n                    padding=padw,\n                    bias=use_bias,\n                    use_sn=use_sn),\n                norm_layer(num_feat * nf_mult),\n                nn.LeakyReLU(0.2, True)\n            ]\n\n        nf_mult_prev = nf_mult\n        nf_mult = min(2**num_layers, max_nf_mult)\n        sequence += [\n            get_conv_layer(\n                num_feat * nf_mult_prev,\n                num_feat * nf_mult,\n                kernel_size=kw,\n                stride=1,\n                padding=padw,\n                bias=use_bias,\n                use_sn=use_sn),\n            norm_layer(num_feat * nf_mult),\n            nn.LeakyReLU(0.2, True)\n        ]\n\n        # output 1 channel prediction map\n        sequence += [get_conv_layer(num_feat * nf_mult, 1, kernel_size=kw, stride=1, padding=padw, use_sn=use_sn)]\n\n        if use_sigmoid:\n            sequence += [nn.Sigmoid()]\n        self.model = nn.Sequential(*sequence)\n\n    def _get_norm_layer(self, norm_type='batch'):\n        if norm_type == 'batch':\n            norm_layer = functools.partial(nn.BatchNorm2d, affine=True)\n        elif norm_type == 'instance':\n            norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)\n        elif norm_type == 'batchnorm2d':\n            norm_layer = nn.BatchNorm2d\n        elif norm_type == 'none':\n            norm_layer = nn.Identity\n        else:\n            raise 
NotImplementedError(f'normalization layer [{norm_type}] is not found')\n\n        return norm_layer\n\n    def forward(self, x):\n        return self.model(x)\n\n\n@ARCH_REGISTRY.register()\nclass MultiScaleDiscriminator(nn.Module):\n    \"\"\"Define a multi-scale discriminator, each discriminator is a instance of PatchDiscriminator.\n\n    Args:\n        num_layers (int or list): If the type of this variable is int, then degrade to PatchDiscriminator.\n                                  If the type of this variable is list, then the length of the list is\n                                  the number of discriminators.\n        use_downscale (bool): Progressive downscale the input to feed into different discriminators.\n                              If set to True, then the discriminators are usually the same.\n    \"\"\"\n\n    def __init__(self,\n                 num_in_ch,\n                 num_feat=64,\n                 num_layers=3,\n                 max_nf_mult=8,\n                 norm_type='batch',\n                 use_sigmoid=False,\n                 use_sn=False,\n                 use_downscale=False):\n        super(MultiScaleDiscriminator, self).__init__()\n\n        if isinstance(num_layers, int):\n            num_layers = [num_layers]\n\n        # check whether the discriminators are the same\n        if use_downscale:\n            assert len(set(num_layers)) == 1\n        self.use_downscale = use_downscale\n\n        self.num_dis = len(num_layers)\n        self.dis_list = nn.ModuleList()\n        for nl in num_layers:\n            self.dis_list.append(\n                PatchDiscriminator(\n                    num_in_ch,\n                    num_feat=num_feat,\n                    num_layers=nl,\n                    max_nf_mult=max_nf_mult,\n                    norm_type=norm_type,\n                    use_sigmoid=use_sigmoid,\n                    use_sn=use_sn,\n                ))\n\n    def forward(self, x):\n        outs = []\n        h, w = 
x.size()[2:]\n\n        y = x\n        for i in range(self.num_dis):\n            if i != 0 and self.use_downscale:\n                y = F.interpolate(y, size=(h // 2, w // 2), mode='bilinear', align_corners=True)\n                h, w = y.size()[2:]\n            outs.append(self.dis_list[i](y))\n\n        return outs\n"
  },
  {
    "path": "animesr/archs/simple_degradation_arch.py",
    "content": "from torch import nn as nn\n\nfrom basicsr.archs.arch_util import default_init_weights, pixel_unshuffle\nfrom basicsr.utils.registry import ARCH_REGISTRY\n\n\n@ARCH_REGISTRY.register()\nclass SimpleDegradationArch(nn.Module):\n    \"\"\"simple degradation architecture which consists several conv and non-linear layer\n    it learns the mapping from HR to LR\n    \"\"\"\n\n    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, downscale=2):\n        \"\"\"\n        :param num_in_ch: input is a pseudo HR image, channel is 3\n        :param num_out_ch: output is an LR image, channel is also 3\n        :param num_feat: we use a small network, hidden dimension is 64\n        :param downscale: suppose (h, w) is the height&width of a real-world LR video.\n                          Firstly, we select the best rescaling factor (usually around 0.5) for this LR video.\n                          Secondly, we obtain the pseudo HR frames and resize them to (2h, 2w).\n                          To learn the mapping from pseudo HR to LR, LBO contains a pixel-unshuffle layer with\n                          a scale factor of 2 to perform the downsampling at the beginning.\n        \"\"\"\n        super(SimpleDegradationArch, self).__init__()\n        num_in_ch = num_in_ch * downscale * downscale\n        self.main = nn.Sequential(\n            nn.Conv2d(num_in_ch, num_feat, 3, 1, 1),\n            nn.LeakyReLU(0.2, inplace=True),\n            nn.Conv2d(num_feat, num_feat, 3, 1, 1),\n            nn.LeakyReLU(0.2, inplace=True),\n            nn.Conv2d(num_feat, num_out_ch, 3, 1, 1),\n        )\n        self.downscale = downscale\n\n        default_init_weights(self.main)\n\n    def forward(self, x):\n        x = pixel_unshuffle(x, self.downscale)\n        x = self.main(x)\n        return x\n"
  },
  {
    "path": "animesr/archs/vsr_arch.py",
    "content": "import torch\nfrom torch import nn as nn\nfrom torch.nn import functional as F\n\nfrom basicsr.archs.arch_util import ResidualBlockNoBN, pixel_unshuffle\nfrom basicsr.utils.registry import ARCH_REGISTRY\n\n\nclass RightAlignMSConvResidualBlocks(nn.Module):\n    \"\"\"right align multi-scale ConvResidualBlocks, currently only support 3 scales (1, 2, 4)\"\"\"\n\n    def __init__(self, num_in_ch=3, num_state_ch=64, num_out_ch=64, num_block=(5, 3, 2)):\n        super().__init__()\n\n        assert len(num_block) == 3\n        assert num_block[0] >= num_block[1] >= num_block[2]\n        self.num_block = num_block\n\n        self.conv_s1_first = nn.Sequential(\n            nn.Conv2d(num_in_ch, num_state_ch, 3, 1, 1, bias=True), nn.LeakyReLU(negative_slope=0.1, inplace=True))\n        self.conv_s2_first = nn.Sequential(\n            nn.Conv2d(num_state_ch, num_state_ch, 3, 2, 1, bias=True), nn.LeakyReLU(negative_slope=0.1, inplace=True))\n        self.conv_s4_first = nn.Sequential(\n            nn.Conv2d(num_state_ch, num_state_ch, 3, 2, 1, bias=True),\n            nn.LeakyReLU(negative_slope=0.1, inplace=True),\n        )\n\n        self.body_s1_first = nn.ModuleList()\n        for _ in range(num_block[0]):\n            self.body_s1_first.append(ResidualBlockNoBN(num_feat=num_state_ch))\n        self.body_s2_first = nn.ModuleList()\n        for _ in range(num_block[1]):\n            self.body_s2_first.append(ResidualBlockNoBN(num_feat=num_state_ch))\n        self.body_s4_first = nn.ModuleList()\n        for _ in range(num_block[2]):\n            self.body_s4_first.append(ResidualBlockNoBN(num_feat=num_state_ch))\n\n        self.upsample_x2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)\n        self.upsample_x4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)\n\n        self.fusion = nn.Sequential(\n            nn.Conv2d(3 * num_state_ch, 2 * num_out_ch, 3, 1, 1, bias=True),\n            
nn.LeakyReLU(negative_slope=0.1, inplace=True),\n            nn.Conv2d(2 * num_out_ch, num_out_ch, 3, 1, 1, bias=True),\n        )\n\n    def up(self, x, scale=2):\n        if isinstance(x, int):\n            return x\n        elif scale == 2:\n            return self.upsample_x2(x)\n        else:\n            return self.upsample_x4(x)\n\n    def forward(self, x):\n        x_s1 = self.conv_s1_first(x)\n        x_s2 = self.conv_s2_first(x_s1)\n        x_s4 = self.conv_s4_first(x_s2)\n\n        flag_s2 = False\n        flag_s4 = False\n        for i in range(0, self.num_block[0]):\n            x_s1 = self.body_s1_first[i](\n                x_s1 + (self.up(x_s2, 2) if flag_s2 else 0) + (self.up(x_s4, 4) if flag_s4 else 0))\n            if i >= self.num_block[0] - self.num_block[1]:\n                x_s2 = self.body_s2_first[i - self.num_block[0] + self.num_block[1]](\n                    x_s2 + (self.up(x_s4, 2) if flag_s4 else 0))\n                flag_s2 = True\n            if i >= self.num_block[0] - self.num_block[2]:\n                x_s4 = self.body_s4_first[i - self.num_block[0] + self.num_block[2]](x_s4)\n                flag_s4 = True\n\n        x_fusion = self.fusion(torch.cat((x_s1, self.upsample_x2(x_s2), self.upsample_x4(x_s4)), dim=1))\n\n        return x_fusion\n\n\n@ARCH_REGISTRY.register()\nclass MSRSWVSR(nn.Module):\n    \"\"\"\n    Multi-Scale, unidirectional Recurrent, Sliding Window (MSRSW)\n    The implementation refers to paper: Efficient Video Super-Resolution through Recurrent Latent Space Propagation\n    \"\"\"\n\n    def __init__(self, num_feat=64, num_block=(5, 3, 2), netscale=4):\n        super(MSRSWVSR, self).__init__()\n        self.num_feat = num_feat\n\n        # 3(img channel) * 3(prev cur nxt 3 imgs) + 3(hr img channel) * netscale * netscale + num_feat\n        self.recurrent_cell = RightAlignMSConvResidualBlocks(3 * 3 + 3 * netscale * netscale + num_feat, num_feat,\n                                                             
num_feat + 3 * netscale * netscale, num_block)\n        self.lrelu = nn.LeakyReLU(negative_slope=0.1)\n        self.pixel_shuffle = nn.PixelShuffle(netscale)\n        self.netscale = netscale\n\n    def cell(self, x, fb, state):\n        res = x[:, 3:6]\n        # pre frame, cur frame, nxt frame, pre sr frame, pre hidden state\n        inp = torch.cat((x, pixel_unshuffle(fb, self.netscale), state), dim=1)\n        # the out contains both state and sr frame\n        out = self.recurrent_cell(inp)\n        out_img = self.pixel_shuffle(out[:, :3 * self.netscale * self.netscale]) + F.interpolate(\n            res, scale_factor=self.netscale, mode='bilinear', align_corners=False)\n        out_state = self.lrelu(out[:, 3 * self.netscale * self.netscale:])\n\n        return out_img, out_state\n\n    def forward(self, x):\n        b, n, c, h, w = x.size()\n        # initialize previous sr frame and previous hidden state as zero tensor\n        out = x.new_zeros(b, c, h * self.netscale, w * self.netscale)\n        state = x.new_zeros(b, self.num_feat, h, w)\n        out_l = []\n        for i in range(n):\n            if i == 0:\n                # there is no previous frame for the 1st frame, so reuse 1st frame as previous\n                out, state = self.cell(torch.cat((x[:, i], x[:, i], x[:, i + 1]), dim=1), out, state)\n            elif i == n - 1:\n                # there is no next frame for the last frame, so reuse last frame as next\n                out, state = self.cell(torch.cat((x[:, i - 1], x[:, i], x[:, i]), dim=1), out, state)\n            else:\n                out, state = self.cell(torch.cat((x[:, i - 1], x[:, i], x[:, i + 1]), dim=1), out, state)\n            out_l.append(out)\n\n        return torch.stack(out_l, dim=1)\n"
  },
  {
    "path": "animesr/data/__init__.py",
    "content": "import importlib\nfrom os import path as osp\n\nfrom basicsr.utils import scandir\n\n# automatically scan and import dataset modules for registry\n# scan all the files that end with '_dataset.py' under the data folder\ndata_folder = osp.dirname(osp.abspath(__file__))\ndataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]\n# import all the dataset modules\n_dataset_modules = [importlib.import_module(f'animesr.data.{file_name}') for file_name in dataset_filenames]\n"
  },
  {
    "path": "animesr/data/data_utils.py",
    "content": "import random\nimport torch\n\n\ndef random_crop(imgs, patch_size, top=None, left=None):\n    \"\"\"\n    randomly crop patches from imgs\n    :param imgs: can be (list of) tensor, cv2 img\n    :param patch_size: patch size, usually 256\n    :param top: will sample if is None\n    :param left: will sample if is None\n    :return: cropped patches from input imgs\n    \"\"\"\n    if not isinstance(imgs, list):\n        imgs = [imgs]\n\n    # determine input type: Numpy array or Tensor\n    input_type = 'Tensor' if torch.is_tensor(imgs[0]) else 'Numpy'\n\n    if input_type == 'Tensor':\n        h, w = imgs[0].size()[-2:]\n    else:\n        h, w = imgs[0].shape[0:2]\n\n    # randomly choose top and left coordinates\n    if top is None:\n        top = random.randint(0, h - patch_size)\n    if left is None:\n        left = random.randint(0, w - patch_size)\n\n    if input_type == 'Tensor':\n        imgs = [v[:, :, top:top + patch_size, left:left + patch_size] for v in imgs]\n    else:\n        imgs = [v[top:top + patch_size, left:left + patch_size, ...] for v in imgs]\n    if len(imgs) == 1:\n        imgs = imgs[0]\n    return imgs\n"
  },
  {
    "path": "animesr/data/ffmpeg_anime_dataset.py",
    "content": "import cv2\nimport ffmpeg\nimport glob\nimport numpy as np\nimport os\nimport random\nimport torch\nfrom os import path as osp\nfrom torch.utils import data as data\n\nfrom basicsr.data.degradations import random_add_gaussian_noise, random_mixed_kernels\nfrom basicsr.data.transforms import augment\nfrom basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor\nfrom basicsr.utils.registry import DATASET_REGISTRY\nfrom .data_utils import random_crop\n\n\n@DATASET_REGISTRY.register()\nclass FFMPEGAnimeDataset(data.Dataset):\n    \"\"\"Anime datasets with only classic basic operators\"\"\"\n\n    def __init__(self, opt):\n        super(FFMPEGAnimeDataset, self).__init__()\n        self.opt = opt\n        self.num_frame = opt['num_frame']\n        self.num_half_frames = opt['num_frame'] // 2\n\n        self.keys = []\n        self.clip_frames = {}\n\n        self.gt_root = opt['dataroot_gt']\n\n        logger = get_root_logger()\n\n        clip_names = os.listdir(self.gt_root)\n        for clip_name in clip_names:\n            num_frames = len(glob.glob(osp.join(self.gt_root, clip_name, '*.png')))\n            self.keys.extend([f'{clip_name}/{i:08d}' for i in range(num_frames)])\n            self.clip_frames[clip_name] = num_frames\n\n        # file client (io backend)\n        self.file_client = None\n        self.io_backend_opt = opt['io_backend']\n        self.is_lmdb = False\n\n        self.iso_blur_range = opt.get('iso_blur_range', [0.2, 4])\n        self.aniso_blur_range = opt.get('aniso_blur_range', [0.8, 3])\n        self.noise_range = opt.get('noise_range', [0, 10])\n        self.crf_range = opt.get('crf_range', [18, 35])\n        self.ffmpeg_profile_names = opt.get('ffmpeg_profile_names', ['baseline', 'main', 'high'])\n        self.ffmpeg_profile_probs = opt.get('ffmpeg_profile_probs', [0.1, 0.2, 0.7])\n\n        self.scale = opt.get('scale', 4)\n        assert self.scale in (2, 4)\n\n        # temporal augmentation configs\n   
     self.interval_list = opt.get('interval_list', [1])\n        self.random_reverse = opt.get('random_reverse', False)\n        interval_str = ','.join(str(x) for x in self.interval_list)\n        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '\n                    f'random reverse is {self.random_reverse}.')\n\n    def get_gt_clip(self, index):\n        \"\"\"\n        get the GT(hr) clip with self.num_frame frames\n        :param index: the index from __getitem__\n        :return: a list of images, with numpy(cv2) format\n        \"\"\"\n        key = self.keys[index]  # get clip from this key frame (if possible)\n        clip_name, frame_name = key.split('/')  # key example: 000/00000000\n\n        # determine the \"interval\" of neighboring frames\n        interval = random.choice(self.interval_list)\n\n        # ensure not exceeding the borders\n        center_frame_idx = int(frame_name)\n        start_frame_idx = center_frame_idx - self.num_half_frames * interval\n        end_frame_idx = center_frame_idx + self.num_half_frames * interval\n\n        # if the index doesn't satisfy the requirement, resample it\n        if (start_frame_idx < 0) or (end_frame_idx >= self.clip_frames[clip_name]):\n            center_frame_idx = random.randint(self.num_half_frames * interval,\n                                              self.clip_frames[clip_name] - 1 - self.num_half_frames * interval)\n            start_frame_idx = center_frame_idx - self.num_half_frames * interval\n            end_frame_idx = center_frame_idx + self.num_half_frames * interval\n\n        # determine the neighbor frames\n        neighbor_list = list(range(start_frame_idx, end_frame_idx + 1, interval))\n\n        # random reverse\n        if self.random_reverse and random.random() < 0.5:\n            neighbor_list.reverse()\n\n        # get the neighboring GT frames\n        img_gts = []\n        for neighbor in neighbor_list:\n            if self.is_lmdb:\n                
img_gt_path = f'{clip_name}/{neighbor:08d}'\n            else:\n                img_gt_path = osp.join(self.gt_root, clip_name, f'{neighbor:08d}.png')\n\n            # get GT\n            img_bytes = self.file_client.get(img_gt_path, 'gt')\n            img_gt = imfrombytes(img_bytes, float32=True)\n            img_gts.append(img_gt)\n\n        # random crop\n        img_gts = random_crop(img_gts, self.opt['gt_size'])\n        # augmentation\n        img_gts = augment(img_gts, self.opt['use_flip'], self.opt['use_rot'])\n\n        return img_gts\n\n    def add_ffmpeg_compression(self, img_lqs, width, height):\n        # ffmpeg\n        loglevel = 'error'\n        format = 'h264'\n        fps = random.choices([24, 25, 30, 50, 60], [0.2, 0.2, 0.2, 0.2, 0.2])[0]  # still have problems\n        fps = 25\n        crf = np.random.uniform(self.crf_range[0], self.crf_range[1])\n\n        try:\n            extra_args = dict()\n            if format == 'h264':\n                vcodec = 'libx264'\n                profile = random.choices(self.ffmpeg_profile_names, self.ffmpeg_profile_probs)[0]\n                extra_args['profile:v'] = profile\n\n            ffmpeg_img2video = (\n                ffmpeg.input('pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{width}x{height}',\n                             r=fps).filter('fps', fps=fps, round='up').output(\n                                 'pipe:', format=format, pix_fmt='yuv420p', crf=crf, vcodec=vcodec,\n                                 **extra_args).global_args('-hide_banner').global_args('-loglevel', loglevel).run_async(\n                                     pipe_stdin=True, pipe_stdout=True))\n            ffmpeg_video2img = (\n                ffmpeg.input('pipe:', format=format).output('pipe:', format='rawvideo',\n                                                            pix_fmt='rgb24').global_args('-hide_banner').global_args(\n                                                                '-loglevel',\n                   
                                             loglevel).run_async(pipe_stdin=True, pipe_stdout=True))\n\n            # read a sequence of images\n            for img_lq in img_lqs:\n                ffmpeg_img2video.stdin.write(img_lq.astype(np.uint8).tobytes())\n\n            ffmpeg_img2video.stdin.close()\n            video_bytes = ffmpeg_img2video.stdout.read()\n            ffmpeg_img2video.wait()\n\n            # ffmpeg: video to images\n            ffmpeg_video2img.stdin.write(video_bytes)\n            ffmpeg_video2img.stdin.close()\n            img_lqs_ffmpeg = []\n            while True:\n                in_bytes = ffmpeg_video2img.stdout.read(width * height * 3)\n                if not in_bytes:\n                    break\n                in_frame = (np.frombuffer(in_bytes, np.uint8).reshape([height, width, 3]))\n                in_frame = in_frame.astype(np.float32) / 255.\n                img_lqs_ffmpeg.append(in_frame)\n\n            ffmpeg_video2img.wait()\n\n            assert len(img_lqs_ffmpeg) == self.num_frame, 'Wrong length'\n        except AssertionError as error:\n            logger = get_root_logger()\n            logger.warn(f'ffmpeg assertion error: {error}')\n        except Exception as error:\n            logger = get_root_logger()\n            logger.warn(f'ffmpeg exception error: {error}')\n        else:\n            img_lqs = img_lqs_ffmpeg\n\n        return img_lqs\n\n    def __getitem__(self, index):\n        if self.file_client is None:\n            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)\n\n        img_gts = self.get_gt_clip(index)\n\n        # ------------- generate LQ frames --------------#\n        # add blur\n        kernel = random_mixed_kernels(['iso', 'aniso'], [0.7, 0.3], 21, self.iso_blur_range, self.aniso_blur_range)\n        img_lqs = [cv2.filter2D(v, -1, kernel) for v in img_gts]\n        # add noise\n        img_lqs = [\n            random_add_gaussian_noise(v, 
sigma_range=self.noise_range, gray_prob=0.5, clip=True, rounds=False)\n            for v in img_lqs\n        ]\n        # downsample\n        h, w = img_gts[0].shape[0:2]\n        width = w // self.scale\n        height = h // self.scale\n        resize_type = random.choices([cv2.INTER_AREA, cv2.INTER_LINEAR, cv2.INTER_CUBIC], [0.3, 0.3, 0.4])[0]\n        img_lqs = [cv2.resize(v, (width, height), interpolation=resize_type) for v in img_lqs]\n        # ffmpeg\n        img_lqs = [np.clip(img_lq * 255.0, 0, 255) for img_lq in img_lqs]\n        img_lqs = self.add_ffmpeg_compression(img_lqs, width, height)\n        # ------------- end --------------#\n        img_gts = img2tensor(img_gts)\n        img_lqs = img2tensor(img_lqs)\n        img_gts = torch.stack(img_gts, dim=0)\n        img_lqs = torch.stack(img_lqs, dim=0)\n\n        # img_lqs: (t, c, h, w)\n        # img_gts: (t, c, h, w)\n        return {'lq': img_lqs, 'gt': img_gts}\n\n    def __len__(self):\n        return len(self.keys)\n"
  },
  {
    "path": "animesr/data/ffmpeg_anime_lbo_dataset.py",
    "content": "import numpy as np\nimport random\nimport torch\nfrom torch.nn import functional as F\n\nfrom animesr.archs.simple_degradation_arch import SimpleDegradationArch\nfrom basicsr.data.degradations import random_add_gaussian_noise_pt, random_mixed_kernels\nfrom basicsr.utils import FileClient, get_root_logger, img2tensor\nfrom basicsr.utils.dist_util import get_dist_info\nfrom basicsr.utils.img_process_util import filter2D\nfrom basicsr.utils.registry import DATASET_REGISTRY\nfrom .ffmpeg_anime_dataset import FFMPEGAnimeDataset\n\n\n@DATASET_REGISTRY.register()\nclass FFMPEGAnimeLBODataset(FFMPEGAnimeDataset):\n    \"\"\"Anime datasets with both classic basic operators and learnable basic operators (LBO)\"\"\"\n\n    def __init__(self, opt):\n        super(FFMPEGAnimeLBODataset, self).__init__(opt)\n\n        self.rank, self.world_size = get_dist_info()\n\n        self.lbo = SimpleDegradationArch(downscale=2)\n        lbo_list = opt['degradation_model_path']\n        if not isinstance(lbo_list, list):\n            lbo_list = [lbo_list]\n        self.lbo_list = lbo_list\n        # print(f'degradation model path for {self.rank} {self.world_size}: {degradation_model_path}\\n')\n        # the real load is at reload_degradation_model function\n        self.lbo.load_state_dict(torch.load(self.lbo_list[0], map_location=lambda storage, loc: storage)['params'])\n        self.lbo = self.lbo.to(f'cuda:{self.rank}').eval()\n        self.lbo_prob = opt.get('lbo_prob', 0.5)\n\n    def reload_degradation_model(self):\n        \"\"\"\n        __init__ will be only invoked once for one gpu worker, so if we want to\n         have num_worker_dataset * num_gpu degradation model, we must call this func in __getitem__\n         ref: https://discuss.pytorch.org/t/what-happened-when-set-num-workers-0-in-dataloader/138515\n        \"\"\"\n        degradation_model_path = random.choice(self.lbo_list)\n        self.lbo.load_state_dict(\n            
torch.load(degradation_model_path, map_location=lambda storage, loc: storage)['params'])\n        print(f'reload degradation model path for {self.rank} {self.world_size}: {degradation_model_path}\\n')\n        logger = get_root_logger()\n        logger.info(f'reload degradation model path for {self.rank} {self.world_size}: {degradation_model_path}\\n')\n\n    @torch.no_grad()\n    def custom_resize(self, x, scale=2):\n        if random.random() < self.lbo_prob:  # learned degradation model from real-world\n            x = self.lbo(x)\n        else:  # classic synthetic\n            h, w = x.shape[2:]\n            width = w // scale\n            height = h // scale\n            mode = random.choice(['area', 'bilinear', 'bicubic'])\n            if mode == 'area':\n                align_corners = None\n            else:\n                align_corners = False\n            x = F.interpolate(x, size=(height, width), mode=mode, align_corners=align_corners)\n\n        return x\n\n    @torch.no_grad()\n    def __getitem__(self, index):\n        if self.file_client is None:\n            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)\n            # called only once\n            self.reload_degradation_model()\n\n        img_gts = self.get_gt_clip(index)\n\n        # ------------- generate LQ frames --------------#\n        # change to CUDA implementation\n        img_gts = img2tensor(img_gts)\n        img_gts = torch.stack(img_gts, dim=0)\n        img_gts = img_gts.to(f'cuda:{self.rank}')\n        # add blur\n        kernel = random_mixed_kernels(['iso', 'aniso'], [0.7, 0.3], 21, self.iso_blur_range, self.aniso_blur_range)\n        with torch.no_grad():\n            kernel = torch.FloatTensor(kernel).unsqueeze(0).expand(self.num_frame, 21, 21).to(f'cuda:{self.rank}')\n            img_lqs = filter2D(img_gts, kernel)\n            # add noise\n            img_lqs = random_add_gaussian_noise_pt(\n                img_lqs, 
sigma_range=self.noise_range, clip=True, rounds=False, gray_prob=0.5)\n            # downsample\n            img_lqs = self.custom_resize(img_lqs)\n            if self.scale == 4:\n                img_lqs = self.custom_resize(img_lqs)\n            height, width = img_lqs.shape[2:]\n            # back to numpy since ffmpeg compression operate on cpu\n            img_lqs = img_lqs.detach().clamp_(0, 1).permute(0, 2, 3, 1) * 255  # B, H, W, C\n            img_lqs = img_lqs.type(torch.uint8).cpu().numpy()[:, :, :, ::-1]\n            img_lqs = np.split(img_lqs, self.num_frame, axis=0)\n            img_lqs = [img_lq[0] for img_lq in img_lqs]\n\n        # ffmpeg\n        img_lqs = self.add_ffmpeg_compression(img_lqs, width, height)\n        # ------------- end --------------#\n        img_lqs = img2tensor(img_lqs)\n        img_lqs = torch.stack(img_lqs, dim=0)\n\n        # img_lqs: (t, c, h, w)\n        # img_gts: (t, c, h, w) on gpu\n        return {'lq': img_lqs, 'gt': img_gts.cpu()}\n\n    def __len__(self):\n        return len(self.keys)\n"
  },
  {
    "path": "animesr/data/paired_image_dataset.py",
    "content": "import glob\nimport os\nfrom torch.utils import data as data\nfrom torchvision.transforms.functional import normalize\n\nfrom basicsr.data.transforms import augment, mod_crop, paired_random_crop\nfrom basicsr.utils import FileClient, imfrombytes, img2tensor\nfrom basicsr.utils.registry import DATASET_REGISTRY\n\n\n@DATASET_REGISTRY.register()\nclass CustomPairedImageDataset(data.Dataset):\n    \"\"\"Paired image dataset for training LBO.\n\n    Read real-world LQ and GT frames pairs.\n    The organization of these gt&lq folder is similar to AVC-Train,\n    except that each folder contains 200 clips, and each clip contains 11 frames.\n    We will ignore the first frame, so there are finally 2000 training pair data.\n\n    Args:\n        opt (dict): Config for train datasets. It contains the following keys:\n            dataroot_gt (str): Data root path for gt, also the pseudo HR path.\n            dataroot_lq (str): Data root path for lq.\n            io_backend (dict): IO backend type and other kwarg.\n            gt_size (int): Cropped patched size for gt patches.\n            use_hflip (bool): Use horizontal flips.\n            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).\n\n            scale (bool): Scale, which will be added automatically.\n            phase (str): 'train' or 'val'.\n    \"\"\"\n\n    def __init__(self, opt):\n        super(CustomPairedImageDataset, self).__init__()\n        self.opt = opt\n        # file client (io backend)\n        self.file_client = None\n        self.io_backend_opt = opt['io_backend']\n        self.mean = opt['mean'] if 'mean' in opt else None\n        self.std = opt['std'] if 'std' in opt else None\n        self.mod_crop_scale = 8\n\n        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']\n\n        omit_first_frame = opt.get('omit_first_frame', True)\n        start_idx = 1 if omit_first_frame else 0\n\n        self.paths = []\n        
clip_list = os.listdir(self.lq_folder)\n        for clip_name in clip_list:\n            lq_frame_list = sorted(glob.glob(f'{self.lq_folder}/{clip_name}/*.png'))\n            gt_frame_list = sorted(glob.glob(f'{self.gt_folder}/{clip_name}/*.png'))\n            assert len(lq_frame_list) == len(gt_frame_list)\n            for i in range(start_idx, len(lq_frame_list)):\n                # omit the first frame\n                self.paths.append(dict([('lq_path', lq_frame_list[i]), ('gt_path', gt_frame_list[i])]))\n\n    def __getitem__(self, index):\n        if self.file_client is None:\n            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)\n\n        scale = self.opt['scale']\n\n        # Load gt and lq images. Dimension order: HWC; channel order: BGR;\n        # image range: [0, 1], float32.\n        gt_path = self.paths[index]['gt_path']\n        img_bytes = self.file_client.get(gt_path, 'gt')\n        img_gt = imfrombytes(img_bytes, float32=True)\n        lq_path = self.paths[index]['lq_path']\n        img_bytes = self.file_client.get(lq_path, 'lq')\n        img_lq = imfrombytes(img_bytes, float32=True)\n        img_lq = mod_crop(img_lq, self.mod_crop_scale)\n\n        # augmentation for training\n        if self.opt['phase'] == 'train':\n            gt_size = self.opt['gt_size']\n            # random crop\n            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)\n            # flip, rotation\n            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])\n\n        # BGR to RGB, HWC to CHW, numpy to tensor\n        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)\n        # normalize\n        if self.mean is not None or self.std is not None:\n            normalize(img_lq, self.mean, self.std, inplace=True)\n            normalize(img_gt, self.mean, self.std, inplace=True)\n\n        return {'lq': img_lq, 'gt': img_gt, 'lq_path': 
lq_path, 'gt_path': gt_path}\n\n    def __len__(self):\n        return len(self.paths)\n"
  },
  {
    "path": "animesr/models/__init__.py",
    "content": "import importlib\nfrom os import path as osp\n\nfrom basicsr.utils import scandir\n\n# automatically scan and import model modules for registry\n# scan all the files that end with '_model.py' under the model folder\nmodel_folder = osp.dirname(osp.abspath(__file__))\nmodel_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]\n# import all the model modules\n_model_modules = [importlib.import_module(f'animesr.models.{file_name}') for file_name in model_filenames]\n"
  },
  {
    "path": "animesr/models/degradation_gan_model.py",
    "content": "from collections import OrderedDict\n\nfrom basicsr.models.srgan_model import SRGANModel\nfrom basicsr.utils.registry import MODEL_REGISTRY\n\n\n@MODEL_REGISTRY.register()\nclass DegradationGANModel(SRGANModel):\n    \"\"\"Degradation model for real-world, hard-to-synthesis degradation.\"\"\"\n\n    def feed_data(self, data):\n        # we reverse the order of lq and gt for convenient implementation\n        self.lq = data['gt'].to(self.device)\n        if 'lq' in data:\n            self.gt = data['lq'].to(self.device)\n\n    def optimize_parameters(self, current_iter):\n        # optimize net_g\n        for p in self.net_d.parameters():\n            p.requires_grad = False\n\n        self.optimizer_g.zero_grad()\n        self.output = self.net_g(self.lq)\n\n        l_g_total = 0\n        loss_dict = OrderedDict()\n        if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):\n            # pixel loss\n            if self.cri_pix:\n                l_g_pix = self.cri_pix(self.output, self.gt)\n                l_g_total += l_g_pix\n                loss_dict['l_g_pix'] = l_g_pix\n            # perceptual loss\n            if self.cri_perceptual:\n                l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)\n                if l_g_percep is not None:\n                    l_g_total += l_g_percep\n                    loss_dict['l_g_percep'] = l_g_percep\n                if l_g_style is not None:\n                    l_g_total += l_g_style\n                    loss_dict['l_g_style'] = l_g_style\n            # gan loss\n            fake_g_pred = self.net_d(self.output)\n            l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)\n            l_g_total += l_g_gan\n            loss_dict['l_g_gan'] = l_g_gan\n\n            l_g_total.backward()\n            self.optimizer_g.step()\n\n        # optimize net_d\n        for p in self.net_d.parameters():\n            p.requires_grad = True\n\n        
self.optimizer_d.zero_grad()\n        # real\n        real_d_pred = self.net_d(self.gt)\n        l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)\n        loss_dict['l_d_real'] = l_d_real\n        l_d_real.backward()\n        # fake\n        fake_d_pred = self.net_d(self.output.detach())\n        l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)\n        loss_dict['l_d_fake'] = l_d_fake\n        l_d_fake.backward()\n        self.optimizer_d.step()\n\n        self.log_dict = self.reduce_loss_dict(loss_dict)\n\n        if self.ema_decay > 0:\n            self.model_ema(decay=self.ema_decay)\n"
  },
  {
    "path": "animesr/models/degradation_model.py",
    "content": "from collections import OrderedDict\n\nfrom basicsr.losses import build_loss\nfrom basicsr.models.sr_model import SRModel\nfrom basicsr.utils.registry import MODEL_REGISTRY\n\n\n@MODEL_REGISTRY.register()\nclass DegradationModel(SRModel):\n    \"\"\"Degradation model for real-world, hard-to-synthesis degradation.\"\"\"\n\n    def init_training_settings(self):\n        self.net_g.train()\n        train_opt = self.opt['train']\n\n        # define losses\n        self.l1_pix = build_loss(train_opt['l1_opt']).to(self.device)\n        self.l2_pix = build_loss(train_opt['l2_opt']).to(self.device)\n\n        # set up optimizers and schedulers\n        self.setup_optimizers()\n        self.setup_schedulers()\n\n    def feed_data(self, data):\n        # we reverse the order of lq and gt for convenient implementation\n        self.lq = data['gt'].to(self.device)\n        if 'lq' in data:\n            self.gt = data['lq'].to(self.device)\n\n    def optimize_parameters(self, current_iter):\n        self.optimizer_g.zero_grad()\n        self.output = self.net_g(self.lq)\n\n        l_total = 0\n        loss_dict = OrderedDict()\n        # l1 loss\n        l_l1 = self.l1_pix(self.output, self.gt)\n        l_total += l_l1\n        loss_dict['l_l1'] = l_l1\n        # l2 loss\n        l_l2 = self.l2_pix(self.output, self.gt)\n        l_total += l_l2\n        loss_dict['l_l2'] = l_l2\n\n        l_total.backward()\n        self.optimizer_g.step()\n\n        self.log_dict = self.reduce_loss_dict(loss_dict)\n"
  },
  {
    "path": "animesr/models/video_recurrent_gan_model.py",
    "content": "from collections import OrderedDict\n\nfrom basicsr.archs import build_network\nfrom basicsr.losses import build_loss\nfrom basicsr.utils import get_root_logger\nfrom basicsr.utils.registry import MODEL_REGISTRY\nfrom .video_recurrent_model import VideoRecurrentCustomModel\n\n\n@MODEL_REGISTRY.register()\nclass VideoRecurrentGANCustomModel(VideoRecurrentCustomModel):\n    \"\"\"Currently, the VideoRecurrentGANModel and multi-scale discriminator are not compatible,\n    so we use a custom model.\n    \"\"\"\n\n    def init_training_settings(self):\n        train_opt = self.opt['train']\n\n        self.ema_decay = train_opt.get('ema_decay', 0)\n        if self.ema_decay > 0:\n            logger = get_root_logger()\n            logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')\n            # build network net_g with Exponential Moving Average (EMA)\n            # net_g_ema only used for testing on one GPU and saving.\n            # There is no need to wrap with DistributedDataParallel\n            self.net_g_ema = build_network(self.opt['network_g']).to(self.device)\n            # load pretrained model\n            load_path = self.opt['path'].get('pretrain_network_g', None)\n            if load_path is not None:\n                self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')\n            else:\n                self.model_ema(0)  # copy net_g weight\n            self.net_g_ema.eval()\n\n        # define network net_d\n        self.net_d = build_network(self.opt['network_d'])\n        self.net_d = self.model_to_device(self.net_d)\n        self.print_network(self.net_d)\n\n        # load pretrained models\n        load_path = self.opt['path'].get('pretrain_network_d', None)\n        if load_path is not None:\n            param_key = self.opt['path'].get('param_key_d', 'params')\n            self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', 
True), param_key)\n\n        self.net_g.train()\n        self.net_d.train()\n\n        # define losses\n        if train_opt.get('pixel_opt'):\n            self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)\n        else:\n            self.cri_pix = None\n\n        if train_opt.get('perceptual_opt'):\n            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)\n        else:\n            self.cri_perceptual = None\n\n        if train_opt.get('gan_opt'):\n            self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)\n\n        self.net_d_iters = train_opt.get('net_d_iters', 1)\n        self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)\n\n        # set up optimizers and schedulers\n        self.setup_optimizers()\n        self.setup_schedulers()\n\n    def setup_optimizers(self):\n        train_opt = self.opt['train']\n        if train_opt['fix_flow']:\n            normal_params = []\n            flow_params = []\n            for name, param in self.net_g.named_parameters():\n                if 'spynet' in name:  # The fix_flow now only works for spynet.\n                    flow_params.append(param)\n                else:\n                    normal_params.append(param)\n\n            optim_params = [\n                {  # add flow params first\n                    'params': flow_params,\n                    'lr': train_opt['lr_flow']\n                },\n                {\n                    'params': normal_params,\n                    'lr': train_opt['optim_g']['lr']\n                },\n            ]\n        else:\n            optim_params = self.net_g.parameters()\n\n        # optimizer g\n        optim_type = train_opt['optim_g'].pop('type')\n        self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])\n        self.optimizers.append(self.optimizer_g)\n        # optimizer d\n        optim_type = train_opt['optim_d'].pop('type')\n        self.optimizer_d = 
self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])\n        self.optimizers.append(self.optimizer_d)\n\n    def optimize_parameters(self, current_iter):\n        # optimize net_g\n        for p in self.net_d.parameters():\n            p.requires_grad = False\n\n        self.optimize_parameters_base(current_iter)\n\n        _, _, c, h, w = self.output.size()\n\n        pix_gt = self.gt\n        percep_gt = self.gt\n        gan_gt = self.gt\n        if self.opt.get('l1_gt_usm', False):\n            pix_gt = self.gt_usm\n        if self.opt.get('percep_gt_usm', False):\n            percep_gt = self.gt_usm\n        if self.opt.get('gan_gt_usm', False):\n            gan_gt = self.gt_usm\n\n        l_g_total = 0\n        loss_dict = OrderedDict()\n        if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):\n            # pixel loss\n            if self.cri_pix:\n                l_g_pix = self.cri_pix(self.output, pix_gt)\n                l_g_total += l_g_pix\n                loss_dict['l_g_pix'] = l_g_pix\n            # perceptual loss\n            if self.cri_perceptual:\n                l_g_percep, l_g_style = self.cri_perceptual(self.output.view(-1, c, h, w), percep_gt.view(-1, c, h, w))\n                if l_g_percep is not None:\n                    l_g_total += l_g_percep\n                    loss_dict['l_g_percep'] = l_g_percep\n                if l_g_style is not None:\n                    l_g_total += l_g_style\n                    loss_dict['l_g_style'] = l_g_style\n            # gan loss\n            fake_g_pred = self.net_d(self.output.view(-1, c, h, w))\n            l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)\n            l_g_total += l_g_gan\n            loss_dict['l_g_gan'] = l_g_gan\n\n            l_g_total.backward()\n            self.optimizer_g.step()\n\n        # optimize net_d\n        for p in self.net_d.parameters():\n            p.requires_grad = True\n\n        
self.optimizer_d.zero_grad()\n        # real\n        # reshape to (b*n, c, h, w)\n        real_d_pred = self.net_d(gan_gt.view(-1, c, h, w))\n        l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)\n        loss_dict['l_d_real'] = l_d_real\n        l_d_real.backward()\n        # fake\n        # reshape to (b*n, c, h, w)\n        fake_d_pred = self.net_d(self.output.view(-1, c, h, w).detach())\n        l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)\n        loss_dict['l_d_fake'] = l_d_fake\n        l_d_fake.backward()\n        self.optimizer_d.step()\n\n        self.log_dict = self.reduce_loss_dict(loss_dict)\n\n        if self.ema_decay > 0:\n            self.model_ema(decay=self.ema_decay)\n\n    def save(self, epoch, current_iter):\n        if self.ema_decay > 0:\n            self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])\n        else:\n            self.save_network(self.net_g, 'net_g', current_iter)\n        self.save_network(self.net_d, 'net_d', current_iter)\n        self.save_training_state(epoch, current_iter)\n"
  },
  {
    "path": "animesr/models/video_recurrent_model.py",
    "content": "import cv2\nimport os\nimport torch\nfrom collections import OrderedDict\nfrom os import path as osp\nfrom torch import distributed as dist\nfrom tqdm import tqdm\n\nfrom basicsr.models.video_base_model import VideoBaseModel\nfrom basicsr.utils import USMSharp, get_root_logger, imwrite, tensor2img\nfrom basicsr.utils.dist_util import get_dist_info\nfrom basicsr.utils.registry import MODEL_REGISTRY\n\n\n@MODEL_REGISTRY.register()\nclass VideoRecurrentCustomModel(VideoBaseModel):\n\n    def __init__(self, opt):\n        super(VideoRecurrentCustomModel, self).__init__(opt)\n        if self.is_train:\n            self.fix_flow_iter = opt['train'].get('fix_flow')\n        self.idx = 1\n        lq_from_usm = opt['datasets']['train'].get('lq_from_usm', False)\n        assert lq_from_usm is False\n        self.usm_sharp_gt = opt['datasets']['train'].get('usm_sharp_gt', False)\n\n        if self.usm_sharp_gt:\n            usm_radius = opt['datasets']['train'].get('usm_radius', 50)\n            self.usm_sharpener = USMSharp(radius=usm_radius).cuda()\n            self.usm_weight = opt['datasets']['train'].get('usm_weight', 0.5)\n            self.usm_threshold = opt['datasets']['train'].get('usm_threshold', 10)\n\n    @torch.no_grad()\n    def feed_data(self, data):\n        self.lq = data['lq'].to(self.device)\n        if 'gt' in data:\n            self.gt = data['gt'].to(self.device)\n            if 'gt_usm' in data:\n                self.gt_usm = data['gt_usm'].to(self.device)\n                logger = get_root_logger()\n                logger.warning(\n                    'since lq is not from gt_usm, '\n                    'we should put the usm_sharp operation outside the dataloader to speed up the traning time')\n            elif self.usm_sharp_gt:\n                b, n, c, h, w = self.gt.size()\n                self.gt_usm = self.usm_sharpener(\n                    self.gt.view(b * n, c, h, w), weight=self.usm_weight,\n                    
threshold=self.usm_threshold).view(b, n, c, h, w)\n\n        # if self.opt['rank'] == 0 and 'debug' in self.opt['name']:\n        #     import torchvision\n        #     os.makedirs('tmp/gt', exist_ok=True)\n        #     os.makedirs('tmp/gt_usm', exist_ok=True)\n        #     os.makedirs('tmp/lq', exist_ok=True)\n        #     print(self.idx)\n        #     for i in range(15):\n        #         torchvision.utils.save_image(\n        #             self.lq[:, i, :, :, :],\n        #             f'tmp/lq/lq{self.idx}_{i}.png',\n        #             nrow=4,\n        #             padding=2,\n        #             normalize=True,\n        #             range=(0, 1))\n        #         torchvision.utils.save_image(\n        #             self.gt[:, i, :, :, :],\n        #             f'tmp/gt/gt{self.idx}_{i}.png',\n        #             nrow=4,\n        #             padding=2,\n        #             normalize=True,\n        #             range=(0, 1))\n        #         torchvision.utils.save_image(\n        #             self.gt_usm[:, i, :, :, :],\n        #             f'tmp/gt_usm/gt_usm{self.idx}_{i}.png',\n        #             nrow=4,\n        #             padding=2,\n        #             normalize=True,\n        #             range=(0, 1))\n        #     self.idx += 1\n        #     if self.idx >= 20:\n        #         exit()\n\n    def setup_optimizers(self):\n        train_opt = self.opt['train']\n        flow_lr_mul = train_opt.get('flow_lr_mul', 1)\n        logger = get_root_logger()\n        logger.info(f'Multiple the learning rate for flow network with {flow_lr_mul}.')\n        if flow_lr_mul == 1:\n            optim_params = self.net_g.parameters()\n        else:  # separate flow params and normal params for different lr\n            normal_params = []\n            flow_params = []\n            for name, param in self.net_g.named_parameters():\n                if 'spynet' in name:\n                    flow_params.append(param)\n                
else:\n                    normal_params.append(param)\n            optim_params = [\n                {  # add normal params first\n                    'params': normal_params,\n                    'lr': train_opt['optim_g']['lr']\n                },\n                {\n                    'params': flow_params,\n                    'lr': train_opt['optim_g']['lr'] * flow_lr_mul\n                },\n            ]\n\n        optim_type = train_opt['optim_g'].pop('type')\n        self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])\n        self.optimizers.append(self.optimizer_g)\n\n    def optimize_parameters_base(self, current_iter):\n        if self.fix_flow_iter:\n            logger = get_root_logger()\n            if current_iter == 1:\n                logger.info(f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.')\n                for name, param in self.net_g.named_parameters():\n                    if 'spynet' in name or 'edvr' in name:\n                        param.requires_grad_(False)\n            elif current_iter == self.fix_flow_iter:\n                logger.warning('Train all the parameters.')\n                self.net_g.requires_grad_(True)\n\n        self.optimizer_g.zero_grad()\n        self.output = self.net_g(self.lq)\n\n    def optimize_parameters(self, current_iter):\n        self.optimize_parameters_base(current_iter)\n        _, _, c, h, w = self.output.size()\n\n        pix_gt = self.gt\n        percep_gt = self.gt\n        if self.opt.get('l1_gt_usm', False):\n            pix_gt = self.gt_usm\n        if self.opt.get('percep_gt_usm', False):\n            percep_gt = self.gt_usm\n\n        l_total = 0\n        loss_dict = OrderedDict()\n        # pixel loss\n        if self.cri_pix:\n            l_pix = self.cri_pix(self.output, pix_gt)\n            l_total += l_pix\n            loss_dict['l_pix'] = l_pix\n        # perceptual loss\n        if self.cri_perceptual:\n            
l_percep, l_style = self.cri_perceptual(self.output.view(-1, c, h, w), percep_gt.view(-1, c, h, w))\n            if l_percep is not None:\n                l_total += l_percep\n                loss_dict['l_percep'] = l_percep\n            if l_style is not None:\n                l_total += l_style\n                loss_dict['l_style'] = l_style\n\n        l_total.backward()\n        self.optimizer_g.step()\n\n        self.log_dict = self.reduce_loss_dict(loss_dict)\n\n        if self.ema_decay > 0:\n            self.model_ema(decay=self.ema_decay)\n\n    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):\n        \"\"\"dist_test actually, no gt, no metrics\"\"\"\n        dataset = dataloader.dataset\n        dataset_name = dataset.opt['name']\n        assert dataset_name.endswith('CoreFrames')\n        rank, world_size = get_dist_info()\n\n        num_folders = len(dataset)\n        num_pad = (world_size - (num_folders % world_size)) % world_size\n        if rank == 0:\n            pbar = tqdm(total=len(dataset), unit='folder')\n            os.makedirs(osp.join(self.opt['path']['visualization'], dataset_name, str(current_iter)), exist_ok=True)\n\n        if self.opt['dist']:\n            dist.barrier()\n        # Will evaluate (num_folders + num_pad) times, but only the first num_folders results will be recorded.\n        # (To avoid wait-dead)\n        for i in range(rank, num_folders + num_pad, world_size):\n            idx = min(i, num_folders - 1)\n            val_data = dataset[idx]\n            folder = val_data['folder']\n\n            # compute outputs\n            val_data['lq'].unsqueeze_(0)\n            self.feed_data(val_data)\n            val_data['lq'].squeeze_(0)\n\n            self.test()\n            visuals = self.get_current_visuals()\n\n            # tentative for out of GPU memory\n            del self.lq\n            del self.output\n            if 'gt' in visuals:\n                del self.gt\n            
torch.cuda.empty_cache()\n\n            # evaluate\n            if i < num_folders:\n                for idx in range(visuals['result'].size(1)):\n                    result = visuals['result'][0, idx, :, :, :]\n                    result_img = tensor2img([result])  # uint8, bgr\n\n                    # since we keep all frames, scale of 4 is not very friendly to storage space\n                    # so we use a default scale of 2 to save the frames\n                    save_scale = self.opt.get('savescale', 2)\n                    net_scale = self.opt.get('scale')\n                    if save_scale != net_scale:\n                        h, w = result_img.shape[0:2]\n                        result_img = cv2.resize(\n                            result_img, (w // net_scale * save_scale, h // net_scale * save_scale),\n                            interpolation=cv2.INTER_LANCZOS4)\n\n                    if save_img:\n                        img_path = osp.join(self.opt['path']['visualization'], dataset_name, str(current_iter),\n                                            f\"{folder}_{idx:08d}_{self.opt['name'][:5]}.png\")\n                        # image name only for REDS dataset\n                        imwrite(result_img, img_path)\n\n                # progress bar\n                if rank == 0:\n                    for _ in range(world_size):\n                        pbar.update(1)\n                        pbar.set_description(f'Folder: {folder}')\n\n        if rank == 0:\n            pbar.close()\n\n    def test(self):\n        n = self.lq.size(1)\n        self.net_g.eval()\n\n        flip_seq = self.opt['val'].get('flip_seq', False)\n        self.center_frame_only = self.opt['val'].get('center_frame_only', False)\n\n        if flip_seq:\n            self.lq = torch.cat([self.lq, self.lq.flip(1)], dim=1)\n\n        with torch.no_grad():\n            self.output = self.net_g(self.lq)\n\n        if flip_seq:\n            output_1 = self.output[:, :n, :, :, :]\n        
    output_2 = self.output[:, n:, :, :, :].flip(1)\n            self.output = 0.5 * (output_1 + output_2)\n\n        if self.center_frame_only:\n            self.output = self.output[:, n // 2, :, :, :]\n\n        self.net_g.train()\n"
  },
  {
    "path": "animesr/test.py",
    "content": "# flake8: noqa\nimport os.path as osp\n\nimport animesr.archs\nimport animesr.data\nimport animesr.models\nfrom basicsr.test import test_pipeline\n\nif __name__ == '__main__':\n    root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))\n    test_pipeline(root_path)\n"
  },
  {
    "path": "animesr/train.py",
    "content": "# flake8: noqa\nimport os.path as osp\n\nimport animesr.archs\nimport animesr.data\nimport animesr.models\nfrom basicsr.train import train_pipeline\n\nif __name__ == '__main__':\n    root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))\n    train_pipeline(root_path)\n"
  },
  {
    "path": "animesr/utils/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n"
  },
  {
    "path": "animesr/utils/inference_base.py",
    "content": "import argparse\nimport os.path\nimport torch\n\nfrom animesr.archs.vsr_arch import MSRSWVSR\n\n\ndef get_base_argument_parser() -> argparse.ArgumentParser:\n    \"\"\"get the base argument parser for inference scripts\"\"\"\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-i', '--input', type=str, default='inputs', help='input test image folder or video path')\n    parser.add_argument('-o', '--output', type=str, default='results', help='save image/video path')\n    parser.add_argument(\n        '-n',\n        '--model_name',\n        type=str,\n        default='AnimeSR_v2',\n        help='Model names: AnimeSR_v2 | AnimeSR_v1-PaperModel. Default:AnimeSR_v2')\n    parser.add_argument(\n        '-s',\n        '--outscale',\n        type=int,\n        default=4,\n        help='The netscale is x4, but you can achieve arbitrary output scale (e.g., x2) with the argument outscale'\n        'The program will further perform cheap resize operation after the AnimeSR output. 
'\n        'This is useful when you want to save disk space or avoid too large-resolution output')\n    parser.add_argument(\n        '--expname', type=str, default='animesr', help='A unique name to identify your current inference')\n    parser.add_argument(\n        '--netscale',\n        type=int,\n        default=4,\n        help='the released models are all x4 models, only change this if you train a x2 or x1 model by yourself')\n    parser.add_argument(\n        '--mod_scale',\n        type=int,\n        default=4,\n        help='the scale used for mod crop, since AnimeSR use a multi-scale arch, so the edge should be divisible by 4')\n    parser.add_argument('--fps', type=int, default=None, help='fps of the sr videos')\n    parser.add_argument('--half', action='store_true', help='use half precision to inference')\n\n    return parser\n\n\ndef get_inference_model(args, device) -> MSRSWVSR:\n    \"\"\"return an on device model with eval mode\"\"\"\n    # set up model\n    model = MSRSWVSR(num_feat=64, num_block=[5, 3, 2], netscale=args.netscale)\n\n    model_path = f'weights/{args.model_name}.pth'\n    assert os.path.isfile(model_path), \\\n        f'{model_path} does not exist, please make sure you successfully download the pretrained models ' \\\n        f'and put them into the weights folder'\n\n    # load checkpoint\n    loadnet = torch.load(model_path)\n    model.load_state_dict(loadnet, strict=True)\n    model.eval()\n    model = model.to(device)\n\n    # num_parameters = sum(map(lambda x: x.numel(), model.parameters()))\n    # print(num_parameters)\n    # exit(0)\n\n    return model.half() if args.half else model\n"
  },
  {
    "path": "animesr/utils/shot_detector.py",
    "content": "# The codes below partially refer to the PySceneDetect. According\n# to its BSD 3-Clause License, we keep the following.\n#\n#          PySceneDetect: Python-Based Video Scene Detector\n#   ---------------------------------------------------------------\n#     [  Site: http://www.bcastell.com/projects/PySceneDetect/   ]\n#     [  Github: https://github.com/Breakthrough/PySceneDetect/  ]\n#     [  Documentation: http://pyscenedetect.readthedocs.org/    ]\n#\n# Copyright (C) 2014-2020 Brandon Castellano <http://www.bcastell.com>.\n#\n# PySceneDetect is licensed under the BSD 3-Clause License; see the included\n# LICENSE file, or visit one of the following pages for details:\n#  - https://github.com/Breakthrough/PySceneDetect/\n#  - http://www.bcastell.com/projects/PySceneDetect/\n#\n# This software uses Numpy, OpenCV, click, tqdm, simpletable, and pytest.\n# See the included LICENSE files or one of the above URLs for more information.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  
IN NO EVENT SHALL THE\n# AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN\n# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION\n# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n\nimport cv2\nimport glob\nimport numpy as np\nimport os\nfrom tqdm import tqdm\n\nDEFAULT_DOWNSCALE_FACTORS = {\n    3200: 12,  # ~4k\n    2100: 8,  # ~2k\n    1700: 6,  # ~1080p\n    1200: 5,\n    900: 4,  # ~720p\n    600: 3,\n    400: 1  # ~480p\n}\n\n\ndef compute_downscale_factor(frame_width):\n    \"\"\"Compute Downscale Factor: Returns the optimal default downscale factor\n    based on a video's resolution (specifically, the width parameter).\n\n    Returns:\n        int: The default downscale factor to use with a video of frame_height x\n            frame_width.\n    \"\"\"\n    for width in sorted(DEFAULT_DOWNSCALE_FACTORS, reverse=True):\n        if frame_width >= width:\n            return DEFAULT_DOWNSCALE_FACTORS[width]\n    return 1\n\n\nclass ShotDetector(object):\n    \"\"\"Detects fast cuts using changes in colour and intensity between frames.\n\n    Detect shot boundary using HSV and LUV information.\n    \"\"\"\n\n    def __init__(self, threshold=30.0, min_shot_len=15):\n        super(ShotDetector, self).__init__()\n        self.hsv_threshold = threshold\n        self.delta_hsv_gap_threshold = 10\n        self.luv_threshold = 40\n        self.hsv_weight = 5\n        # minimum length (frames length) of any given shot\n        self.min_shot_len = min_shot_len\n        self.last_frame = None\n        self.last_shot_cut = None\n        self.last_hsv = None\n        self._metric_keys = [\n            'hsv_content_val', 'delta_hsv_hue', 'delta_hsv_sat', 'delta_hsv_lum', 'luv_content_val', 'delta_luv_hue',\n            'delta_luv_sat', 'delta_luv_lum'\n        ]\n        self.cli_name = 'detect-content'\n        self.last_luv = None\n        self.cut_list = []\n\n    def add_cut(self, cut):\n        num_cuts = 
len(self.cut_list)\n        if num_cuts == 0:\n            self.cut_list.append([0, cut - 1])\n        else:\n            self.cut_list.append([self.cut_list[num_cuts - 1][1] + 1, cut - 1])\n\n    def process_frame(self, frame_num, frame_img):\n        \"\"\"Similar to ThresholdDetector, but using the HSV colour space\n        DIFFERENCE instead of single-frame RGB/grayscale intensity (thus cannot\n        detect slow fades with this method).\n\n        Args:\n            frame_num (int): Frame number of frame that is being passed.\n\n            frame_img (Optional[int]): Decoded frame image (np.ndarray) to\n                perform shot detection on. Can be None *only* if the\n                self.is_processing_required() method\n                (inherited from the base shotDetector class) returns True.\n\n        Returns:\n            List[int]: List of frames where shot cuts have been detected.\n            There may be 0 or more frames in the list, and not necessarily\n            the same as frame_num.\n        \"\"\"\n        cut_list = []\n\n        # if self.last_frame is not None:\n        #     # Change in average of HSV (hsv), (h)ue only,\n        #     # (s)aturation only, (l)uminance only.\n        delta_hsv_avg, delta_hsv_h, delta_hsv_s, delta_hsv_v = 0.0, 0.0, 0.0, 0.0\n        delta_luv_avg, delta_luv_h, delta_luv_s, delta_luv_v = 0.0, 0.0, 0.0, 0.0\n        if frame_num == 0:\n            self.last_frame = frame_img.copy()\n            return cut_list\n\n        else:\n            num_pixels = frame_img.shape[0] * frame_img.shape[1]\n            curr_luv = cv2.split(cv2.cvtColor(frame_img, cv2.COLOR_BGR2Luv))\n            curr_hsv = cv2.split(cv2.cvtColor(frame_img, cv2.COLOR_BGR2HSV))\n            last_hsv = self.last_hsv\n            last_luv = self.last_luv\n            if not last_hsv:\n                last_hsv = cv2.split(cv2.cvtColor(self.last_frame, cv2.COLOR_BGR2HSV))\n                last_luv = cv2.split(cv2.cvtColor(self.last_frame, 
cv2.COLOR_BGR2Luv))\n\n            delta_hsv = [0, 0, 0, 0]\n            for i in range(3):\n                num_pixels = curr_hsv[i].shape[0] * curr_hsv[i].shape[1]\n                curr_hsv[i] = curr_hsv[i].astype(np.int32)\n                last_hsv[i] = last_hsv[i].astype(np.int32)\n                delta_hsv[i] = np.sum(np.abs(curr_hsv[i] - last_hsv[i])) / float(num_pixels)\n            delta_hsv[3] = sum(delta_hsv[0:3]) / 3.0\n            delta_hsv_h, delta_hsv_s, delta_hsv_v, delta_hsv_avg = \\\n                delta_hsv\n\n            delta_luv = [0, 0, 0, 0]\n            for i in range(3):\n                num_pixels = curr_luv[i].shape[0] * curr_luv[i].shape[1]\n                curr_luv[i] = curr_luv[i].astype(np.int32)\n                last_luv[i] = last_luv[i].astype(np.int32)\n                delta_luv[i] = np.sum(np.abs(curr_luv[i] - last_luv[i])) / float(num_pixels)\n            delta_luv[3] = sum(delta_luv[0:3]) / 3.0\n            delta_luv_h, delta_luv_s, delta_luv_v, delta_luv_avg = \\\n                delta_luv\n\n            self.last_hsv = curr_hsv\n            self.last_luv = curr_luv\n        if delta_hsv_avg >= self.hsv_threshold and delta_hsv_avg - self.hsv_threshold >= self.delta_hsv_gap_threshold:\n            if self.last_shot_cut is None or ((frame_num - self.last_shot_cut) >= self.min_shot_len):\n                cut_list.append(frame_num)\n                self.last_shot_cut = frame_num\n        elif delta_hsv_avg >= self.hsv_threshold and \\\n                delta_hsv_avg - self.hsv_threshold < \\\n                self.delta_hsv_gap_threshold and \\\n                delta_luv_avg + self.hsv_weight * \\\n                (delta_hsv_avg - self.hsv_threshold) > self.luv_threshold:\n            if self.last_shot_cut is None or ((frame_num - self.last_shot_cut) >= self.min_shot_len):\n                cut_list.append(frame_num)\n                self.last_shot_cut = frame_num\n\n        self.last_frame = frame_img.copy()\n        return 
cut_list\n\n    def detect_shots(self, frame_source, frame_skip=0, show_progress=True, keep_resolution=False):\n        \"\"\"Perform shot detection on the given frame_source using the added\n            shotDetectors.\n\n            Blocks until all frames in the frame_source have been processed.\n            Results can be obtained by calling either the get_shot_list()\n            or get_cut_list() methods.\n            Arguments:\n                frame_source (shotdetect.video_manager.VideoManager or\n                    cv2.VideoCapture):\n                    A source of frames to process (using frame_source.read() as in\n                    VideoCapture).\n                    VideoManager is preferred as it allows concatenation of\n                    multiple videos as well as seeking, by defining start time\n                    and end time/duration.\n                end_time (int or FrameTimecode): Maximum number of frames to detect\n                    (set to None to detect all available frames). Only needed for\n                    OpenCV\n                    VideoCapture objects; for VideoManager objects, use\n                    set_duration() instead.\n                frame_skip (int): Not recommended except for extremely high\n                    framerate videos.\n                    Number of frames to skip (i.e. 
process every 1 in N+1 frames,\n                    where N is frame_skip, processing only 1/N+1 percent of the\n                    video,\n                    speeding up the detection time at the expense of accuracy).\n                    `frame_skip` **must** be 0 (the default) when using a\n                    StatsManager.\n                show_progress (bool): If True, and the ``tqdm`` module is\n                    available, displays\n                    a progress bar with the progress, framerate, and expected\n                    time to\n                    complete processing the video frame source.\n\n            Raises:\n                ValueError: `frame_skip` **must** be 0 (the default)\n                    if the shotManager\n                    was constructed with a StatsManager object.\n            \"\"\"\n\n        if frame_skip > 0 and self._stats_manager is not None:\n            raise ValueError('frame_skip must be 0 when using a StatsManager.')\n\n        curr_frame = 0\n        frame_paths = sorted(glob.glob(os.path.join(frame_source, '*')))\n        total_frames = len(frame_paths)\n        end_frame = total_frames\n\n        progress_bar = None\n        if tqdm and show_progress:\n            progress_bar = tqdm(total=total_frames, unit='frames')\n\n        try:\n            while True:\n                if end_frame is not None and curr_frame >= end_frame:\n                    break\n\n                frame_im = cv2.imread(frame_paths[curr_frame])\n                if not keep_resolution:\n                    if curr_frame == 0:\n                        downscale_factor = compute_downscale_factor(frame_im.shape[1])\n                    frame_im = frame_im[::downscale_factor, ::downscale_factor, :]\n\n                cut = self.process_frame(curr_frame, frame_im)\n\n                if len(cut) != 0:\n                    self.add_cut(cut[0])\n\n                curr_frame += 1\n                if progress_bar:\n                    
progress_bar.update(1)\n\n        finally:\n            if progress_bar:\n                progress_bar.close()\n\n        return self.cut_list\n"
  },
  {
    "path": "animesr/utils/video_util.py",
    "content": "import glob\nimport os\nimport subprocess\n\ndefault_ffmpeg_exe_path = 'ffmpeg'\ndefault_ffprobe_exe_path = 'ffprobe'\ndefault_ffmpeg_vcodec = 'h264'\ndefault_ffmpeg_pix_fmt = 'yuv420p'\n\n\ndef get_video_fps(video_path, ret_type='float'):\n    \"\"\"Get the fps of the video.\n\n    Args:\n        video_path (str): the video path;\n        ret_type (str): the return type, it supports `str`, `float`, and `tuple` (numerator, denominator).\n\n    Returns:\n        --fps (str): if ret_type is `str`.\n        --fps (float): if ret_type is `float`.\n        --fps (tuple): if ret_type is tuple, (numerator, denominator).\n    \"\"\"\n\n    global default_ffprobe_exe_path\n\n    ffprobe_exe_path = os.environ.get('ffprobe_exe_path', default_ffprobe_exe_path)\n\n    cmd = [\n        ffprobe_exe_path, '-v', 'quiet', '-select_streams', 'v', '-of', 'default=noprint_wrappers=1:nokey=1',\n        '-show_entries', 'stream=r_frame_rate', video_path\n    ]\n\n    result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)\n    fps = result.stdout.decode('utf-8').strip()\n\n    # e.g. 
30/1\n    numerator, denominator = map(lambda x: int(x), fps.split('/'))\n    if ret_type == 'float':\n        return numerator / denominator\n    elif ret_type == 'str':\n        return str(numerator / denominator)\n    else:\n        return numerator, denominator\n\n\ndef get_video_num_frames(video_path):\n    \"\"\"Get the video's total number of frames.\"\"\"\n\n    global default_ffprobe_exe_path\n\n    ffprobe_exe_path = os.environ.get('ffprobe_exe_path', default_ffprobe_exe_path)\n\n    cmd = [\n        ffprobe_exe_path, '-v', 'quiet', '-select_streams', 'v', '-count_packets', '-of',\n        'default=noprint_wrappers=1:nokey=1', '-show_entries', 'stream=nb_read_packets', video_path\n    ]\n\n    result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)\n    nb_frames = result.stdout.decode('utf-8').strip()\n\n    return int(nb_frames)\n\n\ndef get_video_bitrate(video_path):\n    \"\"\"Get the bitrate of the video.\"\"\"\n\n    global default_ffprobe_exe_path\n\n    ffprobe_exe_path = os.environ.get('ffprobe_exe_path', default_ffprobe_exe_path)\n\n    cmd = [\n        ffprobe_exe_path, '-v', 'quiet', '-select_streams', 'v', '-of', 'default=noprint_wrappers=1:nokey=1',\n        '-show_entries', 'stream=bit_rate', video_path\n    ]\n\n    result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)\n    bitrate = result.stdout.decode('utf-8').strip()\n\n    if bitrate == 'N/A':\n        return bitrate\n\n    return int(bitrate) // 1000\n\n\ndef get_video_resolution(video_path):\n    \"\"\"Get the resolution (h and w) of the video.\n\n    Args:\n        video_path (str): the video path;\n\n    Returns:\n        h, w (int)\n    \"\"\"\n\n    global default_ffprobe_exe_path\n\n    ffprobe_exe_path = os.environ.get('ffprobe_exe_path', default_ffprobe_exe_path)\n\n    cmd = [\n        ffprobe_exe_path, '-v', 'quiet', '-select_streams', 'v', '-of', 'csv=s=x:p=0', '-show_entries',\n        'stream=width,height', video_path\n    
]\n\n    result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)\n    resolution = result.stdout.decode('utf-8').strip()\n\n    # print(resolution)\n\n    w, h = resolution.split('x')\n    return int(h), int(w)\n\n\ndef video2frames(video_path, out_dir, force=False, high_quality=True, ss=None, to=None, vf=None):\n    \"\"\"Extract frames from the video\n\n    Args:\n        out_dir: where to save the frames\n        force: if out_dir is not empty, force=True will still extract frames\n        high_quality: whether to use the highest quality\n        ss: start time, format HH.MM.SS[.xxx], if None, extract full video\n        to: end time, format HH.MM.SS[.xxx], if None, extract full video\n        vf: video filter\n    \"\"\"\n    global default_ffmpeg_exe_path\n\n    ffmpeg_exc_path = os.environ.get('ffmpeg_exe_path', default_ffmpeg_exe_path)\n\n    imgs = glob.glob(os.path.join(out_dir, '*.png'))\n    length = len(imgs)\n    if length > 0:\n        print(f'{out_dir} already has frames!, force extract = {force}')\n        if not force:\n            return out_dir\n\n    print(f'extracting frames for {video_path}')\n\n    cmd = [\n        ffmpeg_exc_path,\n        f'-i {video_path}',\n        '-v error',\n        f'-ss {ss} -to {to}' if ss is not None and to is not None else '',\n        '-qscale:v 1 -qmin 1 -qmax 1 -vsync 0' if high_quality else '',\n        f'-vf {vf}' if vf is not None else '',\n        f'{out_dir}/%08d.png',\n    ]\n    print(' '.join(cmd))\n    subprocess.call(' '.join(cmd), shell=True)\n\n    return out_dir\n\n\ndef frames2video(frames_dir, out_path, fps=25, filter='*', suffix=None):\n    \"\"\"Combine frames under a folder to video\n    Args:\n        frames_dir: input folder where frames locate\n        out_path: the output video path\n        fps: output video fps\n        suffix: the frame suffix, e.g., png jpg\n    \"\"\"\n    global default_ffmpeg_vcodec, default_ffmpeg_pix_fmt, default_ffmpeg_exe_path\n\n    
ffmpeg_exc_path = os.environ.get('ffmpeg_exe_path', default_ffmpeg_exe_path)\n    vcodec = os.environ.get('ffmpeg_vcodec', default_ffmpeg_vcodec)\n    pix_fmt = os.environ.get('ffmpeg_pix_fmt', default_ffmpeg_pix_fmt)\n\n    if suffix is None:\n        images_names = os.listdir(frames_dir)\n        image_name = images_names[0]\n        suffix = image_name.split('.')[-1]\n\n    cmd = [\n        ffmpeg_exc_path,\n        '-y',\n        '-r', str(fps),\n        '-f', 'image2',\n        '-pattern_type', 'glob',\n        '-i', f'{frames_dir}/{filter}.{suffix}',\n        '-vcodec', vcodec,\n        '-pix_fmt', pix_fmt,\n        out_path\n    ]  # yapf: disable\n    print(' '.join(cmd))\n    subprocess.call(cmd)\n\n    return out_path\n"
  },
  {
    "path": "cog.yaml",
    "content": "build:\n  gpu: true\n  cuda: \"11.6.2\"\n  python_version: \"3.10\"\n  system_packages:\n    - \"libgl1-mesa-glx\"\n    - \"libglib2.0-0\"\n    - \"ffmpeg\"\n  python_packages:\n    - \"ipython==8.4.0\"\n    - \"torch==1.13.1\"\n    - \"torchvision==0.14.1\"\n    - \"ffmpeg-python==0.2.0\"\n    - \"facexlib==0.2.5\"\n    - \"basicsr==1.4.2\"\n    - \"opencv-python==4.7.0.68\"\n    - \"Pillow==9.4.0\"\n    - \"psutil==5.9.4\"\n    - \"tqdm==4.64.1\"\n\npredict: \"predict.py:Predictor\"\n"
  },
  {
    "path": "options/train_animesr_step1_gan_BasicOPonly.yml",
    "content": "# general settings\nname: train_animesr_step1_gan_BasicOPonly\nmodel_type: VideoRecurrentGANCustomModel\nscale: 4\nnum_gpu: auto  # set num_gpu: 0 for cpu mode\nmanual_seed: 0\n\n# USM the ground-truth\nl1_gt_usm: False\npercep_gt_usm: False\ngan_gt_usm: False\n\n# dataset and data loader settings\ndatasets:\n  train:\n    name: AVC-Train\n    type: FFMPEGAnimeDataset\n    dataroot_gt: datasets/AVC-Train  # TO_MODIFY\n    test_mode: False\n    io_backend:\n      type: disk\n\n    num_frame: 15\n    gt_size: 256\n    interval_list: [1, 2, 3]\n    random_reverse: True\n    use_flip: true\n    use_rot: true\n    usm_sharp_gt: False\n\n    # data loader\n    use_shuffle: true\n    num_worker_per_gpu: 20\n    batch_size_per_gpu: 4\n    dataset_enlarge_ratio: 1\n    prefetch_mode: ~\n    pin_memory: True\n\n# network structures\nnetwork_g:\n  type: MSRSWVSR\n  num_feat: 64\n  num_block: [5, 3, 2]\n\nnetwork_d:\n  type: MultiScaleDiscriminator\n  num_in_ch: 3\n  num_feat: 64\n  num_layers: [3, 3, 3]\n  max_nf_mult: 8\n  norm_type: none\n  use_sigmoid: False\n  use_sn: True\n  use_downscale: True\n\n# path\npath:\n  pretrain_network_g: experiments/train_animesr_step1_net_BasicOPonly/models/net_g_300000.pth\n  param_key_g: params\n  strict_load_g: true\n  resume_state: ~\n\n# training settings\ntrain:\n  ema_decay: 0.999\n  optim_g:\n    type: Adam\n    lr: !!float 1e-4\n    weight_decay: 0\n    betas: [0.9, 0.99]\n  lr_flow: !!float 1e-5\n  optim_d:\n    type: Adam\n    lr: !!float 1e-4\n    weight_decay: 0\n    betas: [0.9, 0.99]\n\n  scheduler:\n    type: MultiStepLR\n    milestones: [300000]\n    gamma: 0.5\n\n  total_iter: 300000\n  warmup_iter: -1  # no warm up\n  fix_flow: ~\n\n  # losses\n  pixel_opt:\n    type: L1Loss\n    loss_weight: 1.0\n    reduction: mean\n  perceptual_opt:\n    type: PerceptualLoss\n    layer_weights:\n      # before relu\n      'conv1_2': 0.1\n      'conv2_2': 0.1\n      'conv3_4': 1\n      'conv4_4': 1\n      'conv5_4': 1\n  
  vgg_type: vgg19\n    use_input_norm: true\n    range_norm: false\n    perceptual_weight: 1.0\n    style_weight: 0\n    criterion: l1\n  gan_opt:\n    type: MultiScaleGANLoss\n    gan_type: lsgan\n    real_label_val: 1.0\n    fake_label_val: 0.0\n    loss_weight: !!float 1.0\n\n  net_d_iters: 1\n  net_d_init_iters: 0\n\n# validation settings\n#val:\n#  val_freq: !!float 1e4\n#  save_img: true\n\n# logging settings\nlogger:\n  print_freq: 100\n  save_checkpoint_freq: !!float 5e3\n  use_tb_logger: true\n  wandb:\n    project: ~\n    resume_id: ~\n\n# dist training settings\ndist_params:\n  backend: nccl\n  port: 29500\n\nfind_unused_parameters: true\n"
  },
  {
    "path": "options/train_animesr_step1_net_BasicOPonly.yml",
    "content": "# general settings\nname: train_animesr_step1_net_BasicOPonly\nmodel_type: VideoRecurrentCustomModel\nscale: 4\nnum_gpu: auto  # set num_gpu: 0 for cpu mode\nmanual_seed: 0\n\n# USM the ground-truth\nl1_gt_usm: True\n\n# dataset and data loader settings\ndatasets:\n  train:\n    name: AVC-Train\n    type: FFMPEGAnimeDataset\n    dataroot_gt: datasets/AVC-Train  # TO_MODIFY\n    test_mode: False\n    io_backend:\n      type: disk\n\n    num_frame: 15\n    gt_size: 256\n    interval_list: [1, 2, 3]\n    random_reverse: True\n    use_flip: true\n    use_rot: true\n    usm_sharp_gt: True\n    usm_weight: 0.3\n    usm_radius: 50\n\n    # data loader\n    use_shuffle: true\n    num_worker_per_gpu: 20\n    batch_size_per_gpu: 4\n    dataset_enlarge_ratio: 1\n    prefetch_mode: ~\n    pin_memory: True\n\n# network structures\nnetwork_g:\n  type: MSRSWVSR\n  num_feat: 64\n  num_block: [5, 3, 2]\n\n# path\npath:\n  resume_state: ~\n\n# training settings\ntrain:\n  ema_decay: 0.999\n  optim_g:\n    type: Adam\n    lr: !!float 2e-4\n    weight_decay: 0\n    betas: [0.9, 0.99]\n\n  scheduler:\n    type: MultiStepLR\n    milestones: [300000]\n    gamma: 0.5\n\n  total_iter: 300000\n  warmup_iter: -1  # no warm up\n\n  # losses\n  pixel_opt:\n    type: L1Loss\n    loss_weight: 1.0\n    reduction: mean\n\n# validation settings\n#val:\n#  val_freq: !!float 1e4\n#  save_img: true\n\n# logging settings\nlogger:\n  print_freq: 100\n  save_checkpoint_freq: !!float 5e3\n  use_tb_logger: true\n  wandb:\n    project: ~\n    resume_id: ~\n\n# dist training settings\ndist_params:\n  backend: nccl\n  port: 29500\n\nfind_unused_parameters: true\n"
  },
  {
    "path": "options/train_animesr_step2_lbo_1_gan.yml",
    "content": "# general settings\nname: train_animesr_step2_lbo_1_gan\nmodel_type: DegradationGANModel\nscale: 2\nnum_gpu: auto  # set num_gpu: 0 for cpu mode\nmanual_seed: 0\n\n# dataset and data loader settings\ndatasets:\n  train:\n    name: LBO_1\n    type: CustomPairedImageDataset\n    dataroot_gt: results/input_rescaling_strategy_lbo_1/frames  # TO_MODIFY\n    dataroot_lq: datasets/lbo_training_data/real_world_video_to_train_lbo_1  # TO_MODIFY\n    io_backend:\n      type: disk\n\n    gt_size: 256\n    use_hflip: true\n    use_rot: true\n\n    # data loader\n    use_shuffle: true\n    num_worker_per_gpu: 12\n    batch_size_per_gpu: 16\n    dataset_enlarge_ratio: 200\n    prefetch_mode: ~\n\n# network structures\nnetwork_g:\n  type: SimpleDegradationArch\n  num_in_ch: 3\n  num_out_ch: 3\n  num_feat: 64\n  downscale: 2\n\nnetwork_d:\n  type: MultiScaleDiscriminator\n  num_in_ch: 3\n  num_feat: 64\n  num_layers: [3]\n  max_nf_mult: 8\n  norm_type: none\n  use_sigmoid: False\n  use_sn: True\n  use_downscale: True\n\n\n# path\npath:\n  pretrain_network_g: experiments/train_animesr_step2_lbo_1_net/models/net_g_100000.pth\n  param_key_g: params\n  strict_load_g: true\n  resume_state: ~\n\n# training settings\ntrain:\n  optim_g:\n    type: Adam\n    lr: !!float 1e-4\n    weight_decay: 0\n    betas: [0.9, 0.99]\n  optim_d:\n    type: Adam\n    lr: !!float 1e-4\n    weight_decay: 0\n    betas: [0.9, 0.99]\n\n  scheduler:\n    type: MultiStepLR\n    milestones: [50000]\n    gamma: 0.5\n\n  total_iter: 100000\n  warmup_iter: -1  # no warm up\n\n  # losses\n  pixel_opt:\n    type: L1Loss\n    loss_weight: 1.0\n    reduction: mean\n  perceptual_opt:\n    type: PerceptualLoss\n    layer_weights:\n      # before relu\n      'conv1_2': 0.1\n      'conv2_2': 0.1\n      'conv3_4': 1\n      'conv4_4': 1\n      'conv5_4': 1\n    vgg_type: vgg19\n    use_input_norm: true\n    range_norm: false\n    perceptual_weight: 1.0\n    style_weight: 0\n    criterion: l1\n  gan_opt:\n    
type: MultiScaleGANLoss\n    gan_type: lsgan\n    real_label_val: 1.0\n    fake_label_val: 0.0\n    loss_weight: !!float 1.0\n\n# logging settings\nlogger:\n  print_freq: 100\n  save_checkpoint_freq: !!float 5e3\n  use_tb_logger: true\n  wandb:\n    project: ~\n    resume_id: ~\n\n# dist training settings\ndist_params:\n  backend: nccl\n  port: 29500\n"
  },
  {
    "path": "options/train_animesr_step2_lbo_1_net.yml",
    "content": "# general settings\nname: train_animesr_step2_lbo_1_net\nmodel_type: DegradationModel\nscale: 2\nnum_gpu: auto  # set num_gpu: 0 for cpu mode\nmanual_seed: 0\n\n# dataset and data loader settings\ndatasets:\n  train:\n    name: LBO_1\n    type: CustomPairedImageDataset\n    dataroot_gt: results/input_rescaling_strategy_lbo_1/frames  # TO_MODIFY\n    dataroot_lq: datasets/lbo_training_data/real_world_video_to_train_lbo_1  # TO_MODIFY\n    io_backend:\n      type: disk\n\n    gt_size: 256\n    use_hflip: true\n    use_rot: true\n\n    # data loader\n    use_shuffle: true\n    num_worker_per_gpu: 12\n    batch_size_per_gpu: 16\n    dataset_enlarge_ratio: 200\n    prefetch_mode: ~\n\n# network structures\nnetwork_g:\n  type: SimpleDegradationArch\n  num_in_ch: 3\n  num_out_ch: 3\n  num_feat: 64\n  downscale: 2\n\n\n# path\npath:\n  resume_state: ~\n\n# training settings\ntrain:\n  optim_g:\n    type: Adam\n    lr: !!float 2e-4\n    weight_decay: 0\n    betas: [0.9, 0.99]\n\n  scheduler:\n    type: MultiStepLR\n    milestones: [50000]\n    gamma: 0.5\n\n  total_iter: 100000\n  warmup_iter: -1  # no warm up\n\n  # losses\n  l1_opt:\n    type: L1Loss\n    loss_weight: 1.0\n    reduction: mean\n\n  l2_opt:\n    type: MSELoss\n    loss_weight: 1.0\n    reduction: mean\n\n# logging settings\nlogger:\n  print_freq: 100\n  save_checkpoint_freq: !!float 5e3\n  use_tb_logger: true\n  wandb:\n    project: ~\n    resume_id: ~\n\n# dist training settings\ndist_params:\n  backend: nccl\n  port: 29500\n"
  },
  {
    "path": "options/train_animesr_step3_gan_3LBOs.yml",
    "content": "# general settings\nname: train_animesr_step3_gan_3LBOs\nmodel_type: VideoRecurrentGANCustomModel\nscale: 4\nnum_gpu: auto  # set num_gpu: 0 for cpu mode\nmanual_seed: 0\n\n# USM the ground-truth\nl1_gt_usm: False\npercep_gt_usm: False\ngan_gt_usm: False\n\n# dataset and data loader settings\ndatasets:\n  train:\n    name: AVC-Train\n    type: FFMPEGAnimeLBODataset\n    dataroot_gt: datasets/AVC-Train  # TO_MODIFY\n    test_mode: False\n    io_backend:\n      type: disk\n\n    num_frame: 15\n    gt_size: 256\n    interval_list: [1, 2, 3]\n    random_reverse: True\n    use_flip: true\n    use_rot: true\n    usm_sharp_gt: False\n    degradation_model_path: [weights/pretrained_lbo_1.pth, weights/pretrained_lbo_2.pth, weights/pretrained_lbo_3.pth]  # TO_MODIFY\n\n    # data loader\n    use_shuffle: true\n    num_worker_per_gpu: 5\n    batch_size_per_gpu: 4\n    dataset_enlarge_ratio: 1\n    prefetch_mode: ~\n    pin_memory: True\n\n# network structures\nnetwork_g:\n  type: MSRSWVSR\n  num_feat: 64\n  num_block: [5, 3, 2]\n\nnetwork_d:\n  type: MultiScaleDiscriminator\n  num_in_ch: 3\n  num_feat: 64\n  num_layers: [3, 3, 3]\n  max_nf_mult: 8\n  norm_type: none\n  use_sigmoid: False\n  use_sn: True\n  use_downscale: True\n\n# path\npath:\n  pretrain_network_g: experiments/train_animesr_step1_net_BasicOPonly/models/net_g_300000.pth\n  param_key_g: params\n  strict_load_g: true\n  resume_state: ~\n\n# training settings\ntrain:\n  ema_decay: 0.999\n  optim_g:\n    type: Adam\n    lr: !!float 1e-4\n    weight_decay: 0\n    betas: [0.9, 0.99]\n  lr_flow: !!float 1e-5\n  optim_d:\n    type: Adam\n    lr: !!float 1e-4\n    weight_decay: 0\n    betas: [0.9, 0.99]\n\n  scheduler:\n    type: MultiStepLR\n    milestones: [300000]\n    gamma: 0.5\n\n  total_iter: 300000\n  warmup_iter: -1  # no warm up\n  fix_flow: ~\n\n  # losses\n  pixel_opt:\n    type: L1Loss\n    loss_weight: 1.0\n    reduction: mean\n  perceptual_opt:\n    type: PerceptualLoss\n    
layer_weights:\n      # before relu\n      'conv1_2': 0.1\n      'conv2_2': 0.1\n      'conv3_4': 1\n      'conv4_4': 1\n      'conv5_4': 1\n    vgg_type: vgg19\n    use_input_norm: true\n    range_norm: false\n    perceptual_weight: 1.0\n    style_weight: 0\n    criterion: l1\n  gan_opt:\n    type: MultiScaleGANLoss\n    gan_type: lsgan\n    real_label_val: 1.0\n    fake_label_val: 0.0\n    loss_weight: !!float 1.0\n\n  net_d_iters: 1\n  net_d_init_iters: 0\n\n# validation settings\n#val:\n#  val_freq: !!float 1e4\n#  save_img: true\n\n# logging settings\nlogger:\n  print_freq: 100\n  save_checkpoint_freq: !!float 5e3\n  use_tb_logger: true\n  wandb:\n    project: ~\n    resume_id: ~\n\n# dist training settings\ndist_params:\n  backend: nccl\n  port: 29500\n\nfind_unused_parameters: true\n"
  },
  {
    "path": "predict.py",
    "content": "import os\nimport shutil\nimport tempfile\nfrom subprocess import call\nfrom zipfile import ZipFile\nfrom typing import Optional\nimport mimetypes\nimport torch\n\nfrom cog import BasePredictor, Input, Path, BaseModel\n\n\ncall(\"python setup.py develop\", shell=True)\n\n\nclass ModelOutput(BaseModel):\n    video: Path\n    sr_frames: Optional[Path]\n\n\nclass Predictor(BasePredictor):\n    @torch.inference_mode()\n    def predict(\n        self,\n        video: Path = Input(\n            description=\"Input video file\",\n            default=None,\n        ),\n        frames: Path = Input(\n            description=\"Zip file of frames of a video. Ignored when video is provided.\",\n            default=None,\n        ),\n    ) -> ModelOutput:\n        \"\"\"Run a single prediction on the model\"\"\"\n        assert frames or video, \"Please provide frames of video input.\"\n\n        out_path = \"cog_temp\"\n        if os.path.exists(out_path):\n            shutil.rmtree(out_path)\n        os.makedirs(out_path, exist_ok=True)\n\n        if video:\n            print(\"processing video...\")\n            cmd = (\n                \"python scripts/inference_animesr_video.py -i \"\n                + str(video)\n                + \" -o \"\n                + out_path\n                + \" -n AnimeSR_v2 -s 4 --expname animesr_v2 --num_process_per_gpu 1\"\n            )\n            call(cmd, shell=True)\n\n        else:\n            print(\"processing frames...\")\n            unzip_frames = \"cog_frames_temp\"\n            if os.path.exists(unzip_frames):\n                shutil.rmtree(unzip_frames)\n            os.makedirs(unzip_frames)\n\n            with ZipFile(str(frames), \"r\") as zip_ref:\n                for zip_info in zip_ref.infolist():\n                    if zip_info.filename[-1] == \"/\" or zip_info.filename.startswith(\n                        \"__MACOSX\"\n                    ):\n                        continue\n                    mt = 
mimetypes.guess_type(zip_info.filename)\n                    if mt and mt[0] and mt[0].startswith(\"image/\"):\n                        zip_info.filename = os.path.basename(zip_info.filename)\n                        zip_ref.extract(zip_info, unzip_frames)\n\n            cmd = (\n                \"python scripts/inference_animesr_frames.py -i \"\n                + unzip_frames\n                + \" -o \"\n                + out_path\n                + \" -n AnimeSR_v2 --expname animesr_v2 --save_video_too --fps 20\"\n            )\n            call(cmd, shell=True)\n\n            frames_output = Path(tempfile.mkdtemp()) / \"out.zip\"\n            frames_out_dir = os.listdir(f\"{out_path}/animesr_v2/frames\")\n            assert len(frames_out_dir) == 1\n            frames_path = os.path.join(\n                f\"{out_path}/animesr_v2/frames\", frames_out_dir[0]\n            )\n            # by default, sr_frames will be saved in cog_temp/animesr_v2/frames\n            sr_frames_files = os.listdir(frames_path)\n\n            with ZipFile(str(frames_output), \"w\") as zip:\n                for img in sr_frames_files:\n                    zip.write(os.path.join(frames_path, img))\n\n        # by default, video will be saved in cog_temp/animesr_v2/videos\n        video_out_dir = os.listdir(f\"{out_path}/animesr_v2/videos\")\n        assert len(video_out_dir) == 1\n        if video_out_dir[0].endswith(\".mp4\"):\n            source = os.path.join(f\"{out_path}/animesr_v2/videos\", video_out_dir[0])\n        else:\n            video_output = os.listdir(\n                f\"{out_path}/animesr_v2/videos/{video_out_dir[0]}\"\n            )[0]\n            source = os.path.join(\n                f\"{out_path}/animesr_v2/videos\", video_out_dir[0], video_output\n            )\n        video_path = Path(tempfile.mkdtemp()) / \"out.mp4\"\n        shutil.copy(source, str(video_path))\n\n        if video:\n            return ModelOutput(video=video_path)\n        return ModelOutput(sr_frames=frames_output, video=video_path)\n"
  },
  {
    "path": "requirements.txt",
    "content": "basicsr\nfacexlib\nffmpeg-python\nnumpy\nopencv-python\npillow\npsutil\ntorch\ntorchvision\ntqdm\n"
  },
  {
    "path": "scripts/anime_videos_preprocessing.py",
    "content": "import argparse\nimport cv2\nimport glob\nimport numpy as np\nimport os\nimport shutil\nimport torch\nimport torchvision\nfrom multiprocessing import Pool\nfrom os import path as osp\nfrom PIL import Image\nfrom tqdm import tqdm\n\nfrom animesr.utils import video_util\nfrom animesr.utils.shot_detector import ShotDetector\nfrom basicsr.archs.spynet_arch import SpyNet\nfrom basicsr.utils import img2tensor\nfrom basicsr.utils.download_util import download_file_from_google_drive\nfrom facexlib.assessment import init_assessment_model\n\n\ndef main(args):\n    \"\"\"A script to prepare anime videos.\n\n    The preparation can be divided into following steps:\n    1. use ffmpeg to extract frames\n    2. shot detection\n    3. estimate flow\n    4. detect black frames\n    5. use hyperIQA to evaluate the quality of frames\n    6. generate at most 5 clips per video\n    \"\"\"\n\n    opt = dict()\n\n    opt['debug'] = args.debug\n    opt['n_thread'] = args.n_thread\n    opt['ss_idx'] = args.ss_idx\n    opt['to_idx'] = args.to_idx\n\n    # params for step1: extract frames\n    opt['video_root'] = f'{args.dataroot}/raw_videos'\n    opt['save_frames_root'] = f'{args.dataroot}/frames'\n    opt['meta_files_root'] = f'{args.dataroot}/meta'\n\n    # params for step2: shot detection\n    opt['detect_shot_root'] = f'{args.dataroot}/detect_shot'\n\n    # params for step3: flow estimation\n    opt['estimate_flow_root'] = f'{args.dataroot}/estimate_flow'\n    opt['spy_pretrain_weight'] = 'experiments/pretrained_models/flownet/spynet_sintel_final-3d2a1287.pth'\n    opt['downscale_factor'] = 1\n\n    # params for step4: detect black frames\n    opt['black_flag_root'] = f'{args.dataroot}/black_flag'\n    opt['black_threshold'] = 0.98\n\n    # params for step5: image quality assessment\n    opt['num_patch_per_iqa'] = 5\n    opt['iqa_score_root'] = f'{args.dataroot}/iqa_score'\n\n    # params for step6: generate clips\n    opt['num_frames_per_clip'] = 
args.n_frames_per_clip\n    opt['num_clips_per_video'] = args.n_clips_per_video\n    opt['select_clips_root'] = f'{args.dataroot}/{args.select_clip_root}'\n    opt['select_clips_meta'] = osp.join(opt['select_clips_root'], 'meta_info')\n    opt['select_clips_frames'] = osp.join(opt['select_clips_root'], 'frames')\n    opt['select_done_flags'] = osp.join(opt['select_clips_root'], 'done_flags')\n\n    if '1' in args.run:\n        run_step1(opt)\n    if '2' in args.run:\n        run_step2(opt)\n    if '3' in args.run:\n        run_step3(opt)\n    if '4' in args.run:\n        run_step4(opt)\n    if '5' in args.run:\n        run_step5(opt)\n    if '6' in args.run:\n        run_step6(opt)\n\n\n# -------------------------------------------------------------------- #\n# --------------------------- step1 ---------------------------------- #\n# -------------------------------------------------------------------- #\n\n\ndef run_step1(opt):\n    \"\"\"extract frames\n\n    1. read all video files under video_root folder\n    2. filter out the videos that already have been processed\n    3. 
use multi-process to extract the remaining videos\n\n    \"\"\"\n\n    video_root = opt['video_root']\n    frames_root = opt['save_frames_root']\n    meta_root = opt['meta_files_root']\n    os.makedirs(frames_root, exist_ok=True)\n    os.makedirs(meta_root, exist_ok=True)\n\n    if not osp.isdir(video_root):\n        print(f'path {video_root} is not a valid folder, exit.')\n\n    videos_path = sorted(glob.glob(osp.join(video_root, '*')))\n    if opt['debug']:\n        videos_path = videos_path[:3]\n    else:\n        videos_path = videos_path[opt['ss_idx']:opt['to_idx']]\n    pbar = tqdm(total=len(videos_path), unit='video', desc='step1')\n    pool = Pool(opt['n_thread'])\n    for video_path in videos_path:\n        video_name = osp.splitext(osp.basename(video_path))[0]\n        if video_name.startswith('.'):\n            print(f'skip {video_name}')\n            continue\n        frame_path = osp.join(frames_root, video_name)\n        meta_path = osp.join(meta_root, f'{video_name}.txt')\n        pool.apply_async(\n            worker1, args=(opt, video_name, video_path, frame_path, meta_path), callback=lambda arg: pbar.update(1))\n    pool.close()\n    pool.join()\n\n\ndef worker1(opt, video_name, video_path, frame_path, meta_path):\n    # get info of video\n    fps = video_util.get_video_fps(video_path)\n    h, w = video_util.get_video_resolution(video_path)\n    num_frames = video_util.get_video_num_frames(video_path)\n    bit_rate = video_util.get_video_bitrate(video_path)\n\n    # check whether this video has been processed\n    flag = True\n    num_extracted_frames = 0\n    if osp.exists(frame_path):\n        num_extracted_frames = len(glob.glob(osp.join(frame_path, '*.png')))\n        if num_extracted_frames == num_frames:\n            print(f'skip {video_path} since there are already {num_frames} frames have been extracted.')\n            flag = False\n        else:\n            print(f'{num_extracted_frames} of {num_frames} have been extracted for 
{video_path}, re-run.')\n\n    # extract frames\n    os.makedirs(frame_path, exist_ok=True)\n    video_util.video2frames(video_path, frame_path, force=flag, high_quality=True)\n    if flag:\n        num_extracted_frames = len(glob.glob(osp.join(frame_path, '*.png')))\n\n    # write some metadata to meta file\n    with open(meta_path, 'w') as f:\n        f.write(f'Video Name: {video_name}\\n')\n        f.write(f'H: {h}\\n')\n        f.write(f'W: {w}\\n')\n        f.write(f'FPS: {fps}\\n')\n        f.write(f'Bit Rate: {bit_rate}kbps\\n')\n        f.write(f'{num_extracted_frames}/{num_frames} have been extracted\\n')\n\n\n# -------------------------------------------------------------------- #\n# --------------------------- step2 ---------------------------------- #\n# -------------------------------------------------------------------- #\n\n\ndef run_step2(opt):\n    \"\"\"shot detection. refer to lijian's pipeline\"\"\"\n    detect_shot_root = opt['detect_shot_root']\n    meta_root = opt['meta_files_root']\n    os.makedirs(detect_shot_root, exist_ok=True)\n    if not osp.exists(meta_root):\n        print('no videos has run step1, exit.')\n        return\n\n    # get the video which has been extracted frames\n    videos_name = sorted(glob.glob(osp.join(meta_root, '*.txt')))\n    videos_name = [osp.splitext(osp.basename(video_name))[0] for video_name in videos_name]\n\n    if opt['debug']:\n        videos_name = videos_name[:3]\n    else:\n        videos_name = videos_name[opt['ss_idx']:opt['to_idx']]\n\n    pbar = tqdm(total=len(videos_name), unit='video', desc='step2')\n    pool = Pool(opt['n_thread'])\n    for video_name in videos_name:\n        pool.apply_async(worker2, args=(opt, video_name), callback=lambda arg: pbar.update(1))\n    pool.close()\n    pool.join()\n\n\ndef worker2(opt, video_name):\n    video_frame_path = osp.join(opt['save_frames_root'], video_name)\n    detect_shot_file_path = osp.join(opt['detect_shot_root'], f'{video_name}.txt')\n    if 
osp.exists(detect_shot_file_path):\n        print(f'skip {video_name} since {detect_shot_file_path} already exist.')\n        return\n\n    detector = ShotDetector()\n    shot_list = detector.detect_shots(video_frame_path)\n    with open(detect_shot_file_path, 'w') as f:\n        for shot in shot_list:\n            f.write(f'{shot[0]} {shot[1]}\\n')\n\n\n# -------------------------------------------------------------------- #\n# --------------------------- step3 ---------------------------------- #\n# -------------------------------------------------------------------- #\n\n\ndef run_step3(opt):\n    estimate_flow_root = opt['estimate_flow_root']\n    meta_root = opt['meta_files_root']\n    os.makedirs(estimate_flow_root, exist_ok=True)\n    if not osp.exists(meta_root):\n        print('no videos has run step1, exit.')\n        return\n\n    # download the spynet checkpoint first\n    if not osp.exists(opt['spy_pretrain_weight']):\n        download_file_from_google_drive('1VZz1cikwTRVX7zXoD247DB7n5Tj_LQpF', opt['spy_pretrain_weight'])\n\n    # get the video which has been extracted frames\n    videos_name = sorted(glob.glob(osp.join(meta_root, '*.txt')))\n    videos_name = [osp.splitext(osp.basename(video_name))[0] for video_name in videos_name]\n\n    if opt['debug']:\n        videos_name = videos_name[:3]\n    else:\n        videos_name = videos_name[opt['ss_idx']:opt['to_idx']]\n\n    pbar = tqdm(total=len(videos_name), unit='video', desc='step3')\n\n    num_gpus = torch.cuda.device_count()\n    ctx = torch.multiprocessing.get_context('spawn')\n    pool = ctx.Pool(min(3 * num_gpus, opt['n_thread']))\n    for idx, video_name in enumerate(videos_name):\n        pool.apply_async(\n            worker3, args=(opt, video_name, torch.device(idx % num_gpus)), callback=lambda arg: pbar.update(1))\n    pool.close()\n    pool.join()\n\n\ndef read_img(img_path, device, downscale_factor=1):\n    img = cv2.imread(img_path)\n    h, w = img.shape[0:2]\n    if downscale_factor 
!= 1:\n        img = cv2.resize(img, (w // downscale_factor, h // downscale_factor), interpolation=cv2.INTER_LANCZOS4)\n    img = img2tensor(img)\n    img = img.unsqueeze(0).to(device)\n    return img\n\n\n@torch.no_grad()\ndef worker3(opt, video_name, device):\n    video_frame_path = osp.join(opt['save_frames_root'], video_name)\n    frames_path = sorted(glob.glob(osp.join(video_frame_path, '*.png')))\n    estimate_flow_file_path = osp.join(opt['estimate_flow_root'], f'{video_name}.txt')\n    if osp.exists(estimate_flow_file_path):\n        with open(estimate_flow_file_path, 'r') as f:\n            lines = f.readlines()\n            length = len(lines)\n        if length == len(frames_path):\n            print(f'skip {video_name} since {length}/{len(frames_path)} have done.')\n            return\n        else:\n            print(f're-run {video_name} since only {length}/{len(frames_path)} have done.')\n\n    spynet = SpyNet(load_path=opt['spy_pretrain_weight']).eval().to(device)\n    downscale_factor = opt['downscale_factor']\n\n    flow_out_list = []\n\n    pbar = tqdm(total=len(frames_path), unit='frame', desc='worker3')\n    pre_img = None\n    for idx, frame_path in enumerate(frames_path):\n        img_name = osp.basename(frame_path)\n        cur_img = read_img(frame_path, device, downscale_factor=downscale_factor)\n\n        if pre_img is not None:\n            flow = spynet(cur_img, pre_img)\n            flow = flow.abs()\n            flow_max = flow.max().item()\n            flow_avg = flow.mean().item() * 2.0  # according to lijian's hyper-parameter\n        elif idx == 0:\n            flow_max = 0.0\n            flow_avg = 0.0\n        else:\n            raise RuntimeError(f'pre_img is none at {idx}')\n\n        flow_out_list.append(f'{img_name} {flow_max:.6f} {flow_avg:.6f}\\n')\n        pre_img = cur_img\n\n        pbar.update(1)\n\n    with open(estimate_flow_file_path, 'w') as f:\n        for line in flow_out_list:\n            f.write(line)\n\n\n# 
-------------------------------------------------------------------- #\n# --------------------------- step4 ---------------------------------- #\n# -------------------------------------------------------------------- #\n\n\ndef run_step4(opt):\n    black_flag_root = opt['black_flag_root']\n    meta_root = opt['meta_files_root']\n    os.makedirs(black_flag_root, exist_ok=True)\n    if not osp.exists(meta_root):\n        print('no videos has run step1, exit.')\n        return\n\n    # get the video which has been extracted frames\n    videos_name = sorted(glob.glob(osp.join(meta_root, '*.txt')))\n    videos_name = [osp.splitext(osp.basename(video_name))[0] for video_name in videos_name]\n\n    if opt['debug']:\n        videos_name = videos_name[:3]\n        os.makedirs('tmp_black', exist_ok=True)\n    else:\n        videos_name = videos_name[opt['ss_idx']:opt['to_idx']]\n\n    pbar = tqdm(total=len(videos_name), unit='video', desc='step4')\n\n    pool = Pool(opt['n_thread'])\n    for idx, video_name in enumerate(videos_name):\n        pool.apply_async(worker4, args=(opt, video_name), callback=lambda arg: pbar.update(1))\n    pool.close()\n    pool.join()\n\n\ndef worker4(opt, video_name):\n    video_frame_path = osp.join(opt['save_frames_root'], video_name)\n    black_flag_path = osp.join(opt['black_flag_root'], f'{video_name}.txt')\n    if osp.exists(black_flag_path):\n        print(f'skip {video_name} since {black_flag_path} already exists.')\n        return\n\n    frames_path = sorted(glob.glob(osp.join(video_frame_path, '*.png')))\n    out_list = []\n    pbar = tqdm(total=len(frames_path), unit='frame', desc='worker4')\n\n    for frame_path in frames_path:\n        img = cv2.imread(frame_path)\n        img_name = osp.basename(frame_path)\n        h, w = img.shape[0:2]\n        total_pixels = h * w\n        img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)\n        hist = cv2.calcHist([img_gray], [0], None, [256], [0.0, 255.0])\n        max_pixel = max(hist)[0]\n   
     percentage = max_pixel / total_pixels\n\n        if percentage > opt['black_threshold']:\n            out_list.append(f'{img_name} {0} {percentage:.6f}\\n')\n            if opt['debug']:\n                cv2.imwrite(osp.join('tmp_black', f'{video_name}_{img_name}'), img)\n        else:\n            out_list.append(f'{img_name} {1} {percentage:.6f}\\n')\n\n        pbar.update(1)\n\n    with open(black_flag_path, 'w') as f:\n        for line in out_list:\n            f.write(line)\n\n\n# -------------------------------------------------------------------- #\n# --------------------------- step5 ---------------------------------- #\n# -------------------------------------------------------------------- #\n\n\ndef run_step5(opt):\n    iqa_score_root = opt['iqa_score_root']\n    meta_root = opt['meta_files_root']\n    os.makedirs(iqa_score_root, exist_ok=True)\n    if not osp.exists(meta_root):\n        print('no videos has run step1, exit.')\n        return\n\n    # get the video which has been extracted frames\n    videos_name = sorted(glob.glob(osp.join(meta_root, '*.txt')))\n    videos_name = [osp.splitext(osp.basename(video_name))[0] for video_name in videos_name]\n\n    if opt['debug']:\n        videos_name = videos_name[:3]\n        os.makedirs('tmp_low_iqa', exist_ok=True)\n    else:\n        videos_name = videos_name[opt['ss_idx']:opt['to_idx']]\n\n    pbar = tqdm(total=len(videos_name), unit='video', desc='step5')\n\n    num_gpus = torch.cuda.device_count()\n    ctx = torch.multiprocessing.get_context('spawn')\n    pool = ctx.Pool(min(3 * num_gpus, opt['n_thread']))\n    for idx, video_name in enumerate(videos_name):\n        pool.apply_async(\n            worker5, args=(opt, video_name, torch.device(idx % num_gpus)), callback=lambda arg: pbar.update(1))\n    pool.close()\n    pool.join()\n\n\n@torch.no_grad()\ndef worker5(opt, video_name, device):\n    video_frame_path = osp.join(opt['save_frames_root'], video_name)\n    frames_path = 
sorted(glob.glob(osp.join(video_frame_path, '*.png')))\n    iqa_score_path = osp.join(opt['iqa_score_root'], f'{video_name}.txt')\n    if osp.exists(iqa_score_path):\n        with open(iqa_score_path, 'r') as f:\n            lines = f.readlines()\n            length = len(lines)\n        if length == len(frames_path):\n            print(f'skip {video_name} since {length}/{len(frames_path)} have done.')\n            return\n        else:\n            print(f're-run {video_name} since only {length}/{len(frames_path)} have done.')\n\n    assess_net = init_assessment_model('hypernet', device=device)\n    assess_net = assess_net.half()\n\n    # specified transformation in original hyperIQA\n    transforms_resize = torchvision.transforms.Compose([\n        torchvision.transforms.Resize((512, 384)),\n    ])\n    transforms_crop = torchvision.transforms.Compose([\n        torchvision.transforms.RandomCrop(size=224),\n        torchvision.transforms.ToTensor(),\n        torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))\n    ])\n\n    iqa_out_list = []\n\n    pbar = tqdm(total=len(frames_path), unit='frame', desc='worker3')\n    for idx, frame_path in enumerate(frames_path):\n        img_name = osp.basename(frame_path)\n        cv2_img = cv2.imread(frame_path)\n        # BRG -> RGB\n        img = cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB)\n        img = Image.fromarray(img)\n\n        patchs = []\n        img_resize = transforms_resize(img)\n        for _ in range(opt['num_patch_per_iqa']):\n            patchs.append(transforms_crop(img_resize))\n        patch = torch.stack(patchs, dim=0).to(device)\n\n        pred = assess_net(patch.half())\n        score = pred.mean().item()\n\n        iqa_out_list.append(f'{img_name} {score:.6f}\\n')\n        if opt['debug'] and score < 50.0:\n            cv2.imwrite(osp.join('tmp_low_iqa', f'{video_name}_{img_name}'), cv2_img)\n\n        pbar.update(1)\n\n    with open(iqa_score_path, 'w') as f:\n       
 for line in iqa_out_list:\n            f.write(line)\n\n\n# -------------------------------------------------------------------- #\n# --------------------------- step6 ---------------------------------- #\n# -------------------------------------------------------------------- #\n\n\ndef filter_frozen_shots(shots, flows):\n    \"\"\"select clips from input video.\"\"\"\n    flag_shot = np.ones(len(shots))\n\n    for idx, shot in enumerate(shots):\n        shot = shot.split(' ')\n        start = int(shot[0])\n        end = int(shot[1])\n\n        flow_in_shot = []\n\n        for i in range(start, end + 1, 1):\n            if i == 0:\n                continue\n            else:\n                flow_in_shot.append(float(flows[i].split(' ')[2]))\n\n        flow_in_shot = np.array(flow_in_shot)\n        flow_std = np.std(flow_in_shot)\n\n        if flow_std < 14.0:\n            flag_shot[idx] = 0\n\n    return flag_shot\n\n\ndef generate_clips(shots, flows, filter_frames, hyperiqa, max_length=500):\n    \"\"\"\n    hyperiqa [0, 100]\n    flows [0, 15000] (may be larger)\n    \"\"\"\n    clips = []\n    clip_scores = []\n    clip = []\n    shot_flow = 0\n    shot_hyperiqa = 0\n    for shot in shots:\n        shot = shot.split(' ')\n        start = int(shot[0])\n        end = int(shot[1])\n\n        pre_black = 0\n        for i in range(start, end + 1, 1):\n            if i == start:\n                stat = 0\n                pre_black = 1  # the first frame in shot do not need flow\n            else:\n                stat = 1\n\n            black_frame_thr = float(filter_frames[i].split(' ')[2])\n\n            # drop img when 90% of pixels are identical\n            if black_frame_thr < 0.90:\n                black_frame = 0\n            else:\n                black_frame = 1\n\n            # if current frame is a black frame, delete\n            if black_frame == 1:\n                pre_black = 1\n            elif pre_black == 0:\n                flow = 
float(flows[i].split(' ')[1])\n                shot_flow += flow\n            else:\n                pre_black = 0\n                flow = float(flows[i].split(' ')[1])\n\n            # calcu hyperiqa for non-black frames\n            if black_frame == 0:\n                curr_hyperiqa = float(hyperiqa[i].split(' ')[1])\n                shot_hyperiqa += curr_hyperiqa\n\n                clip.append(f'{i+1:08d} {stat} {flow} {curr_hyperiqa}')\n\n                if len(clip) == max_length:\n                    clips.append(clip.copy())\n                    clip_score = shot_flow / 150.0 + shot_hyperiqa\n                    clip_score = clip_score / len(clip)\n                    clip_scores.append(clip_score)\n                    clip = []\n                    shot_flow = 0\n                    shot_hyperiqa = 0\n\n    # print(len(clip))\n    # if len(clip) > 0:\n    #     clips.append(clip.copy())\n    #     clip_score = shot_flow / 150.0 + shot_hyperiqa\n    #     clip_score = clip_score / len(clip)\n    #     clip_scores.append(clip_score)\n\n    sorted_shot = np.argsort(-np.array(clip_scores))\n\n    return [clips[i] for i in sorted_shot], [clip_scores[i] for i in sorted_shot]\n\n\ndef run_step6(opt):\n    meta_root = opt['meta_files_root']\n    if not osp.exists(meta_root):\n        print('no videos has run step1, exit.')\n        return\n\n    # get the video which has been extracted frames\n    videos_name = sorted(glob.glob(osp.join(meta_root, '*.txt')))\n    videos_name = [osp.splitext(osp.basename(video_name))[0] for video_name in videos_name]\n\n    if opt['debug']:\n        videos_name = videos_name[:3]\n    else:\n        videos_name = videos_name[opt['ss_idx']:opt['to_idx']]\n\n    pbar = tqdm(total=len(videos_name), unit='video', desc='step6')\n\n    os.makedirs(opt['select_clips_meta'], exist_ok=True)\n    os.makedirs(opt['select_clips_frames'], exist_ok=True)\n    os.makedirs(opt['select_done_flags'], exist_ok=True)\n\n    pool = 
Pool(opt['n_thread'])\n    for video_name in videos_name:\n        pool.apply_async(worker6, args=(opt, video_name), callback=lambda arg: pbar.update(1))\n    pool.close()\n    pool.join()\n\n\ndef worker6(opt, video_name):\n    select_clips_meta = opt['select_clips_meta']\n    select_clips_frames = opt['select_clips_frames']\n    select_done_flags = opt['select_done_flags']\n\n    if osp.exists(osp.join(select_done_flags, f'{video_name}.txt')):\n        print(f'skip {video_name}.')\n        return\n\n    with open(osp.join(opt['detect_shot_root'], f'{video_name}.txt'), 'r') as f:\n        shots = f.readlines()\n        shots = [shot.strip() for shot in shots]\n    with open(osp.join(opt['estimate_flow_root'], f'{video_name}.txt'), 'r') as f:\n        flows = f.readlines()\n        flows = [flow.strip() for flow in flows]\n    with open(osp.join(opt['black_flag_root'], f'{video_name}.txt'), 'r') as f:\n        black_flags = f.readlines()\n        black_flags = [black_flag.strip() for black_flag in black_flags]\n    with open(osp.join(opt['iqa_score_root'], f'{video_name}.txt'), 'r') as f:\n        iqa_scores = f.readlines()\n        iqa_scores = [iqa_score.strip() for iqa_score in iqa_scores]\n\n    flag_shot = filter_frozen_shots(shots, flows)\n    flag = np.where(flag_shot == 1)\n    flag = flag[0].tolist()\n    filtered_shots = [shots[i] for i in flag]\n\n    clips, scores = generate_clips(\n        filtered_shots, flows, black_flags, iqa_scores, max_length=opt['num_frames_per_clip'])\n    with open(osp.join(select_clips_meta, f'{video_name}.txt'), 'w') as f:\n        for i, clip in enumerate(clips):\n            os.makedirs(osp.join(select_clips_frames, f'{video_name}_{i}'), exist_ok=True)\n            for idx, info in enumerate(clip):\n                f.write(f'clip: {i:02d} {info} {scores[i]}\\n')\n                img_name = info.split(' ')[0] + '.png'\n                shutil.copy(\n                    osp.join(opt['save_frames_root'], video_name, 
img_name),\n                    osp.join(select_clips_frames, f'{video_name}_{i}', f'{idx:08d}.png'))\n            if i >= opt['num_clips_per_video'] - 1:\n                break\n\n    with open(osp.join(select_done_flags, f'{video_name}.txt'), 'w') as f:\n        f.write(f'{i+1} clips are selected for {video_name}.')\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        '--dataroot',\n        type=str,\n        required=True,\n        help='dataset root, dataroot/raw_videos should contains your HQ videos to be processed.')\n    parser.add_argument('--n_thread', type=int, default=4, help='Thread number.')\n    parser.add_argument('--run', type=str, default='123456', help='run which steps')\n    parser.add_argument('--debug', action='store_true')\n    parser.add_argument('--ss_idx', type=int, default=None, help='ss index')\n    parser.add_argument('--to_idx', type=int, default=None, help='to index')\n    parser.add_argument('--n_frames_per_clip', type=int, default=100)\n    parser.add_argument('--n_clips_per_video', type=int, default=1)\n    parser.add_argument('--select_clip_root', type=str, default='select_clips')\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "scripts/inference_animesr_frames.py",
    "content": "\"\"\"inference AnimeSR on frames\"\"\"\nimport argparse\nimport cv2\nimport glob\nimport numpy as np\nimport os\nimport psutil\nimport queue\nimport threading\nimport time\nimport torch\nfrom os import path as osp\nfrom tqdm import tqdm\n\nfrom animesr.utils.inference_base import get_base_argument_parser, get_inference_model\nfrom animesr.utils.video_util import frames2video\nfrom basicsr.data.transforms import mod_crop\nfrom basicsr.utils.img_util import img2tensor, tensor2img\n\n\ndef read_img(path, require_mod_crop=True, mod_scale=4, input_rescaling_factor=1.0):\n    \"\"\" read an image tensor from a given path\n    Args:\n        path: image path\n        require_mod_crop: mod crop or not. since the arch is multi-scale, so mod crop is needed by default\n        mod_scale: scale factor for mod_crop\n\n\n    Returns:\n        torch.Tensor: size(1, c, h, w)\n    \"\"\"\n    img = cv2.imread(path)\n    img = img.astype(np.float32) / 255.\n\n    if input_rescaling_factor != 1.0:\n        h, w = img.shape[:2]\n        img = cv2.resize(\n            img, (int(w * input_rescaling_factor), int(h * input_rescaling_factor)), interpolation=cv2.INTER_LANCZOS4)\n\n    if require_mod_crop:\n        img = mod_crop(img, mod_scale)\n\n    img = img2tensor(img, bgr2rgb=True, float32=True)\n    return img.unsqueeze(0)\n\n\nclass IOConsumer(threading.Thread):\n    \"\"\"Since IO time can take up a significant portion of the total inference time,\n    so we use multi thread to write frames individually.\n    \"\"\"\n\n    def __init__(self, args: argparse.Namespace, que, qid):\n        super().__init__()\n        self._queue = que\n        self.qid = qid\n        self.args = args\n\n    def run(self):\n        while True:\n            msg = self._queue.get()\n            if isinstance(msg, str) and msg == 'quit':\n                break\n\n            output = msg['output']\n            imgname = msg['imgname']\n            out_img = tensor2img(output.squeeze(0))\n  
          if self.args.outscale != self.args.netscale:\n                h, w = out_img.shape[:2]\n                out_img = cv2.resize(\n                    out_img, (int(\n                        w * self.args.outscale / self.args.netscale), int(h * self.args.outscale / self.args.netscale)),\n                    interpolation=cv2.INTER_LANCZOS4)\n            cv2.imwrite(imgname, out_img)\n\n        print(f'IO for worker {self.qid} is done.')\n\n\n@torch.no_grad()\ndef main():\n    \"\"\"Inference demo for AnimeSR.\n    It mainly for restoring anime frames.\n    \"\"\"\n    parser = get_base_argument_parser()\n    parser.add_argument('--input_rescaling_factor', type=float, default=1.0)\n    parser.add_argument('--num_io_consumer', type=int, default=3, help='number of IO consumer')\n    parser.add_argument(\n        '--sample_interval',\n        type=int,\n        default=1,\n        help='save 1 frame for every $sample_interval frames. this will be useful for calculating the metrics')\n    parser.add_argument('--save_video_too', action='store_true')\n    args = parser.parse_args()\n\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    model = get_inference_model(args, device)\n\n    # prepare output dir\n    frame_output = osp.join(args.output, args.expname, 'frames')\n    os.makedirs(frame_output, exist_ok=True)\n\n    # the input format can be:\n    # 1. clip folder which contains frames\n    # or 2. 
a folder which contains several clips\n    first_level_dir = len(glob.glob(osp.join(args.input, '*.png'))) > 0\n    if args.input.endswith('/'):\n        args.input = args.input[:-1]\n    if first_level_dir:\n        videos_name = [osp.basename(args.input)]\n        args.input = osp.dirname(args.input)\n    else:\n        videos_name = sorted(os.listdir(args.input))\n\n    pbar1 = tqdm(total=len(videos_name), unit='video', desc='inference')\n\n    que = queue.Queue()\n    consumers = [IOConsumer(args, que, f'IO_{i}') for i in range(args.num_io_consumer)]\n    for consumer in consumers:\n        consumer.start()\n\n    for video_name in videos_name:\n        video_folder_path = osp.join(args.input, video_name)\n        imgs_list = sorted(glob.glob(osp.join(video_folder_path, '*')))\n        num_imgs = len(imgs_list)\n        os.makedirs(osp.join(frame_output, video_name), exist_ok=True)\n\n        # prepare\n        prev = read_img(\n            imgs_list[0],\n            require_mod_crop=True,\n            mod_scale=args.mod_scale,\n            input_rescaling_factor=args.input_rescaling_factor).to(device)\n        cur = prev\n        nxt = read_img(\n            imgs_list[min(1, num_imgs - 1)],\n            require_mod_crop=True,\n            mod_scale=args.mod_scale,\n            input_rescaling_factor=args.input_rescaling_factor).to(device)\n        c, h, w = prev.size()[-3:]\n        state = prev.new_zeros(1, 64, h, w)\n        out = prev.new_zeros(1, c, h * args.netscale, w * args.netscale)\n\n        pbar2 = tqdm(total=num_imgs, unit='frame', desc='inference')\n        tot_model_time = 0\n        cnt_model_time = 0\n        for idx in range(num_imgs):\n            torch.cuda.synchronize()\n            start = time.time()\n            img_name = osp.splitext(osp.basename(imgs_list[idx]))[0]\n\n            out, state = model.cell(torch.cat((prev, cur, nxt), dim=1), out, state)\n\n            torch.cuda.synchronize()\n            model_time = time.time() - 
start\n            tot_model_time += model_time\n            cnt_model_time += 1\n\n            if (idx + 1) % args.sample_interval == 0:\n                # put the output frame to the queue to be consumed\n                que.put({'output': out.cpu().clone(), 'imgname': osp.join(frame_output, video_name, f'{img_name}.png')})\n\n            torch.cuda.synchronize()\n            start = time.time()\n            prev = cur\n            cur = nxt\n            nxt = read_img(\n                imgs_list[min(idx + 2, num_imgs - 1)],\n                require_mod_crop=True,\n                mod_scale=args.mod_scale,\n                input_rescaling_factor=args.input_rescaling_factor).to(device)\n            torch.cuda.synchronize()\n            read_time = time.time() - start\n\n            pbar2.update(1)\n            pbar2.set_description(f'read_time: {read_time}, model_time: {tot_model_time/cnt_model_time}')\n\n            mem = psutil.virtual_memory()\n            # since the speed of producer (model inference) is faster than the consumer (I/O)\n            # if there is a risk of OOM, just sleep to let the consumer work\n            if mem.percent > 80.0:\n                time.sleep(30)\n\n        pbar1.update(1)\n\n    for _ in range(args.num_io_consumer):\n        que.put('quit')\n    for consumer in consumers:\n        consumer.join()\n\n    if not args.save_video_too:\n        return\n\n    # convert the frames to videos\n    video_output = osp.join(args.output, args.expname, 'videos')\n    os.makedirs(video_output, exist_ok=True)\n    for video_name in videos_name:\n        out_path = osp.join(video_output, f'{video_name}.mp4')\n        frames2video(\n            osp.join(frame_output, video_name), out_path, fps=24 if args.fps is None else args.fps, suffix='png')\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "scripts/inference_animesr_video.py",
    "content": "import cv2\nimport ffmpeg\nimport glob\nimport mimetypes\nimport numpy as np\nimport os\nimport shutil\nimport subprocess\nimport torch\nfrom os import path as osp\nfrom tqdm import tqdm\n\nfrom animesr.utils import video_util\nfrom animesr.utils.inference_base import get_base_argument_parser, get_inference_model\nfrom basicsr.data.transforms import mod_crop\nfrom basicsr.utils.img_util import img2tensor, tensor2img\nfrom basicsr.utils.logger import AvgTimer\n\n\ndef get_video_meta_info(video_path):\n    \"\"\"get the meta info of the video by using ffprobe with python interface\"\"\"\n    ret = {}\n    probe = ffmpeg.probe(video_path)\n    video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video']\n    has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams'])\n    ret['width'] = video_streams[0]['width']\n    ret['height'] = video_streams[0]['height']\n    ret['fps'] = eval(video_streams[0]['avg_frame_rate'])\n    ret['audio'] = ffmpeg.input(video_path).audio if has_audio else None\n    try:\n        ret['nb_frames'] = int(video_streams[0]['nb_frames'])\n    except KeyError:  # bilibili transcoder dont have nb_frames\n        ret['duration'] = float(probe['format']['duration'])\n        ret['nb_frames'] = int(ret['duration'] * ret['fps'])\n        print(ret['duration'], ret['nb_frames'])\n    return ret\n\n\ndef get_sub_video(args, num_process, process_idx):\n    \"\"\"Cut the whole video into num_process parts, return the process_idx-th part\"\"\"\n    if num_process == 1:\n        return args.input\n    meta = get_video_meta_info(args.input)\n    duration = int(meta['nb_frames'] / meta['fps'])\n    part_time = duration // num_process\n    print(f'duration: {duration}, part_time: {part_time}')\n    out_path = osp.join(args.output, 'inp_sub_videos', f'{process_idx:03d}.mp4')\n    cmd = [\n        args.ffmpeg_bin,\n        f'-i {args.input}',\n        f'-ss {part_time * process_idx}',\n        
f'-to {part_time * (process_idx + 1)}' if process_idx != num_process - 1 else '',\n        '-async 1',\n        out_path,\n        '-y',\n    ]\n    print(' '.join(cmd))\n    subprocess.call(' '.join(cmd), shell=True)\n    return out_path\n\n\nclass Reader:\n    \"\"\"read frames from a video stream or frames list\"\"\"\n\n    def __init__(self, args, total_workers=1, worker_idx=0, device=torch.device('cuda')):\n        self.args = args\n        input_type = mimetypes.guess_type(args.input)[0]\n        self.input_type = 'folder' if input_type is None else input_type\n        self.paths = []  # for image&folder type\n        self.audio = None\n        self.input_fps = None\n        if self.input_type.startswith('video'):\n            video_path = get_sub_video(args, total_workers, worker_idx)\n            # read bgr from stream, which is the same format as opencv\n            self.stream_reader = (\n                ffmpeg\n                .input(video_path)\n                .output('pipe:', format='rawvideo', pix_fmt='bgr24', loglevel='error')\n                .run_async(pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin)\n            )  # yapf: disable  # noqa\n            meta = get_video_meta_info(video_path)\n            self.width = meta['width']\n            self.height = meta['height']\n            self.input_fps = meta['fps']\n            self.audio = meta['audio']\n            self.nb_frames = meta['nb_frames']\n\n        else:\n            if self.input_type.startswith('image'):\n                self.paths = [args.input]\n            else:\n                paths = sorted(glob.glob(os.path.join(args.input, '*')))\n                tot_frames = len(paths)\n                num_frame_per_worker = tot_frames // total_workers + (1 if tot_frames % total_workers else 0)\n                self.paths = paths[num_frame_per_worker * worker_idx:num_frame_per_worker * (worker_idx + 1)]\n\n            self.nb_frames = len(self.paths)\n            assert self.nb_frames 
> 0, 'empty folder'\n            from PIL import Image\n            tmp_img = Image.open(self.paths[0])  # lazy load\n            self.width, self.height = tmp_img.size\n        self.idx = 0\n        self.device = device\n\n    def get_resolution(self):\n        return self.height, self.width\n\n    def get_fps(self):\n        \"\"\"the fps of sr video is set to the user input fps first, followed by the input fps,\n        If the first two values are None, then the commonly used fps 24 is set\"\"\"\n        if self.args.fps is not None:\n            return self.args.fps\n        elif self.input_fps is not None:\n            return self.input_fps\n        return 24\n\n    def get_audio(self):\n        return self.audio\n\n    def __len__(self):\n        \"\"\"return the number of frames for this worker, however, this may be not accurate for video stream\"\"\"\n        return self.nb_frames\n\n    def get_frame_from_stream(self):\n        img_bytes = self.stream_reader.stdout.read(self.width * self.height * 3)  # 3 bytes for one pixel\n        if not img_bytes:\n            # end of stream\n            return None\n        img = np.frombuffer(img_bytes, np.uint8).reshape([self.height, self.width, 3])\n        return img\n\n    def get_frame_from_list(self):\n        if self.idx >= self.nb_frames:\n            return None\n        img = cv2.imread(self.paths[self.idx])\n        self.idx += 1\n        return img\n\n    def get_frame(self):\n        if self.input_type.startswith('video'):\n            img = self.get_frame_from_stream()\n        else:\n            img = self.get_frame_from_list()\n\n        if img is None:\n            raise StopIteration\n\n        # bgr uint8 numpy -> rgb float32 [0, 1] tensor on device\n        img = img.astype(np.float32) / 255.\n        img = mod_crop(img, self.args.mod_scale)\n        img = img2tensor(img, bgr2rgb=True, float32=True).unsqueeze(0).to(self.device)\n        if self.args.half:\n            # half precision won't make a 
big impact on visuals\n            img = img.half()\n        return img\n\n    def close(self):\n        # close the video stream\n        if self.input_type.startswith('video'):\n            self.stream_reader.stdin.close()\n            self.stream_reader.wait()\n\n\nclass Writer:\n    \"\"\"write frames to a video stream\"\"\"\n\n    def __init__(self, args, audio, height, width, video_save_path, fps):\n        out_width, out_height = int(width * args.outscale), int(height * args.outscale)\n        if out_height > 2160:\n            print('You are generating video that is larger than 4K, which will be very slow due to IO speed.',\n                  'We highly recommend to decrease the outscale(aka, -s).')\n\n        vsp = video_save_path\n        if audio is not None:\n            self.stream_writer = (\n                ffmpeg\n                .input('pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{out_width}x{out_height}', framerate=fps)\n                .output(audio, vsp, pix_fmt='yuv420p', vcodec='libx264', loglevel='error', acodec='copy')\n                .overwrite_output()\n                .run_async(pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin)\n            )  # yapf: disable  # noqa\n        else:\n            self.stream_writer = (\n                ffmpeg\n                .input('pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{out_width}x{out_height}', framerate=fps)\n                .output(vsp, pix_fmt='yuv420p', vcodec='libx264', loglevel='error')\n                .overwrite_output()\n                .run_async(pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin)\n            )  # yapf: disable  # noqa\n\n        self.out_width = out_width\n        self.out_height = out_height\n        self.args = args\n\n    def write_frame(self, frame):\n        if self.args.outscale != self.args.netscale:\n            frame = cv2.resize(frame, (self.out_width, self.out_height), interpolation=cv2.INTER_LANCZOS4)\n        
self.stream_writer.stdin.write(frame.tobytes())\n\n    def close(self):\n        self.stream_writer.stdin.close()\n        self.stream_writer.wait()\n\n\n@torch.no_grad()\ndef inference_video(args, video_save_path, device=None, total_workers=1, worker_idx=0):\n    # prepare model\n    model = get_inference_model(args, device)\n\n    # prepare reader and writer\n    reader = Reader(args, total_workers, worker_idx, device=device)\n    audio = reader.get_audio()\n    height, width = reader.get_resolution()\n    height = height - height % args.mod_scale\n    width = width - width % args.mod_scale\n    fps = reader.get_fps()\n    writer = Writer(args, audio, height, width, video_save_path, fps)\n\n    # initialize pre/cur/nxt frames, pre sr frame, and pre hidden state for inference\n    end_flag = False\n    prev = reader.get_frame()\n    cur = prev\n    try:\n        nxt = reader.get_frame()\n    except StopIteration:\n        end_flag = True\n        nxt = cur\n    state = prev.new_zeros(1, 64, height, width)\n    out = prev.new_zeros(1, 3, height * args.netscale, width * args.netscale)\n\n    pbar = tqdm(total=len(reader), unit='frame', desc='inference')\n    model_timer = AvgTimer()  # model inference time tracker\n    i_timer = AvgTimer()  # I(input read) time tracker\n    o_timer = AvgTimer()  # O(output write) time tracker\n    while True:\n        # inference at current step\n        torch.cuda.synchronize(device=device)\n        model_timer.start()\n        out, state = model.cell(torch.cat((prev, cur, nxt), dim=1), out, state)\n        torch.cuda.synchronize(device=device)\n        model_timer.record()\n\n        # write current sr frame to video stream\n        torch.cuda.synchronize(device=device)\n        o_timer.start()\n        output_frame = tensor2img(out, rgb2bgr=False)\n        writer.write_frame(output_frame)\n        torch.cuda.synchronize(device=device)\n        o_timer.record()\n\n        # if end of stream, break\n        if end_flag:\n           
 break\n\n        # move the sliding window\n        torch.cuda.synchronize(device=device)\n        i_timer.start()\n        prev = cur\n        cur = nxt\n        try:\n            nxt = reader.get_frame()\n        except StopIteration:\n            nxt = cur\n            end_flag = True\n        torch.cuda.synchronize(device=device)\n        i_timer.record()\n\n        # update&print infomation\n        pbar.update(1)\n        pbar.set_description(\n            f'I: {i_timer.get_avg_time():.4f} O: {o_timer.get_avg_time():.4f} Model: {model_timer.get_avg_time():.4f}')\n\n    reader.close()\n    writer.close()\n\n\ndef run(args):\n    if args.suffix is None:\n        args.suffix = ''\n    else:\n        args.suffix = f'_{args.suffix}'\n    video_save_path = osp.join(args.output, f'{args.video_name}{args.suffix}.mp4')\n\n    # set up multiprocessing\n    num_gpus = torch.cuda.device_count()\n    num_process = num_gpus * args.num_process_per_gpu\n    if num_process == 1:\n        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n        inference_video(args, video_save_path, device=device)\n        return\n\n    ctx = torch.multiprocessing.get_context('spawn')\n    pool = ctx.Pool(num_process)\n    out_sub_videos_dir = osp.join(args.output, 'out_sub_videos')\n    os.makedirs(out_sub_videos_dir, exist_ok=True)\n    os.makedirs(osp.join(args.output, 'inp_sub_videos'), exist_ok=True)\n\n    pbar = tqdm(total=num_process, unit='sub_video', desc='inference')\n    for i in range(num_process):\n        sub_video_save_path = osp.join(out_sub_videos_dir, f'{i:03d}.mp4')\n        pool.apply_async(\n            inference_video,\n            args=(args, sub_video_save_path, torch.device(i % num_gpus), num_process, i),\n            callback=lambda arg: pbar.update(1))\n    pool.close()\n    pool.join()\n\n    # combine sub videos\n    # prepare vidlist.txt\n    with open(f'{args.output}/vidlist.txt', 'w') as f:\n        for i in range(num_process):\n          
  f.write(f'file \\'out_sub_videos/{i:03d}.mp4\\'\\n')\n    # To avoid video&audio desync as mentioned in https://github.com/xinntao/Real-ESRGAN/issues/388\n    # we use the solution provided in https://stackoverflow.com/a/52156277 to solve this issue\n    cmd = [\n        args.ffmpeg_bin,\n        '-f', 'concat',\n        '-safe', '0',\n        '-i', f'{args.output}/vidlist.txt',\n        '-c:v', 'copy',\n        '-af', 'aresample=async=1000',\n        video_save_path,\n        '-y',\n    ]  # yapf: disable\n    print(' '.join(cmd))\n    subprocess.call(cmd)\n    shutil.rmtree(out_sub_videos_dir)\n    shutil.rmtree(osp.join(args.output, 'inp_sub_videos'))\n    os.remove(f'{args.output}/vidlist.txt')\n\n\ndef main():\n    \"\"\"Inference demo for AnimeSR.\n    It mainly for restoring anime videos.\n    \"\"\"\n    parser = get_base_argument_parser()\n    parser.add_argument(\n        '--extract_frame_first',\n        action='store_true',\n        help='if input is a video, you can still extract the frames first, other wise AnimeSR will read from stream')\n    parser.add_argument(\n        '--num_process_per_gpu', type=int, default=1, help='the total process is number_process_per_gpu * num_gpu')\n    parser.add_argument(\n        '--suffix', type=str, default=None, help='you can add a suffix string to the sr video name, for example, x2')\n    args = parser.parse_args()\n    args.ffmpeg_bin = os.environ.get('ffmpeg_exe_path', 'ffmpeg')\n\n    args.input = args.input.rstrip('/').rstrip('\\\\')\n\n    if mimetypes.guess_type(args.input)[0] is not None and mimetypes.guess_type(args.input)[0].startswith('video'):\n        is_video = True\n    else:\n        is_video = False\n\n    if args.extract_frame_first and not is_video:\n        args.extract_frame_first = False\n\n    # prepare input and output\n    args.video_name = osp.splitext(osp.basename(args.input))[0]\n    args.output = osp.join(args.output, args.expname, 'videos', args.video_name)\n    
os.makedirs(args.output, exist_ok=True)\n    if args.extract_frame_first:\n        inp_extracted_frames = osp.join(args.output, 'inp_extracted_frames')\n        os.makedirs(inp_extracted_frames, exist_ok=True)\n        video_util.video2frames(args.input, inp_extracted_frames, force=True, high_quality=True)\n        video_meta = get_video_meta_info(args.input)\n        args.fps = video_meta['fps']\n        args.input = inp_extracted_frames\n\n    run(args)\n\n    if args.extract_frame_first:\n        shutil.rmtree(args.input)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "scripts/metrics/MANIQA/inference_MANIQA.py",
    "content": "import argparse\nimport os\nimport random\nimport torch\nfrom pipal_data import NTIRE2022\nfrom torch.utils.data import DataLoader\nfrom torchvision import transforms\nfrom tqdm import tqdm\nfrom utils import Normalize, ToTensor, crop_image\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='Inference script of RealBasicVSR')\n    parser.add_argument('--model_path', help='checkpoint file', required=True)\n    parser.add_argument('--input_dir', help='directory of the input video', required=True)\n    parser.add_argument(\n        '--output_dir',\n        help='directory of the output results',\n        default='output/ensemble_attentionIQA2_finetune_e2/AnimeSR')\n    args = parser.parse_args()\n\n    return args\n\n\ndef main():\n    args = parse_args()\n\n    # configuration\n    batch_size = 10\n    num_workers = 8\n    average_iters = 20\n    crop_size = 224\n    os.makedirs(args.output_dir, exist_ok=True)\n\n    model = torch.load(args.model_path)\n\n    # map to cuda, if available\n    cuda_flag = False\n    if torch.cuda.is_available():\n        model = model.cuda()\n        cuda_flag = True\n    model.eval()\n    total_avg_score = []\n    subfolder_namelist = []\n    for subfolder_name in sorted(os.listdir(args.input_dir)):\n        avg_score = 0.0\n        subfolder_root = os.path.join(args.input_dir, subfolder_name)\n\n        if os.path.isdir(subfolder_root) and subfolder_name != 'assemble-folder':\n            # data load\n            val_dataset = NTIRE2022(\n                ref_path=subfolder_root,\n                dis_path=subfolder_root,\n                transform=transforms.Compose([Normalize(0.5, 0.5), ToTensor()]),\n            )\n            val_loader = DataLoader(\n                dataset=val_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=True, shuffle=False)\n\n            name_list, pred_list = [], []\n            with open(os.path.join(args.output_dir, f'{subfolder_name}.txt'), 'w') as 
f:\n\n                for data in tqdm(val_loader):\n                    pred = 0\n\n                    for i in range(average_iters):\n                        if cuda_flag:\n                            x_d = data['d_img_org'].cuda()\n                        b, c, h, w = x_d.shape\n                        top = random.randint(0, h - crop_size)\n                        left = random.randint(0, w - crop_size)\n                        img = crop_image(top, left, crop_size, img=x_d)\n                        with torch.no_grad():\n                            pred += model(img)\n                    pred /= average_iters\n                    d_name = data['d_name']\n                    pred = pred.cpu().numpy()\n                    name_list.extend(d_name)\n                    pred_list.extend(pred)\n\n                for i in range(len(name_list)):\n                    f.write(f'{name_list[i]}, {float(pred_list[i][0]): .6f}\\n')\n                    avg_score += float(pred_list[i][0])\n\n                avg_score /= len(name_list)\n                f.write(f'The average score of {subfolder_name} is {avg_score:.6f}')\n                f.close()\n                subfolder_namelist.append(subfolder_name)\n                total_avg_score.append(avg_score)\n\n    with open(os.path.join(args.output_dir, 'average.txt'), 'w') as f:\n        for idx, averge_score in enumerate(total_avg_score):\n            string = f'Folder {subfolder_namelist[idx]}, Average Score: {averge_score:.6f}\\n'\n            f.write(string)\n            print(f'Folder {subfolder_namelist[idx]}, Average Score: {averge_score:.6f}')\n\n        print(f'Average Score of {len(subfolder_namelist)} Folders: {sum(total_avg_score) / len(total_avg_score):.6f}')\n        string = f'Average Score of {len(subfolder_namelist)} Folders: {sum(total_avg_score) / len(total_avg_score):.6f}'  # noqa E501\n        f.write(string)\n        f.close()\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "scripts/metrics/MANIQA/models/model_attentionIQA2.py",
    "content": "# flake8: noqa\nimport timm\nimport torch\nfrom einops import rearrange\nfrom models.swin import SwinTransformer\nfrom timm.models.vision_transformer import Block\nfrom torch import nn\n\n\nclass ChannelAttn(nn.Module):\n\n    def __init__(self, dim, drop=0.1):\n        super().__init__()\n        self.c_q = nn.Linear(dim, dim)\n        self.c_k = nn.Linear(dim, dim)\n        self.c_v = nn.Linear(dim, dim)\n        self.norm_fact = dim**-0.5\n        self.softmax = nn.Softmax(dim=-1)\n        self.attn_drop = nn.Dropout(drop)\n        self.proj_drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        _x = x\n        B, C, N = x.shape\n        q = self.c_q(x)\n        k = self.c_k(x)\n        v = self.c_v(x)\n\n        attn = q @ k.transpose(-2, -1) * self.norm_fact\n        attn = self.softmax(attn)\n        attn = self.attn_drop(attn)\n        x = (attn @ v).transpose(1, 2).reshape(B, C, N)\n        x = self.proj_drop(x)\n        x = x + _x\n        return x\n\n\nclass SaveOutput:\n\n    def __init__(self):\n        self.outputs = []\n\n    def __call__(self, module, module_in, module_out):\n        self.outputs.append(module_out)\n\n    def clear(self):\n        self.outputs = []\n\n\nclass AttentionIQA(nn.Module):\n\n    def __init__(self,\n                 embed_dim=72,\n                 num_outputs=1,\n                 patch_size=8,\n                 drop=0.1,\n                 depths=[2, 2],\n                 window_size=4,\n                 dim_mlp=768,\n                 num_heads=[4, 4],\n                 img_size=224,\n                 num_channel_attn=2,\n                 **kwargs):\n        super().__init__()\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.input_size = img_size // patch_size\n        self.patches_resolution = (img_size // patch_size, img_size // patch_size)\n\n        self.vit = timm.create_model('vit_base_patch8_224', pretrained=True)\n        self.save_output = SaveOutput()\n   
     hook_handles = []\n        for layer in self.vit.modules():\n            if isinstance(layer, Block):\n                handle = layer.register_forward_hook(self.save_output)\n                hook_handles.append(handle)\n\n        self.channel_attn1 = nn.Sequential(*[ChannelAttn(self.input_size**2) for i in range(num_channel_attn)])\n        self.channel_attn2 = nn.Sequential(*[ChannelAttn(self.input_size**2) for i in range(num_channel_attn)])\n\n        self.conv1 = nn.Conv2d(embed_dim * 4, embed_dim, 1, 1, 0)\n        self.swintransformer1 = SwinTransformer(\n            patches_resolution=self.patches_resolution,\n            depths=depths,\n            num_heads=num_heads,\n            embed_dim=embed_dim,\n            window_size=window_size,\n            dim_mlp=dim_mlp)\n        self.swintransformer2 = SwinTransformer(\n            patches_resolution=self.patches_resolution,\n            depths=depths,\n            num_heads=num_heads,\n            embed_dim=embed_dim // 2,\n            window_size=window_size,\n            dim_mlp=dim_mlp)\n        self.conv2 = nn.Conv2d(embed_dim, embed_dim // 2, 1, 1, 0)\n\n        self.fc_score = nn.Sequential(\n            nn.Linear(embed_dim // 2, embed_dim // 2), nn.ReLU(), nn.Dropout(drop),\n            nn.Linear(embed_dim // 2, num_outputs), nn.ReLU())\n        self.fc_weight = nn.Sequential(\n            nn.Linear(embed_dim // 2, embed_dim // 2), nn.ReLU(), nn.Dropout(drop),\n            nn.Linear(embed_dim // 2, num_outputs), nn.Sigmoid())\n\n    def extract_feature(self, save_output):\n        x6 = save_output.outputs[6][:, 1:]\n        x7 = save_output.outputs[7][:, 1:]\n        x8 = save_output.outputs[8][:, 1:]\n        x9 = save_output.outputs[9][:, 1:]\n        x = torch.cat((x6, x7, x8, x9), dim=2)\n        return x\n\n    def forward(self, x):\n        _x = self.vit(x)\n        x = self.extract_feature(self.save_output)\n        self.save_output.outputs.clear()\n\n        # stage 1\n        x = 
rearrange(x, 'b (h w) c -> b c (h w)', h=self.input_size, w=self.input_size)\n        x = self.channel_attn1(x)\n        x = rearrange(x, 'b c (h w) -> b c h w', h=self.input_size, w=self.input_size)\n        x = self.conv1(x)\n        x = self.swintransformer1(x)\n\n        # stage2\n        x = rearrange(x, 'b c h w -> b c (h w)', h=self.input_size, w=self.input_size)\n        x = self.channel_attn2(x)\n        x = rearrange(x, 'b c (h w) -> b c h w', h=self.input_size, w=self.input_size)\n        x = self.conv2(x)\n        x = self.swintransformer2(x)\n\n        x = rearrange(x, 'b c h w -> b (h w) c', h=self.input_size, w=self.input_size)\n        f = self.fc_score(x)\n        w = self.fc_weight(x)\n        s = torch.sum(f * w, dim=1) / torch.sum(w, dim=1)\n        return s\n"
  },
  {
    "path": "scripts/metrics/MANIQA/models/swin.py",
    "content": "\"\"\"\nisort:skip_file\n\"\"\"\n# flake8: noqa\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\nfrom einops import rearrange\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\nfrom torch import nn\n\"\"\" attention decoder mask \"\"\"\n\n\ndef get_attn_decoder_mask(seq):\n    subsequent_mask = torch.ones_like(seq).unsqueeze(-1).expand(seq.size(0), seq.size(1), seq.size(1))\n    subsequent_mask = subsequent_mask.triu(diagonal=1)  # upper triangular part of a matrix(2-D)\n    return subsequent_mask\n\n\n\"\"\" attention pad mask \"\"\"\n\n\ndef get_attn_pad_mask(seq_q, seq_k, i_pad):\n    batch_size, len_q = seq_q.size()\n    batch_size, len_k = seq_k.size()\n    pad_attn_mask = seq_k.data.eq(i_pad)\n    pad_attn_mask = pad_attn_mask.unsqueeze(1).expand(batch_size, len_q, len_k)\n    return pad_attn_mask\n\n\nclass DecoderWindowAttention(nn.Module):\n    r\"\"\" Window based multi-head self attention (W-MSA) module with relative position bias.\n    It supports both of shifted and non-shifted window.\n\n    Args:\n        dim (int): Number of input channels.\n        window_size (tuple[int]): The height and width of the window.\n        num_heads (int): Number of attention heads.\n        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0\n    \"\"\"\n\n    def __init__(self, dim, window_size, num_heads, qk_scale=None, attn_drop=0., proj_drop=0.):\n\n        super().__init__()\n        self.dim = dim\n        self.window_size = window_size  # Wh, Ww\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim**-0.5\n\n        # define a parameter table of relative position bias\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(self.window_size[0])\n        coords_w = torch.arange(self.window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += self.window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        self.W_Q = nn.Linear(dim, dim)\n        self.W_K = nn.Linear(dim, dim)\n        self.W_V = nn.Linear(dim, dim)\n\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        trunc_normal_(self.relative_position_bias_table, std=.02)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, q, k, v, mask=None, attn_mask=None):\n        \"\"\"\n        Args:\n            x: input features with shape of 
(num_windows*B, N, C)\n            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n        \"\"\"\n        B_, N, C = q.shape\n\n        q = self.W_Q(q).view(B_, N, self.num_heads, C // self.num_heads).transpose(1, 2)\n        k = self.W_K(k).view(B_, N, self.num_heads, C // self.num_heads).transpose(1, 2)\n        v = self.W_V(v).view(B_, N, self.num_heads, C // self.num_heads).transpose(1, 2)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))\n        attn.masked_fill_(attn_mask, -1e9)\n\n        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH\n        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n        attn = attn + relative_position_bias.unsqueeze(0)\n\n        if mask is not None:\n            nW = mask.shape[0]\n            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, N, N)\n            attn = self.softmax(attn)\n        else:\n            attn = self.softmax(attn)\n\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'\n\n    def flops(self, N):\n        # calculate flops for 1 window with token length of N\n        flops = 0\n        # qkv = self.qkv(x)\n        flops += N * self.dim * 3 * self.dim\n        # attn = (q @ k.transpose(-2, -1))\n        flops += self.num_heads * N * (self.dim // self.num_heads) * N\n        #  x = (attn @ v)\n        flops += self.num_heads * N * N * (self.dim // self.num_heads)\n        # x = self.proj(x)\n        flops += 
N * self.dim * self.dim\n        return flops\n\n\n\"\"\" decoder layer \"\"\"\n\n\nclass DecoderLayer(nn.Module):\n\n    def __init__(self,\n                 input_resolution=(28, 28),\n                 embed_dim=256,\n                 layer_norm_epsilon=1e-12,\n                 dim_mlp=1024,\n                 num_heads=4,\n                 dim_head=128,\n                 window_size=7,\n                 shift_size=0,\n                 i_layer=0,\n                 act_layer=nn.GELU,\n                 drop=0.,\n                 drop_path=0.):\n        super().__init__()\n        self.i_layer = i_layer\n        self.shift_size = shift_size\n        self.window_size = window_size\n        self.input_resolution = input_resolution\n        self.conv = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)\n        self.layer_norm1 = nn.LayerNorm(embed_dim, eps=layer_norm_epsilon)\n        self.dec_enc_attn_wmsa = DecoderWindowAttention(\n            dim=embed_dim,\n            window_size=to_2tuple(window_size),\n            num_heads=num_heads,\n        )\n        self.layer_norm2 = nn.LayerNorm(embed_dim, eps=layer_norm_epsilon)\n        self.dec_enc_attn_swmsa = DecoderWindowAttention(\n            dim=embed_dim,\n            window_size=to_2tuple(window_size),\n            num_heads=num_heads,\n        )\n        self.layer_norm3 = nn.LayerNorm(embed_dim, eps=layer_norm_epsilon)\n        self.mlp = Mlp(in_features=embed_dim, hidden_features=dim_mlp, act_layer=act_layer, drop=drop)\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n\n        # calculate attention mask for SW-MSA\n        H, W = self.input_resolution\n        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1\n        h_slices = (slice(0, -self.window_size), slice(-self.window_size,\n                                                       -self.shift_size), slice(-self.shift_size, None))\n        w_slices = (slice(0, -self.window_size), slice(-self.window_size,\n                                                       -self.shift_size), slice(-self.shift_size, None))\n        cnt = 0\n        for h in h_slices:\n            for w in w_slices:\n                img_mask[:, h, w, :] = cnt\n                cnt += 1\n\n        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1\n        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)\n        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n\n        self.register_buffer(\"attn_mask\", attn_mask)\n\n    def partition(self, inputs, B, H, W, C, shift_size=0):\n        # partition mask_dec_inputs\n        inputs = inputs.view(B, H, W, C)\n        if shift_size > 0:\n            shifted_inputs = torch.roll(inputs, shifts=(-shift_size, -shift_size), dims=(1, 2))\n        else:\n            shifted_inputs = inputs\n        windows_inputs = window_partition(shifted_inputs, self.window_size)  # nW*B, window_size, window_size, C\n        windows_inputs = windows_inputs.view(-1, self.window_size * self.window_size,\n                                             C)  # nW*B, window_size*window_size, C\n        return windows_inputs\n\n    def reverse(self, inputs, B, H, W, C, shift_size=0):\n        # merge windows\n        inputs = inputs.view(-1, self.window_size, self.window_size, C)\n        inputs = window_reverse(inputs, self.window_size, H, W)  # B H' W' C\n        # reverse 
cyclic shift\n        if shift_size > 0:\n            inputs = torch.roll(inputs, shifts=(shift_size, shift_size), dims=(1, 2))\n        else:\n            inputs = inputs\n        inputs = inputs.view(B, H * W, C)\n        return inputs\n\n    def forward(self, mask_dec_inputs, enc_outputs, self_attn_mask, dec_enc_attn_mask):\n        H, W = self.input_resolution[0], self.input_resolution[1]\n        B, L, C = mask_dec_inputs.shape\n        assert L == H * W, \"input feature has wrong size\"\n\n        dec_enc_att_outputs = mask_dec_inputs\n        shortcut1 = dec_enc_att_outputs\n        dec_enc_att_outputs = self.layer_norm1(dec_enc_att_outputs)\n        enc_outputs = self.partition(enc_outputs, B, H, W, C, shift_size=0)\n        dec_enc_att_outputs = self.partition(dec_enc_att_outputs, B, H, W, C, shift_size=0)\n        dec_enc_att_outputs = self.dec_enc_attn_wmsa(\n            q=dec_enc_att_outputs, k=enc_outputs, v=enc_outputs, mask=None, attn_mask=dec_enc_attn_mask)\n        dec_enc_att_outputs = self.reverse(dec_enc_att_outputs, B, H, W, C, shift_size=0)\n        enc_outputs = self.reverse(enc_outputs, B, H, W, C, shift_size=0)\n        dec_enc_att_outputs = shortcut1 + self.drop_path(dec_enc_att_outputs)\n\n        shortcut2 = dec_enc_att_outputs\n        dec_enc_att_outputs = self.layer_norm2(dec_enc_att_outputs)\n        enc_outputs = self.partition(enc_outputs, B, H, W, C, shift_size=self.window_size // 2)\n        dec_enc_att_outputs = self.partition(dec_enc_att_outputs, B, H, W, C, shift_size=self.window_size // 2)\n        dec_enc_att_outputs = self.dec_enc_attn_swmsa(\n            q=dec_enc_att_outputs, k=enc_outputs, v=enc_outputs, mask=self.attn_mask, attn_mask=dec_enc_attn_mask)\n        dec_enc_att_outputs = self.reverse(dec_enc_att_outputs, B, H, W, C, shift_size=self.window_size // 2)\n        enc_outputs = self.reverse(enc_outputs, B, H, W, C, shift_size=self.window_size // 2)\n        dec_enc_att_outputs = shortcut2 + 
self.drop_path(dec_enc_att_outputs)\n\n        shortcut3 = dec_enc_att_outputs\n        dec_enc_att_outputs = self.layer_norm3(dec_enc_att_outputs)\n        dec_enc_att_outputs = self.mlp(dec_enc_att_outputs)\n        dec_enc_att_outputs = shortcut3 + self.drop_path(dec_enc_att_outputs)\n\n        dec_enc_att_outputs = rearrange(\n            dec_enc_att_outputs, 'b (h w) c -> b c h w', h=self.input_resolution[0], w=self.input_resolution[1])\n        # if self.i_layer % 2 == 0:\n        #     dec_enc_att_outputs = self.conv(dec_enc_att_outputs)\n        dec_enc_att_outputs = rearrange(dec_enc_att_outputs, 'b c h w -> b (h w) c')\n        return dec_enc_att_outputs\n\n\n\"\"\" decoder \"\"\"\n\n\nclass SwinDecoder(nn.Module):\n\n    def __init__(self,\n                 input_resolution=(28, 28),\n                 embed_dim=256,\n                 num_heads=4,\n                 num_layers=2,\n                 drop=0.1,\n                 i_pad=0,\n                 dim_mlp=1024,\n                 window_size=7,\n                 drop_path_rate=0.1):\n        super().__init__()\n        self.window_size = window_size\n        self.embed_dim = embed_dim\n        self.input_resolution = input_resolution\n        self.num_heads = num_heads\n        self.i_pad = i_pad\n        self.dropout = nn.Dropout(drop)\n        self.layers = nn.ModuleList()\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]\n        for i_layer in range(num_layers):\n            layer = DecoderLayer(\n                input_resolution=(input_resolution[0], input_resolution[1]),\n                embed_dim=embed_dim,\n                dim_mlp=dim_mlp,\n                window_size=window_size,\n                i_layer=i_layer + 1,\n                shift_size=window_size // 2,\n                drop_path=dpr[i_layer])\n            self.layers.append(layer)\n\n    def forward(self, y_embed, enc_outputs):\n        inputs_embed = y_embed\n        B, C, H, W = y_embed.shape\n        
inputs_embed = rearrange(inputs_embed, 'b c h w -> b (h w) c')\n        dec_outputs = self.dropout(inputs_embed)\n\n        idx = 1\n        down_rate = 0\n        for layer in self.layers:\n            window_num = int((self.input_resolution[0] // 2 ** down_rate) // self.window_size) * \\\n                         int((self.input_resolution[1] // 2 ** down_rate) // self.window_size)\n            enc_inputs_length = self.window_size * self.window_size\n            dec_inputs_length = self.window_size * self.window_size\n            mask_enc_inputs = torch.ones(B * window_num, enc_inputs_length).cuda()\n            mask_dec_inputs = torch.ones(B * window_num, dec_inputs_length).cuda()\n\n            dec_attn_pad_mask = get_attn_pad_mask(mask_dec_inputs, mask_dec_inputs, self.i_pad)\n            dec_attn_decoder_mask = get_attn_decoder_mask(mask_dec_inputs)\n            dec_self_attn_mask = torch.gt((dec_attn_pad_mask + dec_attn_decoder_mask), 0)\n            dec_enc_attn_mask = get_attn_pad_mask(mask_dec_inputs, mask_enc_inputs, self.i_pad)\n\n            dec_self_attn_mask = dec_self_attn_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)\n            dec_enc_attn_mask = dec_enc_attn_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)\n\n            dec_outputs = layer(dec_outputs, enc_outputs[idx - 1], dec_self_attn_mask, dec_enc_attn_mask)\n            idx += 1\n        dec_outputs = rearrange(\n            dec_outputs,\n            'b (h w) c -> b c h w',\n            h=self.input_resolution[0] // 2**down_rate,\n            w=self.input_resolution[1] // 2**down_rate)\n        return dec_outputs\n\n\nclass Mlp(nn.Module):\n\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        
self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\ndef window_partition(x, window_size):\n    \"\"\"\n    Args:\n        x: (B, H, W, C)\n        window_size (int): window size\n\n    Returns:\n        windows: (num_windows*B, window_size, window_size, C)\n    \"\"\"\n    B, H, W, C = x.shape\n    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n    return windows\n\n\ndef window_reverse(windows, window_size, H, W):\n    \"\"\"\n    Args:\n        windows: (num_windows*B, window_size, window_size, C)\n        window_size (int): Window size\n        H (int): Height of image\n        W (int): Width of image\n\n    Returns:\n        x: (B, H, W, C)\n    \"\"\"\n    B = int(windows.shape[0] / (H * W / window_size / window_size))\n    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n    return x\n\n\nclass WindowAttention(nn.Module):\n    r\"\"\" Window based multi-head self attention (W-MSA) module with relative position bias.\n    It supports both of shifted and non-shifted window.\n\n    Args:\n        dim (int): Number of input channels.\n        window_size (tuple[int]): The height and width of the window.\n        num_heads (int): Number of attention heads.\n        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0\n    \"\"\"\n\n    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):\n\n        super().__init__()\n        self.dim = dim\n        self.window_size = window_size  # Wh, Ww\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim**-0.5\n\n        # define a parameter table of relative position bias\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(self.window_size[0])\n        coords_w = torch.arange(self.window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += self.window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        trunc_normal_(self.relative_position_bias_table, std=.02)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x, mask=None):\n        \"\"\"\n        Args:\n            x: input features with shape of (num_windows*B, N, C)\n            mask: (0/-inf) mask with shape of 
(num_windows, Wh*Ww, Wh*Ww) or None\n        \"\"\"\n        B_, N, C = x.shape\n        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1))\n\n        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH\n        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n        attn = attn + relative_position_bias.unsqueeze(0)\n\n        if mask is not None:\n            nW = mask.shape[0]\n            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, N, N)\n            attn = self.softmax(attn)\n        else:\n            attn = self.softmax(attn)\n\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'\n\n    def flops(self, N):\n        # calculate flops for 1 window with token length of N\n        flops = 0\n        # qkv = self.qkv(x)\n        flops += N * self.dim * 3 * self.dim\n        # attn = (q @ k.transpose(-2, -1))\n        flops += self.num_heads * N * (self.dim // self.num_heads) * N\n        #  x = (attn @ v)\n        flops += self.num_heads * N * N * (self.dim // self.num_heads)\n        # x = self.proj(x)\n        flops += N * self.dim * self.dim\n        return flops\n\n\nclass SwinBlock(nn.Module):\n    r\"\"\" Swin Transformer Block.\n\n    Args:\n        dim (int): Number of input channels.\n       
 input_resolution (tuple[int]): Input resulotion.\n        num_heads (int): Number of attention heads.\n        window_size (int): Window size.\n        shift_size (int): Shift size for SW-MSA.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self,\n                 dim,\n                 input_resolution,\n                 num_heads,\n                 window_size=7,\n                 shift_size=0,\n                 dim_mlp=1024.,\n                 qkv_bias=True,\n                 qk_scale=None,\n                 drop=0.,\n                 attn_drop=0.,\n                 drop_path=0.,\n                 act_layer=nn.GELU,\n                 norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.shift_size = shift_size\n        self.dim_mlp = dim_mlp\n        if min(self.input_resolution) <= self.window_size:\n            # if window size is larger than input resolution, we don't partition windows\n            self.shift_size = 0\n            self.window_size = min(self.input_resolution)\n        assert 0 <= self.shift_size < self.window_size, \"shift_size must in 0-window_size\"\n\n        self.norm1 = norm_layer(dim)\n        self.attn = WindowAttention(\n            dim,\n            
window_size=to_2tuple(self.window_size),\n            num_heads=num_heads,\n            qkv_bias=qkv_bias,\n            qk_scale=qk_scale,\n            attn_drop=attn_drop,\n            proj_drop=drop)\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = self.dim_mlp\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        if self.shift_size > 0:\n            # calculate attention mask for SW-MSA\n            H, W = self.input_resolution\n            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1\n            h_slices = (slice(0, -self.window_size), slice(-self.window_size,\n                                                           -self.shift_size), slice(-self.shift_size, None))\n            w_slices = (slice(0, -self.window_size), slice(-self.window_size,\n                                                           -self.shift_size), slice(-self.shift_size, None))\n            cnt = 0\n            for h in h_slices:\n                for w in w_slices:\n                    img_mask[:, h, w, :] = cnt\n                    cnt += 1\n\n            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1\n            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)\n            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n        else:\n            attn_mask = None\n\n        self.register_buffer(\"attn_mask\", attn_mask)\n\n    def forward(self, x):\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n\n        shortcut = x\n        x = self.norm1(x)\n        x = x.view(B, H, W, C)\n\n        # cyclic shift\n        if self.shift_size > 0:\n  
          shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n        else:\n            shifted_x = x\n\n        # partition windows\n        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C\n        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C\n\n        # W-MSA/SW-MSA\n        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C\n\n        # merge windows\n        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)\n        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C\n\n        # reverse cyclic shift\n        if self.shift_size > 0:\n            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))\n        else:\n            x = shifted_x\n        x = x.view(B, H * W, C)\n\n        # FFN\n        x = shortcut + self.drop_path(x)\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, \" \\\n               f\"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}\"\n\n    def flops(self):\n        flops = 0\n        H, W = self.input_resolution\n        # norm1\n        flops += self.dim * H * W\n        # W-MSA/SW-MSA\n        nW = H * W / self.window_size / self.window_size\n        flops += nW * self.attn.flops(self.window_size * self.window_size)\n        # mlp\n        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio\n        # norm2\n        flops += self.dim * H * W\n        return flops\n\n\nclass BasicLayer(nn.Module):\n    \"\"\" A basic Swin Transformer layer for one stage.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input 
resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        window_size (int): Local window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.\n    \"\"\"\n\n    def __init__(self,\n                 dim,\n                 input_resolution,\n                 depth,\n                 num_heads,\n                 window_size=7,\n                 dim_mlp=1024,\n                 qkv_bias=True,\n                 qk_scale=None,\n                 drop=0.,\n                 attn_drop=0.,\n                 drop_path=0.,\n                 norm_layer=nn.LayerNorm,\n                 downsample=None,\n                 use_checkpoint=False):\n\n        super().__init__()\n        self.dim = dim\n        self.conv = nn.Conv2d(dim, dim, 3, 1, 1)\n        self.input_resolution = input_resolution\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        # build blocks\n        self.blocks = nn.ModuleList([\n            SwinBlock(\n                dim=dim,\n                input_resolution=input_resolution,\n                num_heads=num_heads,\n                window_size=window_size,\n                shift_size=0 if (i % 2 == 0) else window_size // 2,\n                
dim_mlp=dim_mlp,\n                qkv_bias=qkv_bias,\n                qk_scale=qk_scale,\n                drop=drop,\n                attn_drop=attn_drop,\n                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                norm_layer=norm_layer) for i in range(depth)\n        ])\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)\n        else:\n            self.downsample = None\n\n    def forward(self, x):\n        for blk in self.blocks:\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x)\n            else:\n                x = blk(x)\n        x = rearrange(x, 'b (h w) c -> b c h w', h=self.input_resolution[0], w=self.input_resolution[1])\n        x = F.relu(self.conv(x))\n        x = rearrange(x, 'b c h w -> b (h w) c')\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n\n    def flops(self):\n        flops = 0\n        for blk in self.blocks:\n            flops += blk.flops()\n        if self.downsample is not None:\n            flops += self.downsample.flops()\n        return flops\n\n\nclass SwinTransformer(nn.Module):\n\n    def __init__(self,\n                 patches_resolution,\n                 depths=[2, 2, 6, 2],\n                 num_heads=[3, 6, 12, 24],\n                 embed_dim=256,\n                 drop=0.1,\n                 drop_rate=0.,\n                 drop_path_rate=0.1,\n                 dropout=0.,\n                 window_size=7,\n                 dim_mlp=1024,\n                 qkv_bias=True,\n                 qk_scale=None,\n                 attn_drop_rate=0.,\n                 norm_layer=nn.LayerNorm,\n                 downsample=None,\n                 use_checkpoint=False,\n                 **kwargs):\n        super().__init__()\n        self.embed_dim 
= embed_dim\n        self.depths = depths\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.pos_drop = nn.Dropout(p=drop_rate)\n        self.dropout = nn.Dropout(p=drop)\n        self.num_features = embed_dim\n        self.num_layers = len(depths)\n        self.patches_resolution = (patches_resolution[0], patches_resolution[1])\n        self.downsample = nn.Conv2d(self.embed_dim, self.embed_dim, kernel_size=3, stride=2, padding=1)\n        # stochastic depth\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]\n\n        self.layers = nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(\n                dim=self.embed_dim,\n                input_resolution=patches_resolution,\n                depth=self.depths[i_layer],\n                num_heads=self.num_heads[i_layer],\n                window_size=self.window_size,\n                dim_mlp=dim_mlp,\n                qkv_bias=qkv_bias,\n                qk_scale=qk_scale,\n                drop=dropout,\n                attn_drop=attn_drop_rate,\n                drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],\n                norm_layer=norm_layer,\n                downsample=downsample,\n                use_checkpoint=use_checkpoint)\n            self.layers.append(layer)\n\n    def forward(self, x):\n        x = self.dropout(x)\n        x = self.pos_drop(x)\n        x = rearrange(x, 'b c h w -> b (h w) c')\n        for layer in self.layers:\n            _x = x\n            x = layer(x)\n            x = 0.13 * x + _x\n        x = rearrange(x, 'b (h w) c -> b c h w', h=self.patches_resolution[0], w=self.patches_resolution[1])\n        return x\n"
  },
  {
    "path": "scripts/metrics/MANIQA/pipal_data.py",
    "content": "import cv2\nimport numpy as np\nimport os\nimport torch\n\n\nclass NTIRE2022(torch.utils.data.Dataset):\n\n    def __init__(self, ref_path, dis_path, transform):\n        super(NTIRE2022, self).__init__()\n        self.ref_path = ref_path\n        self.dis_path = dis_path\n        self.transform = transform\n\n        ref_files_data, dis_files_data = [], []\n        for dis in os.listdir(dis_path):\n            ref = dis\n            ref_files_data.append(ref)\n            dis_files_data.append(dis)\n        self.data_dict = {'r_img_list': ref_files_data, 'd_img_list': dis_files_data}\n\n    def __len__(self):\n        return len(self.data_dict['r_img_list'])\n\n    def __getitem__(self, idx):\n        # r_img: H x W x C -> C x H x W\n        r_img_name = self.data_dict['r_img_list'][idx]\n        r_img = cv2.imread(os.path.join(self.ref_path, r_img_name), cv2.IMREAD_COLOR)\n        r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB)\n        r_img = np.array(r_img).astype('float32') / 255\n        r_img = np.transpose(r_img, (2, 0, 1))\n\n        d_img_name = self.data_dict['d_img_list'][idx]\n        d_img = cv2.imread(os.path.join(self.dis_path, d_img_name), cv2.IMREAD_COLOR)\n        d_img = cv2.cvtColor(d_img, cv2.COLOR_BGR2RGB)\n        d_img = np.array(d_img).astype('float32') / 255\n        d_img = np.transpose(d_img, (2, 0, 1))\n        sample = {'r_img_org': r_img, 'd_img_org': d_img, 'd_name': d_img_name}\n        if self.transform:\n            sample = self.transform(sample)\n        return sample\n"
  },
  {
    "path": "scripts/metrics/MANIQA/utils.py",
    "content": "import numpy as np\nimport torch\n\n\ndef crop_image(top, left, patch_size, img=None):\n    tmp_img = img[:, :, top:top + patch_size, left:left + patch_size]\n    return tmp_img\n\n\nclass RandCrop(object):\n\n    def __init__(self, patch_size, num_crop):\n        self.patch_size = patch_size\n        self.num_crop = num_crop\n\n    def __call__(self, sample):\n        # r_img : C x H x W (numpy)\n        r_img, d_img = sample['r_img_org'], sample['d_img_org']\n        d_name = sample['d_name']\n\n        c, h, w = d_img.shape\n        new_h = self.patch_size\n        new_w = self.patch_size\n        ret_r_img = np.zeros((c, self.patch_size, self.patch_size))\n        ret_d_img = np.zeros((c, self.patch_size, self.patch_size))\n        for _ in range(self.num_crop):\n            top = np.random.randint(0, h - new_h)\n            left = np.random.randint(0, w - new_w)\n            tmp_r_img = r_img[:, top:top + new_h, left:left + new_w]\n            tmp_d_img = d_img[:, top:top + new_h, left:left + new_w]\n            ret_r_img += tmp_r_img\n            ret_d_img += tmp_d_img\n        ret_r_img /= self.num_crop\n        ret_d_img /= self.num_crop\n\n        sample = {'r_img_org': ret_r_img, 'd_img_org': ret_d_img, 'd_name': d_name}\n\n        return sample\n\n\nclass Normalize(object):\n\n    def __init__(self, mean, var):\n        self.mean = mean\n        self.var = var\n\n    def __call__(self, sample):\n        # r_img: C x H x W (numpy)\n        r_img, d_img = sample['r_img_org'], sample['d_img_org']\n        d_name = sample['d_name']\n\n        r_img = (r_img - self.mean) / self.var\n        d_img = (d_img - self.mean) / self.var\n\n        sample = {'r_img_org': r_img, 'd_img_org': d_img, 'd_name': d_name}\n        return sample\n\n\nclass RandHorizontalFlip(object):\n\n    def __init__(self):\n        pass\n\n    def __call__(self, sample):\n        r_img, d_img = sample['r_img_org'], sample['d_img_org']\n        d_name = sample['d_name']\n   
     prob_lr = np.random.random()\n        # np.fliplr needs HxWxC\n        if prob_lr > 0.5:\n            d_img = np.fliplr(d_img).copy()\n            r_img = np.fliplr(r_img).copy()\n\n        sample = {'r_img_org': r_img, 'd_img_org': d_img, 'd_name': d_name}\n        return sample\n\n\nclass ToTensor(object):\n\n    def __init__(self):\n        pass\n\n    def __call__(self, sample):\n        r_img, d_img = sample['r_img_org'], sample['d_img_org']\n        d_name = sample['d_name']\n        d_img = torch.from_numpy(d_img).type(torch.FloatTensor)\n        r_img = torch.from_numpy(r_img).type(torch.FloatTensor)\n\n        sample = {'r_img_org': r_img, 'd_img_org': d_img, 'd_name': d_name}\n        return sample\n"
  },
  {
    "path": "scripts/metrics/README.md",
    "content": "# Instruction for calculating metrics\n\n## Prepare the frames\nFor fast evaluation, we measure 1 frames every 10 frames, that is, the 0-th frame, 10-th frame, 20-th frame, etc. will participate the metrics calculation.\nThis can be achieved by `sample_interval` argument.\n```bash\npython scripts/inference_animesr_frames.py -i AVC-RealLQ-ROOT -n AnimeSR_v1-PaperModel --expname animesr_v1_si10 --sample_interval 10\n```\n\n## MANIQA calculation\n### requirements\n`pip install timm==0.5.4`\n### checkpoint\nwe use the ensemble model provided by the authors to compute MANIQA. You can download the checkpoint from [Ondrive](https://1drv.ms/u/s!Akkg8_btnkS7mU7eb_dFDHkW05B2?e=G922hL).\n### inference:\n```bash\n# cd into scripts/metrics/MANIQA\npython inference_MANIQA.py --model_path MANIQA_CKPT_YOU_JUST_DOWNLOADED --input_dir ../../../results/animesr_v1_si10/frames --output_dir output/ensemble_attentionIQA2_finetune_e2/AnimeSR_v1_si10\n```\nnote that the result has certain randomness, but the error should be relatively small.\n### license\nthe MANIQA codes&checkpoint are original from [MANIQA](https://github.com/IIGROUP/MANIQA) and @[TianheWu](https://github.com/TianheWu).\n"
  },
  {
    "path": "setup.cfg",
    "content": "[flake8]\nignore =\n    # line break before binary operator (W503)\n    W503,\n    # line break after binary operator (W504)\n    W504,\nmax-line-length=120\n\n[yapf]\nbased_on_style = pep8\ncolumn_limit = 120\nblank_line_before_nested_class_or_def = true\nsplit_before_expression_after_opening_paren = true\n\n[isort]\nline_length = 120\nmulti_line_output = 0\nknown_standard_library = pkg_resources,setuptools\nknown_first_party = basicsr,facexlib,animesr\nknown_third_party = PIL,cv2,ffmpeg,numpy,psutil,torch,torchvision,tqdm\nno_lines_before = STDLIB,LOCALFOLDER\ndefault_section = THIRDPARTY\n"
  },
  {
    "path": "setup.py",
    "content": "#!/usr/bin/env python\n\nfrom setuptools import find_packages, setup\n\nimport os\nimport subprocess\nimport time\n\nversion_file = 'animesr/version.py'\n\n\ndef readme():\n    with open('README.md', encoding='utf-8') as f:\n        content = f.read()\n    return content\n\n\ndef get_git_hash():\n\n    def _minimal_ext_cmd(cmd):\n        # construct minimal environment\n        env = {}\n        for k in ['SYSTEMROOT', 'PATH', 'HOME']:\n            v = os.environ.get(k)\n            if v is not None:\n                env[k] = v\n        # LANGUAGE is used on win32\n        env['LANGUAGE'] = 'C'\n        env['LANG'] = 'C'\n        env['LC_ALL'] = 'C'\n        out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]\n        return out\n\n    try:\n        out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])\n        sha = out.strip().decode('ascii')\n    except OSError:\n        sha = 'unknown'\n\n    return sha\n\n\ndef get_hash():\n    if os.path.exists('.git'):\n        sha = get_git_hash()[:7]\n    elif os.path.exists(version_file):\n        try:\n            from animesr.version import __version__\n            sha = __version__.split('+')[-1]\n        except ImportError:\n            raise ImportError('Unable to get git version')\n    else:\n        sha = 'unknown'\n\n    return sha\n\n\ndef write_version_py():\n    content = \"\"\"# GENERATED VERSION FILE\n# TIME: {}\n__version__ = '{}'\n__gitsha__ = '{}'\nversion_info = ({})\n\"\"\"\n    sha = get_hash()\n    with open('VERSION', 'r') as f:\n        SHORT_VERSION = f.read().strip()\n    VERSION_INFO = ', '.join([x if x.isdigit() else f'\"{x}\"' for x in SHORT_VERSION.split('.')])\n\n    version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)\n    with open(version_file, 'w') as f:\n        f.write(version_file_str)\n\n\ndef get_version():\n    with open(version_file, 'r') as f:\n        exec(compile(f.read(), version_file, 'exec'))\n    return 
locals()['__version__']\n\n\ndef get_requirements(filename='requirements.txt'):\n    here = os.path.dirname(os.path.realpath(__file__))\n    with open(os.path.join(here, filename), 'r') as f:\n        requires = [line.replace('\\n', '') for line in f.readlines()]\n    return requires\n\n\nif __name__ == '__main__':\n    write_version_py()\n    setup(\n        name='animesr',\n        version=get_version(),\n        description='AnimeSR: Learning Real-World Super-Resolution Models for Animation Videos (NeurIPS 2022)',\n        long_description=readme(),\n        long_description_content_type='text/markdown',\n        author='Yanze Wu',\n        author_email='wuyanze123@gmail.com',\n        keywords='computer vision, pytorch, image restoration, super-resolution',\n        url='https://github.com/TencentARC/AnimeSR',\n        include_package_data=True,\n        packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),\n        classifiers=[\n            'Development Status :: 4 - Beta',\n            'License :: OSI Approved :: Apache Software License',\n            'Operating System :: OS Independent',\n            'Programming Language :: Python :: 3',\n            'Programming Language :: Python :: 3.7',\n            'Programming Language :: Python :: 3.8',\n        ],\n        license='BSD-3-Clause License',\n        setup_requires=['cython', 'numpy'],\n        install_requires=get_requirements(),\n        zip_safe=False)\n"
  }
]