[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   
https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/#use-with-ide\n.pdm.toml\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n.idea/\n\nexamples/results/*\ngfpgan/*\ncheckpoints/*\nassets/*\nresults/*\nDockerfile\nstart_docker.sh\nstart.sh\n\ncheckpoints\n\n# Mac\n.DS_Store\n"
  },
  {
    "path": "LICENSE",
    "content": "Tencent is pleased to support the open source community by making SadTalker available.\n\nCopyright (C), a Tencent company. All rights reserved.\n\nSadTalker is licensed under the Apache 2.0 License, except for the third-party components listed below.\n\nTerms of the Apache License Version 2.0:\n---------------------------------------------\n                                Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. 
For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. 
For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n\n<img src='https://user-images.githubusercontent.com/4397546/229094115-862c747e-7397-4b54-ba4a-bd368bfe2e0f.png' width='500px'/>\n\n\n<!--<h2> 😭 SadTalker： <span style=\"font-size:12px\">Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation </span> </h2> -->\n\n  <a href='https://arxiv.org/abs/2211.12194'><img src='https://img.shields.io/badge/ArXiv-PDF-red'></a> &nbsp; <a href='https://sadtalker.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> &nbsp; [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb) &nbsp; [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/vinthony/SadTalker) &nbsp; [![sd webui-colab](https://img.shields.io/badge/Automatic1111-Colab-green)](https://colab.research.google.com/github/camenduru/stable-diffusion-webui-colab/blob/main/video/stable/stable_diffusion_1_5_video_webui_colab.ipynb) &nbsp; <br> [![Replicate](https://replicate.com/cjwbw/sadtalker/badge)](https://replicate.com/cjwbw/sadtalker) [![Discord](https://dcbadge.vercel.app/api/server/rrayYqZ4tf?style=flat)](https://discord.gg/rrayYqZ4tf)\n\n<div>\n    <a target='_blank'>Wenxuan Zhang <sup>*,1,2</sup> </a>&emsp;\n    <a href='https://vinthony.github.io/' target='_blank'>Xiaodong Cun <sup>*,2</a>&emsp;\n    <a href='https://xuanwangvc.github.io/' target='_blank'>Xuan Wang <sup>3</sup></a>&emsp;\n    <a href='https://yzhang2016.github.io/' target='_blank'>Yong Zhang <sup>2</sup></a>&emsp;\n    <a href='https://xishen0220.github.io/' target='_blank'>Xi Shen <sup>2</sup></a>&emsp; </br>\n    <a href='https://yuguo-xjtu.github.io/' target='_blank'>Yu Guo<sup>1</sup> </a>&emsp;\n    <a href='https://scholar.google.com/citations?hl=zh-CN&user=4oXBp9UAAAAJ' target='_blank'>Ying Shan 
<sup>2</sup> </a>&emsp;\n    <a target='_blank'>Fei Wang <sup>1</sup> </a>&emsp;\n</div>\n<br>\n<div>\n    <sup>1</sup> Xi'an Jiaotong University &emsp; <sup>2</sup> Tencent AI Lab &emsp; <sup>3</sup> Ant Group &emsp; \n</div>\n<br>\n<i><strong><a href='https://arxiv.org/abs/2211.12194' target='_blank'>CVPR 2023</a></strong></i>\n<br>\n<br>\n\n\n![sadtalker](https://user-images.githubusercontent.com/4397546/222490039-b1f6156b-bf00-405b-9fda-0c9a9156f991.gif)\n\n<b>TL;DR: &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; single portrait image 🙎‍♂️  &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;+  &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; audio 🎤  &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; =  &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; talking head video 🎞.</b>\n\n<br>\n\n</div>\n\n\n\n## Highlights\n\n- The license has been updated to Apache 2.0, and we've removed the non-commercial restriction.\n- **SadTalker has now officially been integrated into Discord, where you can use it for free by sending files. You can also generate high-quality videos from text prompts. Join: [![Discord](https://dcbadge.vercel.app/api/server/rrayYqZ4tf?style=flat)](https://discord.gg/rrayYqZ4tf)**\n\n- We've published a [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) extension. Check out more details [here](docs/webui_extension.md). [Demo Video](https://user-images.githubusercontent.com/4397546/231495639-5d4bb925-ea64-4a36-a519-6389917dac29.mp4)\n\n- Full image mode is now available! 
[More details...](https://github.com/OpenTalker/SadTalker#full-bodyimage-generation)\n\n| still+enhancer in v0.0.1                 | still + enhancer   in v0.0.2       |   [input image @bagbag1815](https://twitter.com/bagbag1815/status/1642754319094108161) |\n|:--------------------: |:--------------------: | :----: |\n| <video  src=\"https://user-images.githubusercontent.com/48216707/229484996-5d7be64f-2553-4c9e-a452-c5cf0b8ebafe.mp4\" type=\"video/mp4\"> </video> | <video  src=\"https://user-images.githubusercontent.com/4397546/230717873-355b7bf3-d3de-49f9-a439-9220e623fce7.mp4\" type=\"video/mp4\"> </video>  | <img src='./examples/source_image/full_body_2.png' width='380'> \n\n- Several new modes (Still, reference, and resize modes) are now available!\n\n- We're happy to see more community demos on [bilibili](https://search.bilibili.com/all?keyword=sadtalker), [YouTube](https://www.youtube.com/results?search_query=sadtalker) and [X (#sadtalker)](https://twitter.com/search?q=%23sadtalker&src).\n\n## Changelog \n\nThe previous changelog can be found [here](docs/changlelog.md).\n\n- __[2023.06.12]__: Added more new features in WebUI extension, see the discussion [here](https://github.com/OpenTalker/SadTalker/discussions/386).\n\n- __[2023.06.05]__: Released a new 512x512px (beta) face model. 
 Fixed some bugs and improved the performance.\n\n- __[2023.04.15]__: Added a WebUI Colab notebook by [@camenduru](https://github.com/camenduru/): [![sd webui-colab](https://img.shields.io/badge/Automatic1111-Colab-green)](https://colab.research.google.com/github/camenduru/stable-diffusion-webui-colab/blob/main/video/stable/stable_diffusion_1_5_video_webui_colab.ipynb)\n\n- __[2023.04.12]__: Added a more detailed WebUI installation document and fixed a problem when reinstalling.\n\n- __[2023.04.12]__: Fixed WebUI safety issues caused by third-party packages, and optimized the output path in `sd-webui-extension`.\n\n- __[2023.04.08]__: In v0.0.2, we added a logo watermark to the generated video to prevent abuse. _This watermark has since been removed in a later release._\n\n- __[2023.04.08]__: In v0.0.2, we added features for full image animation and a link to download checkpoints from Baidu. We also optimized the enhancer logic.\n\n## To-Do\n\nWe're tracking new updates in [issue #280](https://github.com/OpenTalker/SadTalker/issues/280).\n\n## Troubleshooting\n\nIf you have any problems, please read our [FAQs](docs/FAQ.md) before opening an issue.\n\n\n\n## 1. Installation.\n\nCommunity tutorials: [中文Windows教程 (Chinese Windows tutorial)](https://www.bilibili.com/video/BV1Dc411W7V6/) | [日本語コース (Japanese tutorial)](https://br-d.fanbox.cc/posts/5685086).\n\n### Linux/Unix\n\n1. Install [Anaconda](https://www.anaconda.com/), Python and `git`.\n\n2. Create the environment and install the requirements.\n  ```bash\n  git clone https://github.com/OpenTalker/SadTalker.git\n\n  cd SadTalker \n\n  conda create -n sadtalker python=3.8\n\n  conda activate sadtalker\n\n  pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113\n\n  conda install ffmpeg\n\n  pip install -r requirements.txt\n\n  ### Coqui TTS is optional for gradio demo. 
\n  ### pip install TTS\n\n  ```  \n### Windows\n\nA video tutorial in Chinese is available [here](https://www.bilibili.com/video/BV1Dc411W7V6/). You can also follow these instructions:\n\n1. Install [Python 3.8](https://www.python.org/downloads/windows/) and check \"Add Python to PATH\".\n2. Install [git](https://git-scm.com/download/win) manually or using [Scoop](https://scoop.sh/): `scoop install git`.\n3. Install `ffmpeg`, following [this tutorial](https://www.wikihow.com/Install-FFmpeg-on-Windows) or using [Scoop](https://scoop.sh/): `scoop install ffmpeg`.\n4. Download the SadTalker repository by running `git clone https://github.com/OpenTalker/SadTalker.git`.\n5. Download the checkpoints and gfpgan models in the [downloads section](#2-download-models).\n6. Run `start.bat` from Windows Explorer as a normal, non-administrator user, and a Gradio-powered WebUI demo will be started.\n\n### macOS\n\nA tutorial on installing SadTalker on macOS can be found [here](docs/install.md).\n\n### Docker, WSL, etc\n\nPlease check out additional tutorials [here](docs/install.md).\n\n## 2. 
Download Models\n\nYou can run the following script on Linux/macOS to automatically download all the models:\n\n```bash\nbash scripts/download_models.sh\n```\n\nWe also provide an offline patch (`gfpgan/`), so no models will be downloaded at generation time.\n\n### Pre-Trained Models\n\n* [Google Drive](https://drive.google.com/file/d/1gwWh45pF7aelNP_P78uDJL8Sycep-K7j/view?usp=sharing)\n* [GitHub Releases](https://github.com/OpenTalker/SadTalker/releases)\n* [Baidu (百度云盘)](https://pan.baidu.com/s/1kb1BCPaLOWX1JJb9Czbn6w?pwd=sadt) (Password: `sadt`)\n\n<!-- TODO add Hugging Face links -->\n\n### GFPGAN Offline Patch\n\n* [Google Drive](https://drive.google.com/file/d/19AIBsmfcHW6BRJmeqSFlG5fL445Xmsyi?usp=sharing)\n* [GitHub Releases](https://github.com/OpenTalker/SadTalker/releases)\n* [Baidu (百度云盘)](https://pan.baidu.com/s/1P4fRgk9gaSutZnn8YW034Q?pwd=sadt) (Password: `sadt`)\n\n<!-- TODO add Hugging Face links -->\n\n\n<details><summary>Model Details</summary>\n\n\nModel descriptions:\n\n##### New version \n| Model | Description\n| :--- | :----------\n|checkpoints/mapping_00229-model.pth.tar | Pre-trained MappingNet in Sadtalker.\n|checkpoints/mapping_00109-model.pth.tar | Pre-trained MappingNet in Sadtalker.\n|checkpoints/SadTalker_V0.0.2_256.safetensors | Packaged SadTalker checkpoints of the old version (256 face render).\n|checkpoints/SadTalker_V0.0.2_512.safetensors | Packaged SadTalker checkpoints of the old version (512 face render).\n|gfpgan/weights | Face detection and enhancement models used in `facexlib` and `gfpgan`.\n  \n  \n##### Old version\n| Model | Description\n| :--- | :----------\n|checkpoints/auido2exp_00300-model.pth | Pre-trained ExpNet in Sadtalker.\n|checkpoints/auido2pose_00140-model.pth | Pre-trained PoseVAE in Sadtalker.\n|checkpoints/mapping_00229-model.pth.tar | Pre-trained MappingNet in Sadtalker.\n|checkpoints/mapping_00109-model.pth.tar | Pre-trained MappingNet in Sadtalker.\n|checkpoints/facevid2vid_00189-model.pth.tar | Pre-trained face-vid2vid 
model from [the reproduction of face-vid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis).\n|checkpoints/epoch_20.pth | Pre-trained 3DMM extractor in [Deep3DFaceReconstruction](https://github.com/microsoft/Deep3DFaceReconstruction).\n|checkpoints/wav2lip.pth | Highly accurate lip-sync model in [Wav2lip](https://github.com/Rudrabha/Wav2Lip).\n|checkpoints/shape_predictor_68_face_landmarks.dat | Face landmark model used in [dlib](http://dlib.net/). \n|checkpoints/BFM | 3DMM library file.  \n|checkpoints/hub | Face detection models used in [face alignment](https://github.com/1adrianb/face-alignment).\n|gfpgan/weights | Face detection and enhancement models used in `facexlib` and `gfpgan`.\n\nThe final folder structure should look like this:\n\n<img width=\"331\" alt=\"image\" src=\"https://user-images.githubusercontent.com/4397546/232511411-4ca75cbf-a434-48c5-9ae0-9009e8316484.png\">\n\n\n</details>\n\n## 3. Quick Start\n\nPlease read our document on [best practices and configuration tips](docs/best_practice.md).\n\n### WebUI Demos\n\n**Online Demo**: [HuggingFace](https://huggingface.co/spaces/vinthony/SadTalker) | [SDWebUI-Colab](https://colab.research.google.com/github/camenduru/stable-diffusion-webui-colab/blob/main/video/stable/stable_diffusion_1_5_video_webui_colab.ipynb) | [Colab](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb)\n\n**Local WebUI extension**: Please refer to [WebUI docs](docs/webui_extension.md).\n\n**Local Gradio demo (recommended)**: A Gradio instance similar to our [Hugging Face demo](https://huggingface.co/spaces/vinthony/SadTalker) can be run locally:\n\n```bash\n## You need to manually install TTS (https://github.com/coqui-ai/TTS) via `pip install TTS` in advance.\npython app_sadtalker.py\n```\n\nYou can also start it more easily:\n\n- Windows: just double-click `webui.bat`; the requirements will be installed automatically.\n- Linux/Mac OS: run `bash webui.sh` to start the 
webui.\n\n\n### CLI usage\n\n##### Animating a portrait image from default config:\n```bash\npython inference.py --driven_audio <audio.wav> \\\n                    --source_image <video.mp4 or picture.png> \\\n                    --enhancer gfpgan \n```\nThe results will be saved in `results/$SOME_TIMESTAMP/*.mp4`.\n\n##### Full body/image Generation:\n\nUse `--still` to generate a natural full body video. You can add `--enhancer` to improve the quality of the generated video. \n\n```bash\npython inference.py --driven_audio <audio.wav> \\\n                    --source_image <video.mp4 or picture.png> \\\n                    --result_dir <a directory to store results> \\\n                    --still \\\n                    --preprocess full \\\n                    --enhancer gfpgan \n```\n\nMore examples, configuration options, and tips can be found in the [ >>> best practice documents <<<](docs/best_practice.md).\n\n## Citation\n\nIf you find our work useful in your research, please consider citing:\n\n```bibtex\n@article{zhang2022sadtalker,\n  title={SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation},\n  author={Zhang, Wenxuan and Cun, Xiaodong and Wang, Xuan and Zhang, Yong and Shen, Xi and Guo, Yu and Shan, Ying and Wang, Fei},\n  journal={arXiv preprint arXiv:2211.12194},\n  year={2022}\n}\n```\n\n## Acknowledgements\n\nFacerender code borrows heavily from [zhanglonghao's reproduction of face-vid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis) and [PIRender](https://github.com/RenYurui/PIRender). We thank the authors for sharing their wonderful code. During training, we also used the models from [Deep3DFaceReconstruction](https://github.com/microsoft/Deep3DFaceReconstruction) and [Wav2lip](https://github.com/Rudrabha/Wav2Lip). 
We thank them for their wonderful work.\n\nWe also use the following 3rd-party libraries:\n\n- **Face Utils**: https://github.com/xinntao/facexlib\n- **Face Enhancement**: https://github.com/TencentARC/GFPGAN\n- **Image/Video Enhancement**: https://github.com/xinntao/Real-ESRGAN\n\n## Extensions:\n\n- [SadTalker-Video-Lip-Sync](https://github.com/Zz-ww/SadTalker-Video-Lip-Sync) from [@Zz-ww](https://github.com/Zz-ww): SadTalker for Video Lip Editing\n\n## Related Works\n- [StyleHEAT: One-Shot High-Resolution Editable Talking Face Generation via Pre-trained StyleGAN (ECCV 2022)](https://github.com/FeiiYin/StyleHEAT)\n- [CodeTalker: Speech-Driven 3D Facial Animation with Discrete Motion Prior (CVPR 2023)](https://github.com/Doubiiu/CodeTalker)\n- [VideoReTalking: Audio-based Lip Synchronization for Talking Head Video Editing In the Wild (SIGGRAPH Asia 2022)](https://github.com/vinthony/video-retalking)\n- [DPE: Disentanglement of Pose and Expression for General Video Portrait Editing (CVPR 2023)](https://github.com/Carlyx/DPE)\n- [3D GAN Inversion with Facial Symmetry Prior (CVPR 2023)](https://github.com/FeiiYin/SPI/)\n- [T2M-GPT: Generating Human Motion from Textual Descriptions with Discrete Representations (CVPR 2023)](https://github.com/Mael-zys/T2M-GPT)\n\n## Disclaimer\n\nThis is not an official product of Tencent. \n\n```\n1. Please carefully read and comply with the open-source license applicable to this code before using it. \n2. Please carefully read and comply with the intellectual property declaration applicable to this code before using it.\n3. This open-source code runs completely offline and does not collect any personal information or other data. If you use this code to provide services to end-users and collect related data, please take necessary compliance measures according to applicable laws and regulations (such as publishing privacy policies, adopting necessary data security strategies, etc.). 
If the collected data involves personal information, user consent must be obtained (if applicable). Any legal liabilities arising from this are unrelated to Tencent.\n4. Without Tencent's written permission, you are not authorized to use the names or logos legally owned by Tencent, such as \"Tencent.\" Otherwise, you may be liable for legal responsibilities.\n5. This open-source code does not have the ability to directly provide services to end-users. If you need to use this code for further model training or demos, as part of your product to provide services to end-users, or for similar use, please comply with applicable laws and regulations for your product or service. Any legal liabilities arising from this are unrelated to Tencent.\n6. It is prohibited to use this open-source code for activities that harm the legitimate rights and interests of others (including but not limited to fraud, deception, infringement of others' portrait rights, reputation rights, etc.), or other behaviors that violate applicable laws and regulations or go against social ethics and good customs (including providing incorrect or false information, spreading pornographic, terrorist, and violent information, etc.). Otherwise, you may be liable for legal responsibilities.\n```\n\nLOGO: color and font suggestion: [ChatGPT](https://chat.openai.com), logo font: [Montserrat Alternates](https://fonts.google.com/specimen/Montserrat+Alternates?preview.text=SadTalker&preview.text_type=custom&query=mont).\n\nAll the copyrights of the demo images and audio are from community users or the generation from stable diffusion. Feel free to contact us if you would like us to remove them.\n\n\n<!-- Spelling fixed on Tuesday, September 12, 2023 by @fakerybakery (https://github.com/fakerybakery). These changes are licensed under the Apache 2.0 license. -->\n"
  },
  {
    "path": "app_sadtalker.py",
    "content": "import os, sys\nimport gradio as gr\nfrom src.gradio_demo import SadTalker  \n\n\ntry:\n    import webui  # in webui\n    in_webui = True\nexcept:\n    in_webui = False\n\n\ndef toggle_audio_file(choice):\n    if choice == False:\n        return gr.update(visible=True), gr.update(visible=False)\n    else:\n        return gr.update(visible=False), gr.update(visible=True)\n    \ndef ref_video_fn(path_of_ref_video):\n    if path_of_ref_video is not None:\n        return gr.update(value=True)\n    else:\n        return gr.update(value=False)\n\ndef sadtalker_demo(checkpoint_path='checkpoints', config_path='src/config', warpfn=None):\n\n    sad_talker = SadTalker(checkpoint_path, config_path, lazy_load=True)\n\n    with gr.Blocks(analytics_enabled=False) as sadtalker_interface:\n        gr.Markdown(\"<div align='center'> <h2> 😭 SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation (CVPR 2023) </span> </h2> \\\n                    <a style='font-size:18px;color: #efefef' href='https://arxiv.org/abs/2211.12194'>Arxiv</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; \\\n                    <a style='font-size:18px;color: #efefef' href='https://sadtalker.github.io'>Homepage</a>  &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; \\\n                     <a style='font-size:18px;color: #efefef' href='https://github.com/Winfredy/SadTalker'> Github </div>\")\n        \n        with gr.Row().style(equal_height=False):\n            with gr.Column(variant='panel'):\n                with gr.Tabs(elem_id=\"sadtalker_source_image\"):\n                    with gr.TabItem('Upload image'):\n                        with gr.Row():\n                            source_image = gr.Image(label=\"Source image\", source=\"upload\", type=\"filepath\", elem_id=\"img2img_image\").style(width=512)\n\n                with gr.Tabs(elem_id=\"sadtalker_driven_audio\"):\n                    with gr.TabItem('Upload OR TTS'):\n                        with 
gr.Column(variant='panel'):\n                            driven_audio = gr.Audio(label=\"Input audio\", source=\"upload\", type=\"filepath\")\n\n                        if sys.platform != 'win32' and not in_webui: \n                            from src.utils.text2speech import TTSTalker\n                            tts_talker = TTSTalker()\n                            with gr.Column(variant='panel'):\n                                input_text = gr.Textbox(label=\"Generating audio from text\", lines=5, placeholder=\"please enter some text here, we genreate the audio from text using @Coqui.ai TTS.\")\n                                tts = gr.Button('Generate audio',elem_id=\"sadtalker_audio_generate\", variant='primary')\n                                tts.click(fn=tts_talker.test, inputs=[input_text], outputs=[driven_audio])\n                            \n            with gr.Column(variant='panel'): \n                with gr.Tabs(elem_id=\"sadtalker_checkbox\"):\n                    with gr.TabItem('Settings'):\n                        gr.Markdown(\"need help? 
please visit our [best practice page](https://github.com/OpenTalker/SadTalker/blob/main/docs/best_practice.md) for more detials\")\n                        with gr.Column(variant='panel'):\n                            # width = gr.Slider(minimum=64, elem_id=\"img2img_width\", maximum=2048, step=8, label=\"Manually Crop Width\", value=512) # img2img_width\n                            # height = gr.Slider(minimum=64, elem_id=\"img2img_height\", maximum=2048, step=8, label=\"Manually Crop Height\", value=512) # img2img_width\n                            pose_style = gr.Slider(minimum=0, maximum=46, step=1, label=\"Pose style\", value=0) # \n                            size_of_image = gr.Radio([256, 512], value=256, label='face model resolution', info=\"use 256/512 model?\") # \n                            preprocess_type = gr.Radio(['crop', 'resize','full', 'extcrop', 'extfull'], value='crop', label='preprocess', info=\"How to handle input image?\")\n                            is_still_mode = gr.Checkbox(label=\"Still Mode (fewer head motion, works with preprocess `full`)\")\n                            batch_size = gr.Slider(label=\"batch size in generation\", step=1, maximum=10, value=2)\n                            enhancer = gr.Checkbox(label=\"GFPGAN as Face enhancer\")\n                            submit = gr.Button('Generate', elem_id=\"sadtalker_generate\", variant='primary')\n                            \n                with gr.Tabs(elem_id=\"sadtalker_genearted\"):\n                        gen_video = gr.Video(label=\"Generated video\", format=\"mp4\").style(width=256)\n\n        if warpfn:\n            submit.click(\n                        fn=warpfn(sad_talker.test), \n                        inputs=[source_image,\n                                driven_audio,\n                                preprocess_type,\n                                is_still_mode,\n                                enhancer,\n                                batch_size,             
               \n                                size_of_image,\n                                pose_style\n                                ], \n                        outputs=[gen_video]\n                        )\n        else:\n            submit.click(\n                        fn=sad_talker.test, \n                        inputs=[source_image,\n                                driven_audio,\n                                preprocess_type,\n                                is_still_mode,\n                                enhancer,\n                                batch_size,                            \n                                size_of_image,\n                                pose_style\n                                ], \n                        outputs=[gen_video]\n                        )\n\n    return sadtalker_interface\n \n\nif __name__ == \"__main__\":\n\n    demo = sadtalker_demo()\n    demo.queue()\n    demo.launch()\n\n\n"
  },
  {
    "path": "cog.yaml",
    "content": "build:\n  gpu: true\n  cuda: \"11.3\"\n  python_version: \"3.8\"\n  system_packages:\n    - \"ffmpeg\"\n    - \"libgl1-mesa-glx\"\n    - \"libglib2.0-0\"\n  python_packages:\n    - \"torch==1.12.1\"\n    - \"torchvision==0.13.1\"\n    - \"torchaudio==0.12.1\"\n    - \"joblib==1.1.0\"\n    - \"scikit-image==0.19.3\"\n    - \"basicsr==1.4.2\"\n    - \"facexlib==0.3.0\"\n    - \"resampy==0.3.1\"\n    - \"pydub==0.25.1\"\n    - \"scipy==1.10.1\"\n    - \"kornia==0.6.8\"\n    - \"face_alignment==1.3.5\"\n    - \"imageio==2.19.3\"\n    - \"imageio-ffmpeg==0.4.7\"\n    - \"librosa==0.9.2\" #\n    - \"tqdm==4.65.0\"\n    - \"yacs==0.1.8\"\n    - \"gfpgan==1.3.8\"\n    - \"dlib-bin==19.24.1\"\n    - \"av==10.0.0\"\n    - \"trimesh==3.9.20\"\n  run:\n    - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document \"/root/.cache/torch/hub/checkpoints/s3fd-619a316812.pth\" \"https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth\"\n    - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document \"/root/.cache/torch/hub/checkpoints/2DFAN4-cd938726ad.zip\" \"https://www.adrianbulat.com/downloads/python-fan/2DFAN4-cd938726ad.zip\"\n\npredict: \"predict.py:Predictor\"\n"
  },
  {
    "path": "docs/FAQ.md",
    "content": "\n## Frequency Asked Question\n\n**Q: `ffmpeg` is not recognized as an internal or external command**\n\nIn Linux, you can install the ffmpeg via `conda install ffmpeg`. Or on Mac OS X, try to install ffmpeg via `brew install ffmpeg`. On windows, make sure you have `ffmpeg` in the `%PATH%` as suggested in [#54](https://github.com/Winfredy/SadTalker/issues/54), then, following [this](https://www.geeksforgeeks.org/how-to-install-ffmpeg-on-windows/) installation to install `ffmpeg`.\n\n**Q: Running Requirments.**\n\nPlease refer to the discussion here: https://github.com/Winfredy/SadTalker/issues/124#issuecomment-1508113989\n\n\n**Q: ModuleNotFoundError: No module named 'ai'**\n\nplease check the checkpoint's size of the `epoch_20.pth`. (https://github.com/Winfredy/SadTalker/issues/167, https://github.com/Winfredy/SadTalker/issues/113)\n\n**Q: Illegal Hardware Error: Mac M1**\n\nplease reinstall the `dlib` by `pip install dlib` individually. (https://github.com/Winfredy/SadTalker/issues/129, https://github.com/Winfredy/SadTalker/issues/109)\n\n\n**Q: FileNotFoundError: [Errno 2] No such file or directory: checkpoints\\BFM_Fitting\\similarity_Lm3D_all.mat**\n\nMake sure you have downloaded the checkpoints and gfpgan as [here](https://github.com/Winfredy/SadTalker#-2-download-trained-models) and placed them in the right place. \n\n**Q: RuntimeError: unexpected EOF, expected 237192 more bytes. The file might be corrupted.**\n\nThe files are not automatically downloaded. 
Please update the code and download the gfpgan folders as described [here](https://github.com/Winfredy/SadTalker#-2-download-trained-models).\n\n**Q: CUDA out of memory error**\n\nPlease refer to https://stackoverflow.com/questions/73747731/runtimeerror-cuda-out-of-memory-how-setting-max-split-size-mb\n\n``` \n# windows\nset PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 \npython inference.py ...\n\n# linux\nexport PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 \npython inference.py ...\n```\n\n**Q: Error while decoding stream #0:0: Invalid data found when processing input [mp3float @ 0000015037628c00] Header missing**\n\nOur method only supports wav or mp3 files as input; please make sure the input audio is in one of these formats.\n"
  },
  {
    "path": "docs/best_practice.md",
    "content": "# Best Practices and Tips for configuration\n\n> Our model only works on REAL people or the portrait image similar to REAL person. The anime talking head genreation method will be released in future.\n\nAdvanced confiuration options for `inference.py`:\n\n| Name        | Configuration | default |   Explaination  | \n|:------------- |:------------- |:----- | :------------- |\n| Enhance Mode | `--enhancer` | None | Using `gfpgan` or `RestoreFormer` to enhance the generated face via face restoration network \n| Background Enhancer | `--background_enhancer` | None | Using `realesrgan` to enhance the full video. \n| Still Mode   | ` --still` | False |  Using the same pose parameters as the original image, fewer head motion.\n| Expressive Mode | `--expression_scale` | 1.0 | a larger value will make the expression motion stronger.\n| save path | `--result_dir` |`./results` | The file will be save in the newer location.\n| preprocess | `--preprocess` | `crop` | Run and produce the results in the croped input image. Other choices: `resize`, where the images will be resized to the specific resolution. `full` Run the full image animation, use with `--still` to get better results.\n| ref Mode (eye) | `--ref_eyeblink` | None | A video path, where we borrow the eyeblink from this reference video to provide more natural eyebrow movement.\n| ref Mode (pose) | `--ref_pose` | None | A video path, where we borrow the pose from the head reference video. \n| 3D Mode | `--face3dvis` | False | Need additional installation. More details to generate the 3d face can be founded [here](docs/face3d.md). \n| free-view Mode | `--input_yaw`,<br> `--input_pitch`,<br> `--input_roll` | None | Genearting novel view or free-view 4D talking head from a single image. 
More details can be found [here](https://github.com/Winfredy/SadTalker#generating-4d-free-view-talking-examples-from-audio-and-a-single-image).\n\n\n### About `--preprocess`\n\nOur system automatically handles the input images via `crop`, `resize` and `full`.\n\nIn `crop` mode, we only generate the cropped image via the facial keypoints and generate the animated facial avatar. The animation of both expression and head pose is realistic.\n\n> Still mode will stop the eyeblink and head pose movement.\n\n|  [input image @bagbag1815](https://twitter.com/bagbag1815/status/1642754319094108161) | crop | crop w/still |\n|:--------------------: |:--------------------: | :----: |\n| <img src='../examples/source_image/full_body_2.png' width='380'> | ![full_body_2](example_crop.gif) | ![full_body_2](example_crop_still.gif) |\n\n\nIn `resize` mode, we resize the whole image to generate the full talking head video. Thus, an image similar to an ID photo can be produced. ⚠️ It will produce bad results for full person images.\n\n\n \n\n| <img src='../examples/source_image/full_body_2.png' width='380'> |  <img src='../examples/source_image/full4.jpeg' width='380'> |\n|:--------------------: |:--------------------: |\n| ❌ not suitable for resize mode | ✅ good for resize mode |\n| <img src='resize_no.gif'> |  <img src='resize_good.gif' width='380'> |\n\nIn `full` mode, our model will automatically process the cropped region and paste it back into the original image. Remember to use `--still` to keep the original head pose.\n\n| input | `--still` | `--still` & `enhancer` |\n|:--------------------: |:--------------------: | :--:|\n| <img src='../examples/source_image/full_body_2.png' width='380'> |  <img src='./example_full.gif' width='380'> |  <img src='./example_full_enhanced.gif' width='380'> \n\n\n### About `--enhancer`\n\nFor higher resolution, we integrate [gfpgan](https://github.com/TencentARC/GFPGAN) and [real-esrgan](https://github.com/xinntao/Real-ESRGAN) for different purposes. 
Add `--enhancer <gfpgan or RestoreFormer>` to enhance the face, or `--background_enhancer <realesrgan>` to enhance the full image.\n\n```bash\n# make sure the above packages are available:\npip install gfpgan\npip install realesrgan\n```\n\n### About `--face3dvis`\n\nThis flag generates the 3D-rendered face and its 3D facial landmarks. More details can be found [here](face3d.md).\n\n| Input        | Animated 3d face | \n|:-------------: | :-------------: |\n|  <img src='../examples/source_image/art_0.png' width='200px'> | <video src=\"https://user-images.githubusercontent.com/4397546/226856847-5a6a0a4d-a5ec-49e2-9b05-3206db65e8e3.mp4\"></video>  | \n\n> Please make sure to enable the audio manually, as the video does not play audio by default on GitHub.\n\n\n\n#### Reference eye-blink mode.\n\n| Input, w/ reference video   ,  reference video    | \n|:-------------: | \n|  ![free_view](using_ref_video.gif)| \n| If the reference video is shorter than the input audio, we will loop the reference video. \n\n\n\n#### Generating 4D free-view talking examples from audio and a single image\n\nWe use `input_yaw`, `input_pitch`, `input_roll` to control the head pose. For example, `--input_yaw -20 30 10` means the input head yaw degree changes from -20 to 30 and then changes from 30 to 10.\n```bash\npython inference.py --driven_audio <audio.wav> \\\n                    --source_image <video.mp4 or picture.png> \\\n                    --result_dir <a directory to store results> \\\n                    --input_yaw -20 30 10\n```\n\n| Results, Free-view results,  Novel view results  | \n|:-------------: | \n|  ![free_view](free_view_result.gif)| \n"
  },
  {
    "path": "docs/changlelog.md",
    "content": "## changelogs\n\n\n- __[2023.04.06]__: stable-diffiusion webui extension is release.\n\n- __[2023.04.03]__: Enable TTS in huggingface and gradio local demo.\n\n- __[2023.03.30]__: Launch beta version of the full body mode.\n\n- __[2023.03.30]__: Launch new feature: through using reference videos, our algorithm can generate videos with more natural eye blinking and some eyebrow movement.\n\n- __[2023.03.29]__: `resize mode` is online by `python infererence.py --preprocess resize`! Where we can produce a larger crop of the image as discussed in https://github.com/Winfredy/SadTalker/issues/35.\n\n- __[2023.03.29]__: local gradio demo is online! `python app.py` to start the demo. New `requirments.txt` is used to avoid the bugs in `librosa`.\n\n- __[2023.03.28]__: Online demo is launched in [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/vinthony/SadTalker), thanks AK!\n \n- __[2023.03.22]__: Launch new feature: generating the 3d face animation from a single image. New applications about it will be updated.\n\n- __[2023.03.22]__: Launch new feature: `still mode`, where only a small head pose will be produced via `python inference.py --still`. 
\n\n- __[2023.03.18]__: Support `expression intensity`; now you can change the intensity of the generated motion: `python inference.py --expression_scale 1.3 (some value > 1)`.\n\n- __[2023.03.18]__: Reorganize the data folders; now you can download the checkpoints automatically using `bash scripts/download_models.sh`.\n- __[2023.03.18]__: We have officially integrated [GFPGAN](https://github.com/TencentARC/GFPGAN) for face enhancement; use `python inference.py --enhancer gfpgan` for better visualization performance.\n- __[2023.03.14]__: Specify the version of package `joblib` to remove the errors in using `librosa`; [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb) is online!\n- __[2023.03.06]__: Fix some bugs in the code and errors in installation.\n- __[2023.03.03]__: Release the test code for audio-driven single image animation!\n- __[2023.02.28]__: SadTalker has been accepted by CVPR 2023!\n"
  },
  {
    "path": "docs/face3d.md",
    "content": "## 3D Face Visualization\n\nWe use `pytorch3d` to visualize the 3D faces from a single image.\n\nThe requirements for 3D visualization are difficult to install, so here's a tutorial:\n\n```bash\ngit clone https://github.com/OpenTalker/SadTalker.git\ncd SadTalker \nconda create -n sadtalker3d python=3.8\nsource activate sadtalker3d\n\nconda install ffmpeg\nconda install -c fvcore -c iopath -c conda-forge fvcore iopath\nconda install libgcc gmp\n\npip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113\n\n# insintall pytorch3d\npip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu113_pyt1110/download.html\n\npip install -r requirements3d.txt\n\n### install gpfgan for enhancer\npip install git+https://github.com/TencentARC/GFPGAN\n\n\n### when occurs gcc version problem `from pytorch import _C` from pytorch3d, add the anaconda path to LD_LIBRARY_PATH\nexport LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/$YOUR_ANACONDA_PATH/lib/\n\n``` \n\nThen, generate the result via:\n\n```bash\n\n\npython inference.py --driven_audio <audio.wav> \\\n                    --source_image <video.mp4 or picture.png> \\\n                    --result_dir <a file to store results> \\\n                    --face3dvis\n\n```\n\nThe results will appear, named `face3d.mp4`.\n\nMore applications about 3D face rendering will be released soon.\n"
  },
  {
    "path": "docs/install.md",
    "content": "### macOS\n\nThis method has been tested on a M1 Mac (13.3)\n\n```bash\ngit clone https://github.com/OpenTalker/SadTalker.git\ncd SadTalker \nconda create -n sadtalker python=3.8\nconda activate sadtalker\n# install pytorch 2.0\npip install torch torchvision torchaudio\nconda install ffmpeg\npip install -r requirements.txt\npip install dlib # macOS needs to install the original dlib.\n```\n\n### Windows Native\n\n- Make sure you have `ffmpeg` in the `%PATH%` as suggested in [#54](https://github.com/Winfredy/SadTalker/issues/54), following [this](https://www.geeksforgeeks.org/how-to-install-ffmpeg-on-windows/) tutorial to install `ffmpeg` or using scoop.\n\n\n### Windows WSL\n\n\n- Make sure the environment: `export LD_LIBRARY_PATH=/usr/lib/wsl/lib:$LD_LIBRARY_PATH`\n\n\n### Docker Installation\n\nA community Docker image by [@thegenerativegeneration](https://github.com/thegenerativegeneration) is available on the [Docker hub](https://hub.docker.com/repository/docker/wawa9000/sadtalker), which can be used directly:\n```bash\ndocker run --gpus \"all\" --rm -v $(pwd):/host_dir wawa9000/sadtalker \\\n    --driven_audio /host_dir/deyu.wav \\\n    --source_image /host_dir/image.jpg \\\n    --expression_scale 1.0 \\\n    --still \\\n    --result_dir /host_dir\n```\n\n"
  },
  {
    "path": "docs/webui_extension.md",
    "content": "## Run SadTalker as a Stable Diffusion WebUI Extension.\n\n1. Install the lastest version of [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) and install SadTalker via `extension`.\n<img width=\"726\" alt=\"image\" src=\"https://user-images.githubusercontent.com/4397546/230698519-267d1d1f-6e99-4dd4-81e1-7b889259efbd.png\">\n\n2. Download the checkpoints manually, for Linux and Mac:\n\n    ```bash\n\n    cd SOMEWHERE_YOU_LIKE\n\n    bash <(wget -qO- https://raw.githubusercontent.com/Winfredy/OpenTalker/main/scripts/download_models.sh)\n    ```\n\n    For Windows, you can download all the checkpoints [here](https://github.com/OpenTalker/SadTalker/tree/main#2-download-models).\n\n3.1. Option 1: put the checkpoint in `stable-diffusion-webui/models/SadTalker` or `stable-diffusion-webui/extensions/SadTalker/checkpoints/`, the checkpoints will be detected automatically.\n\n3.2. Option 2: Set the path of `SADTALKTER_CHECKPOINTS` in `webui_user.sh`(linux) or `webui_user.bat`(windows) by:\n\n    > only works if you are directly starting webui from `webui_user.sh` or `webui_user.bat`.\n\n    ```bash\n    # Windows (webui_user.bat)\n    set SADTALKER_CHECKPOINTS=D:\\SadTalker\\checkpoints\n\n    # Linux/macOS (webui_user.sh)\n    export SADTALKER_CHECKPOINTS=/path/to/SadTalker/checkpoints\n    ```\n\n4. Start the WebUI via `webui.sh or webui_user.sh(linux)` or `webui_user.bat(windows)` or any other method. SadTalker can also be used in stable-diffusion-webui directly.\n    \n    <img width=\"726\" alt=\"image\" src=\"https://user-images.githubusercontent.com/4397546/230698614-58015182-2916-4240-b324-e69022ef75b3.png\">\n    \n## Questions\n\n1. 
If you are running on CPU, you need to specify `--disable-safe-unpickle` in `webui_user.sh` or `webui_user.bat`.\n\n    ```bash\n    # windows (webui_user.bat)\n    set COMMANDLINE_ARGS=\"--disable-safe-unpickle\"\n\n    # linux (webui_user.sh)\n    export COMMANDLINE_ARGS=\"--disable-safe-unpickle\"\n    ```\n\n\n\n(If you're unable to use the `full` mode, please read this [discussion](https://github.com/Winfredy/SadTalker/issues/78).)\n"
  },
  {
    "path": "inference.py",
    "content": "from glob import glob\nimport shutil\nimport torch\nfrom time import  strftime\nimport os, sys, time\nfrom argparse import ArgumentParser\n\nfrom src.utils.preprocess import CropAndExtract\nfrom src.test_audio2coeff import Audio2Coeff  \nfrom src.facerender.animate import AnimateFromCoeff\nfrom src.generate_batch import get_data\nfrom src.generate_facerender_batch import get_facerender_data\nfrom src.utils.init_path import init_path\n\ndef main(args):\n    #torch.backends.cudnn.enabled = False\n\n    pic_path = args.source_image\n    audio_path = args.driven_audio\n    save_dir = os.path.join(args.result_dir, strftime(\"%Y_%m_%d_%H.%M.%S\"))\n    os.makedirs(save_dir, exist_ok=True)\n    pose_style = args.pose_style\n    device = args.device\n    batch_size = args.batch_size\n    input_yaw_list = args.input_yaw\n    input_pitch_list = args.input_pitch\n    input_roll_list = args.input_roll\n    ref_eyeblink = args.ref_eyeblink\n    ref_pose = args.ref_pose\n\n    current_root_path = os.path.split(sys.argv[0])[0]\n\n    sadtalker_paths = init_path(args.checkpoint_dir, os.path.join(current_root_path, 'src/config'), args.size, args.old_version, args.preprocess)\n\n    #init model\n    preprocess_model = CropAndExtract(sadtalker_paths, device)\n\n    audio_to_coeff = Audio2Coeff(sadtalker_paths,  device)\n    \n    animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device)\n\n    #crop image and extract 3dmm from image\n    first_frame_dir = os.path.join(save_dir, 'first_frame_dir')\n    os.makedirs(first_frame_dir, exist_ok=True)\n    print('3DMM Extraction for source image')\n    first_coeff_path, crop_pic_path, crop_info =  preprocess_model.generate(pic_path, first_frame_dir, args.preprocess,\\\n                                                                             source_image_flag=True, pic_size=args.size)\n    if first_coeff_path is None:\n        print(\"Can't get the coeffs of the input\")\n        return\n\n    if ref_eyeblink is not 
None:\n        ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[0]\n        ref_eyeblink_frame_dir = os.path.join(save_dir, ref_eyeblink_videoname)\n        os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)\n        print('3DMM Extraction for the reference video providing eye blinking')\n        ref_eyeblink_coeff_path, _, _ =  preprocess_model.generate(ref_eyeblink, ref_eyeblink_frame_dir, args.preprocess, source_image_flag=False)\n    else:\n        ref_eyeblink_coeff_path=None\n\n    if ref_pose is not None:\n        if ref_pose == ref_eyeblink: \n            ref_pose_coeff_path = ref_eyeblink_coeff_path\n        else:\n            ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]\n            ref_pose_frame_dir = os.path.join(save_dir, ref_pose_videoname)\n            os.makedirs(ref_pose_frame_dir, exist_ok=True)\n            print('3DMM Extraction for the reference video providing pose')\n            ref_pose_coeff_path, _, _ =  preprocess_model.generate(ref_pose, ref_pose_frame_dir, args.preprocess, source_image_flag=False)\n    else:\n        ref_pose_coeff_path=None\n\n    #audio2ceoff\n    batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=args.still)\n    coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)\n\n    # 3dface render\n    if args.face3dvis:\n        from src.face3d.visualize import gen_composed_video\n        gen_composed_video(args, device, first_coeff_path, coeff_path, audio_path, os.path.join(save_dir, '3dface.mp4'))\n    \n    #coeff2video\n    data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, \n                                batch_size, input_yaw_list, input_pitch_list, input_roll_list,\n                                expression_scale=args.expression_scale, still_mode=args.still, preprocess=args.preprocess, size=args.size)\n    \n    result = animate_from_coeff.generate(data, save_dir, 
pic_path, crop_info, \\\n                                enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess, img_size=args.size)\n    \n    shutil.move(result, save_dir+'.mp4')\n    print('The generated video is named:', save_dir+'.mp4')\n\n    if not args.verbose:\n        shutil.rmtree(save_dir)\n\n    \nif __name__ == '__main__':\n\n    parser = ArgumentParser()  \n    parser.add_argument(\"--driven_audio\", default='./examples/driven_audio/bus_chinese.wav', help=\"path to driven audio\")\n    parser.add_argument(\"--source_image\", default='./examples/source_image/full_body_1.png', help=\"path to source image\")\n    parser.add_argument(\"--ref_eyeblink\", default=None, help=\"path to reference video providing eye blinking\")\n    parser.add_argument(\"--ref_pose\", default=None, help=\"path to reference video providing pose\")\n    parser.add_argument(\"--checkpoint_dir\", default='./checkpoints', help=\"path to output\")\n    parser.add_argument(\"--result_dir\", default='./results', help=\"path to output\")\n    parser.add_argument(\"--pose_style\", type=int, default=0,  help=\"input pose style from [0, 46)\")\n    parser.add_argument(\"--batch_size\", type=int, default=2,  help=\"the batch size of facerender\")\n    parser.add_argument(\"--size\", type=int, default=256,  help=\"the image size of the facerender\")\n    parser.add_argument(\"--expression_scale\", type=float, default=1.,  help=\"the batch size of facerender\")\n    parser.add_argument('--input_yaw', nargs='+', type=int, default=None, help=\"the input yaw degree of the user \")\n    parser.add_argument('--input_pitch', nargs='+', type=int, default=None, help=\"the input pitch degree of the user\")\n    parser.add_argument('--input_roll', nargs='+', type=int, default=None, help=\"the input roll degree of the user\")\n    parser.add_argument('--enhancer',  type=str, default=None, help=\"Face enhancer, [gfpgan, RestoreFormer]\")\n    
parser.add_argument('--background_enhancer',  type=str, default=None, help=\"background enhancer, [realesrgan]\")\n    parser.add_argument(\"--cpu\", dest=\"cpu\", action=\"store_true\") \n    parser.add_argument(\"--face3dvis\", action=\"store_true\", help=\"generate 3d face and 3d landmarks\") \n    parser.add_argument(\"--still\", action=\"store_true\", help=\"can crop back to the original videos for the full body aniamtion\") \n    parser.add_argument(\"--preprocess\", default='crop', choices=['crop', 'extcrop', 'resize', 'full', 'extfull'], help=\"how to preprocess the images\" ) \n    parser.add_argument(\"--verbose\",action=\"store_true\", help=\"saving the intermedia output or not\" ) \n    parser.add_argument(\"--old_version\",action=\"store_true\", help=\"use the pth other than safetensor version\" ) \n\n\n    # net structure and parameters\n    parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='useless')\n    parser.add_argument('--init_path', type=str, default=None, help='Useless')\n    parser.add_argument('--use_last_fc',default=False, help='zero initialize the last fc')\n    parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/')\n    parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')\n\n    # default renderer parameters\n    parser.add_argument('--focal', type=float, default=1015.)\n    parser.add_argument('--center', type=float, default=112.)\n    parser.add_argument('--camera_d', type=float, default=10.)\n    parser.add_argument('--z_near', type=float, default=5.)\n    parser.add_argument('--z_far', type=float, default=15.)\n\n    args = parser.parse_args()\n\n    if torch.cuda.is_available() and not args.cpu:\n        args.device = \"cuda\"\n    else:\n        args.device = \"cpu\"\n\n    main(args)\n\n"
  },
  {
    "path": "launcher.py",
    "content": "# this scripts installs necessary requirements and launches main program in webui.py\n# borrow from : https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/launch.py\nimport subprocess\nimport os\nimport sys\nimport importlib.util\nimport shlex\nimport platform\nimport json\n\npython = sys.executable\ngit = os.environ.get('GIT', \"git\")\nindex_url = os.environ.get('INDEX_URL', \"\")\nstored_commit_hash = None\nskip_install = False\ndir_repos = \"repositories\"\nscript_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))\n\nif 'GRADIO_ANALYTICS_ENABLED' not in os.environ:\n    os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'\n\n\ndef check_python_version():\n    is_windows = platform.system() == \"Windows\"\n    major = sys.version_info.major\n    minor = sys.version_info.minor\n    micro = sys.version_info.micro\n\n    if is_windows:\n        supported_minors = [10]\n    else:\n        supported_minors = [7, 8, 9, 10, 11]\n\n    if not (major == 3 and minor in supported_minors):\n\n        raise (f\"\"\"\nINCOMPATIBLE PYTHON VERSION\nThis program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.\nIf you encounter an error with \"RuntimeError: Couldn't install torch.\" message,\nor any other error regarding unsuccessful package (library) installation,\nplease downgrade (or upgrade) to the latest version of 3.10 Python\nand delete current Python and \"venv\" folder in WebUI's directory.\nYou can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/\n{\"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases\" if is_windows else \"\"}\nUse --skip-python-version-check to suppress this warning.\n\"\"\")\n\n\ndef commit_hash():\n    global stored_commit_hash\n\n    if stored_commit_hash is not None:\n        return stored_commit_hash\n\n    try:\n        stored_commit_hash = run(f\"{git} rev-parse HEAD\").strip()\n    
except Exception:\n        stored_commit_hash = \"<none>\"\n\n    return stored_commit_hash\n\n\ndef run(command, desc=None, errdesc=None, custom_env=None, live=False):\n    if desc is not None:\n        print(desc)\n\n    if live:\n        result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)\n        if result.returncode != 0:\n            raise RuntimeError(f\"\"\"{errdesc or 'Error running command'}.\nCommand: {command}\nError code: {result.returncode}\"\"\")\n\n        return \"\"\n\n    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)\n\n    if result.returncode != 0:\n\n        message = f\"\"\"{errdesc or 'Error running command'}.\nCommand: {command}\nError code: {result.returncode}\nstdout: {result.stdout.decode(encoding=\"utf8\", errors=\"ignore\") if len(result.stdout)>0 else '<empty>'}\nstderr: {result.stderr.decode(encoding=\"utf8\", errors=\"ignore\") if len(result.stderr)>0 else '<empty>'}\n\"\"\"\n        raise RuntimeError(message)\n\n    return result.stdout.decode(encoding=\"utf8\", errors=\"ignore\")\n\n\ndef check_run(command):\n    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)\n    return result.returncode == 0\n\n\ndef is_installed(package):\n    try:\n        spec = importlib.util.find_spec(package)\n    except ModuleNotFoundError:\n        return False\n\n    return spec is not None\n\n\ndef repo_dir(name):\n    return os.path.join(script_path, dir_repos, name)\n\n\ndef run_python(code, desc=None, errdesc=None):\n    return run(f'\"{python}\" -c \"{code}\"', desc, errdesc)\n\n\ndef run_pip(args, desc=None):\n    if skip_install:\n        return\n\n    index_url_line = f' --index-url {index_url}' if index_url != '' else ''\n    return run(f'\"{python}\" -m pip {args} --prefer-binary{index_url_line}', desc=f\"Installing {desc}\", errdesc=f\"Couldn't 
install {desc}\")\n\n\ndef check_run_python(code):\n    return check_run(f'\"{python}\" -c \"{code}\"')\n\n\ndef git_clone(url, dir, name, commithash=None):\n    # TODO clone into temporary dir and move if successful\n\n    if os.path.exists(dir):\n        if commithash is None:\n            return\n\n        current_hash = run(f'\"{git}\" -C \"{dir}\" rev-parse HEAD', None, f\"Couldn't determine {name}'s hash: {commithash}\").strip()\n        if current_hash == commithash:\n            return\n\n        run(f'\"{git}\" -C \"{dir}\" fetch', f\"Fetching updates for {name}...\", f\"Couldn't fetch {name}\")\n        run(f'\"{git}\" -C \"{dir}\" checkout {commithash}', f\"Checking out commit for {name} with hash: {commithash}...\", f\"Couldn't checkout commit {commithash} for {name}\")\n        return\n\n    run(f'\"{git}\" clone \"{url}\" \"{dir}\"', f\"Cloning {name} into {dir}...\", f\"Couldn't clone {name}\")\n\n    if commithash is not None:\n        run(f'\"{git}\" -C \"{dir}\" checkout {commithash}', None, \"Couldn't checkout {name}'s hash: {commithash}\")\n\n\ndef git_pull_recursive(dir):\n    for subdir, _, _ in os.walk(dir):\n        if os.path.exists(os.path.join(subdir, '.git')):\n            try:\n                output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])\n                print(f\"Pulled changes for repository in '{subdir}':\\n{output.decode('utf-8').strip()}\\n\")\n            except subprocess.CalledProcessError as e:\n                print(f\"Couldn't perform 'git pull' on repository in '{subdir}':\\n{e.output.decode('utf-8').strip()}\\n\")\n\n\ndef run_extension_installer(extension_dir):\n    path_installer = os.path.join(extension_dir, \"install.py\")\n    if not os.path.isfile(path_installer):\n        return\n\n    try:\n        env = os.environ.copy()\n        env['PYTHONPATH'] = os.path.abspath(\".\")\n\n        print(run(f'\"{python}\" \"{path_installer}\"', errdesc=f\"Error running install.py for extension 
{extension_dir}\", custom_env=env))\n    except Exception as e:\n        print(e, file=sys.stderr)\n\n\ndef prepare_environment():\n    global skip_install\n\n    torch_command = os.environ.get('TORCH_COMMAND', \"pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113\")\n\n    ## check windows \n    if sys.platform != 'win32':\n        requirements_file = os.environ.get('REQS_FILE', \"req.txt\")\n    else:\n        requirements_file = os.environ.get('REQS_FILE', \"requirements.txt\")\n\n    commit = commit_hash()\n\n    print(f\"Python {sys.version}\")\n    print(f\"Commit hash: {commit}\")\n\n    if not is_installed(\"torch\") or not is_installed(\"torchvision\"):\n        run(f'\"{python}\" -m {torch_command}', \"Installing torch and torchvision\", \"Couldn't install torch\", live=True)\n\n    run_pip(f\"install -r \\\"{requirements_file}\\\"\", \"requirements for SadTalker WebUI (may take longer time in first time)\")\n\n    if sys.platform != 'win32' and not is_installed('tts'):\n        run_pip(f\"install TTS\", \"install TTS individually in SadTalker, which might not work on windows.\")\n\n\ndef start():\n    print(f\"Launching SadTalker Web UI\")\n    from app_sadtalker import sadtalker_demo\n    demo = sadtalker_demo()\n    demo.queue()\n    demo.launch()\n\nif __name__ == \"__main__\":\n    prepare_environment()\n    start()"
  },
  {
    "path": "predict.py",
    "content": "\"\"\"run bash scripts/download_models.sh first to prepare the weights file\"\"\"\nimport os\nimport shutil\nfrom argparse import Namespace\nfrom src.utils.preprocess import CropAndExtract\nfrom src.test_audio2coeff import Audio2Coeff\nfrom src.facerender.animate import AnimateFromCoeff\nfrom src.generate_batch import get_data\nfrom src.generate_facerender_batch import get_facerender_data\nfrom src.utils.init_path import init_path\nfrom cog import BasePredictor, Input, Path\n\ncheckpoints = \"checkpoints\"\n\n\nclass Predictor(BasePredictor):\n    def setup(self):\n        \"\"\"Load the model into memory to make running multiple predictions efficient\"\"\"\n        device = \"cuda\"\n\n        \n        sadtalker_paths = init_path(checkpoints,os.path.join(\"src\",\"config\"))\n\n        # init model\n        self.preprocess_model = CropAndExtract(sadtalker_paths, device\n        )\n\n        self.audio_to_coeff = Audio2Coeff(\n            sadtalker_paths,\n            device,\n        )\n\n        self.animate_from_coeff = {\n            \"full\": AnimateFromCoeff(\n                sadtalker_paths,\n                device,\n            ),\n            \"others\": AnimateFromCoeff(\n                sadtalker_paths,\n                device,\n            ),\n        }\n\n    def predict(\n        self,\n        source_image: Path = Input(\n            description=\"Upload the source image, it can be video.mp4 or picture.png\",\n        ),\n        driven_audio: Path = Input(\n            description=\"Upload the driven audio, accepts .wav and .mp4 file\",\n        ),\n        enhancer: str = Input(\n            description=\"Choose a face enhancer\",\n            choices=[\"gfpgan\", \"RestoreFormer\"],\n            default=\"gfpgan\",\n        ),\n        preprocess: str = Input(\n            description=\"how to preprocess the images\",\n            choices=[\"crop\", \"resize\", \"full\"],\n            default=\"full\",\n        ),\n        
ref_eyeblink: Path = Input(\n            description=\"path to reference video providing eye blinking\",\n            default=None,\n        ),\n        ref_pose: Path = Input(\n            description=\"path to reference video providing pose\",\n            default=None,\n        ),\n        still: bool = Input(\n            description=\"can crop back to the original videos for the full body animation when preprocess is full\",\n            default=True,\n        ),\n    ) -> Path:\n        \"\"\"Run a single prediction on the model\"\"\"\n\n        animate_from_coeff = (\n            self.animate_from_coeff[\"full\"]\n            if preprocess == \"full\"\n            else self.animate_from_coeff[\"others\"]\n        )\n\n        args = load_default()\n        args.pic_path = str(source_image)\n        args.audio_path = str(driven_audio)\n        device = \"cuda\"\n        args.still = still\n        args.ref_eyeblink = None if ref_eyeblink is None else str(ref_eyeblink)\n        args.ref_pose = None if ref_pose is None else str(ref_pose)\n\n        # crop image and extract 3dmm from image\n        results_dir = \"results\"\n        if os.path.exists(results_dir):\n            shutil.rmtree(results_dir)\n        os.makedirs(results_dir)\n        first_frame_dir = os.path.join(results_dir, \"first_frame_dir\")\n        os.makedirs(first_frame_dir)\n\n        print(\"3DMM Extraction for source image\")\n        first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(\n            args.pic_path, first_frame_dir, preprocess, source_image_flag=True\n        )\n        if first_coeff_path is None:\n            print(\"Can't get the coeffs of the input\")\n            return\n\n        if ref_eyeblink is not None:\n            ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[\n                0\n            ]\n            ref_eyeblink_frame_dir = os.path.join(results_dir, ref_eyeblink_videoname)\n            
os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)\n            print(\"3DMM Extraction for the reference video providing eye blinking\")\n            ref_eyeblink_coeff_path, _, _ = self.preprocess_model.generate(\n                ref_eyeblink, ref_eyeblink_frame_dir\n            )\n        else:\n            ref_eyeblink_coeff_path = None\n\n        if ref_pose is not None:\n            if ref_pose == ref_eyeblink:\n                ref_pose_coeff_path = ref_eyeblink_coeff_path\n            else:\n                ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]\n                ref_pose_frame_dir = os.path.join(results_dir, ref_pose_videoname)\n                os.makedirs(ref_pose_frame_dir, exist_ok=True)\n                print(\"3DMM Extraction for the reference video providing pose\")\n                ref_pose_coeff_path, _, _ = self.preprocess_model.generate(\n                    ref_pose, ref_pose_frame_dir\n                )\n        else:\n            ref_pose_coeff_path = None\n\n        # audio2ceoff\n        batch = get_data(\n            first_coeff_path,\n            args.audio_path,\n            device,\n            ref_eyeblink_coeff_path,\n            still=still,\n        )\n        coeff_path = self.audio_to_coeff.generate(\n            batch, results_dir, args.pose_style, ref_pose_coeff_path\n        )\n        # coeff2video\n        print(\"coeff2video\")\n        data = get_facerender_data(\n            coeff_path,\n            crop_pic_path,\n            first_coeff_path,\n            args.audio_path,\n            args.batch_size,\n            args.input_yaw,\n            args.input_pitch,\n            args.input_roll,\n            expression_scale=args.expression_scale,\n            still_mode=still,\n            preprocess=preprocess,\n        )\n        animate_from_coeff.generate(\n            data, results_dir, args.pic_path, crop_info,\n            enhancer=enhancer, background_enhancer=args.background_enhancer,\n    
        preprocess=preprocess)\n\n        output = \"/tmp/out.mp4\"\n        mp4_path = os.path.join(results_dir, [f for f in os.listdir(results_dir) if \"enhanced.mp4\" in f][0])\n        shutil.copy(mp4_path, output)\n\n        return Path(output)\n\n\ndef load_default():\n    return Namespace(\n        pose_style=0,\n        batch_size=2,\n        expression_scale=1.0,\n        input_yaw=None,\n        input_pitch=None,\n        input_roll=None,\n        background_enhancer=None,\n        face3dvis=False,\n        net_recon=\"resnet50\",\n        init_path=None,\n        use_last_fc=False,\n        bfm_folder=\"./src/config/\",\n        bfm_model=\"BFM_model_front.mat\",\n        focal=1015.0,\n        center=112.0,\n        camera_d=10.0,\n        z_near=5.0,\n        z_far=15.0,\n    )\n"
  },
  {
    "path": "quick_demo.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"attachments\": {},\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"M74Gs_TjYl_B\"\n      },\n      \"source\": [\n        \"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb)\"\n      ]\n    },\n    {\n      \"attachments\": {},\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"view-in-github\"\n      },\n      \"source\": [\n        \"### SadTalker：Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation \\n\",\n        \"\\n\",\n        \"[arxiv](https://arxiv.org/abs/2211.12194) | [project](https://sadtalker.github.io) | [Github](https://github.com/Winfredy/SadTalker)\\n\",\n        \"\\n\",\n        \"Wenxuan Zhang, Xiaodong Cun, Xuan Wang, Yong Zhang, Xi Shen, Yu Guo, Ying Shan, Fei Wang.\\n\",\n        \"\\n\",\n        \"Xi'an Jiaotong University, Tencent AI Lab, Ant Group\\n\",\n        \"\\n\",\n        \"CVPR 2023\\n\",\n        \"\\n\",\n        \"TL;DR: A realistic and stylized talking head video generation method from a single image and audio\\n\"\n      ]\n    },\n    {\n      \"attachments\": {},\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"kA89DV-sKS4i\"\n      },\n      \"source\": [\n        \"Installation (around 5 mins)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"qJ4CplXsYl_E\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"### make sure that CUDA is available in Edit -> Notebook settings -> GPU\\n\",\n        \"!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"Mdq6j4E5KQAR\"\n      
},\n      \"outputs\": [],\n      \"source\": [\n        \"!update-alternatives --install /usr/local/bin/python3 python3 /usr/bin/python3.8 2\\n\",\n        \"!update-alternatives --install /usr/local/bin/python3 python3 /usr/bin/python3.9 1\\n\",\n        \"!sudo apt install python3.8\\n\",\n        \"\\n\",\n        \"!sudo apt-get install python3.8-distutils\\n\",\n        \"\\n\",\n        \"!python --version\\n\",\n        \"\\n\",\n        \"!apt-get update\\n\",\n        \"\\n\",\n        \"!apt install software-properties-common\\n\",\n        \"\\n\",\n        \"!sudo dpkg --remove --force-remove-reinstreq python3-pip python3-setuptools python3-wheel\\n\",\n        \"\\n\",\n        \"!apt-get install python3-pip\\n\",\n        \"\\n\",\n        \"print('Git clone project and install requirements...')\\n\",\n        \"!git clone https://github.com/Winfredy/SadTalker &> /dev/null\\n\",\n        \"%cd SadTalker\\n\",\n        \"!export PYTHONPATH=/content/SadTalker:$PYTHONPATH\\n\",\n        \"!python3.8 -m pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113\\n\",\n        \"!apt update\\n\",\n        \"!apt install ffmpeg &> /dev/null\\n\",\n        \"!python3.8 -m pip install -r requirements.txt\"\n      ]\n    },\n    {\n      \"attachments\": {},\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"DddcKB_nKsnk\"\n      },\n      \"source\": [\n        \"Download models (1 mins)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"eDw3_UN8K2xa\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"print('Download pre-trained models...')\\n\",\n        \"!rm -rf checkpoints\\n\",\n        \"!bash scripts/download_models.sh\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": 
\"kK7DYeo7Yl_H\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# borrowed from MakeItTalk\\n\",\n        \"import ipywidgets as widgets\\n\",\n        \"import glob\\n\",\n        \"import matplotlib.pyplot as plt\\n\",\n        \"print(\\\"Choose the image name to animate: (saved in folder 'examples/')\\\")\\n\",\n        \"img_list = glob.glob1('examples/source_image', '*.png')\\n\",\n        \"img_list.sort()\\n\",\n        \"img_list = [item.split('.')[0] for item in img_list]\\n\",\n        \"default_head_name = widgets.Dropdown(options=img_list, value='full3')\\n\",\n        \"def on_change(change):\\n\",\n        \"    if change['type'] == 'change' and change['name'] == 'value':\\n\",\n        \"        plt.imshow(plt.imread('examples/source_image/{}.png'.format(default_head_name.value)))\\n\",\n        \"        plt.axis('off')\\n\",\n        \"        plt.show()\\n\",\n        \"default_head_name.observe(on_change)\\n\",\n        \"display(default_head_name)\\n\",\n        \"plt.imshow(plt.imread('examples/source_image/{}.png'.format(default_head_name.value)))\\n\",\n        \"plt.axis('off')\\n\",\n        \"plt.show()\"\n      ]\n    },\n    {\n      \"attachments\": {},\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"-khNZcnGK4UK\"\n      },\n      \"source\": [\n        \"Animation\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"ToBlDusjK5sS\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# selected audio from examples/driven_audio\\n\",\n        \"img = 'examples/source_image/{}.png'.format(default_head_name.value)\\n\",\n        \"print(img)\\n\",\n        \"!python3.8 inference.py --driven_audio ./examples/driven_audio/RD_Radio31_000.wav \\\\\\n\",\n        \"           --source_image {img} \\\\\\n\",\n        \"           --result_dir ./results --still --preprocess full --enhancer gfpgan\"\n      
]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"fAjwGmKKYl_I\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# visualization code borrowed from MakeItTalk\\n\",\n        \"from IPython.display import HTML\\n\",\n        \"from base64 import b64encode\\n\",\n        \"import os, sys\\n\",\n        \"\\n\",\n        \"# pick the first .mp4 that glob finds in ./results/ (glob order is not guaranteed, so this is not necessarily the newest)\\n\",\n        \"\\n\",\n        \"results = sorted(os.listdir('./results/'))\\n\",\n        \"\\n\",\n        \"mp4_name = glob.glob('./results/*.mp4')[0]\\n\",\n        \"\\n\",\n        \"mp4 = open('{}'.format(mp4_name),'rb').read()\\n\",\n        \"data_url = \\\"data:video/mp4;base64,\\\" + b64encode(mp4).decode()\\n\",\n        \"\\n\",\n        \"print('Display animation: {}'.format(mp4_name), file=sys.stderr)\\n\",\n        \"display(HTML(\\\"\\\"\\\"\\n\",\n        \"  <video width=256 controls>\\n\",\n        \"        <source src=\\\"%s\\\" type=\\\"video/mp4\\\">\\n\",\n        \"  </video>\\n\",\n        \"  \\\"\\\"\\\" % data_url))\\n\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"accelerator\": \"GPU\",\n    \"colab\": {\n      \"provenance\": []\n    },\n    \"gpuClass\": \"standard\",\n    \"kernelspec\": {\n      \"display_name\": \"base\",\n      \"language\": \"python\",\n      \"name\": \"python3\"\n    },\n    \"language_info\": {\n      \"name\": \"python\",\n      \"version\": \"3.9.7\"\n    },\n    \"vscode\": {\n      \"interpreter\": {\n        \"hash\": \"db5031b3636a3f037ea48eb287fd3d023feb9033aefc2a9652a92e470fb0851b\"\n      }\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "req.txt",
    "content": "llvmlite==0.38.1\nnumpy==1.21.6\nface_alignment==1.3.5\nimageio==2.19.3\nimageio-ffmpeg==0.4.7\nlibrosa==0.10.0.post2\nnumba==0.55.1\nresampy==0.3.1\npydub==0.25.1 \nscipy==1.10.1\nkornia==0.6.8\ntqdm\nyacs==0.1.8\npyyaml  \njoblib==1.1.0\nscikit-image==0.19.3\nbasicsr==1.4.2\nfacexlib==0.3.0\ngradio\ngfpgan\nav\nsafetensors\n"
  },
  {
    "path": "requirements.txt",
    "content": "numpy==1.23.4\nface_alignment==1.3.5\nimageio==2.19.3\nimageio-ffmpeg==0.4.7\nlibrosa==0.9.2 # \nnumba\nresampy==0.3.1\npydub==0.25.1 \nscipy==1.10.1\nkornia==0.6.8\ntqdm\nyacs==0.1.8\npyyaml  \njoblib==1.1.0\nscikit-image==0.19.3\nbasicsr==1.4.2\nfacexlib==0.3.0\ngradio\ngfpgan\nav\nsafetensors\n"
  },
  {
    "path": "requirements3d.txt",
    "content": "numpy==1.23.4\nface_alignment==1.3.5\nimageio==2.19.3\nimageio-ffmpeg==0.4.7\nlibrosa==0.9.2 # \nnumba\nresampy==0.3.1\npydub==0.25.1 \nscipy==1.5.3\nkornia==0.6.8\ntqdm\nyacs==0.1.8\npyyaml  \njoblib==1.1.0\nscikit-image==0.19.3\nbasicsr==1.4.2\nfacexlib==0.3.0\ntrimesh==3.9.20\ngradio\ngfpgan\nsafetensors"
  },
  {
    "path": "scripts/download_models.sh",
    "content": "mkdir ./checkpoints  \n\n# legacy download links (kept for reference)\n# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2exp_00300-model.pth -O ./checkpoints/auido2exp_00300-model.pth\n# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2pose_00140-model.pth -O ./checkpoints/auido2pose_00140-model.pth\n# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/epoch_20.pth -O ./checkpoints/epoch_20.pth\n# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/facevid2vid_00189-model.pth.tar -O ./checkpoints/facevid2vid_00189-model.pth.tar\n# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/shape_predictor_68_face_landmarks.dat -O ./checkpoints/shape_predictor_68_face_landmarks.dat\n# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/wav2lip.pth -O ./checkpoints/wav2lip.pth\n# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00229-model.pth.tar -O ./checkpoints/mapping_00229-model.pth.tar\n# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00109-model.pth.tar -O ./checkpoints/mapping_00109-model.pth.tar\n# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/hub.zip -O ./checkpoints/hub.zip\n# unzip -n ./checkpoints/hub.zip -d ./checkpoints/\n\n\n#### download from the new links.\nwget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00109-model.pth.tar -O  ./checkpoints/mapping_00109-model.pth.tar\nwget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00229-model.pth.tar -O  ./checkpoints/mapping_00229-model.pth.tar\nwget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_256.safetensors -O  ./checkpoints/SadTalker_V0.0.2_256.safetensors\nwget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_512.safetensors -O  
./checkpoints/SadTalker_V0.0.2_512.safetensors\n\n\n# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/BFM_Fitting.zip -O ./checkpoints/BFM_Fitting.zip\n# unzip -n ./checkpoints/BFM_Fitting.zip -d ./checkpoints/\n\n### enhancer \nmkdir -p ./gfpgan/weights\nwget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth -O ./gfpgan/weights/alignment_WFLW_4HG.pth \nwget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth -O ./gfpgan/weights/detection_Resnet50_Final.pth \nwget -nc https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -O ./gfpgan/weights/GFPGANv1.4.pth \nwget -nc https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth -O ./gfpgan/weights/parsing_parsenet.pth \n\n"
  },
  {
    "path": "scripts/extension.py",
    "content": "import os, sys\r\nfrom pathlib import Path\r\nimport tempfile\r\nimport gradio as gr\r\nfrom modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call\r\nfrom modules.shared import opts, OptionInfo\r\nfrom modules import shared, paths, script_callbacks\r\nimport launch\r\nimport glob\r\nfrom huggingface_hub import snapshot_download\r\n\r\n\r\n\r\ndef check_all_files_safetensor(current_dir):\r\n    kv = {\r\n        \"SadTalker_V0.0.2_256.safetensors\": \"sadtalker-256\",\r\n        \"SadTalker_V0.0.2_512.safetensors\": \"sadtalker-512\",\r\n        \"mapping_00109-model.pth.tar\" : \"mapping-109\" ,\r\n        \"mapping_00229-model.pth.tar\" : \"mapping-229\" ,\r\n    }\r\n\r\n    if not os.path.isdir(current_dir):\r\n        return False\r\n    \r\n    dirs = os.listdir(current_dir)\r\n\r\n    for f in dirs:\r\n        if f in kv.keys():\r\n            del kv[f]\r\n\r\n    return len(kv.keys()) == 0\r\n\r\ndef check_all_files(current_dir):\r\n    kv = {\r\n        \"auido2exp_00300-model.pth\": \"audio2exp\",\r\n        \"auido2pose_00140-model.pth\": \"audio2pose\",\r\n        \"epoch_20.pth\": \"face_recon\",\r\n        \"facevid2vid_00189-model.pth.tar\": \"face-render\",\r\n        \"mapping_00109-model.pth.tar\" : \"mapping-109\" ,\r\n        \"mapping_00229-model.pth.tar\" : \"mapping-229\" ,\r\n        \"wav2lip.pth\": \"wav2lip\",\r\n        \"shape_predictor_68_face_landmarks.dat\": \"dlib\",\r\n    }\r\n\r\n    if not os.path.isdir(current_dir):\r\n        return False\r\n    \r\n    dirs = os.listdir(current_dir)\r\n\r\n    for f in dirs:\r\n        if f in kv.keys():\r\n            del kv[f]\r\n\r\n    return len(kv.keys()) == 0\r\n\r\n    \r\n\r\ndef download_model(local_dir='./checkpoints'):\r\n    REPO_ID = 'vinthony/SadTalker'\r\n    snapshot_download(repo_id=REPO_ID, local_dir=local_dir, local_dir_use_symlinks=False)\r\n\r\ndef get_source_image(image):   \r\n        return image\r\n\r\ndef get_img_from_txt2img(x):\r\n    
talker_path = Path(paths.script_path) / \"outputs\"\r\n    imgs_from_txt_dir = str(talker_path / \"txt2img-images/\")\r\n    imgs = glob.glob(imgs_from_txt_dir+'/*/*.png')\r\n    imgs.sort(key=lambda x:os.path.getmtime(os.path.join(imgs_from_txt_dir, x)))\r\n    img_from_txt_path = os.path.join(imgs_from_txt_dir, imgs[-1])\r\n    return img_from_txt_path, img_from_txt_path\r\n\r\ndef get_img_from_img2img(x):\r\n    talker_path = Path(paths.script_path) / \"outputs\"\r\n    imgs_from_img_dir = str(talker_path / \"img2img-images/\")\r\n    imgs = glob.glob(imgs_from_img_dir+'/*/*.png')\r\n    imgs.sort(key=lambda x:os.path.getmtime(os.path.join(imgs_from_img_dir, x)))\r\n    img_from_img_path = os.path.join(imgs_from_img_dir, imgs[-1])\r\n    return img_from_img_path, img_from_img_path\r\n \r\ndef get_default_checkpoint_path():\r\n    # check the path of models/checkpoints and extensions/\r\n    checkpoint_path = Path(paths.script_path) / \"models\"/ \"SadTalker\" \r\n    extension_checkpoint_path = Path(paths.script_path) / \"extensions\"/ \"SadTalker\" / \"checkpoints\"\r\n\r\n    if check_all_files_safetensor(checkpoint_path):\r\n        # print('founding sadtalker checkpoint in ' + str(checkpoint_path))\r\n        return checkpoint_path\r\n\r\n    if check_all_files_safetensor(extension_checkpoint_path):\r\n        # print('founding sadtalker checkpoint in ' + str(extension_checkpoint_path))\r\n        return extension_checkpoint_path\r\n    \r\n    if check_all_files(checkpoint_path):\r\n        # print('founding sadtalker checkpoint in ' + str(checkpoint_path))\r\n        return checkpoint_path\r\n\r\n    if check_all_files(extension_checkpoint_path):\r\n        # print('founding sadtalker checkpoint in ' + str(extension_checkpoint_path))\r\n        return extension_checkpoint_path\r\n\r\n    return None\r\n\r\n\r\n\r\ndef install():\r\n\r\n    kv = {\r\n        \"face_alignment\": \"face-alignment==1.3.5\",\r\n        \"imageio\": \"imageio==2.19.3\",\r\n      
  \"imageio_ffmpeg\": \"imageio-ffmpeg==0.4.7\",\r\n        \"librosa\":\"librosa==0.8.0\",\r\n        \"pydub\":\"pydub==0.25.1\",\r\n        \"scipy\":\"scipy==1.8.1\",\r\n        \"tqdm\": \"tqdm\",\r\n        \"yacs\":\"yacs==0.1.8\",\r\n        \"yaml\": \"pyyaml\", \r\n        \"av\":\"av\",\r\n        \"gfpgan\": \"gfpgan\",\r\n    }\r\n\r\n    # # dlib is not necessary currently\r\n    # if 'darwin' in sys.platform:\r\n    #     kv['dlib'] = \"dlib\"\r\n    # else:\r\n    #     kv['dlib'] = 'dlib-bin'\r\n\r\n    # #### we need to have a newer version of imageio for our method.\r\n    # launch.run_pip(\"install imageio==2.19.3\", \"requirements for SadTalker\")\r\n\r\n    for k,v in kv.items():\r\n        if not launch.is_installed(k):\r\n            print(k, launch.is_installed(k))\r\n            launch.run_pip(\"install \"+ v, \"requirements for SadTalker\")\r\n\r\n    if os.getenv('SADTALKER_CHECKPOINTS'):\r\n        print('load Sadtalker Checkpoints from '+ os.getenv('SADTALKER_CHECKPOINTS'))\r\n\r\n    elif get_default_checkpoint_path() is not None:\r\n        os.environ['SADTALKER_CHECKPOINTS'] = str(get_default_checkpoint_path())\r\n    else:\r\n\r\n        print(\r\n            \"\"\"\"\r\n            SadTalker will not support download all the files from hugging face, which will take a long time.\r\n             \r\n            please manually set the SADTALKER_CHECKPOINTS in `webui_user.bat`(windows) or `webui_user.sh`(linux)\r\n            \"\"\"\r\n            )\r\n        \r\n        # python = sys.executable\r\n\r\n        # launch.run(f'\"{python}\" -m pip uninstall -y huggingface_hub', live=True)\r\n        # launch.run(f'\"{python}\" -m pip install --upgrade git+https://github.com/huggingface/huggingface_hub@main', live=True)\r\n        # ### run the scripts to downlod models to correct localtion.\r\n        # # print('download models for SadTalker')\r\n        # # launch.run(\"cd \" + paths.script_path+\"/extensions/SadTalker && bash 
./scripts/download_models.sh\", live=True)\r\n        # # print('SadTalker is successfully installed!')\r\n        # download_model(paths.script_path+'/extensions/SadTalker/checkpoints')\r\n    \r\n \r\ndef on_ui_tabs():\r\n    install()\r\n\r\n    sys.path.extend([paths.script_path+'/extensions/SadTalker']) \r\n    \r\n    repo_dir = paths.script_path+'/extensions/SadTalker/'\r\n\r\n    result_dir = opts.sadtalker_result_dir\r\n    os.makedirs(result_dir, exist_ok=True)\r\n\r\n    from app_sadtalker import sadtalker_demo  \r\n\r\n    if  os.getenv('SADTALKER_CHECKPOINTS'):\r\n        checkpoint_path = os.getenv('SADTALKER_CHECKPOINTS')\r\n    else:\r\n        checkpoint_path = repo_dir+'checkpoints/'\r\n\r\n    audio_to_video = sadtalker_demo(checkpoint_path=checkpoint_path, config_path=repo_dir+'src/config', warpfn = wrap_queued_call)\r\n   \r\n    return [(audio_to_video, \"SadTalker\", \"extension\")]\r\n\r\ndef on_ui_settings():\r\n    talker_path = Path(paths.script_path) / \"outputs\"\r\n    section = ('extension', \"SadTalker\") \r\n    opts.add_option(\"sadtalker_result_dir\", OptionInfo(str(talker_path / \"SadTalker/\"), \"Path to save results of sadtalker\", section=section)) \r\n\r\nscript_callbacks.on_ui_settings(on_ui_settings)\r\nscript_callbacks.on_ui_tabs(on_ui_tabs)\r\n"
  },
  {
    "path": "scripts/test.sh",
    "content": "# ### some test command before commit.\n# python inference.py --preprocess crop --size 256\n# python inference.py --preprocess crop --size 512\n\n# python inference.py --preprocess extcrop --size 256\n# python inference.py --preprocess extcrop --size 512\n\n# python inference.py --preprocess resize --size 256\n# python inference.py --preprocess resize --size 512\n\n# python inference.py --preprocess full --size 256\n# python inference.py --preprocess full --size 512\n\n# python inference.py --preprocess extfull --size 256\n# python inference.py --preprocess extfull --size 512\n\npython inference.py --preprocess full --size 256 --enhancer gfpgan\npython inference.py --preprocess full --size 512 --enhancer gfpgan\n\npython inference.py --preprocess full --size 256 --enhancer gfpgan --still\npython inference.py --preprocess full --size 512 --enhancer gfpgan --still\n"
  },
  {
    "path": "src/audio2exp_models/audio2exp.py",
    "content": "from tqdm import tqdm\nimport torch\nfrom torch import nn\n\n\nclass Audio2Exp(nn.Module):\n    def __init__(self, netG, cfg, device, prepare_training_loss=False):\n        super(Audio2Exp, self).__init__()\n        self.cfg = cfg\n        self.device = device\n        self.netG = netG.to(device)\n\n    def test(self, batch):\n\n        mel_input = batch['indiv_mels']                         # bs T 1 80 16\n        bs = mel_input.shape[0]\n        T = mel_input.shape[1]\n\n        exp_coeff_pred = []\n\n        for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames\n            \n            current_mel_input = mel_input[:,i:i+10]\n\n            #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1))           #bs T 64\n            ref = batch['ref'][:, :, :64][:, i:i+10]\n            ratio = batch['ratio_gt'][:, i:i+10]                               #bs T\n\n            audiox = current_mel_input.view(-1, 1, 80, 16)                  # bs*T 1 80 16\n\n            curr_exp_coeff_pred  = self.netG(audiox, ref, ratio)         # bs T 64 \n\n            exp_coeff_pred += [curr_exp_coeff_pred]\n\n        # BS x T x 64\n        results_dict = {\n            'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1)\n            }\n        return results_dict\n\n\n"
  },
  {
    "path": "src/audio2exp_models/networks.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nclass Conv2d(nn.Module):\n    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.conv_block = nn.Sequential(\n                            nn.Conv2d(cin, cout, kernel_size, stride, padding),\n                            nn.BatchNorm2d(cout)\n                            )\n        self.act = nn.ReLU()\n        self.residual = residual\n        self.use_act = use_act\n\n    def forward(self, x):\n        out = self.conv_block(x)\n        if self.residual:\n            out += x\n        \n        if self.use_act:\n            return self.act(out)\n        else:\n            return out\n\nclass SimpleWrapperV2(nn.Module):\n    def __init__(self) -> None:\n        super().__init__()\n        self.audio_encoder = nn.Sequential(\n            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),\n            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),\n            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),\n\n            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),\n            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),\n            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),\n\n            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),\n            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),\n            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),\n\n            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),\n            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),\n\n            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),\n            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),\n            )\n\n        #### load the pre-trained audio_encoder \n        
#self.audio_encoder = self.audio_encoder.to(device)  \n        '''\n        wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']\n        state_dict = self.audio_encoder.state_dict()\n\n        for k,v in wav2lip_state_dict.items():\n            if 'audio_encoder' in k:\n                print('init:', k)\n                state_dict[k.replace('module.audio_encoder.', '')] = v\n        self.audio_encoder.load_state_dict(state_dict)\n        '''\n\n        self.mapping1 = nn.Linear(512+64+1, 64)\n        #self.mapping2 = nn.Linear(30, 64)\n        #nn.init.constant_(self.mapping1.weight, 0.)\n        nn.init.constant_(self.mapping1.bias, 0.)\n\n    def forward(self, x, ref, ratio):\n        x = self.audio_encoder(x).view(x.size(0), -1)\n        ref_reshape = ref.reshape(x.size(0), -1)\n        ratio = ratio.reshape(x.size(0), -1)\n        \n        y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1)) \n        out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # resudial\n        return out\n"
  },
  {
    "path": "src/audio2pose_models/audio2pose.py",
    "content": "import torch\nfrom torch import nn\nfrom src.audio2pose_models.cvae import CVAE\nfrom src.audio2pose_models.discriminator import PoseSequenceDiscriminator\nfrom src.audio2pose_models.audio_encoder import AudioEncoder\n\nclass Audio2Pose(nn.Module):\n    def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):\n        super().__init__()\n        self.cfg = cfg\n        self.seq_len = cfg.MODEL.CVAE.SEQ_LEN\n        self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE\n        self.device = device\n\n        self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)\n        self.audio_encoder.eval()\n        for param in self.audio_encoder.parameters():\n            param.requires_grad = False\n\n        self.netG = CVAE(cfg)\n        self.netD_motion = PoseSequenceDiscriminator(cfg)\n        \n        \n    def forward(self, x):\n\n        batch = {}\n        coeff_gt = x['gt'].cuda().squeeze(0)           #bs frame_len+1 73\n        batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70] #bs frame_len 6\n        batch['ref'] = coeff_gt[:, 0, 64:70]  #bs  6\n        batch['class'] = x['class'].squeeze(0).cuda() # bs\n        indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16\n\n        # forward\n        audio_emb_list = []\n        audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512\n        batch['audio_emb'] = audio_emb\n        batch = self.netG(batch)\n\n        pose_motion_pred = batch['pose_motion_pred']           # bs frame_len 6\n        pose_gt = coeff_gt[:, 1:, 64:70].clone()               # bs frame_len 6\n        pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred  # bs frame_len 6\n\n        batch['pose_pred'] = pose_pred\n        batch['pose_gt'] = pose_gt\n\n        return batch\n\n    def test(self, x):\n\n        batch = {}\n        ref = x['ref']                            #bs 1 70\n        batch['ref'] = x['ref'][:,0,-6:]  \n        batch['class'] = x['class']  \n 
       bs = ref.shape[0]\n        \n        indiv_mels= x['indiv_mels']               # bs T 1 80 16\n        indiv_mels_use = indiv_mels[:, 1:]        # we regard the ref as the first frame\n        num_frames = x['num_frames']\n        num_frames = int(num_frames) - 1\n\n        #  \n        div = num_frames//self.seq_len\n        re = num_frames%self.seq_len\n        audio_emb_list = []\n        pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype, \n                                                device=batch['ref'].device)]\n\n        for i in range(div):\n            z = torch.randn(bs, self.latent_dim).to(ref.device)\n            batch['z'] = z\n            audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512\n            batch['audio_emb'] = audio_emb\n            batch = self.netG.test(batch)\n            pose_motion_pred_list.append(batch['pose_motion_pred'])  #list of bs seq_len 6\n        \n        if re != 0:\n            z = torch.randn(bs, self.latent_dim).to(ref.device)\n            batch['z'] = z\n            audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len  512\n            if audio_emb.shape[1] != self.seq_len:\n                pad_dim = self.seq_len-audio_emb.shape[1]\n                pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1) \n                audio_emb = torch.cat([pad_audio_emb, audio_emb], 1) \n            batch['audio_emb'] = audio_emb\n            batch = self.netG.test(batch)\n            pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:])   \n        \n        pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1)\n        batch['pose_motion_pred'] = pose_motion_pred\n\n        pose_pred = ref[:, :1, -6:] + pose_motion_pred  # bs T 6\n\n        batch['pose_pred'] = pose_pred\n        return batch\n"
  },
  {
    "path": "src/audio2pose_models/audio_encoder.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nclass Conv2d(nn.Module):\n    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.conv_block = nn.Sequential(\n                            nn.Conv2d(cin, cout, kernel_size, stride, padding),\n                            nn.BatchNorm2d(cout)\n                            )\n        self.act = nn.ReLU()\n        self.residual = residual\n\n    def forward(self, x):\n        out = self.conv_block(x)\n        if self.residual:\n            out += x\n        return self.act(out)\n\nclass AudioEncoder(nn.Module):\n    def __init__(self, wav2lip_checkpoint, device):\n        super(AudioEncoder, self).__init__()\n\n        self.audio_encoder = nn.Sequential(\n            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),\n            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),\n            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),\n\n            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),\n            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),\n            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),\n\n            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),\n            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),\n            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),\n\n            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),\n            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),\n\n            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),\n            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)\n\n        #### load the pre-trained audio_encoder, we do not need to load wav2lip model here.\n        # wav2lip_state_dict = torch.load(wav2lip_checkpoint, 
map_location=torch.device(device))['state_dict']\n        # state_dict = self.audio_encoder.state_dict()\n\n        # for k,v in wav2lip_state_dict.items():\n        #     if 'audio_encoder' in k:\n        #         state_dict[k.replace('module.audio_encoder.', '')] = v\n        # self.audio_encoder.load_state_dict(state_dict)\n\n\n    def forward(self, audio_sequences):\n        # audio_sequences = (B, T, 1, 80, 16)\n        B = audio_sequences.size(0)\n\n        audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)\n\n        audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1\n        dim = audio_embedding.shape[1]\n        audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))\n\n        return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512 \n"
  },
  {
    "path": "src/audio2pose_models/cvae.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom src.audio2pose_models.res_unet import ResUnet\n\ndef class2onehot(idx, class_num):\n\n    assert torch.max(idx).item() < class_num\n    onehot = torch.zeros(idx.size(0), class_num).to(idx.device)\n    onehot.scatter_(1, idx, 1)\n    return onehot\n\nclass CVAE(nn.Module):\n    def __init__(self, cfg):\n        super().__init__()\n        encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES\n        decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES\n        latent_size = cfg.MODEL.CVAE.LATENT_SIZE\n        num_classes = cfg.DATASET.NUM_CLASSES\n        audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE\n        audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE\n        seq_len = cfg.MODEL.CVAE.SEQ_LEN\n\n        self.latent_size = latent_size\n\n        self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,\n                                audio_emb_in_size, audio_emb_out_size, seq_len)\n        self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,\n                                audio_emb_in_size, audio_emb_out_size, seq_len)\n    def reparameterize(self, mu, logvar):\n        std = torch.exp(0.5 * logvar)\n        eps = torch.randn_like(std)\n        return mu + eps * std\n\n    def forward(self, batch):\n        batch = self.encoder(batch)\n        mu = batch['mu']\n        logvar = batch['logvar']\n        z = self.reparameterize(mu, logvar)\n        batch['z'] = z\n        return self.decoder(batch)\n\n    def test(self, batch):\n        '''\n        class_id = batch['class']\n        z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)\n        batch['z'] = z\n        '''\n        return self.decoder(batch)\n\nclass ENCODER(nn.Module):\n    def __init__(self, layer_sizes, latent_size, num_classes, \n                audio_emb_in_size, audio_emb_out_size, seq_len):\n        super().__init__()\n\n        
self.resunet = ResUnet()\n        self.num_classes = num_classes\n        self.seq_len = seq_len\n\n        self.MLP = nn.Sequential()\n        layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6\n        for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):\n            self.MLP.add_module(\n                name=\"L{:d}\".format(i), module=nn.Linear(in_size, out_size))\n            self.MLP.add_module(name=\"A{:d}\".format(i), module=nn.ReLU())\n\n        self.linear_means = nn.Linear(layer_sizes[-1], latent_size)\n        self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)\n        self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)\n\n        self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))\n\n    def forward(self, batch):\n        class_id = batch['class']\n        pose_motion_gt = batch['pose_motion_gt']                             #bs seq_len 6\n        ref = batch['ref']                             #bs 6\n        bs = pose_motion_gt.shape[0]\n        audio_in = batch['audio_emb']                          # bs seq_len audio_emb_in_size\n\n        #pose encode\n        pose_emb = self.resunet(pose_motion_gt.unsqueeze(1))          #bs 1 seq_len 6 \n        pose_emb = pose_emb.reshape(bs, -1)                    #bs seq_len*6\n\n        #audio mapping\n        print(audio_in.shape)\n        audio_out = self.linear_audio(audio_in)                # bs seq_len audio_emb_out_size\n        audio_out = audio_out.reshape(bs, -1)\n\n        class_bias = self.classbias[class_id]                  #bs latent_size\n        x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size\n        x_out = self.MLP(x_in)\n\n        mu = self.linear_means(x_out)\n        logvar = self.linear_means(x_out)                      #bs latent_size \n\n        batch.update({'mu':mu, 'logvar':logvar})\n        return batch\n\nclass DECODER(nn.Module):\n  
  def __init__(self, layer_sizes, latent_size, num_classes, \n                audio_emb_in_size, audio_emb_out_size, seq_len):\n        super().__init__()\n\n        self.resunet = ResUnet()\n        self.num_classes = num_classes\n        self.seq_len = seq_len\n\n        self.MLP = nn.Sequential()\n        input_size = latent_size + seq_len*audio_emb_out_size + 6\n        for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):\n            self.MLP.add_module(\n                name=\"L{:d}\".format(i), module=nn.Linear(in_size, out_size))\n            if i+1 < len(layer_sizes):\n                self.MLP.add_module(name=\"A{:d}\".format(i), module=nn.ReLU())\n            else:\n                self.MLP.add_module(name=\"sigmoid\", module=nn.Sigmoid())\n        \n        self.pose_linear = nn.Linear(6, 6)\n        self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)\n\n        self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))\n\n    def forward(self, batch):\n\n        z = batch['z']                                          #bs latent_size\n        bs = z.shape[0]\n        class_id = batch['class']\n        ref = batch['ref']                             #bs 6\n        audio_in = batch['audio_emb']                           # bs seq_len audio_emb_in_size\n        #print('audio_in: ', audio_in[:, :, :10])\n\n        audio_out = self.linear_audio(audio_in)                 # bs seq_len audio_emb_out_size\n        #print('audio_out: ', audio_out[:, :, :10])\n        audio_out = audio_out.reshape([bs, -1])                 # bs seq_len*audio_emb_out_size\n        class_bias = self.classbias[class_id]                   #bs latent_size\n\n        z = z + class_bias\n        x_in = torch.cat([ref, z, audio_out], dim=-1)\n        x_out = self.MLP(x_in)                                  # bs layer_sizes[-1]\n        x_out = x_out.reshape((bs, self.seq_len, -1))\n\n        #print('x_out: ', x_out)\n\n      
  pose_emb = self.resunet(x_out.unsqueeze(1))             #bs 1 seq_len 6\n\n        pose_motion_pred = self.pose_linear(pose_emb.squeeze(1))       #bs seq_len 6\n\n        batch.update({'pose_motion_pred':pose_motion_pred})\n        return batch\n"
  },
  {
    "path": "src/audio2pose_models/discriminator.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nclass ConvNormRelu(nn.Module):\n    def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,\n                 kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):\n        super().__init__()\n        if kernel_size is None:\n            if downsample:\n                kernel_size, stride, padding = 4, 2, 1\n            else:\n                kernel_size, stride, padding = 3, 1, 1\n\n        if conv_type == '2d':\n            self.conv = nn.Conv2d(\n                in_channels,\n                out_channels,\n                kernel_size,\n                stride,\n                padding,\n                bias=False,\n            )\n            if norm == 'BN':\n                self.norm = nn.BatchNorm2d(out_channels)\n            elif norm == 'IN':\n                self.norm = nn.InstanceNorm2d(out_channels)\n            else:\n                raise NotImplementedError\n        elif conv_type == '1d':\n            self.conv = nn.Conv1d(\n                in_channels,\n                out_channels,\n                kernel_size,\n                stride,\n                padding,\n                bias=False,\n            )\n            if norm == 'BN':\n                self.norm = nn.BatchNorm1d(out_channels)\n            elif norm == 'IN':\n                self.norm = nn.InstanceNorm1d(out_channels)\n            else:\n                raise NotImplementedError\n        nn.init.kaiming_normal_(self.conv.weight)\n\n        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)\n\n    def forward(self, x):\n        x = self.conv(x)\n        if isinstance(self.norm, nn.InstanceNorm1d):\n            x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1))  # normalize on [C]\n        else:\n            x = self.norm(x)\n        x = self.act(x)\n        return x\n\n\nclass 
PoseSequenceDiscriminator(nn.Module):\n    def __init__(self, cfg):\n        super().__init__()\n        self.cfg = cfg\n        leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU\n\n        self.seq = nn.Sequential(\n            ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky),  # B, 256, 64\n            ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky),  # B, 512, 32\n            ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky),  # B, 1024, 16\n            nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True)  # B, 1, 16\n        )\n\n    def forward(self, x):\n        x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)\n        x = self.seq(x)\n        x = x.squeeze(1)\n        return x"
  },
  {
    "path": "src/audio2pose_models/networks.py",
    "content": "import torch.nn as nn\nimport torch\n\n\nclass ResidualConv(nn.Module):\n    def __init__(self, input_dim, output_dim, stride, padding):\n        super(ResidualConv, self).__init__()\n\n        self.conv_block = nn.Sequential(\n            nn.BatchNorm2d(input_dim),\n            nn.ReLU(),\n            nn.Conv2d(\n                input_dim, output_dim, kernel_size=3, stride=stride, padding=padding\n            ),\n            nn.BatchNorm2d(output_dim),\n            nn.ReLU(),\n            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),\n        )\n        self.conv_skip = nn.Sequential(\n            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),\n            nn.BatchNorm2d(output_dim),\n        )\n\n    def forward(self, x):\n\n        return self.conv_block(x) + self.conv_skip(x)\n\n\nclass Upsample(nn.Module):\n    def __init__(self, input_dim, output_dim, kernel, stride):\n        super(Upsample, self).__init__()\n\n        self.upsample = nn.ConvTranspose2d(\n            input_dim, output_dim, kernel_size=kernel, stride=stride\n        )\n\n    def forward(self, x):\n        return self.upsample(x)\n\n\nclass Squeeze_Excite_Block(nn.Module):\n    def __init__(self, channel, reduction=16):\n        super(Squeeze_Excite_Block, self).__init__()\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.fc = nn.Sequential(\n            nn.Linear(channel, channel // reduction, bias=False),\n            nn.ReLU(inplace=True),\n            nn.Linear(channel // reduction, channel, bias=False),\n            nn.Sigmoid(),\n        )\n\n    def forward(self, x):\n        b, c, _, _ = x.size()\n        y = self.avg_pool(x).view(b, c)\n        y = self.fc(y).view(b, c, 1, 1)\n        return x * y.expand_as(x)\n\n\nclass ASPP(nn.Module):\n    def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):\n        super(ASPP, self).__init__()\n\n        self.aspp_block1 = nn.Sequential(\n            nn.Conv2d(\n          
      in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]\n            ),\n            nn.ReLU(inplace=True),\n            nn.BatchNorm2d(out_dims),\n        )\n        self.aspp_block2 = nn.Sequential(\n            nn.Conv2d(\n                in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]\n            ),\n            nn.ReLU(inplace=True),\n            nn.BatchNorm2d(out_dims),\n        )\n        self.aspp_block3 = nn.Sequential(\n            nn.Conv2d(\n                in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]\n            ),\n            nn.ReLU(inplace=True),\n            nn.BatchNorm2d(out_dims),\n        )\n\n        self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)\n        self._init_weights()\n\n    def forward(self, x):\n        x1 = self.aspp_block1(x)\n        x2 = self.aspp_block2(x)\n        x3 = self.aspp_block3(x)\n        out = torch.cat([x1, x2, x3], dim=1)\n        return self.output(out)\n\n    def _init_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight)\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\n\nclass Upsample_(nn.Module):\n    def __init__(self, scale=2):\n        super(Upsample_, self).__init__()\n\n        self.upsample = nn.Upsample(mode=\"bilinear\", scale_factor=scale)\n\n    def forward(self, x):\n        return self.upsample(x)\n\n\nclass AttentionBlock(nn.Module):\n    def __init__(self, input_encoder, input_decoder, output_dim):\n        super(AttentionBlock, self).__init__()\n\n        self.conv_encoder = nn.Sequential(\n            nn.BatchNorm2d(input_encoder),\n            nn.ReLU(),\n            nn.Conv2d(input_encoder, output_dim, 3, padding=1),\n            nn.MaxPool2d(2, 2),\n        )\n\n        self.conv_decoder = nn.Sequential(\n            nn.BatchNorm2d(input_decoder),\n          
  nn.ReLU(),\n            nn.Conv2d(input_decoder, output_dim, 3, padding=1),\n        )\n\n        self.conv_attn = nn.Sequential(\n            nn.BatchNorm2d(output_dim),\n            nn.ReLU(),\n            nn.Conv2d(output_dim, 1, 1),\n        )\n\n    def forward(self, x1, x2):\n        out = self.conv_encoder(x1) + self.conv_decoder(x2)\n        out = self.conv_attn(out)\n        return out * x2"
  },
  {
    "path": "src/audio2pose_models/res_unet.py",
    "content": "import torch\nimport torch.nn as nn\nfrom src.audio2pose_models.networks import ResidualConv, Upsample\n\n\nclass ResUnet(nn.Module):\n    def __init__(self, channel=1, filters=[32, 64, 128, 256]):\n        super(ResUnet, self).__init__()\n\n        self.input_layer = nn.Sequential(\n            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),\n            nn.BatchNorm2d(filters[0]),\n            nn.ReLU(),\n            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),\n        )\n        self.input_skip = nn.Sequential(\n            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)\n        )\n\n        self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1)\n        self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1)\n\n        self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1)\n\n        self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1))\n        self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1)\n\n        self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1))\n        self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1)\n\n        self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1))\n        self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1)\n\n        self.output_layer = nn.Sequential(\n            nn.Conv2d(filters[0], 1, 1, 1),\n            nn.Sigmoid(),\n        )\n\n    def forward(self, x):\n        # Encode\n        x1 = self.input_layer(x) + self.input_skip(x)\n        x2 = self.residual_conv_1(x1)\n        x3 = self.residual_conv_2(x2)\n        # Bridge\n        x4 = self.bridge(x3)\n\n        # Decode\n        x4 = self.upsample_1(x4)\n        x5 = torch.cat([x4, x3], dim=1)\n\n        x6 = 
self.up_residual_conv1(x5)\n\n        x6 = self.upsample_2(x6)\n        x7 = torch.cat([x6, x2], dim=1)\n\n        x8 = self.up_residual_conv2(x7)\n\n        x8 = self.upsample_3(x8)\n        x9 = torch.cat([x8, x1], dim=1)\n\n        x10 = self.up_residual_conv3(x9)\n\n        output = self.output_layer(x10)\n\n        return output"
  },
  {
    "path": "src/config/auido2exp.yaml",
    "content": "DATASET:\n  TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt\n  EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt\n  TRAIN_BATCH_SIZE: 32\n  EVAL_BATCH_SIZE: 32\n  EXP: True\n  EXP_DIM: 64\n  FRAME_LEN: 32\n  COEFF_LEN: 73\n  NUM_CLASSES: 46\n  AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav\n  COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm\n  LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb\n  DEBUG: True\n  NUM_REPEATS: 2\n  T: 40\n  \n\nMODEL:\n  FRAMEWORK: V2\n  AUDIOENCODER:\n    LEAKY_RELU: True\n    NORM: 'IN'\n  DISCRIMINATOR:\n    LEAKY_RELU: False\n    INPUT_CHANNELS: 6\n  CVAE:\n    AUDIO_EMB_IN_SIZE: 512\n    AUDIO_EMB_OUT_SIZE: 128\n    SEQ_LEN: 32\n    LATENT_SIZE: 256\n    ENCODER_LAYER_SIZES: [192, 1024]\n    DECODER_LAYER_SIZES: [1024, 192]\n    \n\nTRAIN:\n  MAX_EPOCH: 300\n  GENERATOR:\n    LR: 2.0e-5\n  DISCRIMINATOR:\n    LR: 1.0e-5\n  LOSS:\n    W_FEAT: 0\n    W_COEFF_EXP: 2\n    W_LM: 1.0e-2\n    W_LM_MOUTH: 0\n    W_REG: 0\n    W_SYNC: 0\n    W_COLOR: 0\n    W_EXPRESSION: 0\n    W_LIPREADING: 0.01\n    W_LIPREADING_VV: 0\n    W_EYE_BLINK: 4\n\nTAG:\n  NAME:  small_dataset\n\n\n"
  },
  {
    "path": "src/config/auido2pose.yaml",
    "content": "DATASET:\n  TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt\n  EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt\n  TRAIN_BATCH_SIZE: 64\n  EVAL_BATCH_SIZE: 1\n  EXP: True\n  EXP_DIM: 64\n  FRAME_LEN: 32\n  COEFF_LEN: 73\n  NUM_CLASSES: 46\n  AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav\n  COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb\n  DEBUG: True\n  \n\nMODEL:\n  AUDIOENCODER:\n    LEAKY_RELU: True\n    NORM: 'IN'\n  DISCRIMINATOR:\n    LEAKY_RELU: False\n    INPUT_CHANNELS: 6\n  CVAE:\n    AUDIO_EMB_IN_SIZE: 512\n    AUDIO_EMB_OUT_SIZE: 6\n    SEQ_LEN: 32\n    LATENT_SIZE: 64\n    ENCODER_LAYER_SIZES: [192, 128]\n    DECODER_LAYER_SIZES: [128, 192]\n    \n\nTRAIN:\n  MAX_EPOCH: 150\n  GENERATOR:\n    LR: 1.0e-4\n  DISCRIMINATOR:\n    LR: 1.0e-4\n  LOSS:\n    LAMBDA_REG: 1\n    LAMBDA_LANDMARKS: 0\n    LAMBDA_VERTICES: 0\n    LAMBDA_GAN_MOTION: 0.7\n    LAMBDA_GAN_COEFF: 0\n    LAMBDA_KL: 1\n\nTAG:\n  NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder\n\n\n"
  },
  {
    "path": "src/config/facerender.yaml",
    "content": "model_params:\n  common_params:\n    num_kp: 15 \n    image_channel: 3                    \n    feature_channel: 32\n    estimate_jacobian: False   # True\n  kp_detector_params:\n     temperature: 0.1\n     block_expansion: 32            \n     max_features: 1024\n     scale_factor: 0.25         # 0.25\n     num_blocks: 5\n     reshape_channel: 16384  # 16384 = 1024 * 16\n     reshape_depth: 16\n  he_estimator_params:\n     block_expansion: 64            \n     max_features: 2048\n     num_bins: 66\n  generator_params:\n    block_expansion: 64\n    max_features: 512\n    num_down_blocks: 2\n    reshape_channel: 32\n    reshape_depth: 16         # 512 = 32 * 16\n    num_resblocks: 6\n    estimate_occlusion_map: True\n    dense_motion_params:\n      block_expansion: 32\n      max_features: 1024\n      num_blocks: 5\n      reshape_depth: 16\n      compress: 4\n  discriminator_params:\n    scales: [1]\n    block_expansion: 32                 \n    max_features: 512\n    num_blocks: 4\n    sn: True\n  mapping_params:\n      coeff_nc: 70\n      descriptor_nc: 1024\n      layer: 3\n      num_kp: 15\n      num_bins: 66\n\n"
  },
  {
    "path": "src/config/facerender_still.yaml",
    "content": "model_params:\n  common_params:\n    num_kp: 15 \n    image_channel: 3                    \n    feature_channel: 32\n    estimate_jacobian: False   # True\n  kp_detector_params:\n     temperature: 0.1\n     block_expansion: 32            \n     max_features: 1024\n     scale_factor: 0.25         # 0.25\n     num_blocks: 5\n     reshape_channel: 16384  # 16384 = 1024 * 16\n     reshape_depth: 16\n  he_estimator_params:\n     block_expansion: 64            \n     max_features: 2048\n     num_bins: 66\n  generator_params:\n    block_expansion: 64\n    max_features: 512\n    num_down_blocks: 2\n    reshape_channel: 32\n    reshape_depth: 16         # 512 = 32 * 16\n    num_resblocks: 6\n    estimate_occlusion_map: True\n    dense_motion_params:\n      block_expansion: 32\n      max_features: 1024\n      num_blocks: 5\n      reshape_depth: 16\n      compress: 4\n  discriminator_params:\n    scales: [1]\n    block_expansion: 32                 \n    max_features: 512\n    num_blocks: 4\n    sn: True\n  mapping_params:\n      coeff_nc: 73\n      descriptor_nc: 1024\n      layer: 3\n      num_kp: 15\n      num_bins: 66\n\n"
  },
  {
    "path": "src/face3d/data/__init__.py",
    "content": "\"\"\"This package includes all the modules related to data loading and preprocessing\n\n To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.\n You need to implement four functions:\n    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).\n    -- <__len__>:                       return the size of dataset.\n    -- <__getitem__>:                   get a data point from data loader.\n    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.\n\nNow you can use the dataset class by specifying flag '--dataset_mode dummy'.\nSee our template dataset class 'template_dataset.py' for more details.\n\"\"\"\nimport numpy as np\nimport importlib\nimport torch.utils.data\nfrom face3d.data.base_dataset import BaseDataset\n\n\ndef find_dataset_using_name(dataset_name):\n    \"\"\"Import the module \"data/[dataset_name]_dataset.py\".\n\n    In the file, the class called DatasetNameDataset() will\n    be instantiated. 
It has to be a subclass of BaseDataset,\n    and it is case-insensitive.\n    \"\"\"\n    dataset_filename = \"data.\" + dataset_name + \"_dataset\"\n    datasetlib = importlib.import_module(dataset_filename)\n\n    dataset = None\n    target_dataset_name = dataset_name.replace('_', '') + 'dataset'\n    for name, cls in datasetlib.__dict__.items():\n        if name.lower() == target_dataset_name.lower() \\\n           and issubclass(cls, BaseDataset):\n            dataset = cls\n\n    if dataset is None:\n        raise NotImplementedError(\"In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase.\" % (dataset_filename, target_dataset_name))\n\n    return dataset\n\n\ndef get_option_setter(dataset_name):\n    \"\"\"Return the static method <modify_commandline_options> of the dataset class.\"\"\"\n    dataset_class = find_dataset_using_name(dataset_name)\n    return dataset_class.modify_commandline_options\n\n\ndef create_dataset(opt, rank=0):\n    \"\"\"Create a dataset given the option.\n\n    This function wraps the class CustomDatasetDataLoader.\n        This is the main interface between this package and 'train.py'/'test.py'\n\n    Example:\n        >>> from data import create_dataset\n        >>> dataset = create_dataset(opt)\n    \"\"\"\n    data_loader = CustomDatasetDataLoader(opt, rank=rank)\n    dataset = data_loader.load_data()\n    return dataset\n\nclass CustomDatasetDataLoader():\n    \"\"\"Wrapper class of Dataset class that performs multi-threaded data loading\"\"\"\n\n    def __init__(self, opt, rank=0):\n        \"\"\"Initialize this class\n\n        Step 1: create a dataset instance given the name [dataset_mode]\n        Step 2: create a multi-threaded data loader.\n        \"\"\"\n        self.opt = opt\n        dataset_class = find_dataset_using_name(opt.dataset_mode)\n        self.dataset = dataset_class(opt)\n        self.sampler = None\n        print(\"rank %d %s dataset [%s] was created\" % (rank, 
self.dataset.name, type(self.dataset).__name__))\n        if opt.use_ddp and opt.isTrain:\n            world_size = opt.world_size\n            self.sampler = torch.utils.data.distributed.DistributedSampler(\n                    self.dataset,\n                    num_replicas=world_size,\n                    rank=rank,\n                    shuffle=not opt.serial_batches\n                )\n            self.dataloader = torch.utils.data.DataLoader(\n                        self.dataset,\n                        sampler=self.sampler,\n                        num_workers=int(opt.num_threads / world_size), \n                        batch_size=int(opt.batch_size / world_size), \n                        drop_last=True)\n        else:\n            self.dataloader = torch.utils.data.DataLoader(\n                self.dataset,\n                batch_size=opt.batch_size,\n                shuffle=(not opt.serial_batches) and opt.isTrain,\n                num_workers=int(opt.num_threads),\n                drop_last=True\n            )\n\n    def set_epoch(self, epoch):\n        self.dataset.current_epoch = epoch\n        if self.sampler is not None:\n            self.sampler.set_epoch(epoch)\n\n    def load_data(self):\n        return self\n\n    def __len__(self):\n        \"\"\"Return the number of data in the dataset\"\"\"\n        return min(len(self.dataset), self.opt.max_dataset_size)\n\n    def __iter__(self):\n        \"\"\"Return a batch of data\"\"\"\n        for i, data in enumerate(self.dataloader):\n            if i * self.opt.batch_size >= self.opt.max_dataset_size:\n                break\n            yield data\n"
  },
  {
    "path": "src/face3d/data/base_dataset.py",
    "content": "\"\"\"This module implements an abstract base class (ABC) 'BaseDataset' for datasets.\n\nIt also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.\n\"\"\"\nimport random\nimport numpy as np\nimport torch.utils.data as data\nfrom PIL import Image\nimport torchvision.transforms as transforms\nfrom abc import ABC, abstractmethod\n\n\nclass BaseDataset(data.Dataset, ABC):\n    \"\"\"This class is an abstract base class (ABC) for datasets.\n\n    To create a subclass, you need to implement the following four functions:\n    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).\n    -- <__len__>:                       return the size of dataset.\n    -- <__getitem__>:                   get a data point.\n    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.\n    \"\"\"\n\n    def __init__(self, opt):\n        \"\"\"Initialize the class; save the options in the class\n\n        Parameters:\n            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions\n        \"\"\"\n        self.opt = opt\n        # self.root = opt.dataroot\n        self.current_epoch = 0\n\n    @staticmethod\n    def modify_commandline_options(parser, is_train):\n        \"\"\"Add new dataset-specific options, and rewrite default values for existing options.\n\n        Parameters:\n            parser          -- original option parser\n            is_train (bool) -- whether training phase or test phase. 
You can use this flag to add training-specific or test-specific options.\n\n        Returns:\n            the modified parser.\n        \"\"\"\n        return parser\n\n    @abstractmethod\n    def __len__(self):\n        \"\"\"Return the total number of images in the dataset.\"\"\"\n        return 0\n\n    @abstractmethod\n    def __getitem__(self, index):\n        \"\"\"Return a data point and its metadata information.\n\n        Parameters:\n            index - - a random integer for data indexing\n\n        Returns:\n            a dictionary of data with their names. It ususally contains the data itself and its metadata information.\n        \"\"\"\n        pass\n\n\ndef get_transform(grayscale=False):\n    transform_list = []\n    if grayscale:\n        transform_list.append(transforms.Grayscale(1))\n    transform_list += [transforms.ToTensor()]\n    return transforms.Compose(transform_list)\n\ndef get_affine_mat(opt, size):\n    shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False\n    w, h = size\n\n    if 'shift' in opt.preprocess:\n        shift_pixs = int(opt.shift_pixs)\n        shift_x = random.randint(-shift_pixs, shift_pixs)\n        shift_y = random.randint(-shift_pixs, shift_pixs)\n    if 'scale' in opt.preprocess:\n        scale = 1 + opt.scale_delta * (2 * random.random() - 1)\n    if 'rot' in opt.preprocess:\n        rot_angle = opt.rot_angle * (2 * random.random() - 1)\n        rot_rad = -rot_angle * np.pi/180\n    if 'flip' in opt.preprocess:\n        flip = random.random() > 0.5\n\n    shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])\n    flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])\n    shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])\n    rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])\n    scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 
3])\n    shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])\n    \n    affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin    \n    affine_inv = np.linalg.inv(affine)\n    return affine, affine_inv, flip\n\ndef apply_img_affine(img, affine_inv, method=Image.BICUBIC):\n    return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC)\n\ndef apply_lm_affine(landmark, affine, flip, size):\n    _, h = size\n    lm = landmark.copy()\n    lm[:, 1] = h - 1 - lm[:, 1]\n    lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)\n    lm = lm @ np.transpose(affine)\n    lm[:, :2] = lm[:, :2] / lm[:, 2:]\n    lm = lm[:, :2]\n    lm[:, 1] = h - 1 - lm[:, 1]\n    if flip:\n        lm_ = lm.copy()\n        lm_[:17] = lm[16::-1]\n        lm_[17:22] = lm[26:21:-1]\n        lm_[22:27] = lm[21:16:-1]\n        lm_[31:36] = lm[35:30:-1]\n        lm_[36:40] = lm[45:41:-1]\n        lm_[40:42] = lm[47:45:-1]\n        lm_[42:46] = lm[39:35:-1]\n        lm_[46:48] = lm[41:39:-1]\n        lm_[48:55] = lm[54:47:-1]\n        lm_[55:60] = lm[59:54:-1]\n        lm_[60:65] = lm[64:59:-1]\n        lm_[65:68] = lm[67:64:-1]\n        lm = lm_\n    return lm\n"
  },
  {
    "path": "src/face3d/data/flist_dataset.py",
    "content": "\"\"\"This script defines the custom dataset for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport os.path\nfrom data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine\nfrom data.image_folder import make_dataset\nfrom PIL import Image\nimport random\nimport util.util as util\nimport numpy as np\nimport json\nimport torch\nfrom scipy.io import loadmat, savemat\nimport pickle\nfrom util.preprocess import align_img, estimate_norm\nfrom util.load_mats import load_lm3d\n\n\ndef default_flist_reader(flist):\n    \"\"\"\n    flist format: impath label\\nimpath label\\n ...(same to caffe's filelist)\n    \"\"\"\n    imlist = []\n    with open(flist, 'r') as rf:\n        for line in rf.readlines():\n            impath = line.strip()\n            imlist.append(impath)\n\n    return imlist\n\ndef jason_flist_reader(flist):\n    with open(flist, 'r') as fp:\n        info = json.load(fp)\n    return info\n\ndef parse_label(label):\n    return torch.tensor(np.array(label).astype(np.float32))\n\n\nclass FlistDataset(BaseDataset):\n    \"\"\"\n    It requires one directories to host training images '/path/to/data/train'\n    You can train the model with the dataset flag '--dataroot /path/to/data'.\n    \"\"\"\n\n    def __init__(self, opt):\n        \"\"\"Initialize this dataset class.\n\n        Parameters:\n            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions\n        \"\"\"\n        BaseDataset.__init__(self, opt)\n        \n        self.lm3d_std = load_lm3d(opt.bfm_folder)\n        \n        msk_names = default_flist_reader(opt.flist)\n        self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]\n\n        self.size = len(self.msk_paths) \n        self.opt = opt\n        \n        self.name = 'train' if opt.isTrain else 'val'\n        if '_' in opt.flist:\n            self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]\n        \n\n    def 
__getitem__(self, index):\n        \"\"\"Return a data point and its metadata information.\n\n        Parameters:\n            index (int)      -- a random integer for data indexing\n\n        Returns a dictionary that contains A, B, A_paths and B_paths\n            img (tensor)       -- an image in the input domain\n            msk (tensor)       -- its corresponding attention mask\n            lm  (tensor)       -- its corresponding 3d landmarks\n            im_paths (str)     -- image paths\n            aug_flag (bool)    -- a flag used to tell whether its raw or augmented\n        \"\"\"\n        msk_path = self.msk_paths[index % self.size]  # make sure index is within then range\n        img_path = msk_path.replace('mask/', '')\n        lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'\n\n        raw_img = Image.open(img_path).convert('RGB')\n        raw_msk = Image.open(msk_path).convert('RGB')\n        raw_lm = np.loadtxt(lm_path).astype(np.float32)\n\n        _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)\n        \n        aug_flag = self.opt.use_aug and self.opt.isTrain\n        if aug_flag:\n            img, lm, msk = self._augmentation(img, lm, self.opt, msk)\n        \n        _, H = img.size\n        M = estimate_norm(lm, H)\n        transform = get_transform()\n        img_tensor = transform(img)\n        msk_tensor = transform(msk)[:1, ...]\n        lm_tensor = parse_label(lm)\n        M_tensor = parse_label(M)\n\n\n        return {'imgs': img_tensor, \n                'lms': lm_tensor, \n                'msks': msk_tensor, \n                'M': M_tensor,\n                'im_paths': img_path, \n                'aug_flag': aug_flag,\n                'dataset': self.name}\n\n    def _augmentation(self, img, lm, opt, msk=None):\n        affine, affine_inv, flip = get_affine_mat(opt, img.size)\n        img = apply_img_affine(img, affine_inv)\n        lm = apply_lm_affine(lm, affine, flip, 
img.size)\n        if msk is not None:\n            msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)\n        return img, lm, msk\n    \n\n\n\n    def __len__(self):\n        \"\"\"Return the total number of images in the dataset.\n        \"\"\"\n        return self.size\n"
  },
  {
    "path": "src/face3d/data/image_folder.py",
    "content": "\"\"\"A modified image folder class\n\nWe modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)\nso that this class can load images from both current directory and its subdirectories.\n\"\"\"\nimport numpy as np\nimport torch.utils.data as data\n\nfrom PIL import Image\nimport os\nimport os.path\n\nIMG_EXTENSIONS = [\n    '.jpg', '.JPG', '.jpeg', '.JPEG',\n    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',\n    '.tif', '.TIF', '.tiff', '.TIFF',\n]\n\n\ndef is_image_file(filename):\n    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)\n\n\ndef make_dataset(dir, max_dataset_size=float(\"inf\")):\n    images = []\n    assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir\n\n    for root, _, fnames in sorted(os.walk(dir, followlinks=True)):\n        for fname in fnames:\n            if is_image_file(fname):\n                path = os.path.join(root, fname)\n                images.append(path)\n    return images[:min(max_dataset_size, len(images))]\n\n\ndef default_loader(path):\n    return Image.open(path).convert('RGB')\n\n\nclass ImageFolder(data.Dataset):\n\n    def __init__(self, root, transform=None, return_paths=False,\n                 loader=default_loader):\n        imgs = make_dataset(root)\n        if len(imgs) == 0:\n            raise(RuntimeError(\"Found 0 images in: \" + root + \"\\n\"\n                               \"Supported image extensions are: \" + \",\".join(IMG_EXTENSIONS)))\n\n        self.root = root\n        self.imgs = imgs\n        self.transform = transform\n        self.return_paths = return_paths\n        self.loader = loader\n\n    def __getitem__(self, index):\n        path = self.imgs[index]\n        img = self.loader(path)\n        if self.transform is not None:\n            img = self.transform(img)\n        if self.return_paths:\n            return img, path\n        else:\n            return 
img\n\n    def __len__(self):\n        return len(self.imgs)\n"
  },
  {
    "path": "src/face3d/data/template_dataset.py",
    "content": "\"\"\"Dataset class template\n\nThis module provides a template for users to implement custom datasets.\nYou can specify '--dataset_mode template' to use this dataset.\nThe class name should be consistent with both the filename and its dataset_mode option.\nThe filename should be <dataset_mode>_dataset.py\nThe class name should be <Dataset_mode>Dataset.py\nYou need to implement the following functions:\n    -- <modify_commandline_options>:　Add dataset-specific options and rewrite default values for existing options.\n    -- <__init__>: Initialize this dataset class.\n    -- <__getitem__>: Return a data point and its metadata information.\n    -- <__len__>: Return the number of images.\n\"\"\"\nfrom data.base_dataset import BaseDataset, get_transform\n# from data.image_folder import make_dataset\n# from PIL import Image\n\n\nclass TemplateDataset(BaseDataset):\n    \"\"\"A template dataset class for you to implement custom datasets.\"\"\"\n    @staticmethod\n    def modify_commandline_options(parser, is_train):\n        \"\"\"Add new dataset-specific options, and rewrite default values for existing options.\n\n        Parameters:\n            parser          -- original option parser\n            is_train (bool) -- whether training phase or test phase. 
You can use this flag to add training-specific or test-specific options.\n\n        Returns:\n            the modified parser.\n        \"\"\"\n        parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')\n        parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0)  # specify dataset-specific default values\n        return parser\n\n    def __init__(self, opt):\n        \"\"\"Initialize this dataset class.\n\n        Parameters:\n            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions\n\n        A few things can be done here.\n        - save the options (have been done in BaseDataset)\n        - get image paths and meta information of the dataset.\n        - define the image transformation.\n        \"\"\"\n        # save the option and dataset root\n        BaseDataset.__init__(self, opt)\n        # get the image paths of your dataset;\n        self.image_paths = []  # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root\n        # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function\n        self.transform = get_transform(opt)\n\n    def __getitem__(self, index):\n        \"\"\"Return a data point and its metadata information.\n\n        Parameters:\n            index -- a random integer for data indexing\n\n        Returns:\n            a dictionary of data with their names. It usually contains the data itself and its metadata information.\n\n        Step 1: get a random image path: e.g., path = self.image_paths[index]\n        Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').\n        Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. 
e.g., data = self.transform(image)\n        Step 4: return a data point as a dictionary.\n        \"\"\"\n        path = 'temp'    # needs to be a string\n        data_A = None    # needs to be a tensor\n        data_B = None    # needs to be a tensor\n        return {'data_A': data_A, 'data_B': data_B, 'path': path}\n\n    def __len__(self):\n        \"\"\"Return the total number of images.\"\"\"\n        return len(self.image_paths)\n"
  },
  {
    "path": "src/face3d/extract_kp_videos.py",
    "content": "import os\nimport cv2\nimport time\nimport glob\nimport argparse\nimport face_alignment\nimport numpy as np\nfrom PIL import Image\nfrom tqdm import tqdm\nfrom itertools import cycle\n\nfrom torch.multiprocessing import Pool, Process, set_start_method\n\nclass KeypointExtractor():\n    def __init__(self, device):\n        self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, \n                                                     device=device)   \n\n    def extract_keypoint(self, images, name=None, info=True):\n        if isinstance(images, list):\n            keypoints = []\n            if info:\n                i_range = tqdm(images,desc='landmark Det:')\n            else:\n                i_range = images\n\n            for image in i_range:\n                current_kp = self.extract_keypoint(image)\n                if np.mean(current_kp) == -1 and keypoints:\n                    keypoints.append(keypoints[-1])\n                else:\n                    keypoints.append(current_kp[None])\n\n            keypoints = np.concatenate(keypoints, 0)\n            np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))\n            return keypoints\n        else:\n            while True:\n                try:\n                    keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]\n                    break\n                except RuntimeError as e:\n                    if str(e).startswith('CUDA'):\n                        print(\"Warning: out of memory, sleep for 1s\")\n                        time.sleep(1)\n                    else:\n                        print(e)\n                        break    \n                except TypeError:\n                    print('No face detected in this image')\n                    shape = [68, 2]\n                    keypoints = -1. 
* np.ones(shape)                    \n                    break\n            if name is not None:\n                np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))\n            return keypoints\n\ndef read_video(filename):\n    frames = []\n    cap = cv2.VideoCapture(filename)\n    while cap.isOpened():\n        ret, frame = cap.read()\n        if ret:\n            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n            frame = Image.fromarray(frame)\n            frames.append(frame)\n        else:\n            break\n    cap.release()\n    return frames\n\ndef run(data):\n    filename, opt, device = data\n    os.environ['CUDA_VISIBLE_DEVICES'] = device\n    kp_extractor = KeypointExtractor()\n    images = read_video(filename)\n    name = filename.split('/')[-2:]\n    os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)\n    kp_extractor.extract_keypoint(\n        images, \n        name=os.path.join(opt.output_dir, name[-2], name[-1])\n    )\n\nif __name__ == '__main__':\n    set_start_method('spawn')\n    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n    parser.add_argument('--input_dir', type=str, help='the folder of the input files')\n    parser.add_argument('--output_dir', type=str, help='the folder of the output files')\n    parser.add_argument('--device_ids', type=str, default='0,1')\n    parser.add_argument('--workers', type=int, default=4)\n\n    opt = parser.parse_args()\n    filenames = list()\n    VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}\n    VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})\n    extensions = VIDEO_EXTENSIONS\n    \n    for ext in extensions:\n        os.listdir(f'{opt.input_dir}')\n        print(f'{opt.input_dir}/*.{ext}')\n        filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))\n    print('Total number of videos:', len(filenames))\n    pool = Pool(opt.workers)\n    args_list = cycle([opt])\n    
device_ids = opt.device_ids.split(\",\")\n    device_ids = cycle(device_ids)\n    for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):\n        None\n"
  },
  {
    "path": "src/face3d/extract_kp_videos_safe.py",
    "content": "import os\nimport cv2\nimport time\nimport glob\nimport argparse\nimport numpy as np\nfrom PIL import Image\nimport torch\nfrom tqdm import tqdm\nfrom itertools import cycle\nfrom torch.multiprocessing import Pool, Process, set_start_method\n\nfrom facexlib.alignment import landmark_98_to_68\nfrom facexlib.detection import init_detection_model\n\nfrom facexlib.utils import load_file_from_url\nfrom src.face3d.util.my_awing_arch import FAN\n\ndef init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None):\n    if model_name == 'awing_fan':\n        model = FAN(num_modules=4, num_landmarks=98, device=device)\n        model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth'\n    else:\n        raise NotImplementedError(f'{model_name} is not implemented.')\n\n    model_path = load_file_from_url(\n        url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)\n    model.load_state_dict(torch.load(model_path, map_location=device)['state_dict'], strict=True)\n    model.eval()\n    model = model.to(device)\n    return model\n\n\nclass KeypointExtractor():\n    def __init__(self, device='cuda'):\n\n        ### gfpgan/weights\n        try:\n            import webui  # in webui\n            root_path = 'extensions/SadTalker/gfpgan/weights' \n\n        except:\n            root_path = 'gfpgan/weights'\n\n        self.detector = init_alignment_model('awing_fan',device=device, model_rootpath=root_path)   \n        self.det_net = init_detection_model('retinaface_resnet50', half=False,device=device, model_rootpath=root_path)\n\n    def extract_keypoint(self, images, name=None, info=True):\n        if isinstance(images, list):\n            keypoints = []\n            if info:\n                i_range = tqdm(images,desc='landmark Det:')\n            else:\n                i_range = images\n\n            for image in i_range:\n                current_kp = 
self.extract_keypoint(image)\n                # current_kp = self.detector.get_landmarks(np.array(image))\n                if np.mean(current_kp) == -1 and keypoints:\n                    keypoints.append(keypoints[-1])\n                else:\n                    keypoints.append(current_kp[None])\n\n            keypoints = np.concatenate(keypoints, 0)\n            np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))\n            return keypoints\n        else:\n            while True:\n                try:\n                    with torch.no_grad():\n                        # face detection -> face alignment.\n                        img = np.array(images)\n                        bboxes = self.det_net.detect_faces(images, 0.97)\n                        \n                        bboxes = bboxes[0]\n                        img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :]\n\n                        keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0]\n\n                        #### keypoints to the original location\n                        keypoints[:,0] += int(bboxes[0])\n                        keypoints[:,1] += int(bboxes[1])\n\n                        break\n                except RuntimeError as e:\n                    if str(e).startswith('CUDA'):\n                        print(\"Warning: out of memory, sleep for 1s\")\n                        time.sleep(1)\n                    else:\n                        print(e)\n                        break    \n                except TypeError:\n                    print('No face detected in this image')\n                    shape = [68, 2]\n                    keypoints = -1. 
* np.ones(shape)                    \n                    break\n            if name is not None:\n                np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))\n            return keypoints\n\ndef read_video(filename):\n    frames = []\n    cap = cv2.VideoCapture(filename)\n    while cap.isOpened():\n        ret, frame = cap.read()\n        if ret:\n            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n            frame = Image.fromarray(frame)\n            frames.append(frame)\n        else:\n            break\n    cap.release()\n    return frames\n\ndef run(data):\n    filename, opt, device = data\n    os.environ['CUDA_VISIBLE_DEVICES'] = device\n    kp_extractor = KeypointExtractor()\n    images = read_video(filename)\n    name = filename.split('/')[-2:]\n    os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)\n    kp_extractor.extract_keypoint(\n        images, \n        name=os.path.join(opt.output_dir, name[-2], name[-1])\n    )\n\nif __name__ == '__main__':\n    set_start_method('spawn')\n    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n    parser.add_argument('--input_dir', type=str, help='the folder of the input files')\n    parser.add_argument('--output_dir', type=str, help='the folder of the output files')\n    parser.add_argument('--device_ids', type=str, default='0,1')\n    parser.add_argument('--workers', type=int, default=4)\n\n    opt = parser.parse_args()\n    filenames = list()\n    VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}\n    VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})\n    extensions = VIDEO_EXTENSIONS\n    \n    for ext in extensions:\n        os.listdir(f'{opt.input_dir}')\n        print(f'{opt.input_dir}/*.{ext}')\n        filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))\n    print('Total number of videos:', len(filenames))\n    pool = Pool(opt.workers)\n    args_list = cycle([opt])\n    
device_ids = opt.device_ids.split(\",\")\n    device_ids = cycle(device_ids)\n    for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):\n        None\n"
  },
  {
    "path": "src/face3d/models/__init__.py",
    "content": "\"\"\"This package contains modules related to objective functions, optimizations, and network architectures.\n\nTo add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.\nYou need to implement the following five functions:\n    -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).\n    -- <set_input>:                     unpack data from dataset and apply preprocessing.\n    -- <forward>:                       produce intermediate results.\n    -- <optimize_parameters>:           calculate loss, gradients, and update network weights.\n    -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.\n\nIn the function <__init__>, you need to define four lists:\n    -- self.loss_names (str list):          specify the training losses that you want to plot and save.\n    -- self.model_names (str list):         define networks used in our training.\n    -- self.visual_names (str list):        specify the images that you want to display and save.\n    -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.\n\nNow you can use the model class by specifying flag '--model dummy'.\nSee our template model class 'template_model.py' for more details.\n\"\"\"\n\nimport importlib\nfrom src.face3d.models.base_model import BaseModel\n\n\ndef find_model_using_name(model_name):\n    \"\"\"Import the module \"models/[model_name]_model.py\".\n\n    In the file, the class called DatasetNameModel() will\n    be instantiated. 
It has to be a subclass of BaseModel,\n    and it is case-insensitive.\n    \"\"\"\n    model_filename = \"face3d.models.\" + model_name + \"_model\"\n    modellib = importlib.import_module(model_filename)\n    model = None\n    target_model_name = model_name.replace('_', '') + 'model'\n    for name, cls in modellib.__dict__.items():\n        if name.lower() == target_model_name.lower() \\\n           and issubclass(cls, BaseModel):\n            model = cls\n\n    if model is None:\n        print(\"In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase.\" % (model_filename, target_model_name))\n        exit(0)\n\n    return model\n\n\ndef get_option_setter(model_name):\n    \"\"\"Return the static method <modify_commandline_options> of the model class.\"\"\"\n    model_class = find_model_using_name(model_name)\n    return model_class.modify_commandline_options\n\n\ndef create_model(opt):\n    \"\"\"Create a model given the option.\n\n    This function warps the class CustomDatasetDataLoader.\n    This is the main interface between this package and 'train.py'/'test.py'\n\n    Example:\n        >>> from models import create_model\n        >>> model = create_model(opt)\n    \"\"\"\n    model = find_model_using_name(opt.model)\n    instance = model(opt)\n    print(\"model [%s] was created\" % type(instance).__name__)\n    return instance\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/README.md",
    "content": "# Distributed Arcface Training in Pytorch\n\nThis is a deep learning library that makes face recognition training efficient and effective; it can train on tens of millions of\nidentities on a single server.\n\n## Requirements\n\n- Install [pytorch](http://pytorch.org) (torch>=1.6.0); see our doc [install.md](docs/install.md).\n- `pip install -r requirements.txt`.\n- Download the dataset\n  from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_)\n  .\n\n## How to Train\n\nTo train a model, run `train.py` with the path to the configs:\n\n### 1. Single node, 8 GPUs:\n\n```shell\npython -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr=\"127.0.0.1\" --master_port=1234 train.py configs/ms1mv3_r50\n```\n\n### 2. Multiple nodes, each node 8 GPUs:\n\nNode 0:\n\n```shell\npython -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=\"ip1\" --master_port=1234 train.py configs/ms1mv3_r50\n```\n\nNode 1:\n\n```shell\npython -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=\"ip1\" --master_port=1234 train.py configs/ms1mv3_r50\n```\n\n### 3. Training resnet2060 with 8 GPUs:\n\n```shell\npython -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr=\"127.0.0.1\" --master_port=1234 train.py configs/ms1mv3_r2060.py\n```\n\n## Model Zoo\n\n- The models are available for non-commercial research purposes only.  \n- All models can be found here.  
\n- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g):   e8pw  \n- [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)\n\n### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/)\n\nThe ICCV2021-MFR test set consists of non-celebrities, so we can ensure that it has very little overlap with publicly available face \nrecognition training sets such as MS1M and CASIA, which are mostly collected from online celebrities. \nAs a result, we can evaluate the FAIR performance of different algorithms.  \n\nFor the **ICCV2021-MFR-ALL** set, TAR is measured on the all-to-all 1:1 protocol, with FAR less than 0.000001 (1e-6). The \nglobalised multi-racial test set contains 242,143 identities and 1,624,305 images. \n\nFor the **ICCV2021-MFR-MASK** set, TAR is measured on the mask-to-nonmask 1:1 protocol, with FAR less than 0.0001 (1e-4). \nThe mask test set contains 6,964 identities, 6,964 masked images and 13,928 non-masked images. \nThere are 13,928 positive pairs and 96,983,824 negative pairs in total.\n\n| Datasets | backbone  | Training throughput | Size / MB  | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** |\n| :---:    | :---      | :---                | :---       |:---                   |:---                  |     \n| MS1MV3    | r18  | -              | 91   | **47.85** | **68.33** |\n| Glint360k | r18  | 8536           | 91   | **53.32** | **72.07** |\n| MS1MV3    | r34  | -              | 130  | **58.72** | **77.36** |\n| Glint360k | r34  | 6344           | 130  | **65.10** | **83.02** |\n| MS1MV3    | r50  | 5500           | 166  | **63.85** | **80.53** |\n| Glint360k | r50  | 5136           | 166  | **70.23** | **87.08** |\n| MS1MV3    | r100 | -              | 248  | **69.09** | **84.31** |\n| Glint360k | r100 | 3332           | 248  | **75.57** | **90.66** |\n| MS1MV3    | mobilefacenet | 12185 | 7.8  | **41.52** | **65.26** |        \n| Glint360k | mobilefacenet | 11197 | 7.8  | **44.52** | **66.48** |  \n\n### Performance on IJB-C and Verification Datasets\n\n|   
Datasets | backbone      | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw  |  log    |\n| :---:      |    :---       | :---          | :---  | :---  |:---   |:---    |:---     |  \n| MS1MV3     | r18      | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)|         \n| MS1MV3     | r34      | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)|        \n| MS1MV3     | r50      | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)|         \n| MS1MV3     | r100     | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)|        \n| MS1MV3     | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)|\n| Glint360k  |r18-0.1   | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)| \n| Glint360k  |r34-0.1   | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)| \n| Glint360k  |r50-0.1   | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)| \n| Glint360k  |r100-0.1  | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)|\n\n[comment]: <> (More details see 
[model.md]&#40;docs/modelzoo.md&#41; in docs.)\n\n\n## [Speed Benchmark](docs/speed_benchmark.md)\n\n**Arcface Torch** can train on large-scale face recognition training sets efficiently and quickly. When the number of\nclasses in the training set is greater than 300K and the training is sufficient, the Partial FC sampling strategy achieves the same\naccuracy with several times faster training and lower GPU memory usage. \nPartial FC is a sparse variant of the model parallel architecture for large-scale face recognition. Partial FC uses a \nsparse softmax, where each batch dynamically samples a subset of class centers for training. In each iteration, only a \nsparse part of the parameters is updated, which greatly reduces GPU memory usage and computation. With Partial FC, \nwe can scale the training set to 29 million identities, the largest to date. Partial FC also supports multi-machine distributed \ntraining and mixed precision training.\n\n![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png)\n\nFor more details, see \n[speed_benchmark.md](docs/speed_benchmark.md) in docs.\n\n### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)\n\n`-` means training failed because of GPU memory limitations.\n\n| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |\n| :---    | :--- | :--- | :--- |\n|125000   | 4681         | 4824          | 5004     |\n|1400000  | **1672**     | 3043          | 4738     |\n|5500000  | **-**        | **1389**      | 3975     |\n|8000000  | **-**        | **-**         | 3565     |\n|16000000 | **-**        | **-**         | 2679     |\n|29000000 | **-**        | **-**         | **1855** |\n\n### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. 
(Smaller is better)\n\n| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |\n| :---    | :---      | :---      | :---  |\n|125000   | 7358      | 5306      | 4868  |\n|1400000  | 32252     | 11178     | 6056  |\n|5500000  | **-**     | 32188     | 9854  |\n|8000000  | **-**     | **-**     | 12310 |\n|16000000 | **-**     | **-**     | 19950 |\n|29000000 | **-**     | **-**     | 32324 |\n\n## Evaluation on ICCV2021-MFR and IJB-C\n\nFor more details, see [eval.md](docs/eval.md) in docs.\n\n## Test\n\nWe tested many versions of PyTorch. Please create an issue if you are having trouble.  \n\n- [x] torch 1.6.0\n- [x] torch 1.7.1\n- [x] torch 1.8.0\n- [x] torch 1.9.0\n\n## Citation\n\n```\n@inproceedings{deng2019arcface,\n  title={Arcface: Additive angular margin loss for deep face recognition},\n  author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},\n  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},\n  pages={4690--4699},\n  year={2019}\n}\n@inproceedings{an2020partical_fc,\n  title={Partial FC: Training 10 Million Identities on a Single Machine},\n  author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and\n  Zhang, Debing and Fu, Ying},\n  booktitle={arXiv preprint arXiv:2010.05222},\n  year={2020}\n}\n```\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/backbones/__init__.py",
    "content": "from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200\nfrom .mobilefacenet import get_mbf\n\n\ndef get_model(name, **kwargs):\n    # resnet\n    if name == \"r18\":\n        return iresnet18(False, **kwargs)\n    elif name == \"r34\":\n        return iresnet34(False, **kwargs)\n    elif name == \"r50\":\n        return iresnet50(False, **kwargs)\n    elif name == \"r100\":\n        return iresnet100(False, **kwargs)\n    elif name == \"r200\":\n        return iresnet200(False, **kwargs)\n    elif name == \"r2060\":\n        from .iresnet2060 import iresnet2060\n        return iresnet2060(False, **kwargs)\n    elif name == \"mbf\":\n        fp16 = kwargs.get(\"fp16\", False)\n        num_features = kwargs.get(\"num_features\", 512)\n        return get_mbf(fp16=fp16, num_features=num_features)\n    else:\n        raise ValueError()"
  },
  {
    "path": "src/face3d/models/arcface_torch/backbones/iresnet.py",
    "content": "import torch\nfrom torch import nn\n\n__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']\n\n\ndef conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes,\n                     out_planes,\n                     kernel_size=3,\n                     stride=stride,\n                     padding=dilation,\n                     groups=groups,\n                     bias=False,\n                     dilation=dilation)\n\n\ndef conv1x1(in_planes, out_planes, stride=1):\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes,\n                     out_planes,\n                     kernel_size=1,\n                     stride=stride,\n                     bias=False)\n\n\nclass IBasicBlock(nn.Module):\n    expansion = 1\n    def __init__(self, inplanes, planes, stride=1, downsample=None,\n                 groups=1, base_width=64, dilation=1):\n        super(IBasicBlock, self).__init__()\n        if groups != 1 or base_width != 64:\n            raise ValueError('BasicBlock only supports groups=1 and base_width=64')\n        if dilation > 1:\n            raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)\n        self.conv1 = conv3x3(inplanes, planes)\n        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)\n        self.prelu = nn.PReLU(planes)\n        self.conv2 = conv3x3(planes, planes, stride)\n        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        identity = x\n        out = self.bn1(x)\n        out = self.conv1(out)\n        out = self.bn2(out)\n        out = self.prelu(out)\n        out = self.conv2(out)\n        out = self.bn3(out)\n        if self.downsample is not None:\n            identity = self.downsample(x)\n        out += identity\n        return 
out\n\n\nclass IResNet(nn.Module):\n    fc_scale = 7 * 7\n    def __init__(self,\n                 block, layers, dropout=0, num_features=512, zero_init_residual=False,\n                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):\n        super(IResNet, self).__init__()\n        self.fp16 = fp16\n        self.inplanes = 64\n        self.dilation = 1\n        if replace_stride_with_dilation is None:\n            replace_stride_with_dilation = [False, False, False]\n        if len(replace_stride_with_dilation) != 3:\n            raise ValueError(\"replace_stride_with_dilation should be None \"\n                             \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation))\n        self.groups = groups\n        self.base_width = width_per_group\n        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)\n        self.prelu = nn.PReLU(self.inplanes)\n        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)\n        self.layer2 = self._make_layer(block,\n                                       128,\n                                       layers[1],\n                                       stride=2,\n                                       dilate=replace_stride_with_dilation[0])\n        self.layer3 = self._make_layer(block,\n                                       256,\n                                       layers[2],\n                                       stride=2,\n                                       dilate=replace_stride_with_dilation[1])\n        self.layer4 = self._make_layer(block,\n                                       512,\n                                       layers[3],\n                                       stride=2,\n                                       dilate=replace_stride_with_dilation[2])\n        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)\n        self.dropout = 
nn.Dropout(p=dropout, inplace=True)\n        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)\n        self.features = nn.BatchNorm1d(num_features, eps=1e-05)\n        nn.init.constant_(self.features.weight, 1.0)\n        self.features.weight.requires_grad = False\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.normal_(m.weight, 0, 0.1)\n            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n        if zero_init_residual:\n            for m in self.modules():\n                if isinstance(m, IBasicBlock):\n                    nn.init.constant_(m.bn2.weight, 0)\n\n    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):\n        downsample = None\n        previous_dilation = self.dilation\n        if dilate:\n            self.dilation *= stride\n            stride = 1\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                conv1x1(self.inplanes, planes * block.expansion, stride),\n                nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),\n            )\n        layers = []\n        layers.append(\n            block(self.inplanes, planes, stride, downsample, self.groups,\n                  self.base_width, previous_dilation))\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(\n                block(self.inplanes,\n                      planes,\n                      groups=self.groups,\n                      base_width=self.base_width,\n                      dilation=self.dilation))\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        with torch.cuda.amp.autocast(self.fp16):\n            x = self.conv1(x)\n            x = self.bn1(x)\n            x = self.prelu(x)\n            x = self.layer1(x)\n 
           x = self.layer2(x)\n            x = self.layer3(x)\n            x = self.layer4(x)\n            x = self.bn2(x)\n            x = torch.flatten(x, 1)\n            x = self.dropout(x)\n        x = self.fc(x.float() if self.fp16 else x)\n        x = self.features(x)\n        return x\n\n\ndef _iresnet(arch, block, layers, pretrained, progress, **kwargs):\n    model = IResNet(block, layers, **kwargs)\n    if pretrained:\n        raise ValueError()\n    return model\n\n\ndef iresnet18(pretrained=False, progress=True, **kwargs):\n    return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,\n                    progress, **kwargs)\n\n\ndef iresnet34(pretrained=False, progress=True, **kwargs):\n    return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,\n                    progress, **kwargs)\n\n\ndef iresnet50(pretrained=False, progress=True, **kwargs):\n    return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,\n                    progress, **kwargs)\n\n\ndef iresnet100(pretrained=False, progress=True, **kwargs):\n    return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,\n                    progress, **kwargs)\n\n\ndef iresnet200(pretrained=False, progress=True, **kwargs):\n    return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,\n                    progress, **kwargs)\n\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/backbones/iresnet2060.py",
    "content": "import torch\nfrom torch import nn\n\nassert torch.__version__ >= \"1.8.1\"\nfrom torch.utils.checkpoint import checkpoint_sequential\n\n__all__ = ['iresnet2060']\n\n\ndef conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes,\n                     out_planes,\n                     kernel_size=3,\n                     stride=stride,\n                     padding=dilation,\n                     groups=groups,\n                     bias=False,\n                     dilation=dilation)\n\n\ndef conv1x1(in_planes, out_planes, stride=1):\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes,\n                     out_planes,\n                     kernel_size=1,\n                     stride=stride,\n                     bias=False)\n\n\nclass IBasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None,\n                 groups=1, base_width=64, dilation=1):\n        super(IBasicBlock, self).__init__()\n        if groups != 1 or base_width != 64:\n            raise ValueError('BasicBlock only supports groups=1 and base_width=64')\n        if dilation > 1:\n            raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, )\n        self.conv1 = conv3x3(inplanes, planes)\n        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, )\n        self.prelu = nn.PReLU(planes)\n        self.conv2 = conv3x3(planes, planes, stride)\n        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, )\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        identity = x\n        out = self.bn1(x)\n        out = self.conv1(out)\n        out = self.bn2(out)\n        out = self.prelu(out)\n        out = self.conv2(out)\n        out = self.bn3(out)\n        if self.downsample is not None:\n            identity = 
self.downsample(x)\n        out += identity\n        return out\n\n\nclass IResNet(nn.Module):\n    fc_scale = 7 * 7\n\n    def __init__(self,\n                 block, layers, dropout=0, num_features=512, zero_init_residual=False,\n                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):\n        super(IResNet, self).__init__()\n        self.fp16 = fp16\n        self.inplanes = 64\n        self.dilation = 1\n        if replace_stride_with_dilation is None:\n            replace_stride_with_dilation = [False, False, False]\n        if len(replace_stride_with_dilation) != 3:\n            raise ValueError(\"replace_stride_with_dilation should be None \"\n                             \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation))\n        self.groups = groups\n        self.base_width = width_per_group\n        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)\n        self.prelu = nn.PReLU(self.inplanes)\n        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)\n        self.layer2 = self._make_layer(block,\n                                       128,\n                                       layers[1],\n                                       stride=2,\n                                       dilate=replace_stride_with_dilation[0])\n        self.layer3 = self._make_layer(block,\n                                       256,\n                                       layers[2],\n                                       stride=2,\n                                       dilate=replace_stride_with_dilation[1])\n        self.layer4 = self._make_layer(block,\n                                       512,\n                                       layers[3],\n                                       stride=2,\n                                       dilate=replace_stride_with_dilation[2])\n        self.bn2 = 
nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )\n        self.dropout = nn.Dropout(p=dropout, inplace=True)\n        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)\n        self.features = nn.BatchNorm1d(num_features, eps=1e-05)\n        nn.init.constant_(self.features.weight, 1.0)\n        self.features.weight.requires_grad = False\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.normal_(m.weight, 0, 0.1)\n            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n        if zero_init_residual:\n            for m in self.modules():\n                if isinstance(m, IBasicBlock):\n                    nn.init.constant_(m.bn2.weight, 0)\n\n    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):\n        downsample = None\n        previous_dilation = self.dilation\n        if dilate:\n            self.dilation *= stride\n            stride = 1\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                conv1x1(self.inplanes, planes * block.expansion, stride),\n                nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),\n            )\n        layers = []\n        layers.append(\n            block(self.inplanes, planes, stride, downsample, self.groups,\n                  self.base_width, previous_dilation))\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(\n                block(self.inplanes,\n                      planes,\n                      groups=self.groups,\n                      base_width=self.base_width,\n                      dilation=self.dilation))\n\n        return nn.Sequential(*layers)\n\n    def checkpoint(self, func, num_seg, x):\n        if self.training:\n            return checkpoint_sequential(func, 
num_seg, x)\n        else:\n            return func(x)\n\n    def forward(self, x):\n        with torch.cuda.amp.autocast(self.fp16):\n            x = self.conv1(x)\n            x = self.bn1(x)\n            x = self.prelu(x)\n            x = self.layer1(x)\n            x = self.checkpoint(self.layer2, 20, x)\n            x = self.checkpoint(self.layer3, 100, x)\n            x = self.layer4(x)\n            x = self.bn2(x)\n            x = torch.flatten(x, 1)\n            x = self.dropout(x)\n        x = self.fc(x.float() if self.fp16 else x)\n        x = self.features(x)\n        return x\n\n\ndef _iresnet(arch, block, layers, pretrained, progress, **kwargs):\n    model = IResNet(block, layers, **kwargs)\n    if pretrained:\n        raise ValueError()\n    return model\n\n\ndef iresnet2060(pretrained=False, progress=True, **kwargs):\n    return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/backbones/mobilefacenet.py",
    "content": "'''\nAdapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py\nOriginal author cavalleria\n'''\n\nimport torch.nn as nn\nfrom torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module\nimport torch\n\n\nclass Flatten(Module):\n    def forward(self, x):\n        return x.view(x.size(0), -1)\n\n\nclass ConvBlock(Module):\n    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):\n        super(ConvBlock, self).__init__()\n        self.layers = nn.Sequential(\n            Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),\n            BatchNorm2d(num_features=out_c),\n            PReLU(num_parameters=out_c)\n        )\n\n    def forward(self, x):\n        return self.layers(x)\n\n\nclass LinearBlock(Module):\n    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):\n        super(LinearBlock, self).__init__()\n        self.layers = nn.Sequential(\n            Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),\n            BatchNorm2d(num_features=out_c)\n        )\n\n    def forward(self, x):\n        return self.layers(x)\n\n\nclass DepthWise(Module):\n    def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):\n        super(DepthWise, self).__init__()\n        self.residual = residual\n        self.layers = nn.Sequential(\n            ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),\n            ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),\n            LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))\n        )\n\n    def forward(self, x):\n        short_cut = None\n        if self.residual:\n            short_cut = x\n        x = self.layers(x)\n        if self.residual:\n            output = short_cut + x\n      
  else:\n            output = x\n        return output\n\n\nclass Residual(Module):\n    def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):\n        super(Residual, self).__init__()\n        modules = []\n        for _ in range(num_block):\n            modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))\n        self.layers = Sequential(*modules)\n\n    def forward(self, x):\n        return self.layers(x)\n\n\nclass GDC(Module):\n    def __init__(self, embedding_size):\n        super(GDC, self).__init__()\n        self.layers = nn.Sequential(\n            LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),\n            Flatten(),\n            Linear(512, embedding_size, bias=False),\n            BatchNorm1d(embedding_size))\n\n    def forward(self, x):\n        return self.layers(x)\n\n\nclass MobileFaceNet(Module):\n    def __init__(self, fp16=False, num_features=512):\n        super(MobileFaceNet, self).__init__()\n        scale = 2\n        self.fp16 = fp16\n        self.layers = nn.Sequential(\n            ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)),\n            ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64),\n            DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),\n            Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),\n            DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),\n            Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),\n            DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),\n            Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),\n        )\n        self.conv_sep = ConvBlock(128 * scale, 512, 
kernel=(1, 1), stride=(1, 1), padding=(0, 0))\n        self.features = GDC(num_features)\n        self._initialize_weights()\n\n    def _initialize_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n                if m.bias is not None:\n                    m.bias.data.zero_()\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n            elif isinstance(m, nn.Linear):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n                if m.bias is not None:\n                    m.bias.data.zero_()\n\n    def forward(self, x):\n        with torch.cuda.amp.autocast(self.fp16):\n            x = self.layers(x)\n        x = self.conv_sep(x.float() if self.fp16 else x)\n        x = self.features(x)\n        return x\n\n\ndef get_mbf(fp16, num_features):\n    return MobileFaceNet(fp16, num_features)"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/3millions.py",
    "content": "from easydict import EasyDict as edict\n\n# configs for test speed\n\nconfig = edict()\nconfig.loss = \"arcface\"\nconfig.network = \"r50\"\nconfig.resume = False\nconfig.output = None\nconfig.embedding_size = 512\nconfig.sample_rate = 1.0\nconfig.fp16 = True\nconfig.momentum = 0.9\nconfig.weight_decay = 5e-4\nconfig.batch_size = 128\nconfig.lr = 0.1  # batch size is 512\n\nconfig.rec = \"synthetic\"\nconfig.num_classes = 300 * 10000\nconfig.num_epoch = 30\nconfig.warmup_epoch = -1\nconfig.decay_epoch = [10, 16, 22]\nconfig.val_targets = []\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/3millions_pfc.py",
    "content": "from easydict import EasyDict as edict\n\n# configs for test speed\n\nconfig = edict()\nconfig.loss = \"arcface\"\nconfig.network = \"r50\"\nconfig.resume = False\nconfig.output = None\nconfig.embedding_size = 512\nconfig.sample_rate = 0.1\nconfig.fp16 = True\nconfig.momentum = 0.9\nconfig.weight_decay = 5e-4\nconfig.batch_size = 128\nconfig.lr = 0.1  # batch size is 512\n\nconfig.rec = \"synthetic\"\nconfig.num_classes = 300 * 10000\nconfig.num_epoch = 30\nconfig.warmup_epoch = -1\nconfig.decay_epoch = [10, 16, 22]\nconfig.val_targets = []\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/__init__.py",
    "content": ""
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/base.py",
    "content": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G  tmpfs /train_tmp\n\nconfig = edict()\nconfig.loss = \"arcface\"\nconfig.network = \"r50\"\nconfig.resume = False\nconfig.output = \"ms1mv3_arcface_r50\"\n\nconfig.dataset = \"ms1m-retinaface-t1\"\nconfig.embedding_size = 512\nconfig.sample_rate = 1\nconfig.fp16 = False\nconfig.momentum = 0.9\nconfig.weight_decay = 5e-4\nconfig.batch_size = 128\nconfig.lr = 0.1  # batch size is 512\n\nif config.dataset == \"emore\":\n    config.rec = \"/train_tmp/faces_emore\"\n    config.num_classes = 85742\n    config.num_image = 5822653\n    config.num_epoch = 16\n    config.warmup_epoch = -1\n    config.decay_epoch = [8, 14, ]\n    config.val_targets = [\"lfw\", ]\n\nelif config.dataset == \"ms1m-retinaface-t1\":\n    config.rec = \"/train_tmp/ms1m-retinaface-t1\"\n    config.num_classes = 93431\n    config.num_image = 5179510\n    config.num_epoch = 25\n    config.warmup_epoch = -1\n    config.decay_epoch = [11, 17, 22]\n    config.val_targets = [\"lfw\", \"cfp_fp\", \"agedb_30\"]\n\nelif config.dataset == \"glint360k\":\n    config.rec = \"/train_tmp/glint360k\"\n    config.num_classes = 360232\n    config.num_image = 17091657\n    config.num_epoch = 20\n    config.warmup_epoch = -1\n    config.decay_epoch = [8, 12, 15, 18]\n    config.val_targets = [\"lfw\", \"cfp_fp\", \"agedb_30\"]\n\nelif config.dataset == \"webface\":\n    config.rec = \"/train_tmp/faces_webface_112x112\"\n    config.num_classes = 10572\n    config.num_image = \"forget\"\n    config.num_epoch = 34\n    config.warmup_epoch = -1\n    config.decay_epoch = [20, 28, 32]\n    config.val_targets = [\"lfw\", \"cfp_fp\", \"agedb_30\"]\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/glint360k_mbf.py",
    "content": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G  tmpfs /train_tmp\n\nconfig = edict()\nconfig.loss = \"cosface\"\nconfig.network = \"mbf\"\nconfig.resume = False\nconfig.output = None\nconfig.embedding_size = 512\nconfig.sample_rate = 0.1\nconfig.fp16 = True\nconfig.momentum = 0.9\nconfig.weight_decay = 2e-4\nconfig.batch_size = 128\nconfig.lr = 0.1  # batch size is 512\n\nconfig.rec = \"/train_tmp/glint360k\"\nconfig.num_classes = 360232\nconfig.num_image = 17091657\nconfig.num_epoch = 20\nconfig.warmup_epoch = -1\nconfig.decay_epoch = [8, 12, 15, 18]\nconfig.val_targets = [\"lfw\", \"cfp_fp\", \"agedb_30\"]\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/glint360k_r100.py",
    "content": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G  tmpfs /train_tmp\n\nconfig = edict()\nconfig.loss = \"cosface\"\nconfig.network = \"r100\"\nconfig.resume = False\nconfig.output = None\nconfig.embedding_size = 512\nconfig.sample_rate = 1.0\nconfig.fp16 = True\nconfig.momentum = 0.9\nconfig.weight_decay = 5e-4\nconfig.batch_size = 128\nconfig.lr = 0.1  # batch size is 512\n\nconfig.rec = \"/train_tmp/glint360k\"\nconfig.num_classes = 360232\nconfig.num_image = 17091657\nconfig.num_epoch = 20\nconfig.warmup_epoch = -1\nconfig.decay_epoch = [8, 12, 15, 18]\nconfig.val_targets = [\"lfw\", \"cfp_fp\", \"agedb_30\"]\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/glint360k_r18.py",
    "content": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G  tmpfs /train_tmp\n\nconfig = edict()\nconfig.loss = \"cosface\"\nconfig.network = \"r18\"\nconfig.resume = False\nconfig.output = None\nconfig.embedding_size = 512\nconfig.sample_rate = 1.0\nconfig.fp16 = True\nconfig.momentum = 0.9\nconfig.weight_decay = 5e-4\nconfig.batch_size = 128\nconfig.lr = 0.1  # batch size is 512\n\nconfig.rec = \"/train_tmp/glint360k\"\nconfig.num_classes = 360232\nconfig.num_image = 17091657\nconfig.num_epoch = 20\nconfig.warmup_epoch = -1\nconfig.decay_epoch = [8, 12, 15, 18]\nconfig.val_targets = [\"lfw\", \"cfp_fp\", \"agedb_30\"]\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/glint360k_r34.py",
    "content": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G  tmpfs /train_tmp\n\nconfig = edict()\nconfig.loss = \"cosface\"\nconfig.network = \"r34\"\nconfig.resume = False\nconfig.output = None\nconfig.embedding_size = 512\nconfig.sample_rate = 1.0\nconfig.fp16 = True\nconfig.momentum = 0.9\nconfig.weight_decay = 5e-4\nconfig.batch_size = 128\nconfig.lr = 0.1  # batch size is 512\n\nconfig.rec = \"/train_tmp/glint360k\"\nconfig.num_classes = 360232\nconfig.num_image = 17091657\nconfig.num_epoch = 20\nconfig.warmup_epoch = -1\nconfig.decay_epoch = [8, 12, 15, 18]\nconfig.val_targets = [\"lfw\", \"cfp_fp\", \"agedb_30\"]\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/glint360k_r50.py",
    "content": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G  tmpfs /train_tmp\n\nconfig = edict()\nconfig.loss = \"cosface\"\nconfig.network = \"r50\"\nconfig.resume = False\nconfig.output = None\nconfig.embedding_size = 512\nconfig.sample_rate = 1.0\nconfig.fp16 = True\nconfig.momentum = 0.9\nconfig.weight_decay = 5e-4\nconfig.batch_size = 128\nconfig.lr = 0.1  # batch size is 512\n\nconfig.rec = \"/train_tmp/glint360k\"\nconfig.num_classes = 360232\nconfig.num_image = 17091657\nconfig.num_epoch = 20\nconfig.warmup_epoch = -1\nconfig.decay_epoch = [8, 12, 15, 18]\nconfig.val_targets = [\"lfw\", \"cfp_fp\", \"agedb_30\"]\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py",
    "content": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G  tmpfs /train_tmp\n\nconfig = edict()\nconfig.loss = \"arcface\"\nconfig.network = \"mbf\"\nconfig.resume = False\nconfig.output = None\nconfig.embedding_size = 512\nconfig.sample_rate = 1.0\nconfig.fp16 = True\nconfig.momentum = 0.9\nconfig.weight_decay = 2e-4\nconfig.batch_size = 128\nconfig.lr = 0.1  # batch size is 512\n\nconfig.rec = \"/train_tmp/ms1m-retinaface-t1\"\nconfig.num_classes = 93431\nconfig.num_image = 5179510\nconfig.num_epoch = 30\nconfig.warmup_epoch = -1\nconfig.decay_epoch = [10, 20, 25]\nconfig.val_targets = [\"lfw\", \"cfp_fp\", \"agedb_30\"]\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/ms1mv3_r18.py",
    "content": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G  tmpfs /train_tmp\n\nconfig = edict()\nconfig.loss = \"arcface\"\nconfig.network = \"r18\"\nconfig.resume = False\nconfig.output = None\nconfig.embedding_size = 512\nconfig.sample_rate = 1.0\nconfig.fp16 = True\nconfig.momentum = 0.9\nconfig.weight_decay = 5e-4\nconfig.batch_size = 128\nconfig.lr = 0.1  # batch size is 512\n\nconfig.rec = \"/train_tmp/ms1m-retinaface-t1\"\nconfig.num_classes = 93431\nconfig.num_image = 5179510\nconfig.num_epoch = 25\nconfig.warmup_epoch = -1\nconfig.decay_epoch = [10, 16, 22]\nconfig.val_targets = [\"lfw\", \"cfp_fp\", \"agedb_30\"]\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py",
    "content": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G  tmpfs /train_tmp\n\nconfig = edict()\nconfig.loss = \"arcface\"\nconfig.network = \"r2060\"\nconfig.resume = False\nconfig.output = None\nconfig.embedding_size = 512\nconfig.sample_rate = 1.0\nconfig.fp16 = True\nconfig.momentum = 0.9\nconfig.weight_decay = 5e-4\nconfig.batch_size = 64\nconfig.lr = 0.1  # batch size is 512\n\nconfig.rec = \"/train_tmp/ms1m-retinaface-t1\"\nconfig.num_classes = 93431\nconfig.num_image = 5179510\nconfig.num_epoch = 25\nconfig.warmup_epoch = -1\nconfig.decay_epoch = [10, 16, 22]\nconfig.val_targets = [\"lfw\", \"cfp_fp\", \"agedb_30\"]\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/ms1mv3_r34.py",
    "content": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G  tmpfs /train_tmp\n\nconfig = edict()\nconfig.loss = \"arcface\"\nconfig.network = \"r34\"\nconfig.resume = False\nconfig.output = None\nconfig.embedding_size = 512\nconfig.sample_rate = 1.0\nconfig.fp16 = True\nconfig.momentum = 0.9\nconfig.weight_decay = 5e-4\nconfig.batch_size = 128\nconfig.lr = 0.1  # batch size is 512\n\nconfig.rec = \"/train_tmp/ms1m-retinaface-t1\"\nconfig.num_classes = 93431\nconfig.num_image = 5179510\nconfig.num_epoch = 25\nconfig.warmup_epoch = -1\nconfig.decay_epoch = [10, 16, 22]\nconfig.val_targets = [\"lfw\", \"cfp_fp\", \"agedb_30\"]\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/ms1mv3_r50.py",
    "content": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G  tmpfs /train_tmp\n\nconfig = edict()\nconfig.loss = \"arcface\"\nconfig.network = \"r50\"\nconfig.resume = False\nconfig.output = None\nconfig.embedding_size = 512\nconfig.sample_rate = 1.0\nconfig.fp16 = True\nconfig.momentum = 0.9\nconfig.weight_decay = 5e-4\nconfig.batch_size = 128\nconfig.lr = 0.1  # batch size is 512\n\nconfig.rec = \"/train_tmp/ms1m-retinaface-t1\"\nconfig.num_classes = 93431\nconfig.num_image = 5179510\nconfig.num_epoch = 25\nconfig.warmup_epoch = -1\nconfig.decay_epoch = [10, 16, 22]\nconfig.val_targets = [\"lfw\", \"cfp_fp\", \"agedb_30\"]\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/speed.py",
    "content": "from easydict import EasyDict as edict\n\n# configs for test speed\n\nconfig = edict()\nconfig.loss = \"arcface\"\nconfig.network = \"r50\"\nconfig.resume = False\nconfig.output = None\nconfig.embedding_size = 512\nconfig.sample_rate = 1.0\nconfig.fp16 = True\nconfig.momentum = 0.9\nconfig.weight_decay = 5e-4\nconfig.batch_size = 128\nconfig.lr = 0.1  # batch size is 512\n\nconfig.rec = \"synthetic\"\nconfig.num_classes = 100 * 10000\nconfig.num_epoch = 30\nconfig.warmup_epoch = -1\nconfig.decay_epoch = [10, 16, 22]\nconfig.val_targets = []\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/dataset.py",
    "content": "import numbers\nimport os\nimport queue as Queue\nimport threading\n\nimport mxnet as mx\nimport numpy as np\nimport torch\nfrom torch.utils.data import DataLoader, Dataset\nfrom torchvision import transforms\n\n\nclass BackgroundGenerator(threading.Thread):\n    def __init__(self, generator, local_rank, max_prefetch=6):\n        super(BackgroundGenerator, self).__init__()\n        self.queue = Queue.Queue(max_prefetch)\n        self.generator = generator\n        self.local_rank = local_rank\n        self.daemon = True\n        self.start()\n\n    def run(self):\n        torch.cuda.set_device(self.local_rank)\n        for item in self.generator:\n            self.queue.put(item)\n        self.queue.put(None)\n\n    def next(self):\n        next_item = self.queue.get()\n        if next_item is None:\n            raise StopIteration\n        return next_item\n\n    def __next__(self):\n        return self.next()\n\n    def __iter__(self):\n        return self\n\n\nclass DataLoaderX(DataLoader):\n\n    def __init__(self, local_rank, **kwargs):\n        super(DataLoaderX, self).__init__(**kwargs)\n        self.stream = torch.cuda.Stream(local_rank)\n        self.local_rank = local_rank\n\n    def __iter__(self):\n        self.iter = super(DataLoaderX, self).__iter__()\n        self.iter = BackgroundGenerator(self.iter, self.local_rank)\n        self.preload()\n        return self\n\n    def preload(self):\n        self.batch = next(self.iter, None)\n        if self.batch is None:\n            return None\n        with torch.cuda.stream(self.stream):\n            for k in range(len(self.batch)):\n                self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)\n\n    def __next__(self):\n        torch.cuda.current_stream().wait_stream(self.stream)\n        batch = self.batch\n        if batch is None:\n            raise StopIteration\n        self.preload()\n        return batch\n\n\nclass MXFaceDataset(Dataset):\n    def 
__init__(self, root_dir, local_rank):\n        super(MXFaceDataset, self).__init__()\n        self.transform = transforms.Compose(\n            [transforms.ToPILImage(),\n             transforms.RandomHorizontalFlip(),\n             transforms.ToTensor(),\n             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),\n             ])\n        self.root_dir = root_dir\n        self.local_rank = local_rank\n        path_imgrec = os.path.join(root_dir, 'train.rec')\n        path_imgidx = os.path.join(root_dir, 'train.idx')\n        self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')\n        s = self.imgrec.read_idx(0)\n        header, _ = mx.recordio.unpack(s)\n        if header.flag > 0:\n            self.header0 = (int(header.label[0]), int(header.label[1]))\n            self.imgidx = np.array(range(1, int(header.label[0])))\n        else:\n            self.imgidx = np.array(list(self.imgrec.keys))\n\n    def __getitem__(self, index):\n        idx = self.imgidx[index]\n        s = self.imgrec.read_idx(idx)\n        header, img = mx.recordio.unpack(s)\n        label = header.label\n        if not isinstance(label, numbers.Number):\n            label = label[0]\n        label = torch.tensor(label, dtype=torch.long)\n        sample = mx.image.imdecode(img).asnumpy()\n        if self.transform is not None:\n            sample = self.transform(sample)\n        return sample, label\n\n    def __len__(self):\n        return len(self.imgidx)\n\n\nclass SyntheticDataset(Dataset):\n    def __init__(self, local_rank):\n        super(SyntheticDataset, self).__init__()\n        img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)\n        img = np.transpose(img, (2, 0, 1))\n        img = torch.from_numpy(img).squeeze(0).float()\n        img = ((img / 255) - 0.5) / 0.5\n        self.img = img\n        self.label = 1\n\n    def __getitem__(self, index):\n        return self.img, self.label\n\n    def __len__(self):\n        
return 1000000\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/docs/eval.md",
    "content": "## Eval on ICCV2021-MFR\n\nComing soon.\n\n\n## Eval IJBC\nYou can evaluate on IJB-C with either PyTorch or ONNX.\n\n\n1. Eval IJBC with ONNX\n```shell\nCUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50\n```\n\n2. Eval IJBC with PyTorch\n```shell\nCUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \\\n--model-prefix ms1mv3_arcface_r50/backbone.pth \\\n--image-path IJB_release/IJBC \\\n--result-dir ms1mv3_arcface_r50 \\\n--batch-size 128 \\\n--job ms1mv3_arcface_r50 \\\n--target IJBC \\\n--network iresnet50\n```\n\n## Inference\n\n```shell\npython inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50\n```\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/docs/install.md",
    "content": "## v1.8.0 \n### Linux and Windows  \n```shell\n# CUDA 11.1\npip --default-timeout=100 install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html\n\n# CUDA 10.2\npip --default-timeout=100 install torch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0\n\n# CPU only\npip --default-timeout=100 install torch==1.8.0+cpu torchvision==0.9.0+cpu torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html\n\n```\n\n\n## v1.7.1  \n### Linux and Windows  \n```shell\n# CUDA 11.0\npip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html\n\n# CUDA 10.2\npip install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2\n\n# CUDA 10.1\npip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html\n\n# CUDA 9.2\npip install torch==1.7.1+cu92 torchvision==0.8.2+cu92 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html\n\n# CPU only\npip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html\n```\n\n\n## v1.6.0  \n\n### Linux and Windows\n```shell\n# CUDA 10.2\npip install torch==1.6.0 torchvision==0.7.0\n\n# CUDA 10.1\npip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html\n\n# CUDA 9.2\npip install torch==1.6.0+cu92 torchvision==0.7.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html\n\n# CPU only\npip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html\n```"
  },
  {
    "path": "src/face3d/models/arcface_torch/docs/modelzoo.md",
    "content": ""
  },
  {
    "path": "src/face3d/models/arcface_torch/docs/speed_benchmark.md",
    "content": "## Test Training Speed\n\n- Test Commands\n\nYou need to use the following two commands to test the Partial FC training performance. \nThe number of identites is **3 millions** (synthetic data), turn mixed precision  training on, backbone is resnet50, \nbatch size is 1024.\n```shell\n# Model Parallel\npython -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr=\"127.0.0.1\" --master_port=1234 train.py configs/3millions\n# Partial FC 0.1\npython -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr=\"127.0.0.1\" --master_port=1234 train.py configs/3millions_pfc\n```\n\n- GPU Memory\n\n```\n# (Model Parallel) gpustat -i\n[0] Tesla V100-SXM2-32GB | 64'C,  94 % | 30338 / 32510 MB \n[1] Tesla V100-SXM2-32GB | 60'C,  99 % | 28876 / 32510 MB \n[2] Tesla V100-SXM2-32GB | 60'C,  99 % | 28872 / 32510 MB \n[3] Tesla V100-SXM2-32GB | 69'C,  99 % | 28872 / 32510 MB \n[4] Tesla V100-SXM2-32GB | 66'C,  99 % | 28888 / 32510 MB \n[5] Tesla V100-SXM2-32GB | 60'C,  99 % | 28932 / 32510 MB \n[6] Tesla V100-SXM2-32GB | 68'C, 100 % | 28916 / 32510 MB \n[7] Tesla V100-SXM2-32GB | 65'C,  99 % | 28860 / 32510 MB \n\n# (Partial FC 0.1) gpustat -i\n[0] Tesla V100-SXM2-32GB | 60'C,  95 % | 10488 / 32510 MB                                                                                                                                          │·······················\n[1] Tesla V100-SXM2-32GB | 60'C,  97 % | 10344 / 32510 MB                                                                                                                                          │·······················\n[2] Tesla V100-SXM2-32GB | 61'C,  95 % | 10340 / 32510 MB                                                                                                                                          │·······················\n[3] Tesla V100-SXM2-32GB | 66'C,  95 % | 10340 / 32510 MB                                                           
                                                                               │·······················\n[4] Tesla V100-SXM2-32GB | 65'C,  94 % | 10356 / 32510 MB                                                                                                                                          │·······················\n[5] Tesla V100-SXM2-32GB | 61'C,  95 % | 10400 / 32510 MB                                                                                                                                          │·······················\n[6] Tesla V100-SXM2-32GB | 68'C,  96 % | 10384 / 32510 MB                                                                                                                                          │·······················\n[7] Tesla V100-SXM2-32GB | 64'C,  95 % | 10328 / 32510 MB                                                                                                                                        │·······················\n```\n\n- Training Speed\n\n```python\n# (Model Parallel) trainging.log\nTraining: Speed 2271.33 samples/sec   Loss 1.1624   LearningRate 0.2000   Epoch: 0   Global Step: 100 \nTraining: Speed 2269.94 samples/sec   Loss 0.0000   LearningRate 0.2000   Epoch: 0   Global Step: 150 \nTraining: Speed 2272.67 samples/sec   Loss 0.0000   LearningRate 0.2000   Epoch: 0   Global Step: 200 \nTraining: Speed 2266.55 samples/sec   Loss 0.0000   LearningRate 0.2000   Epoch: 0   Global Step: 250 \nTraining: Speed 2272.54 samples/sec   Loss 0.0000   LearningRate 0.2000   Epoch: 0   Global Step: 300 \n\n# (Partial FC 0.1) trainging.log\nTraining: Speed 5299.56 samples/sec   Loss 1.0965   LearningRate 0.2000   Epoch: 0   Global Step: 100  \nTraining: Speed 5296.37 samples/sec   Loss 0.0000   LearningRate 0.2000   Epoch: 0   Global Step: 150  \nTraining: Speed 5304.37 samples/sec   Loss 0.0000   LearningRate 0.2000   Epoch: 0   Global Step: 200  \nTraining: Speed 5274.43 samples/sec   Loss 0.0000   LearningRate 
0.2000   Epoch: 0   Global Step: 250  \nTraining: Speed 5300.10 samples/sec   Loss 0.0000   LearningRate 0.2000   Epoch: 0   Global Step: 300   \n```\n\nIn this test case, Partial FC 0.1 uses only about 1/3 of the GPU memory of model parallel, \nand trains about 2.3 times faster than model parallel (about 5300 vs. 2270 samples/sec).\n\n\n## Speed Benchmark\n\n1. Training speed of different parallel methods (samples/second), Tesla V100 32GB * 8. (Larger is better)\n\n| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |\n| :---    | :--- | :--- | :--- |\n|125000   | 4681 | 4824 | 5004 |\n|250000   | 4047 | 4521 | 4976 |\n|500000   | 3087 | 4013 | 4900 |\n|1000000  | 2090 | 3449 | 4803 |\n|1400000  | 1672 | 3043 | 4738 |\n|2000000  | -    | 2593 | 4626 |\n|4000000  | -    | 1748 | 4208 |\n|5500000  | -    | 1389 | 3975 |\n|8000000  | -    | -    | 3565 |\n|16000000 | -    | -    | 2679 |\n|29000000 | -    | -    | 1855 |\n\n2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)\n\n| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |\n| :---    | :---  | :---  | :---  |\n|125000   | 7358  | 5306  | 4868  |\n|250000   | 9940  | 5826  | 5004  |\n|500000   | 14220 | 7114  | 5202  |\n|1000000  | 23708 | 9966  | 5620  |\n|1400000  | 32252 | 11178 | 6056  |\n|2000000  | -     | 13978 | 6472  |\n|4000000  | -     | 23238 | 8284  |\n|5500000  | -     | 32188 | 9854  |\n|8000000  | -     | -     | 12310 |\n|16000000 | -     | -     | 19950 |\n|29000000 | -     | -     | 32324 |\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/eval/__init__.py",
    "content": ""
  },
  {
    "path": "src/face3d/models/arcface_torch/eval/verification.py",
    "content": "\"\"\"Helper for evaluation on the Labeled Faces in the Wild dataset \n\"\"\"\n\n# MIT License\n#\n# Copyright (c) 2016 David Sandberg\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\n\nimport datetime\nimport os\nimport pickle\n\nimport mxnet as mx\nimport numpy as np\nimport sklearn\nimport torch\nfrom mxnet import ndarray as nd\nfrom scipy import interpolate\nfrom sklearn.decomposition import PCA\nfrom sklearn.model_selection import KFold\n\n\nclass LFold:\n    def __init__(self, n_splits=2, shuffle=False):\n        self.n_splits = n_splits\n        if self.n_splits > 1:\n            self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle)\n\n    def split(self, indices):\n        if self.n_splits > 1:\n            return self.k_fold.split(indices)\n        else:\n            return [(indices, indices)]\n\n\ndef calculate_roc(thresholds,\n                  embeddings1,\n                  embeddings2,\n                  
actual_issame,\n                  nrof_folds=10,\n                  pca=0):\n    assert (embeddings1.shape[0] == embeddings2.shape[0])\n    assert (embeddings1.shape[1] == embeddings2.shape[1])\n    nrof_pairs = min(len(actual_issame), embeddings1.shape[0])\n    nrof_thresholds = len(thresholds)\n    k_fold = LFold(n_splits=nrof_folds, shuffle=False)\n\n    tprs = np.zeros((nrof_folds, nrof_thresholds))\n    fprs = np.zeros((nrof_folds, nrof_thresholds))\n    accuracy = np.zeros((nrof_folds))\n    indices = np.arange(nrof_pairs)\n\n    if pca == 0:\n        diff = np.subtract(embeddings1, embeddings2)\n        dist = np.sum(np.square(diff), 1)\n\n    for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):\n        if pca > 0:\n            print('doing pca on', fold_idx)\n            embed1_train = embeddings1[train_set]\n            embed2_train = embeddings2[train_set]\n            _embed_train = np.concatenate((embed1_train, embed2_train), axis=0)\n            pca_model = PCA(n_components=pca)\n            pca_model.fit(_embed_train)\n            embed1 = pca_model.transform(embeddings1)\n            embed2 = pca_model.transform(embeddings2)\n            embed1 = sklearn.preprocessing.normalize(embed1)\n            embed2 = sklearn.preprocessing.normalize(embed2)\n            diff = np.subtract(embed1, embed2)\n            dist = np.sum(np.square(diff), 1)\n\n        # Find the best threshold for the fold\n        acc_train = np.zeros((nrof_thresholds))\n        for threshold_idx, threshold in enumerate(thresholds):\n            _, _, acc_train[threshold_idx] = calculate_accuracy(\n                threshold, dist[train_set], actual_issame[train_set])\n        best_threshold_index = np.argmax(acc_train)\n        for threshold_idx, threshold in enumerate(thresholds):\n            tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy(\n                threshold, dist[test_set],\n                
actual_issame[test_set])\n        _, _, accuracy[fold_idx] = calculate_accuracy(\n            thresholds[best_threshold_index], dist[test_set],\n            actual_issame[test_set])\n\n    tpr = np.mean(tprs, 0)\n    fpr = np.mean(fprs, 0)\n    return tpr, fpr, accuracy\n\n\ndef calculate_accuracy(threshold, dist, actual_issame):\n    predict_issame = np.less(dist, threshold)\n    tp = np.sum(np.logical_and(predict_issame, actual_issame))\n    fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))\n    tn = np.sum(\n        np.logical_and(np.logical_not(predict_issame),\n                       np.logical_not(actual_issame)))\n    fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))\n\n    tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn)\n    fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn)\n    acc = float(tp + tn) / dist.size\n    return tpr, fpr, acc\n\n\ndef calculate_val(thresholds,\n                  embeddings1,\n                  embeddings2,\n                  actual_issame,\n                  far_target,\n                  nrof_folds=10):\n    assert (embeddings1.shape[0] == embeddings2.shape[0])\n    assert (embeddings1.shape[1] == embeddings2.shape[1])\n    nrof_pairs = min(len(actual_issame), embeddings1.shape[0])\n    nrof_thresholds = len(thresholds)\n    k_fold = LFold(n_splits=nrof_folds, shuffle=False)\n\n    val = np.zeros(nrof_folds)\n    far = np.zeros(nrof_folds)\n\n    diff = np.subtract(embeddings1, embeddings2)\n    dist = np.sum(np.square(diff), 1)\n    indices = np.arange(nrof_pairs)\n\n    for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):\n\n        # Find the threshold that gives FAR = far_target\n        far_train = np.zeros(nrof_thresholds)\n        for threshold_idx, threshold in enumerate(thresholds):\n            _, far_train[threshold_idx] = calculate_val_far(\n                threshold, dist[train_set], actual_issame[train_set])\n        if 
np.max(far_train) >= far_target:\n            f = interpolate.interp1d(far_train, thresholds, kind='slinear')\n            threshold = f(far_target)\n        else:\n            threshold = 0.0\n\n        val[fold_idx], far[fold_idx] = calculate_val_far(\n            threshold, dist[test_set], actual_issame[test_set])\n\n    val_mean = np.mean(val)\n    far_mean = np.mean(far)\n    val_std = np.std(val)\n    return val_mean, val_std, far_mean\n\n\ndef calculate_val_far(threshold, dist, actual_issame):\n    predict_issame = np.less(dist, threshold)\n    true_accept = np.sum(np.logical_and(predict_issame, actual_issame))\n    false_accept = np.sum(\n        np.logical_and(predict_issame, np.logical_not(actual_issame)))\n    n_same = np.sum(actual_issame)\n    n_diff = np.sum(np.logical_not(actual_issame))\n    # print(true_accept, false_accept)\n    # print(n_same, n_diff)\n    val = float(true_accept) / float(n_same)\n    far = float(false_accept) / float(n_diff)\n    return val, far\n\n\ndef evaluate(embeddings, actual_issame, nrof_folds=10, pca=0):\n    # Calculate evaluation metrics\n    thresholds = np.arange(0, 4, 0.01)\n    embeddings1 = embeddings[0::2]\n    embeddings2 = embeddings[1::2]\n    tpr, fpr, accuracy = calculate_roc(thresholds,\n                                       embeddings1,\n                                       embeddings2,\n                                       np.asarray(actual_issame),\n                                       nrof_folds=nrof_folds,\n                                       pca=pca)\n    thresholds = np.arange(0, 4, 0.001)\n    val, val_std, far = calculate_val(thresholds,\n                                      embeddings1,\n                                      embeddings2,\n                                      np.asarray(actual_issame),\n                                      1e-3,\n                                      nrof_folds=nrof_folds)\n    return tpr, fpr, accuracy, val, val_std, far\n\n@torch.no_grad()\ndef 
load_bin(path, image_size):\n    try:\n        with open(path, 'rb') as f:\n            bins, issame_list = pickle.load(f)  # py2\n    except UnicodeDecodeError as e:\n        with open(path, 'rb') as f:\n            bins, issame_list = pickle.load(f, encoding='bytes')  # py3\n    data_list = []\n    for flip in [0, 1]:\n        data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1]))\n        data_list.append(data)\n    for idx in range(len(issame_list) * 2):\n        _bin = bins[idx]\n        img = mx.image.imdecode(_bin)\n        if img.shape[1] != image_size[0]:\n            img = mx.image.resize_short(img, image_size[0])\n        img = nd.transpose(img, axes=(2, 0, 1))\n        for flip in [0, 1]:\n            if flip == 1:\n                img = mx.ndarray.flip(data=img, axis=2)\n            data_list[flip][idx][:] = torch.from_numpy(img.asnumpy())\n        if idx % 1000 == 0:\n            print('loading bin', idx)\n    print(data_list[0].shape)\n    return data_list, issame_list\n\n@torch.no_grad()\ndef test(data_set, backbone, batch_size, nfolds=10):\n    print('testing verification..')\n    data_list = data_set[0]\n    issame_list = data_set[1]\n    embeddings_list = []\n    time_consumed = 0.0\n    for i in range(len(data_list)):\n        data = data_list[i]\n        embeddings = None\n        ba = 0\n        while ba < data.shape[0]:\n            bb = min(ba + batch_size, data.shape[0])\n            count = bb - ba\n            _data = data[bb - batch_size: bb]\n            time0 = datetime.datetime.now()\n            img = ((_data / 255) - 0.5) / 0.5\n            net_out: torch.Tensor = backbone(img)\n            _embeddings = net_out.detach().cpu().numpy()\n            time_now = datetime.datetime.now()\n            diff = time_now - time0\n            time_consumed += diff.total_seconds()\n            if embeddings is None:\n                embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))\n            embeddings[ba:bb, 
:] = _embeddings[(batch_size - count):, :]\n            ba = bb\n        embeddings_list.append(embeddings)\n\n    _xnorm = 0.0\n    _xnorm_cnt = 0\n    for embed in embeddings_list:\n        for i in range(embed.shape[0]):\n            _em = embed[i]\n            _norm = np.linalg.norm(_em)\n            _xnorm += _norm\n            _xnorm_cnt += 1\n    _xnorm /= _xnorm_cnt\n\n    acc1 = 0.0\n    std1 = 0.0\n    embeddings = embeddings_list[0] + embeddings_list[1]\n    embeddings = sklearn.preprocessing.normalize(embeddings)\n    print(embeddings.shape)\n    print('infer time', time_consumed)\n    _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds)\n    acc2, std2 = np.mean(accuracy), np.std(accuracy)\n    return acc1, std1, acc2, std2, _xnorm, embeddings_list\n\n\ndef dumpR(data_set,\n          backbone,\n          batch_size,\n          name='',\n          data_extra=None,\n          label_shape=None):\n    print('dump verification embedding..')\n    data_list = data_set[0]\n    issame_list = data_set[1]\n    embeddings_list = []\n    time_consumed = 0.0\n    for i in range(len(data_list)):\n        data = data_list[i]\n        embeddings = None\n        ba = 0\n        while ba < data.shape[0]:\n            bb = min(ba + batch_size, data.shape[0])\n            count = bb - ba\n\n            _data = nd.slice_axis(data, axis=0, begin=bb - batch_size, end=bb)\n            time0 = datetime.datetime.now()\n            if data_extra is None:\n                db = mx.io.DataBatch(data=(_data,), label=(_label,))\n            else:\n                db = mx.io.DataBatch(data=(_data, _data_extra),\n                                     label=(_label,))\n            model.forward(db, is_train=False)\n            net_out = model.get_outputs()\n            _embeddings = net_out[0].asnumpy()\n            time_now = datetime.datetime.now()\n            diff = time_now - time0\n            time_consumed += diff.total_seconds()\n            if 
embeddings is None:\n                embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))\n            embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :]\n            ba = bb\n        embeddings_list.append(embeddings)\n    embeddings = embeddings_list[0] + embeddings_list[1]\n    embeddings = sklearn.preprocessing.normalize(embeddings)\n    actual_issame = np.asarray(issame_list)\n    outname = os.path.join('temp.bin')\n    with open(outname, 'wb') as f:\n        pickle.dump((embeddings, issame_list),\n                    f,\n                    protocol=pickle.HIGHEST_PROTOCOL)\n\n\n# if __name__ == '__main__':\n#\n#     parser = argparse.ArgumentParser(description='do verification')\n#     # general\n#     parser.add_argument('--data-dir', default='', help='')\n#     parser.add_argument('--model',\n#                         default='../model/softmax,50',\n#                         help='path to load model.')\n#     parser.add_argument('--target',\n#                         default='lfw,cfp_ff,cfp_fp,agedb_30',\n#                         help='test targets.')\n#     parser.add_argument('--gpu', default=0, type=int, help='gpu id')\n#     parser.add_argument('--batch-size', default=32, type=int, help='')\n#     parser.add_argument('--max', default='', type=str, help='')\n#     parser.add_argument('--mode', default=0, type=int, help='')\n#     parser.add_argument('--nfolds', default=10, type=int, help='')\n#     args = parser.parse_args()\n#     image_size = [112, 112]\n#     print('image_size', image_size)\n#     ctx = mx.gpu(args.gpu)\n#     nets = []\n#     vec = args.model.split(',')\n#     prefix = args.model.split(',')[0]\n#     epochs = []\n#     if len(vec) == 1:\n#         pdir = os.path.dirname(prefix)\n#         for fname in os.listdir(pdir):\n#             if not fname.endswith('.params'):\n#                 continue\n#             _file = os.path.join(pdir, fname)\n#             if _file.startswith(prefix):\n#                 epoch = 
int(fname.split('.')[0].split('-')[1])\n#                 epochs.append(epoch)\n#         epochs = sorted(epochs, reverse=True)\n#         if len(args.max) > 0:\n#             _max = [int(x) for x in args.max.split(',')]\n#             assert len(_max) == 2\n#             if len(epochs) > _max[1]:\n#                 epochs = epochs[_max[0]:_max[1]]\n#\n#     else:\n#         epochs = [int(x) for x in vec[1].split('|')]\n#     print('model number', len(epochs))\n#     time0 = datetime.datetime.now()\n#     for epoch in epochs:\n#         print('loading', prefix, epoch)\n#         sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)\n#         # arg_params, aux_params = ch_dev(arg_params, aux_params, ctx)\n#         all_layers = sym.get_internals()\n#         sym = all_layers['fc1_output']\n#         model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)\n#         # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])\n#         model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0],\n#                                           image_size[1]))])\n#         model.set_params(arg_params, aux_params)\n#         nets.append(model)\n#     time_now = datetime.datetime.now()\n#     diff = time_now - time0\n#     print('model loading time', diff.total_seconds())\n#\n#     ver_list = []\n#     ver_name_list = []\n#     for name in args.target.split(','):\n#         path = os.path.join(args.data_dir, name + \".bin\")\n#         if os.path.exists(path):\n#             print('loading.. 
', name)\n#             data_set = load_bin(path, image_size)\n#             ver_list.append(data_set)\n#             ver_name_list.append(name)\n#\n#     if args.mode == 0:\n#         for i in range(len(ver_list)):\n#             results = []\n#             for model in nets:\n#                 acc1, std1, acc2, std2, xnorm, embeddings_list = test(\n#                     ver_list[i], model, args.batch_size, args.nfolds)\n#                 print('[%s]XNorm: %f' % (ver_name_list[i], xnorm))\n#                 print('[%s]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], acc1, std1))\n#                 print('[%s]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], acc2, std2))\n#                 results.append(acc2)\n#             print('Max of [%s] is %1.5f' % (ver_name_list[i], np.max(results)))\n#     elif args.mode == 1:\n#         raise ValueError\n#     else:\n#         model = nets[0]\n#         dumpR(ver_list[0], model, args.batch_size, args.target)\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/eval_ijbc.py",
    "content": "# coding: utf-8\n\nimport os\nimport pickle\n\nimport matplotlib\nimport pandas as pd\n\nmatplotlib.use('Agg')\nimport matplotlib.pyplot as plt\nimport timeit\nimport sklearn\nimport argparse\nimport cv2\nimport numpy as np\nimport torch\nfrom skimage import transform as trans\nfrom backbones import get_model\nfrom sklearn.metrics import roc_curve, auc\n\nfrom menpo.visualize.viewmatplotlib import sample_colours_from_colourmap\nfrom prettytable import PrettyTable\nfrom pathlib import Path\n\nimport sys\nimport warnings\n\nsys.path.insert(0, \"../\")\nwarnings.filterwarnings(\"ignore\")\n\nparser = argparse.ArgumentParser(description='do ijb test')\n# general\nparser.add_argument('--model-prefix', default='', help='path to load model.')\nparser.add_argument('--image-path', default='', type=str, help='')\nparser.add_argument('--result-dir', default='.', type=str, help='')\nparser.add_argument('--batch-size', default=128, type=int, help='')\nparser.add_argument('--network', default='iresnet50', type=str, help='')\nparser.add_argument('--job', default='insightface', type=str, help='job name')\nparser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB')\nargs = parser.parse_args()\n\ntarget = args.target\nmodel_path = args.model_prefix\nimage_path = args.image_path\nresult_dir = args.result_dir\ngpu_id = None\nuse_norm_score = True  # if Ture, TestMode(N1)\nuse_detector_score = True  # if Ture, TestMode(D1)\nuse_flip_test = True  # if Ture, TestMode(F1)\njob = args.job\nbatch_size = args.batch_size\n\n\nclass Embedding(object):\n    def __init__(self, prefix, data_shape, batch_size=1):\n        image_size = (112, 112)\n        self.image_size = image_size\n        weight = torch.load(prefix)\n        resnet = get_model(args.network, dropout=0, fp16=False).cuda()\n        resnet.load_state_dict(weight)\n        model = torch.nn.DataParallel(resnet)\n        self.model = model\n        self.model.eval()\n        src = 
np.array([\n            [30.2946, 51.6963],\n            [65.5318, 51.5014],\n            [48.0252, 71.7366],\n            [33.5493, 92.3655],\n            [62.7299, 92.2041]], dtype=np.float32)\n        src[:, 0] += 8.0\n        self.src = src\n        self.batch_size = batch_size\n        self.data_shape = data_shape\n\n    def get(self, rimg, landmark):\n\n        assert landmark.shape[0] == 68 or landmark.shape[0] == 5\n        assert landmark.shape[1] == 2\n        if landmark.shape[0] == 68:\n            landmark5 = np.zeros((5, 2), dtype=np.float32)\n            landmark5[0] = (landmark[36] + landmark[39]) / 2\n            landmark5[1] = (landmark[42] + landmark[45]) / 2\n            landmark5[2] = landmark[30]\n            landmark5[3] = landmark[48]\n            landmark5[4] = landmark[54]\n        else:\n            landmark5 = landmark\n        tform = trans.SimilarityTransform()\n        tform.estimate(landmark5, self.src)\n        M = tform.params[0:2, :]\n        img = cv2.warpAffine(rimg,\n                             M, (self.image_size[1], self.image_size[0]),\n                             borderValue=0.0)\n        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n        img_flip = np.fliplr(img)\n        img = np.transpose(img, (2, 0, 1))  # 3*112*112, RGB\n        img_flip = np.transpose(img_flip, (2, 0, 1))\n        input_blob = np.zeros((2, 3, self.image_size[1], self.image_size[0]), dtype=np.uint8)\n        input_blob[0] = img\n        input_blob[1] = img_flip\n        return input_blob\n\n    @torch.no_grad()\n    def forward_db(self, batch_data):\n        imgs = torch.Tensor(batch_data).cuda()\n        imgs.div_(255).sub_(0.5).div_(0.5)\n        feat = self.model(imgs)\n        feat = feat.reshape([self.batch_size, 2 * feat.shape[1]])\n        return feat.cpu().numpy()\n\n\n# 将一个list尽量均分成n份，限制len(list)==n，份数大于原list内元素个数则分配空list[]\ndef divideIntoNstrand(listTemp, n):\n    twoList = [[] for i in range(n)]\n    for i, e in enumerate(listTemp):\n     
   twoList[i % n].append(e)\n    return twoList\n\n\ndef read_template_media_list(path):\n    # ijb_meta = np.loadtxt(path, dtype=str)\n    ijb_meta = pd.read_csv(path, sep=' ', header=None).values\n    templates = ijb_meta[:, 1].astype(np.int)\n    medias = ijb_meta[:, 2].astype(np.int)\n    return templates, medias\n\n\n# In[ ]:\n\n\ndef read_template_pair_list(path):\n    # pairs = np.loadtxt(path, dtype=str)\n    pairs = pd.read_csv(path, sep=' ', header=None).values\n    # print(pairs.shape)\n    # print(pairs[:, 0].astype(np.int))\n    t1 = pairs[:, 0].astype(np.int)\n    t2 = pairs[:, 1].astype(np.int)\n    label = pairs[:, 2].astype(np.int)\n    return t1, t2, label\n\n\n# In[ ]:\n\n\ndef read_image_feature(path):\n    with open(path, 'rb') as fid:\n        img_feats = pickle.load(fid)\n    return img_feats\n\n\n# In[ ]:\n\n\ndef get_image_feature(img_path, files_list, model_path, epoch, gpu_id):\n    batch_size = args.batch_size\n    data_shape = (3, 112, 112)\n\n    files = files_list\n    print('files:', len(files))\n    rare_size = len(files) % batch_size\n    faceness_scores = []\n    batch = 0\n    img_feats = np.empty((len(files), 1024), dtype=np.float32)\n\n    batch_data = np.empty((2 * batch_size, 3, 112, 112))\n    embedding = Embedding(model_path, data_shape, batch_size)\n    for img_index, each_line in enumerate(files[:len(files) - rare_size]):\n        name_lmk_score = each_line.strip().split(' ')\n        img_name = os.path.join(img_path, name_lmk_score[0])\n        img = cv2.imread(img_name)\n        lmk = np.array([float(x) for x in name_lmk_score[1:-1]],\n                       dtype=np.float32)\n        lmk = lmk.reshape((5, 2))\n        input_blob = embedding.get(img, lmk)\n\n        batch_data[2 * (img_index - batch * batch_size)][:] = input_blob[0]\n        batch_data[2 * (img_index - batch * batch_size) + 1][:] = input_blob[1]\n        if (img_index + 1) % batch_size == 0:\n            print('batch', batch)\n            
img_feats[batch * batch_size:batch * batch_size +\n                                         batch_size][:] = embedding.forward_db(batch_data)\n            batch += 1\n        faceness_scores.append(name_lmk_score[-1])\n\n    batch_data = np.empty((2 * rare_size, 3, 112, 112))\n    embedding = Embedding(model_path, data_shape, rare_size)\n    for img_index, each_line in enumerate(files[len(files) - rare_size:]):\n        name_lmk_score = each_line.strip().split(' ')\n        img_name = os.path.join(img_path, name_lmk_score[0])\n        img = cv2.imread(img_name)\n        lmk = np.array([float(x) for x in name_lmk_score[1:-1]],\n                       dtype=np.float32)\n        lmk = lmk.reshape((5, 2))\n        input_blob = embedding.get(img, lmk)\n        batch_data[2 * img_index][:] = input_blob[0]\n        batch_data[2 * img_index + 1][:] = input_blob[1]\n        if (img_index + 1) % rare_size == 0:\n            print('batch', batch)\n            img_feats[len(files) -\n                      rare_size:][:] = embedding.forward_db(batch_data)\n            batch += 1\n        faceness_scores.append(name_lmk_score[-1])\n    faceness_scores = np.array(faceness_scores).astype(np.float32)\n    # img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01\n    # faceness_scores = np.ones( (len(files), ), dtype=np.float32 )\n    return img_feats, faceness_scores\n\n\n# In[ ]:\n\n\ndef image2template_feature(img_feats=None, templates=None, medias=None):\n    # ==========================================================\n    # 1. face image feature l2 normalization. img_feats:[number_image x feats_dim]\n    # 2. compute media feature.\n    # 3. 
compute template feature.\n    # ==========================================================\n    unique_templates = np.unique(templates)\n    template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))\n\n    for count_template, uqt in enumerate(unique_templates):\n\n        (ind_t,) = np.where(templates == uqt)\n        face_norm_feats = img_feats[ind_t]\n        face_medias = medias[ind_t]\n        unique_medias, unique_media_counts = np.unique(face_medias,\n                                                       return_counts=True)\n        media_norm_feats = []\n        for u, ct in zip(unique_medias, unique_media_counts):\n            (ind_m,) = np.where(face_medias == u)\n            if ct == 1:\n                media_norm_feats += [face_norm_feats[ind_m]]\n            else:  # image features from the same video will be aggregated into one feature\n                media_norm_feats += [\n                    np.mean(face_norm_feats[ind_m], axis=0, keepdims=True)\n                ]\n        media_norm_feats = np.array(media_norm_feats)\n        # media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True))\n        template_feats[count_template] = np.sum(media_norm_feats, axis=0)\n        if count_template % 2000 == 0:\n            print('Finish Calculating {} template features.'.format(\n                count_template))\n    # template_norm_feats = template_feats / np.sqrt(np.sum(template_feats ** 2, -1, keepdims=True))\n    template_norm_feats = sklearn.preprocessing.normalize(template_feats)\n    # print(template_norm_feats.shape)\n    return template_norm_feats, unique_templates\n\n\n# In[ ]:\n\n\ndef verification(template_norm_feats=None,\n                 unique_templates=None,\n                 p1=None,\n                 p2=None):\n    # ==========================================================\n    #         Compute set-to-set Similarity Score.\n    # 
==========================================================\n    template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)\n    for count_template, uqt in enumerate(unique_templates):\n        template2id[uqt] = count_template\n\n    score = np.zeros((len(p1),))  # save cosine distance between pairs\n\n    total_pairs = np.array(range(len(p1)))\n    batchsize = 100000  # small batchsize instead of all pairs in one batch due to the memory limiation\n    sublists = [\n        total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)\n    ]\n    total_sublists = len(sublists)\n    for c, s in enumerate(sublists):\n        feat1 = template_norm_feats[template2id[p1[s]]]\n        feat2 = template_norm_feats[template2id[p2[s]]]\n        similarity_score = np.sum(feat1 * feat2, -1)\n        score[s] = similarity_score.flatten()\n        if c % 10 == 0:\n            print('Finish {}/{} pairs.'.format(c, total_sublists))\n    return score\n\n\n# In[ ]:\ndef verification2(template_norm_feats=None,\n                  unique_templates=None,\n                  p1=None,\n                  p2=None):\n    template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)\n    for count_template, uqt in enumerate(unique_templates):\n        template2id[uqt] = count_template\n    score = np.zeros((len(p1),))  # save cosine distance between pairs\n    total_pairs = np.array(range(len(p1)))\n    batchsize = 100000  # small batchsize instead of all pairs in one batch due to the memory limiation\n    sublists = [\n        total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)\n    ]\n    total_sublists = len(sublists)\n    for c, s in enumerate(sublists):\n        feat1 = template_norm_feats[template2id[p1[s]]]\n        feat2 = template_norm_feats[template2id[p2[s]]]\n        similarity_score = np.sum(feat1 * feat2, -1)\n        score[s] = similarity_score.flatten()\n        if c % 10 == 0:\n            print('Finish {}/{} pairs.'.format(c, total_sublists))\n  
  return score\n\n\ndef read_score(path):\n    with open(path, 'rb') as fid:\n        img_feats = pickle.load(fid)\n    return img_feats\n\n\n# # Step1: Load Meta Data\n\n# In[ ]:\n\nassert target == 'IJBC' or target == 'IJBB'\n\n# =============================================================\n# load image and template relationships for template feature embedding\n# tid --> template id,  mid --> media id\n# format:\n#           image_name tid mid\n# =============================================================\nstart = timeit.default_timer()\ntemplates, medias = read_template_media_list(\n    os.path.join('%s/meta' % image_path,\n                 '%s_face_tid_mid.txt' % target.lower()))\nstop = timeit.default_timer()\nprint('Time: %.2f s. ' % (stop - start))\n\n# In[ ]:\n\n# =============================================================\n# load template pairs for template-to-template verification\n# tid : template id,  label : 1/0\n# format:\n#           tid_1 tid_2 label\n# =============================================================\nstart = timeit.default_timer()\np1, p2, label = read_template_pair_list(\n    os.path.join('%s/meta' % image_path,\n                 '%s_template_pair_label.txt' % target.lower()))\nstop = timeit.default_timer()\nprint('Time: %.2f s. 
' % (stop - start))\n\n# # Step 2: Get Image Features\n\n# In[ ]:\n\n# =============================================================\n# load image features\n# format:\n#           img_feats: [image_num x feats_dim] (227630, 512)\n# =============================================================\nstart = timeit.default_timer()\nimg_path = '%s/loose_crop' % image_path\nimg_list_path = '%s/meta/%s_name_5pts_score.txt' % (image_path, target.lower())\nimg_list = open(img_list_path)\nfiles = img_list.readlines()\n# files_list = divideIntoNstrand(files, rank_size)\nfiles_list = files\n\n# img_feats\n# for i in range(rank_size):\nimg_feats, faceness_scores = get_image_feature(img_path, files_list,\n                                               model_path, 0, gpu_id)\nstop = timeit.default_timer()\nprint('Time: %.2f s. ' % (stop - start))\nprint('Feature Shape: ({} , {}) .'.format(img_feats.shape[0],\n                                          img_feats.shape[1]))\n\n# # Step3: Get Template Features\n\n# In[ ]:\n\n# =============================================================\n# compute template features from image features.\n# =============================================================\nstart = timeit.default_timer()\n# ==========================================================\n# Norm feature before aggregation into template feature?\n# Feature norm from embedding network and faceness score are able to decrease weights for noise samples (not face).\n# ==========================================================\n# 1. FaceScore （Feature Norm）\n# 2. 
FaceScore （Detector）\n\nif use_flip_test:\n    # concat --- F1\n    # img_input_feats = img_feats\n    # add --- F2\n    img_input_feats = img_feats[:, 0:img_feats.shape[1] //\n                                     2] + img_feats[:, img_feats.shape[1] // 2:]\nelse:\n    img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2]\n\nif use_norm_score:\n    img_input_feats = img_input_feats\nelse:\n    # normalise features to remove norm information\n    img_input_feats = img_input_feats / np.sqrt(\n        np.sum(img_input_feats ** 2, -1, keepdims=True))\n\nif use_detector_score:\n    print(img_input_feats.shape, faceness_scores.shape)\n    img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]\nelse:\n    img_input_feats = img_input_feats\n\ntemplate_norm_feats, unique_templates = image2template_feature(\n    img_input_feats, templates, medias)\nstop = timeit.default_timer()\nprint('Time: %.2f s. ' % (stop - start))\n\n# # Step 4: Get Template Similarity Scores\n\n# In[ ]:\n\n# =============================================================\n# compute verification scores between template pairs.\n# =============================================================\nstart = timeit.default_timer()\nscore = verification(template_norm_feats, unique_templates, p1, p2)\nstop = timeit.default_timer()\nprint('Time: %.2f s. 
' % (stop - start))\n\n# In[ ]:\nsave_path = os.path.join(result_dir, args.job)\n# save_path = result_dir + '/%s_result' % target\n\nif not os.path.exists(save_path):\n    os.makedirs(save_path)\n\nscore_save_file = os.path.join(save_path, \"%s.npy\" % target.lower())\nnp.save(score_save_file, score)\n\n# # Step 5: Get ROC Curves and TPR@FPR Table\n\n# In[ ]:\n\nfiles = [score_save_file]\nmethods = []\nscores = []\nfor file in files:\n    methods.append(Path(file).stem)\n    scores.append(np.load(file))\n\nmethods = np.array(methods)\nscores = dict(zip(methods, scores))\ncolours = dict(\n    zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))\nx_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]\ntpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])\nfig = plt.figure()\nfor method in methods:\n    fpr, tpr, _ = roc_curve(label, scores[method])\n    roc_auc = auc(fpr, tpr)\n    fpr = np.flipud(fpr)\n    tpr = np.flipud(tpr)  # select largest tpr at same fpr\n    plt.plot(fpr,\n             tpr,\n             color=colours[method],\n             lw=1,\n             label=('[%s (AUC = %0.4f %%)]' %\n                    (method.split('-')[-1], roc_auc * 100)))\n    tpr_fpr_row = []\n    tpr_fpr_row.append(\"%s-%s\" % (method, target))\n    for fpr_iter in np.arange(len(x_labels)):\n        _, min_index = min(\n            list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))\n        tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))\n    tpr_fpr_table.add_row(tpr_fpr_row)\nplt.xlim([10 ** -6, 0.1])\nplt.ylim([0.3, 1.0])\nplt.grid(linestyle='--', linewidth=1)\nplt.xticks(x_labels)\nplt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))\nplt.xscale('log')\nplt.xlabel('False Positive Rate')\nplt.ylabel('True Positive Rate')\nplt.title('ROC on IJB')\nplt.legend(loc=\"lower right\")\nfig.savefig(os.path.join(save_path, '%s.pdf' % target.lower()))\nprint(tpr_fpr_table)\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/inference.py",
    "content": "import argparse\n\nimport cv2\nimport numpy as np\nimport torch\n\nfrom backbones import get_model\n\n\n@torch.no_grad()\ndef inference(weight, name, img):\n    if img is None:\n        img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8)\n    else:\n        img = cv2.imread(img)\n        img = cv2.resize(img, (112, 112))\n\n    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n    img = np.transpose(img, (2, 0, 1))\n    img = torch.from_numpy(img).unsqueeze(0).float()\n    img.div_(255).sub_(0.5).div_(0.5)\n    net = get_model(name, fp16=False)\n    net.load_state_dict(torch.load(weight))\n    net.eval()\n    feat = net(img).numpy()\n    print(feat)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description='PyTorch ArcFace Training')\n    parser.add_argument('--network', type=str, default='r50', help='backbone network')\n    parser.add_argument('--weight', type=str, default='')\n    parser.add_argument('--img', type=str, default=None)\n    args = parser.parse_args()\n    inference(args.weight, args.network, args.img)\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/losses.py",
    "content": "import torch\nfrom torch import nn\n\n\ndef get_loss(name):\n    if name == \"cosface\":\n        return CosFace()\n    elif name == \"arcface\":\n        return ArcFace()\n    else:\n        raise ValueError()\n\n\nclass CosFace(nn.Module):\n    def __init__(self, s=64.0, m=0.40):\n        super(CosFace, self).__init__()\n        self.s = s\n        self.m = m\n\n    def forward(self, cosine, label):\n        index = torch.where(label != -1)[0]\n        m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)\n        m_hot.scatter_(1, label[index, None], self.m)\n        cosine[index] -= m_hot\n        ret = cosine * self.s\n        return ret\n\n\nclass ArcFace(nn.Module):\n    def __init__(self, s=64.0, m=0.5):\n        super(ArcFace, self).__init__()\n        self.s = s\n        self.m = m\n\n    def forward(self, cosine: torch.Tensor, label):\n        index = torch.where(label != -1)[0]\n        m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)\n        m_hot.scatter_(1, label[index, None], self.m)\n        cosine.acos_()\n        cosine[index] += m_hot\n        cosine.cos_().mul_(self.s)\n        return cosine\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/onnx_helper.py",
    "content": "from __future__ import division\nimport datetime\nimport os\nimport os.path as osp\nimport glob\nimport numpy as np\nimport cv2\nimport sys\nimport onnxruntime\nimport onnx\nimport argparse\nfrom onnx import numpy_helper\nfrom insightface.data import get_image\n\nclass ArcFaceORT:\n    def __init__(self, model_path, cpu=False):\n        self.model_path = model_path\n        # providers = None will use available provider, for onnxruntime-gpu it will be \"CUDAExecutionProvider\"\n        self.providers = ['CPUExecutionProvider'] if cpu else None\n\n    #input_size is (w,h), return error message, return None if success\n    def check(self, track='cfat', test_img = None):\n        #default is cfat\n        max_model_size_mb=1024\n        max_feat_dim=512\n        max_time_cost=15\n        if track.startswith('ms1m'):\n            max_model_size_mb=1024\n            max_feat_dim=512\n            max_time_cost=10\n        elif track.startswith('glint'):\n            max_model_size_mb=1024\n            max_feat_dim=1024\n            max_time_cost=20\n        elif track.startswith('cfat'):\n            max_model_size_mb = 1024\n            max_feat_dim = 512\n            max_time_cost = 15\n        elif track.startswith('unconstrained'):\n            max_model_size_mb=1024\n            max_feat_dim=1024\n            max_time_cost=30\n        else:\n            return \"track not found\"\n\n        if not os.path.exists(self.model_path):\n            return \"model_path not exists\"\n        if not os.path.isdir(self.model_path):\n            return \"model_path should be directory\"\n        onnx_files = []\n        for _file in os.listdir(self.model_path):\n            if _file.endswith('.onnx'):\n                onnx_files.append(osp.join(self.model_path, _file))\n        if len(onnx_files)==0:\n            return \"do not have onnx files\"\n        self.model_file = sorted(onnx_files)[-1]\n        print('use onnx-model:', self.model_file)\n        try:\n 
           session = onnxruntime.InferenceSession(self.model_file, providers=self.providers)\n        except:\n            return \"load onnx failed\"\n        input_cfg = session.get_inputs()[0]\n        input_shape = input_cfg.shape\n        print('input-shape:', input_shape)\n        if len(input_shape)!=4:\n            return \"length of input_shape should be 4\"\n        if not isinstance(input_shape[0], str):\n            #return \"input_shape[0] should be str to support batch-inference\"\n            print('reset input-shape[0] to None')\n            model = onnx.load(self.model_file)\n            model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'\n            new_model_file = osp.join(self.model_path, 'zzzzrefined.onnx')\n            onnx.save(model, new_model_file)\n            self.model_file = new_model_file\n            print('use new onnx-model:', self.model_file)\n            try:\n                session = onnxruntime.InferenceSession(self.model_file, providers=self.providers)\n            except:\n                return \"load onnx failed\"\n            input_cfg = session.get_inputs()[0]\n            input_shape = input_cfg.shape\n            print('new-input-shape:', input_shape)\n\n        self.image_size = tuple(input_shape[2:4][::-1])\n        #print('image_size:', self.image_size)\n        input_name = input_cfg.name\n        outputs = session.get_outputs()\n        output_names = []\n        for o in outputs:\n            output_names.append(o.name)\n            #print(o.name, o.shape)\n        if len(output_names)!=1:\n            return \"number of output nodes should be 1\"\n        self.session = session\n        self.input_name = input_name\n        self.output_names = output_names\n        #print(self.output_names)\n        model = onnx.load(self.model_file)\n        graph = model.graph\n        if len(graph.node)<8:\n            return \"too small onnx graph\"\n\n        input_size = (112,112)\n        self.crop = 
None\n        if track=='cfat':\n            crop_file = osp.join(self.model_path, 'crop.txt')\n            if osp.exists(crop_file):\n                lines = open(crop_file,'r').readlines()\n                if len(lines)!=6:\n                    return \"crop.txt should contain 6 lines\"\n                lines = [int(x) for x in lines]\n                self.crop = lines[:4]\n                input_size = tuple(lines[4:6])\n        if input_size!=self.image_size:\n            return \"input-size is inconsistant with onnx model input, %s vs %s\"%(input_size, self.image_size)\n\n        self.model_size_mb = os.path.getsize(self.model_file) / float(1024*1024)\n        if self.model_size_mb > max_model_size_mb:\n            return \"max model size exceed, given %.3f-MB\"%self.model_size_mb\n\n        input_mean = None\n        input_std = None\n        if track=='cfat':\n            pn_file = osp.join(self.model_path, 'pixel_norm.txt')\n            if osp.exists(pn_file):\n                lines = open(pn_file,'r').readlines()\n                if len(lines)!=2:\n                    return \"pixel_norm.txt should contain 2 lines\"\n                input_mean = float(lines[0])\n                input_std = float(lines[1])\n        if input_mean is not None or input_std is not None:\n            if input_mean is None or input_std is None:\n                return \"please set input_mean and input_std simultaneously\"\n        else:\n            find_sub = False\n            find_mul = False\n            for nid, node in enumerate(graph.node[:8]):\n                print(nid, node.name)\n                if node.name.startswith('Sub') or node.name.startswith('_minus'):\n                    find_sub = True\n                if node.name.startswith('Mul') or node.name.startswith('_mul') or node.name.startswith('Div'):\n                    find_mul = True\n            if find_sub and find_mul:\n                print(\"find sub and mul\")\n                #mxnet arcface model\n       
         input_mean = 0.0\n                input_std = 1.0\n            else:\n                input_mean = 127.5\n                input_std = 127.5\n        self.input_mean = input_mean\n        self.input_std = input_std\n        for initn in graph.initializer:\n            weight_array = numpy_helper.to_array(initn)\n            dt = weight_array.dtype\n            if dt.itemsize<4:\n                return 'invalid weight type - (%s:%s)' % (initn.name, dt.name)\n        if test_img is None:\n            test_img = get_image('Tom_Hanks_54745')\n            test_img = cv2.resize(test_img, self.image_size)\n        else:\n            test_img = cv2.resize(test_img, self.image_size)\n        feat, cost = self.benchmark(test_img)\n        batch_result = self.check_batch(test_img)\n        batch_result_sum = float(np.sum(batch_result))\n        if batch_result_sum in [float('inf'), -float('inf')] or batch_result_sum != batch_result_sum:\n            print(batch_result)\n            print(batch_result_sum)\n            return \"batch result output contains NaN!\"\n\n        if len(feat.shape) < 2:\n           return \"the shape of the feature must be two, but get {}\".format(str(feat.shape))\n\n        if feat.shape[1] > max_feat_dim:\n            return \"max feat dim exceed, given %d\"%feat.shape[1]\n        self.feat_dim = feat.shape[1]\n        cost_ms = cost*1000\n        if cost_ms>max_time_cost:\n            return \"max time cost exceed, given %.4f\"%cost_ms\n        self.cost_ms = cost_ms\n        print('check stat:, model-size-mb: %.4f, feat-dim: %d, time-cost-ms: %.4f, input-mean: %.3f, input-std: %.3f'%(self.model_size_mb, self.feat_dim, self.cost_ms, self.input_mean, self.input_std))\n        return None\n\n    def check_batch(self, img):\n        if not isinstance(img, list):\n            imgs = [img, ] * 32\n        if self.crop is not None:\n            nimgs = []\n            for img in imgs:\n                nimg = img[self.crop[1]:self.crop[3], 
self.crop[0]:self.crop[2], :]\n                if nimg.shape[0] != self.image_size[1] or nimg.shape[1] != self.image_size[0]:\n                    nimg = cv2.resize(nimg, self.image_size)\n                nimgs.append(nimg)\n            imgs = nimgs\n        blob = cv2.dnn.blobFromImages(\n            images=imgs, scalefactor=1.0 / self.input_std, size=self.image_size,\n            mean=(self.input_mean, self.input_mean, self.input_mean), swapRB=True)\n        net_out = self.session.run(self.output_names, {self.input_name: blob})[0]\n        return net_out\n\n\n    def meta_info(self):\n        return {'model-size-mb':self.model_size_mb, 'feature-dim':self.feat_dim, 'infer': self.cost_ms}\n\n\n    def forward(self, imgs):\n        if not isinstance(imgs, list):\n            imgs = [imgs]\n        input_size = self.image_size\n        if self.crop is not None:\n            nimgs = []\n            for img in imgs:\n                nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:]\n                if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]:\n                    nimg = cv2.resize(nimg, input_size)\n                nimgs.append(nimg)\n            imgs = nimgs\n        blob = cv2.dnn.blobFromImages(imgs, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)\n        net_out = self.session.run(self.output_names, {self.input_name : blob})[0]\n        return net_out\n\n    def benchmark(self, img):\n        input_size = self.image_size\n        if self.crop is not None:\n            nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:]\n            if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]:\n                nimg = cv2.resize(nimg, input_size)\n            img = nimg\n        blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)\n        costs = []\n        for _ in range(50):\n        
    ta = datetime.datetime.now()\n            net_out = self.session.run(self.output_names, {self.input_name : blob})[0]\n            tb = datetime.datetime.now()\n            cost = (tb-ta).total_seconds()\n            costs.append(cost)\n        costs = sorted(costs)\n        cost = costs[5]\n        return net_out, cost\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='')\n    # general\n    parser.add_argument('workdir', help='submitted work dir', type=str)\n    parser.add_argument('--track', help='track name, for different challenge', type=str, default='cfat')\n    args = parser.parse_args()\n    handler = ArcFaceORT(args.workdir)\n    err = handler.check(args.track)\n    print('err:', err)\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/onnx_ijbc.py",
    "content": "import argparse\nimport os\nimport pickle\nimport timeit\n\nimport cv2\nimport mxnet as mx\nimport numpy as np\nimport pandas as pd\nimport prettytable\nimport skimage.transform\nfrom sklearn.metrics import roc_curve\nfrom sklearn.preprocessing import normalize\n\nfrom onnx_helper import ArcFaceORT\n\nSRC = np.array(\n    [\n        [30.2946, 51.6963],\n        [65.5318, 51.5014],\n        [48.0252, 71.7366],\n        [33.5493, 92.3655],\n        [62.7299, 92.2041]]\n    , dtype=np.float32)\nSRC[:, 0] += 8.0\n\n\nclass AlignedDataSet(mx.gluon.data.Dataset):\n    def __init__(self, root, lines, align=True):\n        self.lines = lines\n        self.root = root\n        self.align = align\n\n    def __len__(self):\n        return len(self.lines)\n\n    def __getitem__(self, idx):\n        each_line = self.lines[idx]\n        name_lmk_score = each_line.strip().split(' ')\n        name = os.path.join(self.root, name_lmk_score[0])\n        img = cv2.cvtColor(cv2.imread(name), cv2.COLOR_BGR2RGB)\n        landmark5 = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32).reshape((5, 2))\n        st = skimage.transform.SimilarityTransform()\n        st.estimate(landmark5, SRC)\n        img = cv2.warpAffine(img, st.params[0:2, :], (112, 112), borderValue=0.0)\n        img_1 = np.expand_dims(img, 0)\n        img_2 = np.expand_dims(np.fliplr(img), 0)\n        output = np.concatenate((img_1, img_2), axis=0).astype(np.float32)\n        output = np.transpose(output, (0, 3, 1, 2))\n        output = mx.nd.array(output)\n        return output\n\n\ndef extract(model_root, dataset):\n    model = ArcFaceORT(model_path=model_root)\n    model.check()\n    feat_mat = np.zeros(shape=(len(dataset), 2 * model.feat_dim))\n\n    def batchify_fn(data):\n        return mx.nd.concat(*data, dim=0)\n\n    data_loader = mx.gluon.data.DataLoader(\n        dataset, 128, last_batch='keep', num_workers=4,\n        thread_pool=True, prefetch=16, batchify_fn=batchify_fn)\n   
 num_iter = 0\n    for batch in data_loader:\n        batch = batch.asnumpy()\n        batch = (batch - model.input_mean) / model.input_std\n        feat = model.session.run(model.output_names, {model.input_name: batch})[0]\n        feat = np.reshape(feat, (-1, model.feat_dim * 2))\n        feat_mat[128 * num_iter: 128 * num_iter + feat.shape[0], :] = feat\n        num_iter += 1\n        if num_iter % 50 == 0:\n            print(num_iter)\n    return feat_mat\n\n\ndef read_template_media_list(path):\n    ijb_meta = pd.read_csv(path, sep=' ', header=None).values\n    templates = ijb_meta[:, 1].astype(np.int)\n    medias = ijb_meta[:, 2].astype(np.int)\n    return templates, medias\n\n\ndef read_template_pair_list(path):\n    pairs = pd.read_csv(path, sep=' ', header=None).values\n    t1 = pairs[:, 0].astype(np.int)\n    t2 = pairs[:, 1].astype(np.int)\n    label = pairs[:, 2].astype(np.int)\n    return t1, t2, label\n\n\ndef read_image_feature(path):\n    with open(path, 'rb') as fid:\n        img_feats = pickle.load(fid)\n    return img_feats\n\n\ndef image2template_feature(img_feats=None,\n                           templates=None,\n                           medias=None):\n    unique_templates = np.unique(templates)\n    template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))\n    for count_template, uqt in enumerate(unique_templates):\n        (ind_t,) = np.where(templates == uqt)\n        face_norm_feats = img_feats[ind_t]\n        face_medias = medias[ind_t]\n        unique_medias, unique_media_counts = np.unique(face_medias, return_counts=True)\n        media_norm_feats = []\n        for u, ct in zip(unique_medias, unique_media_counts):\n            (ind_m,) = np.where(face_medias == u)\n            if ct == 1:\n                media_norm_feats += [face_norm_feats[ind_m]]\n            else:  # image features from the same video will be aggregated into one feature\n                media_norm_feats += [np.mean(face_norm_feats[ind_m], axis=0, 
keepdims=True), ]\n        media_norm_feats = np.array(media_norm_feats)\n        template_feats[count_template] = np.sum(media_norm_feats, axis=0)\n        if count_template % 2000 == 0:\n            print('Finish Calculating {} template features.'.format(\n                count_template))\n    template_norm_feats = normalize(template_feats)\n    return template_norm_feats, unique_templates\n\n\ndef verification(template_norm_feats=None,\n                 unique_templates=None,\n                 p1=None,\n                 p2=None):\n    template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)\n    for count_template, uqt in enumerate(unique_templates):\n        template2id[uqt] = count_template\n    score = np.zeros((len(p1),))\n    total_pairs = np.array(range(len(p1)))\n    batchsize = 100000\n    sublists = [total_pairs[i: i + batchsize] for i in range(0, len(p1), batchsize)]\n    total_sublists = len(sublists)\n    for c, s in enumerate(sublists):\n        feat1 = template_norm_feats[template2id[p1[s]]]\n        feat2 = template_norm_feats[template2id[p2[s]]]\n        similarity_score = np.sum(feat1 * feat2, -1)\n        score[s] = similarity_score.flatten()\n        if c % 10 == 0:\n            print('Finish {}/{} pairs.'.format(c, total_sublists))\n    return score\n\n\ndef verification2(template_norm_feats=None,\n                  unique_templates=None,\n                  p1=None,\n                  p2=None):\n    template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)\n    for count_template, uqt in enumerate(unique_templates):\n        template2id[uqt] = count_template\n    score = np.zeros((len(p1),))  # save cosine distance between pairs\n    total_pairs = np.array(range(len(p1)))\n    batchsize = 100000  # small batchsize instead of all pairs in one batch due to the memory limiation\n    sublists = [total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)]\n    total_sublists = len(sublists)\n    for c, s in 
enumerate(sublists):\n        feat1 = template_norm_feats[template2id[p1[s]]]\n        feat2 = template_norm_feats[template2id[p2[s]]]\n        similarity_score = np.sum(feat1 * feat2, -1)\n        score[s] = similarity_score.flatten()\n        if c % 10 == 0:\n            print('Finish {}/{} pairs.'.format(c, total_sublists))\n    return score\n\n\ndef main(args):\n    use_norm_score = True  # if Ture, TestMode(N1)\n    use_detector_score = True  # if Ture, TestMode(D1)\n    use_flip_test = True  # if Ture, TestMode(F1)\n    assert args.target == 'IJBC' or args.target == 'IJBB'\n\n    start = timeit.default_timer()\n    templates, medias = read_template_media_list(\n        os.path.join('%s/meta' % args.image_path, '%s_face_tid_mid.txt' % args.target.lower()))\n    stop = timeit.default_timer()\n    print('Time: %.2f s. ' % (stop - start))\n\n    start = timeit.default_timer()\n    p1, p2, label = read_template_pair_list(\n        os.path.join('%s/meta' % args.image_path,\n                     '%s_template_pair_label.txt' % args.target.lower()))\n    stop = timeit.default_timer()\n    print('Time: %.2f s. ' % (stop - start))\n\n    start = timeit.default_timer()\n    img_path = '%s/loose_crop' % args.image_path\n    img_list_path = '%s/meta/%s_name_5pts_score.txt' % (args.image_path, args.target.lower())\n    img_list = open(img_list_path)\n    files = img_list.readlines()\n    dataset = AlignedDataSet(root=img_path, lines=files, align=True)\n    img_feats = extract(args.model_root, dataset)\n\n    faceness_scores = []\n    for each_line in files:\n        name_lmk_score = each_line.split()\n        faceness_scores.append(name_lmk_score[-1])\n    faceness_scores = np.array(faceness_scores).astype(np.float32)\n    stop = timeit.default_timer()\n    print('Time: %.2f s. 
' % (stop - start))\n    print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], img_feats.shape[1]))\n    start = timeit.default_timer()\n\n    if use_flip_test:\n        img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + img_feats[:, img_feats.shape[1] // 2:]\n    else:\n        img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2]\n\n    if use_norm_score:\n        img_input_feats = img_input_feats\n    else:\n        img_input_feats = img_input_feats / np.sqrt(np.sum(img_input_feats ** 2, -1, keepdims=True))\n\n    if use_detector_score:\n        print(img_input_feats.shape, faceness_scores.shape)\n        img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]\n    else:\n        img_input_feats = img_input_feats\n\n    template_norm_feats, unique_templates = image2template_feature(\n        img_input_feats, templates, medias)\n    stop = timeit.default_timer()\n    print('Time: %.2f s. ' % (stop - start))\n\n    start = timeit.default_timer()\n    score = verification(template_norm_feats, unique_templates, p1, p2)\n    stop = timeit.default_timer()\n    print('Time: %.2f s. 
' % (stop - start))\n    save_path = os.path.join(args.result_dir, \"{}_result\".format(args.target))\n    if not os.path.exists(save_path):\n        os.makedirs(save_path)\n    score_save_file = os.path.join(save_path, \"{}.npy\".format(args.model_root))\n    np.save(score_save_file, score)\n    files = [score_save_file]\n    methods = []\n    scores = []\n    for file in files:\n        methods.append(os.path.basename(file))\n        scores.append(np.load(file))\n    methods = np.array(methods)\n    scores = dict(zip(methods, scores))\n    x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]\n    tpr_fpr_table = prettytable.PrettyTable(['Methods'] + [str(x) for x in x_labels])\n    for method in methods:\n        fpr, tpr, _ = roc_curve(label, scores[method])\n        fpr = np.flipud(fpr)\n        tpr = np.flipud(tpr)\n        tpr_fpr_row = []\n        tpr_fpr_row.append(\"%s-%s\" % (method, args.target))\n        for fpr_iter in np.arange(len(x_labels)):\n            _, min_index = min(\n                list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))\n            tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))\n        tpr_fpr_table.add_row(tpr_fpr_row)\n    print(tpr_fpr_table)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='do ijb test')\n    # general\n    parser.add_argument('--model-root', default='', help='path to load model.')\n    parser.add_argument('--image-path', default='', type=str, help='')\n    parser.add_argument('--result-dir', default='.', type=str, help='')\n    parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB')\n    main(parser.parse_args())\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/partial_fc.py",
    "content": "import logging\nimport os\n\nimport torch\nimport torch.distributed as dist\nfrom torch.nn import Module\nfrom torch.nn.functional import normalize, linear\nfrom torch.nn.parameter import Parameter\n\n\nclass PartialFC(Module):\n    \"\"\"\n    Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint,\n    Partial FC: Training 10 Million Identities on a Single Machine\n    See the original paper:\n    https://arxiv.org/abs/2010.05222\n    \"\"\"\n\n    @torch.no_grad()\n    def __init__(self, rank, local_rank, world_size, batch_size, resume,\n                 margin_softmax, num_classes, sample_rate=1.0, embedding_size=512, prefix=\"./\"):\n        \"\"\"\n        rank: int\n            Unique process(GPU) ID from 0 to world_size - 1.\n        local_rank: int\n            Unique process(GPU) ID within the server from 0 to 7.\n        world_size: int\n            Number of GPU.\n        batch_size: int\n            Batch size on current rank(GPU).\n        resume: bool\n            Select whether to restore the weight of softmax.\n        margin_softmax: callable\n            A function of margin softmax, eg: cosface, arcface.\n        num_classes: int\n            The number of class center storage in current rank(CPU/GPU), usually is total_classes // world_size,\n            required.\n        sample_rate: float\n            The partial fc sampling rate, when the number of classes increases to more than 2 millions, Sampling\n            can greatly speed up training, and reduce a lot of GPU memory, default is 1.0.\n        embedding_size: int\n            The feature dimension, default is 512.\n        prefix: str\n            Path for save checkpoint, default is './'.\n        \"\"\"\n        super(PartialFC, self).__init__()\n        #\n        self.num_classes: int = num_classes\n        self.rank: int = rank\n        self.local_rank: int = local_rank\n        self.device: torch.device = torch.device(\"cuda:{}\".format(self.local_rank))\n        
self.world_size: int = world_size\n        self.batch_size: int = batch_size\n        self.margin_softmax: callable = margin_softmax\n        self.sample_rate: float = sample_rate\n        self.embedding_size: int = embedding_size\n        self.prefix: str = prefix\n        self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size)\n        self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size)\n        self.num_sample: int = int(self.sample_rate * self.num_local)\n\n        self.weight_name = os.path.join(self.prefix, \"rank_{}_softmax_weight.pt\".format(self.rank))\n        self.weight_mom_name = os.path.join(self.prefix, \"rank_{}_softmax_weight_mom.pt\".format(self.rank))\n\n        if resume:\n            try:\n                self.weight: torch.Tensor = torch.load(self.weight_name)\n                self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name)\n                if self.weight.shape[0] != self.num_local or self.weight_mom.shape[0] != self.num_local:\n                    raise IndexError\n                logging.info(\"softmax weight resume successfully!\")\n                logging.info(\"softmax weight mom resume successfully!\")\n            except (FileNotFoundError, KeyError, IndexError):\n                self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)\n                self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)\n                logging.info(\"softmax weight init!\")\n                logging.info(\"softmax weight mom init!\")\n        else:\n            self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)\n            self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)\n            logging.info(\"softmax weight init successfully!\")\n            logging.info(\"softmax weight mom init successfully!\")\n        self.stream: torch.cuda.Stream = 
torch.cuda.Stream(local_rank)\n\n        self.index = None\n        if int(self.sample_rate) == 1:\n            self.update = lambda: 0\n            self.sub_weight = Parameter(self.weight)\n            self.sub_weight_mom = self.weight_mom\n        else:\n            self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank))\n\n    def save_params(self):\n        \"\"\" Save softmax weight for each rank on prefix\n        \"\"\"\n        torch.save(self.weight.data, self.weight_name)\n        torch.save(self.weight_mom, self.weight_mom_name)\n\n    @torch.no_grad()\n    def sample(self, total_label):\n        \"\"\"\n        Sample all positive class centers in each rank, and random select neg class centers to filling a fixed\n        `num_sample`.\n\n        total_label: tensor\n            Label after all gather, which cross all GPUs.\n        \"\"\"\n        index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local)\n        total_label[~index_positive] = -1\n        total_label[index_positive] -= self.class_start\n        if int(self.sample_rate) != 1:\n            positive = torch.unique(total_label[index_positive], sorted=True)\n            if self.num_sample - positive.size(0) >= 0:\n                perm = torch.rand(size=[self.num_local], device=self.device)\n                perm[positive] = 2.0\n                index = torch.topk(perm, k=self.num_sample)[1]\n                index = index.sort()[0]\n            else:\n                index = positive\n            self.index = index\n            total_label[index_positive] = torch.searchsorted(index, total_label[index_positive])\n            self.sub_weight = Parameter(self.weight[index])\n            self.sub_weight_mom = self.weight_mom[index]\n\n    def forward(self, total_features, norm_weight):\n        \"\"\" Partial fc forward, `logits = X * sample(W)`\n        \"\"\"\n        torch.cuda.current_stream().wait_stream(self.stream)\n        logits = 
linear(total_features, norm_weight)\n        return logits\n\n    @torch.no_grad()\n    def update(self):\n        \"\"\" Set updated weight and weight_mom to memory bank.\n        \"\"\"\n        self.weight_mom[self.index] = self.sub_weight_mom\n        self.weight[self.index] = self.sub_weight\n\n    def prepare(self, label, optimizer):\n        \"\"\"\n        get sampled class centers for cal softmax.\n\n        label: tensor\n            Label tensor on each rank.\n        optimizer: opt\n            Optimizer for partial fc, which need to get weight mom.\n        \"\"\"\n        with torch.cuda.stream(self.stream):\n            total_label = torch.zeros(\n                size=[self.batch_size * self.world_size], device=self.device, dtype=torch.long)\n            dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label)\n            self.sample(total_label)\n            optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None)\n            optimizer.param_groups[-1]['params'][0] = self.sub_weight\n            optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom\n            norm_weight = normalize(self.sub_weight)\n            return total_label, norm_weight\n\n    def forward_backward(self, label, features, optimizer):\n        \"\"\"\n        Partial fc forward and backward with model parallel\n\n        label: tensor\n            Label tensor on each rank(GPU)\n        features: tensor\n            Features tensor on each rank(GPU)\n        optimizer: optimizer\n            Optimizer for partial fc\n\n        Returns:\n        --------\n        x_grad: tensor\n            The gradient of features.\n        loss_v: tensor\n            Loss value for cross entropy.\n        \"\"\"\n        total_label, norm_weight = self.prepare(label, optimizer)\n        total_features = torch.zeros(\n            size=[self.batch_size * self.world_size, self.embedding_size], device=self.device)\n        
dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data)\n        total_features.requires_grad = True\n\n        logits = self.forward(total_features, norm_weight)\n        logits = self.margin_softmax(logits, total_label)\n\n        with torch.no_grad():\n            max_fc = torch.max(logits, dim=1, keepdim=True)[0]\n            dist.all_reduce(max_fc, dist.ReduceOp.MAX)\n\n            # calculate exp(logits) and all-reduce\n            logits_exp = torch.exp(logits - max_fc)\n            logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)\n            dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)\n\n            # calculate prob\n            logits_exp.div_(logits_sum_exp)\n\n            # get one-hot\n            grad = logits_exp\n            index = torch.where(total_label != -1)[0]\n            one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device)\n            one_hot.scatter_(1, total_label[index, None], 1)\n\n            # calculate loss\n            loss = torch.zeros(grad.size()[0], 1, device=grad.device)\n            loss[index] = grad[index].gather(1, total_label[index, None])\n            dist.all_reduce(loss, dist.ReduceOp.SUM)\n            loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)\n\n            # calculate grad\n            grad[index] -= one_hot\n            grad.div_(self.batch_size * self.world_size)\n\n        logits.backward(grad)\n        if total_features.grad is not None:\n            total_features.grad.detach_()\n        x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True)\n        # feature gradient all-reduce\n        dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))\n        x_grad = x_grad * self.world_size\n        # backward backbone\n        return x_grad, loss_v\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/requirement.txt",
    "content": "tensorboard\neasydict\nmxnet\nonnx\nscikit-learn\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/run.sh",
    "content": "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr=\"127.0.0.1\" --master_port=1234 train.py configs/ms1mv3_r50\nps -ef | grep \"train\" | grep -v grep | awk '{print \"kill -9 \"$2}' | sh\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/torch2onnx.py",
    "content": "import numpy as np\nimport onnx\nimport torch\n\n\ndef convert_onnx(net, path_module, output, opset=11, simplify=False):\n    assert isinstance(net, torch.nn.Module)\n    img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)\n    img = img.astype(np.float)\n    img = (img / 255. - 0.5) / 0.5  # torch style norm\n    img = img.transpose((2, 0, 1))\n    img = torch.from_numpy(img).unsqueeze(0).float()\n\n    weight = torch.load(path_module)\n    net.load_state_dict(weight)\n    net.eval()\n    torch.onnx.export(net, img, output, keep_initializers_as_inputs=False, verbose=False, opset_version=opset)\n    model = onnx.load(output)\n    graph = model.graph\n    graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'\n    if simplify:\n        from onnxsim import simplify\n        model, check = simplify(model)\n        assert check, \"Simplified ONNX model could not be validated\"\n    onnx.save(model, output)\n\n    \nif __name__ == '__main__':\n    import os\n    import argparse\n    from backbones import get_model\n\n    parser = argparse.ArgumentParser(description='ArcFace PyTorch to onnx')\n    parser.add_argument('input', type=str, help='input backbone.pth file or path')\n    parser.add_argument('--output', type=str, default=None, help='output onnx path')\n    parser.add_argument('--network', type=str, default=None, help='backbone network')\n    parser.add_argument('--simplify', type=bool, default=False, help='onnx simplify')\n    args = parser.parse_args()\n    input_file = args.input\n    if os.path.isdir(input_file):\n        input_file = os.path.join(input_file, \"backbone.pth\")\n    assert os.path.exists(input_file)\n    model_name = os.path.basename(os.path.dirname(input_file)).lower()\n    params = model_name.split(\"_\")\n    if len(params) >= 3 and params[1] in ('arcface', 'cosface'):\n        if args.network is None:\n            args.network = params[2]\n    assert args.network is not None\n    print(args)\n    
backbone_onnx = get_model(args.network, dropout=0)\n\n    output_path = args.output\n    if output_path is None:\n        output_path = os.path.join(os.path.dirname(__file__), 'onnx')\n    if not os.path.exists(output_path):\n        os.makedirs(output_path)\n    assert os.path.isdir(output_path)\n    output_file = os.path.join(output_path, \"%s.onnx\" % model_name)\n    convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify)\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/train.py",
    "content": "import argparse\nimport logging\nimport os\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn.functional as F\nimport torch.utils.data.distributed\nfrom torch.nn.utils import clip_grad_norm_\n\nimport losses\nfrom backbones import get_model\nfrom dataset import MXFaceDataset, SyntheticDataset, DataLoaderX\nfrom partial_fc import PartialFC\nfrom utils.utils_amp import MaxClipGradScaler\nfrom utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint\nfrom utils.utils_config import get_config\nfrom utils.utils_logging import AverageMeter, init_logging\n\n\ndef main(args):\n    cfg = get_config(args.config)\n    try:\n        world_size = int(os.environ['WORLD_SIZE'])\n        rank = int(os.environ['RANK'])\n        dist.init_process_group('nccl')\n    except KeyError:\n        world_size = 1\n        rank = 0\n        dist.init_process_group(backend='nccl', init_method=\"tcp://127.0.0.1:12584\", rank=rank, world_size=world_size)\n\n    local_rank = args.local_rank\n    torch.cuda.set_device(local_rank)\n    os.makedirs(cfg.output, exist_ok=True)\n    init_logging(rank, cfg.output)\n\n    if cfg.rec == \"synthetic\":\n        train_set = SyntheticDataset(local_rank=local_rank)\n    else:\n        train_set = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank)\n\n    train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True)\n    train_loader = DataLoaderX(\n        local_rank=local_rank, dataset=train_set, batch_size=cfg.batch_size,\n        sampler=train_sampler, num_workers=2, pin_memory=True, drop_last=True)\n    backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).to(local_rank)\n\n    if cfg.resume:\n        try:\n            backbone_pth = os.path.join(cfg.output, \"backbone.pth\")\n            backbone.load_state_dict(torch.load(backbone_pth, map_location=torch.device(local_rank)))\n            if rank == 0:\n               
 logging.info(\"backbone resume successfully!\")\n        except (FileNotFoundError, KeyError, IndexError, RuntimeError):\n            if rank == 0:\n                logging.info(\"resume fail, backbone init successfully!\")\n\n    backbone = torch.nn.parallel.DistributedDataParallel(\n        module=backbone, broadcast_buffers=False, device_ids=[local_rank])\n    backbone.train()\n    margin_softmax = losses.get_loss(cfg.loss)\n    module_partial_fc = PartialFC(\n        rank=rank, local_rank=local_rank, world_size=world_size, resume=cfg.resume,\n        batch_size=cfg.batch_size, margin_softmax=margin_softmax, num_classes=cfg.num_classes,\n        sample_rate=cfg.sample_rate, embedding_size=cfg.embedding_size, prefix=cfg.output)\n\n    opt_backbone = torch.optim.SGD(\n        params=[{'params': backbone.parameters()}],\n        lr=cfg.lr / 512 * cfg.batch_size * world_size,\n        momentum=0.9, weight_decay=cfg.weight_decay)\n    opt_pfc = torch.optim.SGD(\n        params=[{'params': module_partial_fc.parameters()}],\n        lr=cfg.lr / 512 * cfg.batch_size * world_size,\n        momentum=0.9, weight_decay=cfg.weight_decay)\n\n    num_image = len(train_set)\n    total_batch_size = cfg.batch_size * world_size\n    cfg.warmup_step = num_image // total_batch_size * cfg.warmup_epoch\n    cfg.total_step = num_image // total_batch_size * cfg.num_epoch\n\n    def lr_step_func(current_step):\n        cfg.decay_step = [x * num_image // total_batch_size for x in cfg.decay_epoch]\n        if current_step < cfg.warmup_step:\n            return current_step / cfg.warmup_step\n        else:\n            return 0.1 ** len([m for m in cfg.decay_step if m <= current_step])\n\n    scheduler_backbone = torch.optim.lr_scheduler.LambdaLR(\n        optimizer=opt_backbone, lr_lambda=lr_step_func)\n    scheduler_pfc = torch.optim.lr_scheduler.LambdaLR(\n        optimizer=opt_pfc, lr_lambda=lr_step_func)\n\n    for key, value in cfg.items():\n        num_space = 25 - len(key)\n        
logging.info(\": \" + key + \" \" * num_space + str(value))\n\n    val_target = cfg.val_targets\n    callback_verification = CallBackVerification(2000, rank, val_target, cfg.rec)\n    callback_logging = CallBackLogging(50, rank, cfg.total_step, cfg.batch_size, world_size, None)\n    callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output)\n\n    loss = AverageMeter()\n    start_epoch = 0\n    global_step = 0\n    grad_amp = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None\n    for epoch in range(start_epoch, cfg.num_epoch):\n        train_sampler.set_epoch(epoch)\n        for step, (img, label) in enumerate(train_loader):\n            global_step += 1\n            features = F.normalize(backbone(img))\n            x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc)\n            if cfg.fp16:\n                features.backward(grad_amp.scale(x_grad))\n                grad_amp.unscale_(opt_backbone)\n                clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)\n                grad_amp.step(opt_backbone)\n                grad_amp.update()\n            else:\n                features.backward(x_grad)\n                clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)\n                opt_backbone.step()\n\n            opt_pfc.step()\n            module_partial_fc.update()\n            opt_backbone.zero_grad()\n            opt_pfc.zero_grad()\n            loss.update(loss_v, 1)\n            callback_logging(global_step, loss, epoch, cfg.fp16, scheduler_backbone.get_last_lr()[0], grad_amp)\n            callback_verification(global_step, backbone)\n            scheduler_backbone.step()\n            scheduler_pfc.step()\n        callback_checkpoint(global_step, backbone, module_partial_fc)\n    dist.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    torch.backends.cudnn.benchmark = True\n    parser = argparse.ArgumentParser(description='PyTorch 
ArcFace Training')\n    parser.add_argument('config', type=str, help='py config file')\n    parser.add_argument('--local_rank', type=int, default=0, help='local_rank')\n    main(parser.parse_args())\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/utils/__init__.py",
    "content": ""
  },
  {
    "path": "src/face3d/models/arcface_torch/utils/plot.py",
    "content": "# coding: utf-8\n\nimport os\nfrom pathlib import Path\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\nfrom menpo.visualize.viewmatplotlib import sample_colours_from_colourmap\nfrom prettytable import PrettyTable\nfrom sklearn.metrics import roc_curve, auc\n\nimage_path = \"/data/anxiang/IJB_release/IJBC\"\nfiles = [\n        \"./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy\"\n]\n\n\ndef read_template_pair_list(path):\n    pairs = pd.read_csv(path, sep=' ', header=None).values\n    t1 = pairs[:, 0].astype(np.int)\n    t2 = pairs[:, 1].astype(np.int)\n    label = pairs[:, 2].astype(np.int)\n    return t1, t2, label\n\n\np1, p2, label = read_template_pair_list(\n    os.path.join('%s/meta' % image_path,\n                 '%s_template_pair_label.txt' % 'ijbc'))\n\nmethods = []\nscores = []\nfor file in files:\n    methods.append(file.split('/')[-2])\n    scores.append(np.load(file))\n\nmethods = np.array(methods)\nscores = dict(zip(methods, scores))\ncolours = dict(\n    zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))\nx_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]\ntpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])\nfig = plt.figure()\nfor method in methods:\n    fpr, tpr, _ = roc_curve(label, scores[method])\n    roc_auc = auc(fpr, tpr)\n    fpr = np.flipud(fpr)\n    tpr = np.flipud(tpr)  # select largest tpr at same fpr\n    plt.plot(fpr,\n             tpr,\n             color=colours[method],\n             lw=1,\n             label=('[%s (AUC = %0.4f %%)]' %\n                    (method.split('-')[-1], roc_auc * 100)))\n    tpr_fpr_row = []\n    tpr_fpr_row.append(\"%s-%s\" % (method, \"IJBC\"))\n    for fpr_iter in np.arange(len(x_labels)):\n        _, min_index = min(\n            list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))\n        tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))\n    
tpr_fpr_table.add_row(tpr_fpr_row)\nplt.xlim([10 ** -6, 0.1])\nplt.ylim([0.3, 1.0])\nplt.grid(linestyle='--', linewidth=1)\nplt.xticks(x_labels)\nplt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))\nplt.xscale('log')\nplt.xlabel('False Positive Rate')\nplt.ylabel('True Positive Rate')\nplt.title('ROC on IJB')\nplt.legend(loc=\"lower right\")\nprint(tpr_fpr_table)\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/utils/utils_amp.py",
    "content": "from typing import Dict, List\n\nimport torch\n\nif torch.__version__ < '1.9':\n    Iterable = torch._six.container_abcs.Iterable\nelse:\n    import collections\n\n    Iterable = collections.abc.Iterable\nfrom torch.cuda.amp import GradScaler\n\n\nclass _MultiDeviceReplicator(object):\n    \"\"\"\n    Lazily serves copies of a tensor to requested devices.  Copies are cached per-device.\n    \"\"\"\n\n    def __init__(self, master_tensor: torch.Tensor) -> None:\n        assert master_tensor.is_cuda\n        self.master = master_tensor\n        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}\n\n    def get(self, device) -> torch.Tensor:\n        retval = self._per_device_tensors.get(device, None)\n        if retval is None:\n            retval = self.master.to(device=device, non_blocking=True, copy=True)\n            self._per_device_tensors[device] = retval\n        return retval\n\n\nclass MaxClipGradScaler(GradScaler):\n    def __init__(self, init_scale, max_scale: float, growth_interval=100):\n        GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval)\n        self.max_scale = max_scale\n\n    def scale_clip(self):\n        if self.get_scale() == self.max_scale:\n            self.set_growth_factor(1)\n        elif self.get_scale() < self.max_scale:\n            self.set_growth_factor(2)\n        elif self.get_scale() > self.max_scale:\n            self._scale.fill_(self.max_scale)\n            self.set_growth_factor(1)\n\n    def scale(self, outputs):\n        \"\"\"\n        Multiplies ('scales') a tensor or list of tensors by the scale factor.\n\n        Returns scaled outputs.  
If this instance of :class:`GradScaler` is not enabled, outputs are returned\n        unmodified.\n\n        Arguments:\n            outputs (Tensor or iterable of Tensors):  Outputs to scale.\n        \"\"\"\n        if not self._enabled:\n            return outputs\n        self.scale_clip()\n        # Short-circuit for the common case.\n        if isinstance(outputs, torch.Tensor):\n            assert outputs.is_cuda\n            if self._scale is None:\n                self._lazy_init_scale_growth_tracker(outputs.device)\n            assert self._scale is not None\n            return outputs * self._scale.to(device=outputs.device, non_blocking=True)\n\n        # Invoke the more complex machinery only if we're treating multiple outputs.\n        stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scale\n\n        def apply_scale(val):\n            if isinstance(val, torch.Tensor):\n                assert val.is_cuda\n                if len(stash) == 0:\n                    if self._scale is None:\n                        self._lazy_init_scale_growth_tracker(val.device)\n                    assert self._scale is not None\n                    stash.append(_MultiDeviceReplicator(self._scale))\n                return val * stash[0].get(val.device)\n            elif isinstance(val, Iterable):\n                iterable = map(apply_scale, val)\n                if isinstance(val, list) or isinstance(val, tuple):\n                    return type(val)(iterable)\n                else:\n                    return iterable\n            else:\n                raise ValueError(\"outputs must be a Tensor or an iterable of Tensors\")\n\n        return apply_scale(outputs)\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/utils/utils_callbacks.py",
    "content": "import logging\nimport os\nimport time\nfrom typing import List\n\nimport torch\n\nfrom eval import verification\nfrom utils.utils_logging import AverageMeter\n\n\nclass CallBackVerification(object):\n    def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112)):\n        self.frequent: int = frequent\n        self.rank: int = rank\n        self.highest_acc: float = 0.0\n        self.highest_acc_list: List[float] = [0.0] * len(val_targets)\n        self.ver_list: List[object] = []\n        self.ver_name_list: List[str] = []\n        if self.rank is 0:\n            self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size)\n\n    def ver_test(self, backbone: torch.nn.Module, global_step: int):\n        results = []\n        for i in range(len(self.ver_list)):\n            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(\n                self.ver_list[i], backbone, 10, 10)\n            logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm))\n            logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2))\n            if acc2 > self.highest_acc_list[i]:\n                self.highest_acc_list[i] = acc2\n            logging.info(\n                '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i]))\n            results.append(acc2)\n\n    def init_dataset(self, val_targets, data_dir, image_size):\n        for name in val_targets:\n            path = os.path.join(data_dir, name + \".bin\")\n            if os.path.exists(path):\n                data_set = verification.load_bin(path, image_size)\n                self.ver_list.append(data_set)\n                self.ver_name_list.append(name)\n\n    def __call__(self, num_update, backbone: torch.nn.Module):\n        if self.rank is 0 and num_update > 0 and num_update % self.frequent == 0:\n            backbone.eval()\n 
           self.ver_test(backbone, num_update)\n            backbone.train()\n\n\nclass CallBackLogging(object):\n    def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None):\n        self.frequent: int = frequent\n        self.rank: int = rank\n        self.time_start = time.time()\n        self.total_step: int = total_step\n        self.batch_size: int = batch_size\n        self.world_size: int = world_size\n        self.writer = writer\n\n        self.init = False\n        self.tic = 0\n\n    def __call__(self,\n                 global_step: int,\n                 loss: AverageMeter,\n                 epoch: int,\n                 fp16: bool,\n                 learning_rate: float,\n                 grad_scaler: torch.cuda.amp.GradScaler):\n        if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0:\n            if self.init:\n                try:\n                    speed: float = self.frequent * self.batch_size / (time.time() - self.tic)\n                    speed_total = speed * self.world_size\n                except ZeroDivisionError:\n                    speed_total = float('inf')\n\n                time_now = (time.time() - self.time_start) / 3600\n                time_total = time_now / ((global_step + 1) / self.total_step)\n                time_for_end = time_total - time_now\n                if self.writer is not None:\n                    self.writer.add_scalar('time_for_end', time_for_end, global_step)\n                    self.writer.add_scalar('learning_rate', learning_rate, global_step)\n                    self.writer.add_scalar('loss', loss.avg, global_step)\n                if fp16:\n                    msg = \"Speed %.2f samples/sec   Loss %.4f   LearningRate %.4f   Epoch: %d   Global Step: %d   \" \\\n                          \"Fp16 Grad Scale: %2.f   Required: %1.f hours\" % (\n                              speed_total, loss.avg, learning_rate, epoch, global_step,\n                        
      grad_scaler.get_scale(), time_for_end\n                          )\n                else:\n                    msg = \"Speed %.2f samples/sec   Loss %.4f   LearningRate %.4f   Epoch: %d   Global Step: %d   \" \\\n                          \"Required: %1.f hours\" % (\n                              speed_total, loss.avg, learning_rate, epoch, global_step, time_for_end\n                          )\n                logging.info(msg)\n                loss.reset()\n                self.tic = time.time()\n            else:\n                self.init = True\n                self.tic = time.time()\n\n\nclass CallBackModelCheckpoint(object):\n    def __init__(self, rank, output=\"./\"):\n        self.rank: int = rank\n        self.output: str = output\n\n    def __call__(self, global_step, backbone, partial_fc, ):\n        if global_step > 100 and self.rank == 0:\n            path_module = os.path.join(self.output, \"backbone.pth\")\n            torch.save(backbone.module.state_dict(), path_module)\n            logging.info(\"Pytorch Model Saved in '{}'\".format(path_module))\n\n        if global_step > 100 and partial_fc is not None:\n            partial_fc.save_params()\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/utils/utils_config.py",
    "content": "import importlib\nimport os.path as osp\n\n\ndef get_config(config_file):\n    assert config_file.startswith('configs/'), 'config file setting must start with configs/'\n    temp_config_name = osp.basename(config_file)\n    temp_module_name = osp.splitext(temp_config_name)[0]\n    config = importlib.import_module(\"configs.base\")\n    cfg = config.config\n    config = importlib.import_module(\"configs.%s\" % temp_module_name)\n    job_cfg = config.config\n    cfg.update(job_cfg)\n    if cfg.output is None:\n        cfg.output = osp.join('work_dirs', temp_module_name)\n    return cfg"
  },
  {
    "path": "src/face3d/models/arcface_torch/utils/utils_logging.py",
    "content": "import logging\nimport os\nimport sys\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\n    \"\"\"\n\n    def __init__(self):\n        self.val = None\n        self.avg = None\n        self.sum = None\n        self.count = None\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n\n\ndef init_logging(rank, models_root):\n    if rank == 0:\n        log_root = logging.getLogger()\n        log_root.setLevel(logging.INFO)\n        formatter = logging.Formatter(\"Training: %(asctime)s-%(message)s\")\n        handler_file = logging.FileHandler(os.path.join(models_root, \"training.log\"))\n        handler_stream = logging.StreamHandler(sys.stdout)\n        handler_file.setFormatter(formatter)\n        handler_stream.setFormatter(formatter)\n        log_root.addHandler(handler_file)\n        log_root.addHandler(handler_stream)\n        log_root.info('rank_id: %d' % rank)\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/utils/utils_os.py",
    "content": ""
  },
  {
    "path": "src/face3d/models/base_model.py",
    "content": "\"\"\"This script defines the base network model for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport os\nimport numpy as np\nimport torch\nfrom collections import OrderedDict\nfrom abc import ABC, abstractmethod\nfrom . import networks\n\n\nclass BaseModel(ABC):\n    \"\"\"This class is an abstract base class (ABC) for models.\n    To create a subclass, you need to implement the following five functions:\n        -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).\n        -- <set_input>:                     unpack data from dataset and apply preprocessing.\n        -- <forward>:                       produce intermediate results.\n        -- <optimize_parameters>:           calculate losses, gradients, and update network weights.\n        -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.\n    \"\"\"\n\n    def __init__(self, opt):\n        \"\"\"Initialize the BaseModel class.\n\n        Parameters:\n            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions\n\n        When creating your custom class, you need to implement your own initialization.\n        In this fucntion, you should first call <BaseModel.__init__(self, opt)>\n        Then, you need to define four lists:\n            -- self.loss_names (str list):          specify the training losses that you want to plot and save.\n            -- self.model_names (str list):         specify the images that you want to display and save.\n            -- self.visual_names (str list):        define networks used in our training.\n            -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. 
See cycle_gan_model.py for an example.\n        \"\"\"\n        self.opt = opt\n        self.isTrain = False\n        self.device = torch.device('cpu') \n        self.save_dir = \" \" # os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir\n        self.loss_names = []\n        self.model_names = []\n        self.visual_names = []\n        self.parallel_names = []\n        self.optimizers = []\n        self.image_paths = []\n        self.metric = 0  # used for learning rate policy 'plateau'\n\n    @staticmethod\n    def dict_grad_hook_factory(add_func=lambda x: x):\n        saved_dict = dict()\n\n        def hook_gen(name):\n            def grad_hook(grad):\n                saved_vals = add_func(grad)\n                saved_dict[name] = saved_vals\n            return grad_hook\n        return hook_gen, saved_dict\n\n    @staticmethod\n    def modify_commandline_options(parser, is_train):\n        \"\"\"Add new model-specific options, and rewrite default values for existing options.\n\n        Parameters:\n            parser          -- original option parser\n            is_train (bool) -- whether training phase or test phase. 
You can use this flag to add training-specific or test-specific options.\n\n        Returns:\n            the modified parser.\n        \"\"\"\n        return parser\n\n    @abstractmethod\n    def set_input(self, input):\n        \"\"\"Unpack input data from the dataloader and perform necessary pre-processing steps.\n\n        Parameters:\n            input (dict): includes the data itself and its metadata information.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def forward(self):\n        \"\"\"Run forward pass; called by both functions <optimize_parameters> and <test>.\"\"\"\n        pass\n\n    @abstractmethod\n    def optimize_parameters(self):\n        \"\"\"Calculate losses, gradients, and update network weights; called in every training iteration\"\"\"\n        pass\n\n    def setup(self, opt):\n        \"\"\"Load and print networks; create schedulers\n\n        Parameters:\n            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions\n        \"\"\"\n        if self.isTrain:\n            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]\n        \n        if not self.isTrain or opt.continue_train:\n            load_suffix = opt.epoch\n            self.load_networks(load_suffix)\n \n            \n        # self.print_networks(opt.verbose)\n\n    def parallelize(self, convert_sync_batchnorm=True):\n        if not self.opt.use_ddp:\n            for name in self.parallel_names:\n                if isinstance(name, str):\n                    module = getattr(self, name)\n                    setattr(self, name, module.to(self.device))\n        else:\n            for name in self.model_names:\n                if isinstance(name, str):\n                    module = getattr(self, name)\n                    if convert_sync_batchnorm:\n                        module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)\n                    setattr(self, name, 
torch.nn.parallel.DistributedDataParallel(module.to(self.device),\n                        device_ids=[self.device.index], \n                        find_unused_parameters=True, broadcast_buffers=True))\n            \n            # DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.\n            for name in self.parallel_names:\n                if isinstance(name, str) and name not in self.model_names:\n                    module = getattr(self, name)\n                    setattr(self, name, module.to(self.device))\n            \n        # put state_dict of optimizer to gpu device\n        if self.opt.phase != 'test':\n            if self.opt.continue_train:\n                for optim in self.optimizers:\n                    for state in optim.state.values():\n                        for k, v in state.items():\n                            if isinstance(v, torch.Tensor):\n                                state[k] = v.to(self.device)\n\n    def data_dependent_initialize(self, data):\n        pass\n\n    def train(self):\n        \"\"\"Make models train mode\"\"\"\n        for name in self.model_names:\n            if isinstance(name, str):\n                net = getattr(self, name)\n                net.train()\n\n    def eval(self):\n        \"\"\"Make models eval mode\"\"\"\n        for name in self.model_names:\n            if isinstance(name, str):\n                net = getattr(self, name)\n                net.eval()\n\n    def test(self):\n        \"\"\"Forward function used in test time.\n\n        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop\n        It also calls <compute_visuals> to produce additional visualization results\n        \"\"\"\n        with torch.no_grad():\n            self.forward()\n            self.compute_visuals()\n\n    def compute_visuals(self):\n        \"\"\"Calculate additional output images for visdom and HTML 
visualization\"\"\"\n        pass\n\n    def get_image_paths(self, name='A'):\n        \"\"\" Return image paths that are used to load current data\"\"\"\n        return self.image_paths if name =='A' else self.image_paths_B\n\n    def update_learning_rate(self):\n        \"\"\"Update learning rates for all the networks; called at the end of every epoch\"\"\"\n        for scheduler in self.schedulers:\n            if self.opt.lr_policy == 'plateau':\n                scheduler.step(self.metric)\n            else:\n                scheduler.step()\n\n        lr = self.optimizers[0].param_groups[0]['lr']\n        print('learning rate = %.7f' % lr)\n\n    def get_current_visuals(self):\n        \"\"\"Return visualization images. train.py will display these images with visdom, and save the images to a HTML\"\"\"\n        visual_ret = OrderedDict()\n        for name in self.visual_names:\n            if isinstance(name, str):\n                visual_ret[name] = getattr(self, name)[:, :3, ...]\n        return visual_ret\n\n    def get_current_losses(self):\n        \"\"\"Return traning losses / errors. train.py will print out these errors on console, and save them to a file\"\"\"\n        errors_ret = OrderedDict()\n        for name in self.loss_names:\n            if isinstance(name, str):\n                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) 
works for both scalar tensor and float number\n        return errors_ret\n\n    def save_networks(self, epoch):\n        \"\"\"Save all the networks to the disk.\n\n        Parameters:\n            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)\n        \"\"\"\n        if not os.path.isdir(self.save_dir):\n            os.makedirs(self.save_dir)\n\n        save_filename = 'epoch_%s.pth' % (epoch)\n        save_path = os.path.join(self.save_dir, save_filename)\n        \n        save_dict = {}\n        for name in self.model_names:\n            if isinstance(name, str):\n                net = getattr(self, name)\n                if isinstance(net, torch.nn.DataParallel) or isinstance(net,\n                        torch.nn.parallel.DistributedDataParallel):\n                    net = net.module\n                save_dict[name] = net.state_dict()\n                \n\n        for i, optim in enumerate(self.optimizers):\n            save_dict['opt_%02d'%i] = optim.state_dict()\n\n        for i, sched in enumerate(self.schedulers):\n            save_dict['sched_%02d'%i] = sched.state_dict()\n        \n        torch.save(save_dict, save_path)\n\n    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):\n        \"\"\"Fix InstanceNorm checkpoints incompatibility (prior to 0.4)\"\"\"\n        key = keys[i]\n        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer\n            if module.__class__.__name__.startswith('InstanceNorm') and \\\n                    (key == 'running_mean' or key == 'running_var'):\n                if getattr(module, key) is None:\n                    state_dict.pop('.'.join(keys))\n            if module.__class__.__name__.startswith('InstanceNorm') and \\\n               (key == 'num_batches_tracked'):\n                state_dict.pop('.'.join(keys))\n        else:\n            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)\n\n    def 
load_networks(self, epoch):\n        \"\"\"Load all the networks from the disk.\n\n        Parameters:\n            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)\n        \"\"\"\n        if self.opt.isTrain and self.opt.pretrained_name is not None:\n            load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name)\n        else:\n            load_dir = self.save_dir    \n        load_filename = 'epoch_%s.pth' % (epoch)\n        load_path = os.path.join(load_dir, load_filename)\n        state_dict = torch.load(load_path, map_location=self.device)\n        print('loading the model from %s' % load_path)\n\n        for name in self.model_names:\n            if isinstance(name, str):\n                net = getattr(self, name)\n                if isinstance(net, torch.nn.DataParallel):\n                    net = net.module\n                net.load_state_dict(state_dict[name])\n        \n        if self.opt.phase != 'test':\n            if self.opt.continue_train:\n                print('loading the optim from %s' % load_path)\n                for i, optim in enumerate(self.optimizers):\n                    optim.load_state_dict(state_dict['opt_%02d'%i])\n\n                try:\n                    print('loading the sched from %s' % load_path)\n                    for i, sched in enumerate(self.schedulers):\n                        sched.load_state_dict(state_dict['sched_%02d'%i])\n                except:\n                    print('Failed to load schedulers, set schedulers according to epoch count manually')\n                    for i, sched in enumerate(self.schedulers):\n                        sched.last_epoch = self.opt.epoch_count - 1\n                    \n\n            \n\n    def print_networks(self, verbose):\n        \"\"\"Print the total number of parameters in the network and (if verbose) network architecture\n\n        Parameters:\n            verbose (bool) -- if verbose: print the network 
architecture\n        \"\"\"\n        print('---------- Networks initialized -------------')\n        for name in self.model_names:\n            if isinstance(name, str):\n                net = getattr(self, name)\n                num_params = 0\n                for param in net.parameters():\n                    num_params += param.numel()\n                if verbose:\n                    print(net)\n                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))\n        print('-----------------------------------------------')\n\n    def set_requires_grad(self, nets, requires_grad=False):\n        \"\"\"Set requies_grad=Fasle for all the networks to avoid unnecessary computations\n        Parameters:\n            nets (network list)   -- a list of networks\n            requires_grad (bool)  -- whether the networks require gradients or not\n        \"\"\"\n        if not isinstance(nets, list):\n            nets = [nets]\n        for net in nets:\n            if net is not None:\n                for param in net.parameters():\n                    param.requires_grad = requires_grad\n\n    def generate_visuals_for_evaluation(self, data, mode):\n        return {}\n"
  },
  {
    "path": "src/face3d/models/bfm.py",
    "content": "\"\"\"This script defines the parametric 3d face model for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport numpy as np\nimport  torch\nimport torch.nn.functional as F\nfrom scipy.io import loadmat\nfrom src.face3d.util.load_mats import transferBFM09\nimport os\n\ndef perspective_projection(focal, center):\n    # return p.T (N, 3) @ (3, 3) \n    return np.array([\n        focal, 0, center,\n        0, focal, center,\n        0, 0, 1\n    ]).reshape([3, 3]).astype(np.float32).transpose()\n\nclass SH:\n    def __init__(self):\n        self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)]\n        self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) / np.sqrt(12 * np.pi)]\n\n\n\nclass ParametricFaceModel:\n    def __init__(self, \n                bfm_folder='./BFM', \n                recenter=True,\n                camera_distance=10.,\n                init_lit=np.array([\n                    0.8, 0, 0, 0, 0, 0, 0, 0, 0\n                    ]),\n                focal=1015.,\n                center=112.,\n                is_train=True,\n                default_name='BFM_model_front.mat'):\n        \n        if not os.path.isfile(os.path.join(bfm_folder, default_name)):\n            transferBFM09(bfm_folder)\n            \n        model = loadmat(os.path.join(bfm_folder, default_name))\n        # mean face shape. [3*N,1]\n        self.mean_shape = model['meanshape'].astype(np.float32)\n        # identity basis. [3*N,80]\n        self.id_base = model['idBase'].astype(np.float32)\n        # expression basis. [3*N,64]\n        self.exp_base = model['exBase'].astype(np.float32)\n        # mean face texture. [3*N,1] (0-255)\n        self.mean_tex = model['meantex'].astype(np.float32)\n        # texture basis. [3*N,80]\n        self.tex_base = model['texBase'].astype(np.float32)\n        # face indices for each vertex that lies in. starts from 0. 
[N,8]\n        self.point_buf = model['point_buf'].astype(np.int64) - 1\n        # vertex indices for each face. starts from 0. [F,3]\n        self.face_buf = model['tri'].astype(np.int64) - 1\n        # vertex indices for 68 landmarks. starts from 0. [68,1]\n        self.keypoints = np.squeeze(model['keypoints']).astype(np.int64) - 1\n\n        if is_train:\n            # vertex indices for small face region to compute photometric error. starts from 0.\n            self.front_mask = np.squeeze(model['frontmask2_idx']).astype(np.int64) - 1\n            # vertex indices for each face from small face region. starts from 0. [f,3]\n            self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1\n            # vertex indices for pre-defined skin region to compute reflectance loss\n            self.skin_mask = np.squeeze(model['skinmask'])\n        \n        if recenter:\n            mean_shape = self.mean_shape.reshape([-1, 3])\n            mean_shape = mean_shape - np.mean(mean_shape, axis=0, keepdims=True)\n            self.mean_shape = mean_shape.reshape([-1, 1])\n\n        self.persc_proj = perspective_projection(focal, center)\n        self.device = 'cpu'\n        self.camera_distance = camera_distance\n        self.SH = SH()\n        self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32)\n        \n\n    def to(self, device):\n        self.device = device\n        for key, value in self.__dict__.items():\n            if type(value).__module__ == np.__name__:\n                setattr(self, key, torch.tensor(value).to(device))\n\n    \n    def compute_shape(self, id_coeff, exp_coeff):\n        \"\"\"\n        Return:\n            face_shape       -- torch.tensor, size (B, N, 3)\n\n        Parameters:\n            id_coeff         -- torch.tensor, size (B, 80), identity coeffs\n            exp_coeff        -- torch.tensor, size (B, 64), expression coeffs\n        \"\"\"\n        batch_size = id_coeff.shape[0]\n        id_part = 
torch.einsum('ij,aj->ai', self.id_base, id_coeff)\n        exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff)\n        face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1])\n        return face_shape.reshape([batch_size, -1, 3])\n    \n\n    def compute_texture(self, tex_coeff, normalize=True):\n        \"\"\"\n        Return:\n            face_texture     -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.)\n\n        Parameters:\n            tex_coeff        -- torch.tensor, size (B, 80)\n        \"\"\"\n        batch_size = tex_coeff.shape[0]\n        face_texture = torch.einsum('ij,aj->ai', self.tex_base, tex_coeff) + self.mean_tex\n        if normalize:\n            face_texture = face_texture / 255.\n        return face_texture.reshape([batch_size, -1, 3])\n\n\n    def compute_norm(self, face_shape):\n        \"\"\"\n        Return:\n            vertex_norm      -- torch.tensor, size (B, N, 3)\n\n        Parameters:\n            face_shape       -- torch.tensor, size (B, N, 3)\n        \"\"\"\n\n        v1 = face_shape[:, self.face_buf[:, 0]]\n        v2 = face_shape[:, self.face_buf[:, 1]]\n        v3 = face_shape[:, self.face_buf[:, 2]]\n        e1 = v1 - v2\n        e2 = v2 - v3\n        face_norm = torch.cross(e1, e2, dim=-1)\n        face_norm = F.normalize(face_norm, dim=-1, p=2)\n        face_norm = torch.cat([face_norm, torch.zeros(face_norm.shape[0], 1, 3).to(self.device)], dim=1)\n        \n        vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2)\n        vertex_norm = F.normalize(vertex_norm, dim=-1, p=2)\n        return vertex_norm\n\n\n    def compute_color(self, face_texture, face_norm, gamma):\n        \"\"\"\n        Return:\n            face_color       -- torch.tensor, size (B, N, 3), range (0, 1.)\n\n        Parameters:\n            face_texture     -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.)\n            face_norm        -- torch.tensor, size (B, N, 3), rotated face 
normal\n            gamma            -- torch.tensor, size (B, 27), SH coeffs\n        \"\"\"\n        batch_size = gamma.shape[0]\n        v_num = face_texture.shape[1]\n        a, c = self.SH.a, self.SH.c\n        gamma = gamma.reshape([batch_size, 3, 9])\n        gamma = gamma + self.init_lit\n        gamma = gamma.permute(0, 2, 1)\n        Y = torch.cat([\n             a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.device),\n            -a[1] * c[1] * face_norm[..., 1:2],\n             a[1] * c[1] * face_norm[..., 2:],\n            -a[1] * c[1] * face_norm[..., :1],\n             a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2],\n            -a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:],\n            0.5 * a[2] * c[2] / np.sqrt(3.) * (3 * face_norm[..., 2:] ** 2 - 1),\n            -a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:],\n            0.5 * a[2] * c[2] * (face_norm[..., :1] ** 2  - face_norm[..., 1:2] ** 2)\n        ], dim=-1)\n        r = Y @ gamma[..., :1]\n        g = Y @ gamma[..., 1:2]\n        b = Y @ gamma[..., 2:]\n        face_color = torch.cat([r, g, b], dim=-1) * face_texture\n        return face_color\n\n    \n    def compute_rotation(self, angles):\n        \"\"\"\n        Return:\n            rot              -- torch.tensor, size (B, 3, 3) pts @ trans_mat\n\n        Parameters:\n            angles           -- torch.tensor, size (B, 3), radian\n        \"\"\"\n\n        batch_size = angles.shape[0]\n        ones = torch.ones([batch_size, 1]).to(self.device)\n        zeros = torch.zeros([batch_size, 1]).to(self.device)\n        x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:],\n        \n        rot_x = torch.cat([\n            ones, zeros, zeros,\n            zeros, torch.cos(x), -torch.sin(x), \n            zeros, torch.sin(x), torch.cos(x)\n        ], dim=1).reshape([batch_size, 3, 3])\n        \n        rot_y = torch.cat([\n            torch.cos(y), zeros, torch.sin(y),\n            zeros, ones, 
zeros,\n            -torch.sin(y), zeros, torch.cos(y)\n        ], dim=1).reshape([batch_size, 3, 3])\n\n        rot_z = torch.cat([\n            torch.cos(z), -torch.sin(z), zeros,\n            torch.sin(z), torch.cos(z), zeros,\n            zeros, zeros, ones\n        ], dim=1).reshape([batch_size, 3, 3])\n\n        rot = rot_z @ rot_y @ rot_x\n        return rot.permute(0, 2, 1)\n\n\n    def to_camera(self, face_shape):\n        face_shape[..., -1] = self.camera_distance - face_shape[..., -1]\n        return face_shape\n\n    def to_image(self, face_shape):\n        \"\"\"\n        Return:\n            face_proj        -- torch.tensor, size (B, N, 2), y direction is opposite to v direction\n\n        Parameters:\n            face_shape       -- torch.tensor, size (B, N, 3)\n        \"\"\"\n        # to image_plane\n        face_proj = face_shape @ self.persc_proj\n        face_proj = face_proj[..., :2] / face_proj[..., 2:]\n\n        return face_proj\n\n\n    def transform(self, face_shape, rot, trans):\n        \"\"\"\n        Return:\n            face_shape       -- torch.tensor, size (B, N, 3) pts @ rot + trans\n\n        Parameters:\n            face_shape       -- torch.tensor, size (B, N, 3)\n            rot              -- torch.tensor, size (B, 3, 3)\n            trans            -- torch.tensor, size (B, 3)\n        \"\"\"\n        return face_shape @ rot + trans.unsqueeze(1)\n\n\n    def get_landmarks(self, face_proj):\n        \"\"\"\n        Return:\n            face_lms         -- torch.tensor, size (B, 68, 2)\n\n        Parameters:\n            face_proj       -- torch.tensor, size (B, N, 2)\n        \"\"\"  \n        return face_proj[:, self.keypoints]\n\n    def split_coeff(self, coeffs):\n        \"\"\"\n        Return:\n            coeffs_dict     -- a dict of torch.tensors\n\n        Parameters:\n            coeffs          -- torch.tensor, size (B, 256)\n        \"\"\"\n        id_coeffs = coeffs[:, :80]\n        exp_coeffs = coeffs[:, 80: 
144]\n        tex_coeffs = coeffs[:, 144: 224]\n        angles = coeffs[:, 224: 227]\n        gammas = coeffs[:, 227: 254]\n        translations = coeffs[:, 254:]\n        return {\n            'id': id_coeffs,\n            'exp': exp_coeffs,\n            'tex': tex_coeffs,\n            'angle': angles,\n            'gamma': gammas,\n            'trans': translations\n        }\n    def compute_for_render(self, coeffs):\n        \"\"\"\n        Return:\n            face_vertex     -- torch.tensor, size (B, N, 3), in camera coordinate\n            face_color      -- torch.tensor, size (B, N, 3), in RGB order\n            landmark        -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction\n        Parameters:\n            coeffs          -- torch.tensor, size (B, 257)\n        \"\"\"\n        coef_dict = self.split_coeff(coeffs)\n        face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp'])\n        rotation = self.compute_rotation(coef_dict['angle'])\n\n\n        face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans'])\n        face_vertex = self.to_camera(face_shape_transformed)\n        \n        face_proj = self.to_image(face_vertex)\n        landmark = self.get_landmarks(face_proj)\n\n        face_texture = self.compute_texture(coef_dict['tex'])\n        face_norm = self.compute_norm(face_shape)\n        face_norm_roted = face_norm @ rotation\n        face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma'])\n\n        return face_vertex, face_texture, face_color, landmark\n\n    def compute_for_render_woRotation(self, coeffs):\n        \"\"\"\n        Return:\n            face_vertex     -- torch.tensor, size (B, N, 3), in camera coordinate\n            face_color      -- torch.tensor, size (B, N, 3), in RGB order\n            landmark        -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction\n        Parameters:\n            coeffs          -- 
torch.tensor, size (B, 257)\n        \"\"\"\n        coef_dict = self.split_coeff(coeffs)\n        face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp'])\n        #rotation = self.compute_rotation(coef_dict['angle'])\n\n\n        #face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans'])\n        face_vertex = self.to_camera(face_shape)\n        \n        face_proj = self.to_image(face_vertex)\n        landmark = self.get_landmarks(face_proj)\n\n        face_texture = self.compute_texture(coef_dict['tex'])\n        face_norm = self.compute_norm(face_shape)\n        face_norm_roted = face_norm                                    # @ rotation\n        face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma'])\n\n        return face_vertex, face_texture, face_color, landmark\n\n\nif __name__ == '__main__':\n    transferBFM09()"
  },
  {
    "path": "src/face3d/models/facerecon_model.py",
    "content": "\"\"\"This script defines the face reconstruction model for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport numpy as np\nimport torch\nfrom src.face3d.models.base_model import BaseModel\nfrom src.face3d.models import networks\nfrom src.face3d.models.bfm import ParametricFaceModel\nfrom src.face3d.models.losses import perceptual_loss, photo_loss, reg_loss, reflectance_loss, landmark_loss\nfrom src.face3d.util import util \nfrom src.face3d.util.nvdiffrast import MeshRenderer\n# from src.face3d.util.preprocess import estimate_norm_torch  (imported lazily in FaceReconModel.compute_losses)\n\nimport trimesh\nfrom scipy.io import savemat\n\nclass FaceReconModel(BaseModel):\n\n    @staticmethod\n    def modify_commandline_options(parser, is_train=False):\n        \"\"\"Configure command-line options specific to the face reconstruction model.\n\n        Parameters:\n            parser   -- argparse parser to extend\n            is_train -- if True, also register training-only options (recognition net, augmentation, loss weights)\n\n        Returns the same parser with these options added and their defaults overridden.\n        \"\"\"\n        # net structure and parameters\n        parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='network structure')\n        parser.add_argument('--init_path', type=str, default='./checkpoints/init_model/resnet50-0676ba61.pth')\n        parser.add_argument('--use_last_fc', type=util.str2bool, nargs='?', const=True, default=False, help='zero initialize the last fc')\n        parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/')\n        parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')\n\n        # renderer parameters\n        parser.add_argument('--focal', type=float, default=1015.)\n        parser.add_argument('--center', type=float, default=112.)\n        parser.add_argument('--camera_d', type=float, default=10.)\n        parser.add_argument('--z_near', type=float, default=5.)\n        parser.add_argument('--z_far', type=float, default=15.)\n\n        if is_train:\n            # training parameters\n            parser.add_argument('--net_recog', type=str, default='r50', choices=['r18', 'r43', 'r50'], help='face recog network structure')\n            parser.add_argument('--net_recog_path', type=str, default='checkpoints/recog_model/ms1mv3_arcface_r50_fp16/backbone.pth')\n            parser.add_argument('--use_crop_face', type=util.str2bool, nargs='?', const=True, default=False, help='use crop mask for photo loss')\n            parser.add_argument('--use_predef_M', type=util.str2bool, nargs='?', const=True, default=False, help='use predefined M for predicted face')\n\n            \n            # augmentation parameters\n            parser.add_argument('--shift_pixs', type=float, default=10., help='shift pixels')\n            parser.add_argument('--scale_delta', type=float, default=0.1, help='delta scale factor')\n            parser.add_argument('--rot_angle', type=float, default=10., help='rot angles, degree')\n\n            # loss weights\n            parser.add_argument('--w_feat', type=float, default=0.2, help='weight for feat loss')\n            parser.add_argument('--w_color', type=float, default=1.92, help='weight for color loss')\n            parser.add_argument('--w_reg', type=float, default=3.0e-4, help='weight for reg loss')\n            parser.add_argument('--w_id', type=float, default=1.0, help='weight for id_reg loss')\n            parser.add_argument('--w_exp', type=float, default=0.8, help='weight for exp_reg loss')\n            parser.add_argument('--w_tex', type=float, default=1.7e-2, help='weight for tex_reg loss')\n            parser.add_argument('--w_gamma', type=float, default=10.0, help='weight for gamma loss')\n            parser.add_argument('--w_lm', type=float, default=1.6e-3, help='weight for lm loss')\n            parser.add_argument('--w_reflc', type=float, default=5.0, help='weight for reflc loss')\n\n        opt, _ = parser.parse_known_args()\n        parser.set_defaults(\n                focal=1015., center=112., camera_d=10., use_last_fc=False, z_near=5., z_far=15.\n            )\n        if is_train:\n            parser.set_defaults(\n                use_crop_face=True, use_predef_M=False\n            )\n        return parser\n\n    def __init__(self, opt):\n        \"\"\"Initialize this model class.\n\n        Parameters:\n            opt -- training/test options\n\n        A few things can be done here.\n        - (required) call the initialization function of BaseModel\n        - define loss function, visualization images, model names, and optimizers\n        \"\"\"\n        BaseModel.__init__(self, opt)  # call the initialization method of BaseModel\n        \n        self.visual_names = ['output_vis']\n        self.model_names = ['net_recon']\n        self.parallel_names = self.model_names + ['renderer']\n\n        self.facemodel = ParametricFaceModel(\n            bfm_folder=opt.bfm_folder, camera_distance=opt.camera_d, focal=opt.focal, center=opt.center,\n            is_train=self.isTrain, default_name=opt.bfm_model\n        )\n        \n        fov = 2 * np.arctan(opt.center / opt.focal) * 180 / np.pi\n        self.renderer = MeshRenderer(\n            rasterize_fov=fov, znear=opt.z_near, zfar=opt.z_far, rasterize_size=int(2 * opt.center)\n        )\n\n        if self.isTrain:\n            self.loss_names = ['all', 'feat', 'color', 'lm', 'reg', 'gamma', 'reflc']\n\n            self.net_recog = networks.define_net_recog(\n                net_recog=opt.net_recog, pretrained_path=opt.net_recog_path\n                )\n            # loss func name: (compute_%s_loss) % loss_name\n            self.compute_feat_loss = perceptual_loss\n            self.comupte_color_loss = photo_loss\n            self.compute_lm_loss = landmark_loss\n            self.compute_reg_loss = reg_loss\n            self.compute_reflc_loss = reflectance_loss\n\n            # NOTE(review): self.net_recon is not created in this __init__ (it is only listed in\n            # model_names); confirm it is provided elsewhere before training.\n            self.optimizer = torch.optim.Adam(self.net_recon.parameters(), lr=opt.lr)\n            self.optimizers = [self.optimizer]\n            self.parallel_names += ['net_recog']\n        # Our program will automatically call <model.setup> to define schedulers, load networks, and print networks\n\n    def set_input(self, input):\n        \"\"\"Unpack input data from the dataloader and perform necessary pre-processing steps.\n\n        Parameters:\n            input: a dictionary that contains the data itself and its metadata information.\n        \"\"\"\n        self.input_img = input['imgs'].to(self.device) \n        self.atten_mask = input['msks'].to(self.device) if 'msks' in input else None\n        self.gt_lm = input['lms'].to(self.device)  if 'lms' in input else None\n        self.trans_m = input['M'].to(self.device) if 'M' in input else None\n        self.image_paths = input['im_paths'] if 'im_paths' in input else None\n\n    def forward(self, output_coeff, device):\n        \"\"\"Render the face described by a batch of predicted coefficients.\n\n        Parameters:\n            output_coeff -- torch.tensor, coefficients accepted by ParametricFaceModel.compute_for_render\n            device       -- device the face model is moved to before rendering\n\n        Stores pred_vertex, pred_tex, pred_color, pred_lm, pred_mask, pred_face and pred_coeffs_dict on self.\n        \"\"\"\n        self.facemodel.to(device)\n        self.pred_vertex, self.pred_tex, self.pred_color, self.pred_lm = \\\n            self.facemodel.compute_for_render(output_coeff)\n        self.pred_mask, _, self.pred_face = self.renderer(\n            self.pred_vertex, self.facemodel.face_buf, feat=self.pred_color)\n        \n        self.pred_coeffs_dict = self.facemodel.split_coeff(output_coeff)\n\n\n    def compute_losses(self):\n        \"\"\"Compute the weighted training losses and their sum (self.loss_all); no backward pass or optimizer step happens here.\"\"\"\n\n        assert self.net_recog.training == False\n        trans_m = self.trans_m\n        if not self.opt.use_predef_M:\n            # The module-level import is commented out; import on demand so this branch\n            # does not fail with NameError.\n            from src.face3d.util.preprocess import estimate_norm_torch\n            trans_m = estimate_norm_torch(self.pred_lm, self.input_img.shape[-2])\n\n        pred_feat = self.net_recog(self.pred_face, trans_m)\n        gt_feat = self.net_recog(self.input_img, self.trans_m)\n        self.loss_feat = self.opt.w_feat * self.compute_feat_loss(pred_feat, gt_feat)\n\n        face_mask = self.pred_mask\n        if self.opt.use_crop_face:\n            face_mask, _, _ = self.renderer(self.pred_vertex, self.facemodel.front_face_buf)\n        \n        face_mask = face_mask.detach()\n        self.loss_color = self.opt.w_color * self.comupte_color_loss(\n            self.pred_face, self.input_img, self.atten_mask * face_mask)\n        \n        loss_reg, loss_gamma = self.compute_reg_loss(self.pred_coeffs_dict, self.opt)\n        self.loss_reg = self.opt.w_reg * loss_reg\n        self.loss_gamma = self.opt.w_gamma * loss_gamma\n\n        self.loss_lm = self.opt.w_lm * self.compute_lm_loss(self.pred_lm, self.gt_lm)\n\n        self.loss_reflc = self.opt.w_reflc * self.compute_reflc_loss(self.pred_tex, self.facemodel.skin_mask)\n\n        self.loss_all = self.loss_feat + self.loss_color + self.loss_reg + self.loss_gamma \\\n                        + self.loss_lm + self.loss_reflc\n            \n\n    def optimize_parameters(self, isTrain=True):\n        \"\"\"Update network weights; it will be called in every training iteration.\n\n        NOTE(review): forward() is called here without the (output_coeff, device) arguments\n        its signature requires, so this raises TypeError as written; confirm the intended caller.\n        \"\"\"\n        self.forward()               \n        self.compute_losses()\n        if isTrain:\n            self.optimizer.zero_grad()  \n            self.loss_all.backward()         \n            self.optimizer.step()        \n\n    def compute_visuals(self):\n        \"\"\"Build self.output_vis from the input image, the rendered overlay and, when ground-truth landmarks exist, a landmark visualization.\"\"\"\n        with torch.no_grad():\n            input_img_numpy = 255. * self.input_img.detach().cpu().permute(0, 2, 3, 1).numpy()\n            output_vis = self.pred_face * self.pred_mask + (1 - self.pred_mask) * self.input_img\n            output_vis_numpy_raw = 255. * output_vis.detach().cpu().permute(0, 2, 3, 1).numpy()\n            \n            if self.gt_lm is not None:\n                gt_lm_numpy = self.gt_lm.cpu().numpy()\n                pred_lm_numpy = self.pred_lm.detach().cpu().numpy()\n                output_vis_numpy = util.draw_landmarks(output_vis_numpy_raw, gt_lm_numpy, 'b')\n                output_vis_numpy = util.draw_landmarks(output_vis_numpy, pred_lm_numpy, 'r')\n            \n                output_vis_numpy = np.concatenate((input_img_numpy, \n                                    output_vis_numpy_raw, output_vis_numpy), axis=-2)\n            else:\n                output_vis_numpy = np.concatenate((input_img_numpy, \n                                    output_vis_numpy_raw), axis=-2)\n\n            self.output_vis = torch.tensor(\n                    output_vis_numpy / 255., dtype=torch.float32\n                ).permute(0, 3, 1, 2).to(self.device)\n\n    def save_mesh(self, name):\n        \"\"\"Export the first reconstructed face of the batch as a vertex-colored trimesh mesh.\n\n        Parameters:\n            name -- output path passed to trimesh.Trimesh.export\n        \"\"\"\n        # Work on a detached copy so the camera-to-world conversion below does not\n        # modify self.pred_vertex in place (which would corrupt it for later calls).\n        recon_shape = self.pred_vertex.detach().clone()  # get reconstructed shape\n        # from camera space to world space: inverse of ParametricFaceModel.to_camera\n        recon_shape[..., -1] = self.facemodel.camera_distance - recon_shape[..., -1]\n        recon_shape = recon_shape.cpu().numpy()[0]\n        recon_color = self.pred_color.detach().cpu().numpy()[0]\n        tri = self.facemodel.face_buf.cpu().numpy()\n        mesh = trimesh.Trimesh(vertices=recon_shape, faces=tri, vertex_colors=np.clip(255. * recon_color, 0, 255).astype(np.uint8))\n        mesh.export(name)\n\n    def save_coeff(self,name):\n        \"\"\"Save the predicted coefficients plus 68 landmarks (converted to image coordinates) with scipy.io.savemat.\"\"\"\n\n        pred_coeffs = {key:self.pred_coeffs_dict[key].detach().cpu().numpy() for key in self.pred_coeffs_dict}\n        pred_lm = self.pred_lm.detach().cpu().numpy()\n        pred_lm = np.stack([pred_lm[:,:,0],self.input_img.shape[2]-1-pred_lm[:,:,1]],axis=2) # transfer to image coordinate\n        pred_coeffs['lm68'] = pred_lm\n        savemat(name,pred_coeffs)\n\n\n\n"
  },
  {
    "path": "src/face3d/models/losses.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom kornia.geometry import warp_affine\nimport torch.nn.functional as F\n\ndef resize_n_crop(image, M, dsize=112):\n    # image: (b, c, h, w)\n    # M   :  (b, 2, 3)\n    return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True)\n\n### perceptual level loss\nclass PerceptualLoss(nn.Module):\n    def __init__(self, recog_net, input_size=112):\n        super(PerceptualLoss, self).__init__()\n        self.recog_net = recog_net\n        self.preprocess = lambda x: 2 * x - 1\n        self.input_size=input_size\n\n    def forward(self, imageA, imageB, M):\n        \"\"\"\n        1 - cosine distance\n        Parameters:\n            imageA       --torch.tensor (B, 3, H, W), range (0, 1) , RGB order\n            imageB       --same as imageA\n            M            --torch.tensor (B, 2, 3), affine matrix passed to resize_n_crop\n        \"\"\"\n\n        imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size))\n        imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size))\n\n        # freeze bn\n        self.recog_net.eval()\n\n        id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2)\n        id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2)\n        cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)\n        # assert torch.sum((cosine_d > 1).float()) == 0\n        return torch.sum(1 - cosine_d) / cosine_d.shape[0]\n\ndef perceptual_loss(id_featureA, id_featureB):\n    \"\"\"\n    1 - cosine similarity, averaged over the first (batch) dimension\n    Parameters:\n        id_featureA  --torch.tensor, features compared along the last dim (e.g. L2-normalized RecogNetWrapper output)\n        id_featureB  --same as id_featureA\n    \"\"\"\n    cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)\n    # assert torch.sum((cosine_d > 1).float()) == 0\n    return torch.sum(1 - cosine_d) / cosine_d.shape[0]\n\n### image level loss\ndef photo_loss(imageA, imageB, mask, eps=1e-6):\n    \"\"\"\n    l2 norm (with sqrt, to ensure backward stability, use eps, otherwise Nan may occur)\n    Parameters:\n        imageA       --torch.tensor (B, 3, H, W), range (0, 1), RGB order\n        imageB       --same as imageA\n        mask         --torch.tensor broadcastable against (B, 1, H, W), per-pixel weights\n    \"\"\"\n    loss = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdim=True)) * mask\n    loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device))\n    return loss\n\ndef landmark_loss(predict_lm, gt_lm, weight=None):\n    \"\"\"\n    weighted mse loss\n    Parameters:\n        predict_lm    --torch.tensor (B, 68, 2)\n        gt_lm         --torch.tensor (B, 68, 2)\n        weight        --numpy.array or torch.tensor (1, 68); when None, every landmark has\n                        weight 1 except indices 28:31 and the last 8, which get 20\n    \"\"\"\n    if weight is None:\n        weight = np.ones([68])\n        weight[28:31] = 20\n        weight[-8:] = 20\n        weight = np.expand_dims(weight, 0)\n    # `if not weight` is ambiguous for arrays/tensors; also move a caller-supplied\n    # weight to the prediction's device and dtype so it broadcasts cleanly.\n    weight = torch.as_tensor(weight, dtype=predict_lm.dtype, device=predict_lm.device)\n    loss = torch.sum((predict_lm - gt_lm)**2, dim=-1) * weight\n    loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1])\n    return loss\n\n\n### regularization\ndef reg_loss(coeffs_dict, opt=None):\n    \"\"\"\n    l2 norm without the sqrt, from yu's implementation (mse)\n    tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss\n    Parameters:\n        coeffs_dict     -- a  dict of torch.tensors , keys: id, exp, tex, angle, gamma, trans\n        opt             -- options object providing w_id, w_exp, w_tex; all three default to 1 when None\n    Returns:\n        (creg_loss, gamma_loss)\n    \"\"\"\n    # coefficient regularization to ensure plausible 3d faces\n    if opt is not None:\n        w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex\n    else:\n        w_id, w_exp, w_tex = 1, 1, 1\n    creg_loss = (w_id * torch.sum(coeffs_dict['id'] ** 2)\n                 + w_exp * torch.sum(coeffs_dict['exp'] ** 2)\n                 + w_tex * torch.sum(coeffs_dict['tex'] ** 2))\n    creg_loss = creg_loss / coeffs_dict['id'].shape[0]\n\n    # gamma regularization to ensure a nearly-monochromatic light\n    gamma = coeffs_dict['gamma'].reshape([-1, 3, 9])\n    gamma_mean = torch.mean(gamma, dim=1, keepdim=True)\n    gamma_loss = torch.mean((gamma - gamma_mean) ** 2)\n\n    return creg_loss, gamma_loss\n\ndef reflectance_loss(texture, mask):\n    \"\"\"\n    minimize texture variance (mse), albedo regularization to ensure a uniform skin albedo\n    Parameters:\n        texture       --torch.tensor, (B, N, 3)\n        mask          --torch.tensor, (N), 1 or 0\n\n    \"\"\"\n    mask = mask.reshape([1, mask.shape[0], 1])\n    texture_mean = torch.sum(mask * texture, dim=1, keepdim=True) / torch.sum(mask)\n    loss = torch.sum(((texture - texture_mean) * mask)**2) / (texture.shape[0] * torch.sum(mask))\n    return loss\n\n"
  },
  {
    "path": "src/face3d/models/networks.py",
    "content": "\"\"\"This script defines deep neural networks for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport os\nimport numpy as np\nimport torch.nn.functional as F\nfrom torch.nn import init\nimport functools\nfrom torch.optim import lr_scheduler\nimport torch\nfrom torch import Tensor\nimport torch.nn as nn\ntry:\n    from torch.hub import load_state_dict_from_url\nexcept ImportError:\n    from torch.utils.model_zoo import load_url as load_state_dict_from_url\nfrom typing import Type, Any, Callable, Union, List, Optional\nfrom .arcface_torch.backbones import get_model\nfrom kornia.geometry import warp_affine\n\ndef resize_n_crop(image, M, dsize=112):\n    # image: (b, c, h, w)\n    # M   :  (b, 2, 3)\n    return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True)\n\ndef filter_state_dict(state_dict, remove_name='fc'):\n    new_state_dict = {}\n    for key in state_dict:\n        if remove_name in key:\n            continue\n        new_state_dict[key] = state_dict[key]\n    return new_state_dict\n\ndef get_scheduler(optimizer, opt):\n    \"\"\"Return a learning rate scheduler\n\n    Parameters:\n        optimizer          -- the optimizer of the network\n        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions．　\n                              opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine\n\n    For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.\n    See https://pytorch.org/docs/stable/optim.html for more details.\n    \"\"\"\n    if opt.lr_policy == 'linear':\n        def lambda_rule(epoch):\n            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs + 1)\n            return lr_l\n        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)\n    elif opt.lr_policy == 'step':\n        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_epochs, gamma=0.2)\n    elif 
opt.lr_policy == 'plateau':\n        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)\n    elif opt.lr_policy == 'cosine':\n        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)\n    else:\n        return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)\n    return scheduler\n\n\ndef define_net_recon(net_recon, use_last_fc=False, init_path=None):\n    return ReconNetWrapper(net_recon, use_last_fc=use_last_fc, init_path=init_path)\n\ndef define_net_recog(net_recog, pretrained_path=None):\n    net = RecogNetWrapper(net_recog=net_recog, pretrained_path=pretrained_path)\n    net.eval()\n    return net\n\nclass ReconNetWrapper(nn.Module):\n    fc_dim=257\n    def __init__(self, net_recon, use_last_fc=False, init_path=None):\n        super(ReconNetWrapper, self).__init__()\n        self.use_last_fc = use_last_fc\n        if net_recon not in func_dict:\n            return  NotImplementedError('network [%s] is not implemented', net_recon)\n        func, last_dim = func_dict[net_recon]\n        backbone = func(use_last_fc=use_last_fc, num_classes=self.fc_dim)\n        if init_path and os.path.isfile(init_path):\n            state_dict = filter_state_dict(torch.load(init_path, map_location='cpu'))\n            backbone.load_state_dict(state_dict)\n            print(\"loading init net_recon %s from %s\" %(net_recon, init_path))\n        self.backbone = backbone\n        if not use_last_fc:\n            self.final_layers = nn.ModuleList([\n                conv1x1(last_dim, 80, bias=True), # id layer\n                conv1x1(last_dim, 64, bias=True), # exp layer\n                conv1x1(last_dim, 80, bias=True), # tex layer\n                conv1x1(last_dim, 3, bias=True),  # angle layer\n                conv1x1(last_dim, 27, bias=True), # gamma layer\n                conv1x1(last_dim, 2, bias=True),  # tx, ty\n                conv1x1(last_dim, 1, 
bias=True)   # tz\n            ])\n            for m in self.final_layers:\n                nn.init.constant_(m.weight, 0.)\n                nn.init.constant_(m.bias, 0.)\n\n    def forward(self, x):\n        x = self.backbone(x)\n        if not self.use_last_fc:\n            output = []\n            for layer in self.final_layers:\n                output.append(layer(x))\n            x = torch.flatten(torch.cat(output, dim=1), 1)\n        return x\n\n\nclass RecogNetWrapper(nn.Module):\n    def __init__(self, net_recog, pretrained_path=None, input_size=112):\n        super(RecogNetWrapper, self).__init__()\n        net = get_model(name=net_recog, fp16=False)\n        if pretrained_path:\n            state_dict = torch.load(pretrained_path, map_location='cpu')\n            net.load_state_dict(state_dict)\n            print(\"loading pretrained net_recog %s from %s\" %(net_recog, pretrained_path))\n        for param in net.parameters():\n            param.requires_grad = False\n        self.net = net\n        self.preprocess = lambda x: 2 * x - 1\n        self.input_size=input_size\n        \n    def forward(self, image, M):\n        image = self.preprocess(resize_n_crop(image, M, self.input_size))\n        id_feature = F.normalize(self.net(image), dim=-1, p=2)\n        return id_feature\n\n\n# adapted from https://github.com/pytorch/vision/edit/master/torchvision/models/resnet.py\n__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',\n           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',\n           'wide_resnet50_2', 'wide_resnet101_2']\n\n\nmodel_urls = {\n    'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',\n    'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',\n    'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',\n    'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',\n    'resnet152': 
'https://download.pytorch.org/models/resnet152-394f9c45.pth',\n    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',\n    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',\n    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',\n    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',\n}\n\n\ndef conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=dilation, groups=groups, bias=False, dilation=dilation)\n\n\ndef conv1x1(in_planes: int, out_planes: int, stride: int = 1, bias: bool = False) -> nn.Conv2d:\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias)\n\n\nclass BasicBlock(nn.Module):\n    expansion: int = 1\n\n    def __init__(\n        self,\n        inplanes: int,\n        planes: int,\n        stride: int = 1,\n        downsample: Optional[nn.Module] = None,\n        groups: int = 1,\n        base_width: int = 64,\n        dilation: int = 1,\n        norm_layer: Optional[Callable[..., nn.Module]] = None\n    ) -> None:\n        super(BasicBlock, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        if groups != 1 or base_width != 64:\n            raise ValueError('BasicBlock only supports groups=1 and base_width=64')\n        if dilation > 1:\n            raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n        # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = norm_layer(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n     
   self.bn2 = norm_layer(planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x: Tensor) -> Tensor:\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)\n    # while original implementation places the stride at the first 1x1 convolution(self.conv1)\n    # according to \"Deep residual learning for image recognition\"https://arxiv.org/abs/1512.03385.\n    # This variant is also known as ResNet V1.5 and improves accuracy according to\n    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.\n\n    expansion: int = 4\n\n    def __init__(\n        self,\n        inplanes: int,\n        planes: int,\n        stride: int = 1,\n        downsample: Optional[nn.Module] = None,\n        groups: int = 1,\n        base_width: int = 64,\n        dilation: int = 1,\n        norm_layer: Optional[Callable[..., nn.Module]] = None\n    ) -> None:\n        super(Bottleneck, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        width = int(planes * (base_width / 64.)) * groups\n        # Both self.conv2 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv1x1(inplanes, width)\n        self.bn1 = norm_layer(width)\n        self.conv2 = conv3x3(width, width, stride, groups, dilation)\n        self.bn2 = norm_layer(width)\n        self.conv3 = conv1x1(width, planes * self.expansion)\n        self.bn3 = norm_layer(planes * self.expansion)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n   
     self.stride = stride\n\n    def forward(self, x: Tensor) -> Tensor:\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass ResNet(nn.Module):\n\n    def __init__(\n        self,\n        block: Type[Union[BasicBlock, Bottleneck]],\n        layers: List[int],\n        num_classes: int = 1000,\n        zero_init_residual: bool = False,\n        use_last_fc: bool = False,\n        groups: int = 1,\n        width_per_group: int = 64,\n        replace_stride_with_dilation: Optional[List[bool]] = None,\n        norm_layer: Optional[Callable[..., nn.Module]] = None\n    ) -> None:\n        super(ResNet, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        self._norm_layer = norm_layer\n\n        self.inplanes = 64\n        self.dilation = 1\n        if replace_stride_with_dilation is None:\n            # each element in the tuple indicates if we should replace\n            # the 2x2 stride with a dilated convolution instead\n            replace_stride_with_dilation = [False, False, False]\n        if len(replace_stride_with_dilation) != 3:\n            raise ValueError(\"replace_stride_with_dilation should be None \"\n                             \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation))\n        self.use_last_fc = use_last_fc\n        self.groups = groups\n        self.base_width = width_per_group\n        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,\n                               bias=False)\n        self.bn1 = norm_layer(self.inplanes)\n        self.relu = 
nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,\n                                       dilate=replace_stride_with_dilation[0])\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,\n                                       dilate=replace_stride_with_dilation[1])\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,\n                                       dilate=replace_stride_with_dilation[2])\n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n        \n        if self.use_last_fc:\n            self.fc = nn.Linear(512 * block.expansion, num_classes)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n\n\n        # Zero-initialize the last BN in each residual branch,\n        # so that the residual branch starts with zeros, and each residual block behaves like an identity.\n        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n        if zero_init_residual:\n            for m in self.modules():\n                if isinstance(m, Bottleneck):\n                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]\n                elif isinstance(m, BasicBlock):\n                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]\n\n    def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,\n                    stride: int = 1, dilate: bool = False) -> nn.Sequential:\n        norm_layer = self._norm_layer\n        downsample = None\n        previous_dilation = self.dilation\n        
if dilate:\n            self.dilation *= stride\n            stride = 1\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                conv1x1(self.inplanes, planes * block.expansion, stride),\n                norm_layer(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,\n                            self.base_width, previous_dilation, norm_layer))\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(block(self.inplanes, planes, groups=self.groups,\n                                base_width=self.base_width, dilation=self.dilation,\n                                norm_layer=norm_layer))\n\n        return nn.Sequential(*layers)\n\n    def _forward_impl(self, x: Tensor) -> Tensor:\n        # See note [TorchScript super()]\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        if self.use_last_fc:\n            x = torch.flatten(x, 1)\n            x = self.fc(x)\n        return x\n\n    def forward(self, x: Tensor) -> Tensor:\n        return self._forward_impl(x)\n\n\ndef _resnet(\n    arch: str,\n    block: Type[Union[BasicBlock, Bottleneck]],\n    layers: List[int],\n    pretrained: bool,\n    progress: bool,\n    **kwargs: Any\n) -> ResNet:\n    model = ResNet(block, layers, **kwargs)\n    if pretrained:\n        state_dict = load_state_dict_from_url(model_urls[arch],\n                                              progress=progress)\n        model.load_state_dict(state_dict)\n    return model\n\n\ndef resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNet-18 model from\n    `\"Deep Residual 
Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,\n                   **kwargs)\n\n\ndef resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNet-34 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,\n                   **kwargs)\n\n\ndef resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNet-50 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,\n                   **kwargs)\n\n\ndef resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNet-101 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,\n                   **kwargs)\n\n\ndef resnet152(pretrained: bool = False, progress: bool = 
True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNet-152 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,\n                   **kwargs)\n\n\ndef resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNeXt-50 32x4d model from\n    `\"Aggregated Residual Transformation for Deep Neural Networks\" <https://arxiv.org/pdf/1611.05431.pdf>`_.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs['groups'] = 32\n    kwargs['width_per_group'] = 4\n    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],\n                   pretrained, progress, **kwargs)\n\n\ndef resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNeXt-101 32x8d model from\n    `\"Aggregated Residual Transformation for Deep Neural Networks\" <https://arxiv.org/pdf/1611.05431.pdf>`_.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs['groups'] = 32\n    kwargs['width_per_group'] = 8\n    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],\n                   pretrained, progress, **kwargs)\n\n\ndef wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"Wide ResNet-50-2 model from\n    `\"Wide Residual Networks\" <https://arxiv.org/pdf/1605.07146.pdf>`_.\n\n    The model is the same as ResNet except for the bottleneck number of 
channels\n    which is twice larger in every block. The number of channels in outer 1x1\n    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048\n    channels, and in Wide ResNet-50-2 has 2048-1024-2048.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs['width_per_group'] = 64 * 2\n    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],\n                   pretrained, progress, **kwargs)\n\n\ndef wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"Wide ResNet-101-2 model from\n    `\"Wide Residual Networks\" <https://arxiv.org/pdf/1605.07146.pdf>`_.\n\n    The model is the same as ResNet except for the bottleneck number of channels\n    which is twice larger in every block. The number of channels in outer 1x1\n    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048\n    channels, and in Wide ResNet-50-2 has 2048-1024-2048.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs['width_per_group'] = 64 * 2\n    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],\n                   pretrained, progress, **kwargs)\n\n\nfunc_dict = {\n    'resnet18': (resnet18, 512),\n    'resnet50': (resnet50, 2048)\n}\n"
  },
  {
    "path": "src/face3d/models/template_model.py",
    "content": "\"\"\"Model class template\n\nThis module provides a template for users to implement custom models.\nYou can specify '--model template' to use this model.\nThe class name should be consistent with both the filename and its model option.\nThe filename should be <model>_model.py\nThe class name should be <Model>Model\nIt implements a simple image-to-image translation baseline based on regression loss.\nGiven input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss:\n    min_<netG> ||netG(data_A) - data_B||_1\nYou need to implement the following functions:\n    <modify_commandline_options>: Add model-specific options and rewrite default values for existing options.\n    <__init__>: Initialize this model class.\n    <set_input>: Unpack input data and perform data pre-processing.\n    <forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>.\n    <optimize_parameters>: Update network weights; it will be called in every training iteration.\n\"\"\"\nimport numpy as np\nimport torch\nfrom .base_model import BaseModel\nfrom . import networks\n\n\nclass TemplateModel(BaseModel):\n    @staticmethod\n    def modify_commandline_options(parser, is_train=True):\n        \"\"\"Add new model-specific options and rewrite default values for existing options.\n\n        Parameters:\n            parser -- the option parser\n            is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.\n\n        Returns:\n            the modified parser.\n        \"\"\"\n        parser.set_defaults(dataset_mode='aligned')  # You can rewrite default values for this model. 
For example, this model usually uses aligned dataset as its dataset.\n        if is_train:\n            parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss')  # You can define new arguments for this model.\n\n        return parser\n\n    def __init__(self, opt):\n        \"\"\"Initialize this model class.\n\n        Parameters:\n            opt -- training/test options\n\n        A few things can be done here.\n        - (required) call the initialization function of BaseModel\n        - define loss function, visualization images, model names, and optimizers\n        \"\"\"\n        BaseModel.__init__(self, opt)  # call the initialization method of BaseModel\n        # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.\n        self.loss_names = ['loss_G']\n        # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.\n        self.visual_names = ['data_A', 'data_B', 'output']\n        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks.\n        # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.\n        self.model_names = ['G']\n        # define networks; you can use opt.isTrain to specify different behaviors for training and test.\n        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)\n        if self.isTrain:  # only defined during training time\n            # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.\n            # We also provide a GANLoss class \"networks.GANLoss\". 
self.criterionGAN = networks.GANLoss().to(self.device)\n            self.criterionLoss = torch.nn.L1Loss()\n            # define and initialize optimizers. You can define one optimizer for each network.\n            # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.\n            self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))\n            self.optimizers = [self.optimizer]\n\n        # Our program will automatically call <model.setup> to define schedulers, load networks, and print networks\n\n    def set_input(self, input):\n        \"\"\"Unpack input data from the dataloader and perform necessary pre-processing steps.\n\n        Parameters:\n            input: a dictionary that contains the data itself and its metadata information.\n        \"\"\"\n        AtoB = self.opt.direction == 'AtoB'  # use <direction> to swap data_A and data_B\n        self.data_A = input['A' if AtoB else 'B'].to(self.device)  # get image data A\n        self.data_B = input['B' if AtoB else 'A'].to(self.device)  # get image data B\n        self.image_paths = input['A_paths' if AtoB else 'B_paths']  # get image paths\n\n    def forward(self):\n        \"\"\"Run forward pass. This will be called by both functions <optimize_parameters> and <test>.\"\"\"\n        self.output = self.netG(self.data_A)  # generate output image given the input data_A\n\n    def backward(self):\n        \"\"\"Calculate losses and gradients; called in every training iteration\"\"\"\n        # calculate the intermediate results if necessary; here self.output has been computed during function <forward>\n        # calculate loss given the input and intermediate results\n        self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression\n        self.loss_G.backward()       # calculate gradients of network G w.r.t. 
loss_G\n\n    def optimize_parameters(self):\n        \"\"\"Update network weights; it will be called in every training iteration.\"\"\"\n        self.forward()               # first call forward to calculate intermediate results\n        self.optimizer.zero_grad()   # clear network G's existing gradients\n        self.backward()              # calculate gradients for network G\n        self.optimizer.step()        # update the weights of network G\n"
  },
  {
    "path": "src/face3d/options/__init__.py",
    "content": "\"\"\"This package options includes option modules: training options, test options, and basic options (used in both training and test).\"\"\"\n"
  },
  {
    "path": "src/face3d/options/base_options.py",
    "content": "\"\"\"This script contains base options for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport argparse\nimport os\nfrom util import util\nimport numpy as np\nimport torch\nimport face3d.models as models\nimport face3d.data as data\n\n\nclass BaseOptions():\n    \"\"\"This class defines options used during both training and test time.\n\n    It also implements several helper functions such as parsing, printing, and saving the options.\n    It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.\n    \"\"\"\n\n    def __init__(self, cmd_line=None):\n        \"\"\"Reset the class; indicates the class hasn't been initialized\"\"\"\n        self.initialized = False\n        self.cmd_line = None\n        if cmd_line is not None:\n            self.cmd_line = cmd_line.split()\n\n    def initialize(self, parser):\n        \"\"\"Define the common options that are used in both training and test.\"\"\"\n        # basic parameters\n        parser.add_argument('--name', type=str, default='face_recon', help='name of the experiment. It decides where to store samples and models')\n        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. 
use -1 for CPU')\n        parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n        parser.add_argument('--vis_batch_nums', type=float, default=1, help='batch nums of images for visulization')\n        parser.add_argument('--eval_batch_nums', type=float, default=float('inf'), help='batch nums of images for evaluation')\n        parser.add_argument('--use_ddp', type=util.str2bool, nargs='?', const=True, default=True, help='whether use distributed data parallel')\n        parser.add_argument('--ddp_port', type=str, default='12355', help='ddp port')\n        parser.add_argument('--display_per_batch', type=util.str2bool, nargs='?', const=True, default=True, help='whether use batch to show losses')\n        parser.add_argument('--add_image', type=util.str2bool, nargs='?', const=True, default=True, help='whether add image to tensorboard')\n        parser.add_argument('--world_size', type=int, default=1, help='batch nums of images for evaluation')\n\n        # model parameters\n        parser.add_argument('--model', type=str, default='facerecon', help='chooses which model to use.')\n\n        # additional parameters\n        parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? 
set to latest to use latest cached model')\n        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')\n        parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n\n        self.initialized = True\n        return parser\n\n    def gather_options(self):\n        \"\"\"Initialize our parser with basic options(only once).\n        Add additional model-specific and dataset-specific options.\n        These options are defined in the <modify_commandline_options> function\n        in model and dataset classes.\n        \"\"\"\n        if not self.initialized:  # check if it has been initialized\n            parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n            parser = self.initialize(parser)\n\n        # get the basic options\n        if self.cmd_line is None:\n            opt, _ = parser.parse_known_args()\n        else:\n            opt, _ = parser.parse_known_args(self.cmd_line)\n\n        # set cuda visible devices\n        os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids\n\n        # modify model-related parser options\n        model_name = opt.model\n        model_option_setter = models.get_option_setter(model_name)\n        parser = model_option_setter(parser, self.isTrain)\n        if self.cmd_line is None:\n            opt, _ = parser.parse_known_args()  # parse again with new defaults\n        else:\n            opt, _ = parser.parse_known_args(self.cmd_line)  # parse again with new defaults\n\n        # modify dataset-related parser options\n        if opt.dataset_mode:\n            dataset_name = opt.dataset_mode\n            dataset_option_setter = data.get_option_setter(dataset_name)\n            parser = dataset_option_setter(parser, self.isTrain)\n\n        # save and return the parser\n        self.parser = parser\n        if self.cmd_line is None:\n     
       return parser.parse_args()\n        else:\n            return parser.parse_args(self.cmd_line)\n\n    def print_options(self, opt):\n        \"\"\"Print and save options\n\n        It will print both current options and default values(if different).\n        It will save options into a text file / [checkpoints_dir] / opt.txt\n        \"\"\"\n        message = ''\n        message += '----------------- Options ---------------\\n'\n        for k, v in sorted(vars(opt).items()):\n            comment = ''\n            default = self.parser.get_default(k)\n            if v != default:\n                comment = '\\t[default: %s]' % str(default)\n            message += '{:>25}: {:<30}{}\\n'.format(str(k), str(v), comment)\n        message += '----------------- End -------------------'\n        print(message)\n\n        # save to the disk\n        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)\n        util.mkdirs(expr_dir)\n        file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))\n        try:\n            with open(file_name, 'wt') as opt_file:\n                opt_file.write(message)\n                opt_file.write('\\n')\n        except PermissionError as error:\n            print(\"permission error {}\".format(error))\n            pass\n\n    def parse(self):\n        \"\"\"Parse our options, create checkpoints directory suffix, and set up gpu device.\"\"\"\n        opt = self.gather_options()\n        opt.isTrain = self.isTrain   # train or test\n\n        # process opt.suffix\n        if opt.suffix:\n            suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''\n            opt.name = opt.name + suffix\n\n\n        # set gpu ids\n        str_ids = opt.gpu_ids.split(',')\n        gpu_ids = []\n        for str_id in str_ids:\n            id = int(str_id)\n            if id >= 0:\n                gpu_ids.append(id)\n        opt.world_size = len(gpu_ids)\n        # if len(opt.gpu_ids) > 0:\n        #     
torch.cuda.set_device(gpu_ids[0])\n        if opt.world_size == 1:\n            opt.use_ddp = False\n\n        if opt.phase != 'test':\n            # set continue_train automatically\n            if opt.pretrained_name is None:\n                model_dir = os.path.join(opt.checkpoints_dir, opt.name)\n            else:\n                model_dir = os.path.join(opt.checkpoints_dir, opt.pretrained_name)\n            if os.path.isdir(model_dir):\n                model_pths = [i for i in os.listdir(model_dir) if i.endswith('pth')]\n                if os.path.isdir(model_dir) and len(model_pths) != 0:\n                    opt.continue_train= True\n        \n            # update the latest epoch count\n            if opt.continue_train:\n                if opt.epoch == 'latest':\n                    epoch_counts = [int(i.split('.')[0].split('_')[-1]) for i in model_pths if 'latest' not in i]\n                    if len(epoch_counts) != 0:\n                        opt.epoch_count = max(epoch_counts) + 1\n                else:\n                    opt.epoch_count = int(opt.epoch) + 1\n                    \n\n        self.print_options(opt)\n        self.opt = opt\n        return self.opt\n"
  },
  {
    "path": "src/face3d/options/inference_options.py",
    "content": "from face3d.options.base_options import BaseOptions\n\n\nclass InferenceOptions(BaseOptions):\n    \"\"\"This class includes test options.\n\n    It also includes shared options defined in BaseOptions.\n    \"\"\"\n\n    def initialize(self, parser):\n        parser = BaseOptions.initialize(self, parser)  # define shared options\n        parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')\n        parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]')\n\n        parser.add_argument('--input_dir', type=str, help='the folder of the input files')\n        parser.add_argument('--keypoint_dir', type=str, help='the folder of the keypoint files')\n        parser.add_argument('--output_dir', type=str, default='mp4', help='the output dir to save the extracted coefficients')\n        parser.add_argument('--save_split_files', action='store_true', help='save split files or not')\n        parser.add_argument('--inference_batch_size', type=int, default=8)\n        \n        # Dropout and Batchnorm has different behavior during training and test.\n        self.isTrain = False\n        return parser\n"
  },
  {
    "path": "src/face3d/options/test_options.py",
    "content": "\"\"\"This script contains the test options for Deep3DFaceRecon_pytorch\n\"\"\"\n\nfrom .base_options import BaseOptions\n\n\nclass TestOptions(BaseOptions):\n    \"\"\"This class includes test options.\n\n    It also includes shared options defined in BaseOptions.\n    \"\"\"\n\n    def initialize(self, parser):\n        parser = BaseOptions.initialize(self, parser)  # define shared options\n        parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')\n        parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]')\n        parser.add_argument('--img_folder', type=str, default='examples', help='folder for test images.')\n\n        # Dropout and Batchnorm has different behavior during training and test.\n        self.isTrain = False\n        return parser\n"
  },
  {
    "path": "src/face3d/options/train_options.py",
    "content": "\"\"\"This script contains the training options for Deep3DFaceRecon_pytorch\n\"\"\"\n\nfrom .base_options import BaseOptions\nfrom util import util\n\nclass TrainOptions(BaseOptions):\n    \"\"\"This class includes training options.\n\n    It also includes shared options defined in BaseOptions.\n    \"\"\"\n\n    def initialize(self, parser):\n        parser = BaseOptions.initialize(self, parser)\n        # dataset parameters\n        # for train\n        parser.add_argument('--data_root', type=str, default='./', help='dataset root')\n        parser.add_argument('--flist', type=str, default='datalist/train/masks.txt', help='list of mask names of training set')\n        parser.add_argument('--batch_size', type=int, default=32)\n        parser.add_argument('--dataset_mode', type=str, default='flist', help='chooses how datasets are loaded. [None | flist]')\n        parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n        parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')\n        parser.add_argument('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n        parser.add_argument('--preprocess', type=str, default='shift_scale_rot_flip', help='scaling and cropping of images at load time [shift_scale_rot_flip | shift_scale | shift | shift_rot_flip ]')\n        parser.add_argument('--use_aug', type=util.str2bool, nargs='?', const=True, default=True, help='whether use data augmentation')\n\n        # for val\n        parser.add_argument('--flist_val', type=str, default='datalist/val/masks.txt', help='list of mask names of val set')\n        parser.add_argument('--batch_size_val', type=int, default=32)\n\n\n        # visualization parameters\n        parser.add_argument('--display_freq', type=int, default=1000, help='frequency of showing training results on screen')\n        parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')\n        \n        # network saving and loading parameters\n        parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')\n        parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')\n        parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq')\n        parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')\n        parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')\n        parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')\n        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')\n        parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint')\n\n        # training 
parameters\n        parser.add_argument('--n_epochs', type=int, default=20, help='number of epochs with the initial learning rate')\n        parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')\n        parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]')\n        parser.add_argument('--lr_decay_epochs', type=int, default=10, help='multiply by a gamma every lr_decay_epochs epoches')\n\n        self.isTrain = True\n        return parser\n"
  },
  {
    "path": "src/face3d/util/__init__.py",
    "content": "\"\"\"This package includes a miscellaneous collection of useful helper functions.\"\"\"\nfrom src.face3d.util import *\n\n"
  },
  {
    "path": "src/face3d/util/detect_lm68.py",
    "content": "import os\nimport cv2\nimport numpy as np\nfrom scipy.io import loadmat\nimport tensorflow as tf\nfrom util.preprocess import align_for_lm\nfrom shutil import move\n\nmean_face = np.loadtxt('util/test_mean_face.txt')\nmean_face = mean_face.reshape([68, 2])\n\ndef save_label(labels, save_path):\n    np.savetxt(save_path, labels)\n\ndef draw_landmarks(img, landmark, save_name):\n    landmark = landmark\n    lm_img = np.zeros([img.shape[0], img.shape[1], 3])\n    lm_img[:] = img.astype(np.float32)\n    landmark = np.round(landmark).astype(np.int32)\n\n    for i in range(len(landmark)):\n        for j in range(-1, 1):\n            for k in range(-1, 1):\n                if img.shape[0] - 1 - landmark[i, 1]+j > 0 and \\\n                        img.shape[0] - 1 - landmark[i, 1]+j < img.shape[0] and \\\n                        landmark[i, 0]+k > 0 and \\\n                        landmark[i, 0]+k < img.shape[1]:\n                    lm_img[img.shape[0] - 1 - landmark[i, 1]+j, landmark[i, 0]+k,\n                           :] = np.array([0, 0, 255])\n    lm_img = lm_img.astype(np.uint8)\n\n    cv2.imwrite(save_name, lm_img)\n\n\ndef load_data(img_name, txt_name):\n    return cv2.imread(img_name), np.loadtxt(txt_name)\n\n# create tensorflow graph for landmark detector\ndef load_lm_graph(graph_filename):\n    with tf.gfile.GFile(graph_filename, 'rb') as f:\n        graph_def = tf.GraphDef()\n        graph_def.ParseFromString(f.read())\n\n    with tf.Graph().as_default() as graph:\n        tf.import_graph_def(graph_def, name='net')\n        img_224 = graph.get_tensor_by_name('net/input_imgs:0')\n        output_lm = graph.get_tensor_by_name('net/lm:0')\n        lm_sess = tf.Session(graph=graph)\n\n    return lm_sess,img_224,output_lm\n\n# landmark detection\ndef detect_68p(img_path,sess,input_op,output_op):\n    print('detecting landmarks......')\n    names = [i for i in sorted(os.listdir(\n        img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' 
in i]\n    vis_path = os.path.join(img_path, 'vis')\n    remove_path = os.path.join(img_path, 'remove')\n    save_path = os.path.join(img_path, 'landmarks')\n    if not os.path.isdir(vis_path):\n        os.makedirs(vis_path)\n    if not os.path.isdir(remove_path):\n        os.makedirs(remove_path)\n    if not os.path.isdir(save_path):\n        os.makedirs(save_path)\n\n    for i in range(0, len(names)):\n        name = names[i]\n        print('%05d' % (i), ' ', name)\n        full_image_name = os.path.join(img_path, name)\n        txt_name = '.'.join(name.split('.')[:-1]) + '.txt'\n        full_txt_name = os.path.join(img_path, 'detections', txt_name) # 5 facial landmark path for each image\n\n        # if an image does not have detected 5 facial landmarks, remove it from the training list\n        if not os.path.isfile(full_txt_name):\n            move(full_image_name, os.path.join(remove_path, name))\n            continue \n\n        # load data\n        img, five_points = load_data(full_image_name, full_txt_name)\n        input_img, scale, bbox = align_for_lm(img, five_points) # align for 68 landmark detection \n\n        # if the alignment fails, remove corresponding image from the training list\n        if scale == 0:\n            move(full_txt_name, os.path.join(\n                remove_path, txt_name))\n            move(full_image_name, os.path.join(remove_path, name))\n            continue\n\n        # detect landmarks\n        input_img = np.reshape(\n            input_img, [1, 224, 224, 3]).astype(np.float32)\n        landmark = sess.run(\n            output_op, feed_dict={input_op: input_img})\n\n        # transform back to original image coordinate\n        landmark = landmark.reshape([68, 2]) + mean_face\n        landmark[:, 1] = 223 - landmark[:, 1]\n        landmark = landmark / scale\n        landmark[:, 0] = landmark[:, 0] + bbox[0]\n        landmark[:, 1] = landmark[:, 1] + bbox[1]\n        landmark[:, 1] = img.shape[0] - 1 - landmark[:, 1]\n\n    
    if i % 100 == 0:\n            draw_landmarks(img, landmark, os.path.join(vis_path, name))\n        save_label(landmark, os.path.join(save_path, txt_name))\n"
  },
  {
    "path": "src/face3d/util/generate_list.py",
    "content": "\"\"\"This script is to generate training list files for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport os\n\n# save path to training data\ndef write_list(lms_list, imgs_list, msks_list, mode='train',save_folder='datalist', save_name=''):\n    save_path = os.path.join(save_folder, mode)\n    if not os.path.isdir(save_path):\n        os.makedirs(save_path)\n    with open(os.path.join(save_path, save_name + 'landmarks.txt'), 'w') as fd:\n        fd.writelines([i + '\\n' for i in lms_list])\n\n    with open(os.path.join(save_path, save_name + 'images.txt'), 'w') as fd:\n        fd.writelines([i + '\\n' for i in imgs_list])\n\n    with open(os.path.join(save_path, save_name + 'masks.txt'), 'w') as fd:\n        fd.writelines([i + '\\n' for i in msks_list])   \n\n# check if the path is valid\ndef check_list(rlms_list, rimgs_list, rmsks_list):\n    lms_list, imgs_list, msks_list = [], [], []\n    for i in range(len(rlms_list)):\n        flag = 'false'\n        lm_path = rlms_list[i]\n        im_path = rimgs_list[i]\n        msk_path = rmsks_list[i]\n        if os.path.isfile(lm_path) and os.path.isfile(im_path) and os.path.isfile(msk_path):\n            flag = 'true'\n            lms_list.append(rlms_list[i])\n            imgs_list.append(rimgs_list[i])\n            msks_list.append(rmsks_list[i])\n        print(i, rlms_list[i], flag)\n    return lms_list, imgs_list, msks_list\n"
  },
  {
    "path": "src/face3d/util/html.py",
    "content": "import dominate\nfrom dominate.tags import meta, h3, table, tr, td, p, a, img, br\nimport os\n\n\nclass HTML:\n    \"\"\"This HTML class allows us to save images and write texts into a single HTML file.\n\n     It consists of functions such as <add_header> (add a text header to the HTML file),\n     <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).\n     It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.\n    \"\"\"\n\n    def __init__(self, web_dir, title, refresh=0):\n        \"\"\"Initialize the HTML classes\n\n        Parameters:\n            web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir/images/\n            title (str)   -- the webpage name\n            refresh (int) -- how often the website refresh itself; if 0; no refreshing\n        \"\"\"\n        self.title = title\n        self.web_dir = web_dir\n        self.img_dir = os.path.join(self.web_dir, 'images')\n        if not os.path.exists(self.web_dir):\n            os.makedirs(self.web_dir)\n        if not os.path.exists(self.img_dir):\n            os.makedirs(self.img_dir)\n\n        self.doc = dominate.document(title=title)\n        if refresh > 0:\n            with self.doc.head:\n                meta(http_equiv=\"refresh\", content=str(refresh))\n\n    def get_image_dir(self):\n        \"\"\"Return the directory that stores images\"\"\"\n        return self.img_dir\n\n    def add_header(self, text):\n        \"\"\"Insert a header to the HTML file\n\n        Parameters:\n            text (str) -- the header text\n        \"\"\"\n        with self.doc:\n            h3(text)\n\n    def add_images(self, ims, txts, links, width=400):\n        \"\"\"add images to the HTML file\n\n        Parameters:\n            ims (str list)   -- a list of image paths\n            txts (str list)  -- a 
list of image names shown on the website\n            links (str list) --  a list of hyperref links; when you click an image, it will redirect you to a new page\n        \"\"\"\n        self.t = table(border=1, style=\"table-layout: fixed;\")  # Insert a table\n        self.doc.add(self.t)\n        with self.t:\n            with tr():\n                for im, txt, link in zip(ims, txts, links):\n                    with td(style=\"word-wrap: break-word;\", halign=\"center\", valign=\"top\"):\n                        with p():\n                            with a(href=os.path.join('images', link)):\n                                img(style=\"width:%dpx\" % width, src=os.path.join('images', im))\n                            br()\n                            p(txt)\n\n    def save(self):\n        \"\"\"save the current content to the HMTL file\"\"\"\n        html_file = '%s/index.html' % self.web_dir\n        f = open(html_file, 'wt')\n        f.write(self.doc.render())\n        f.close()\n\n\nif __name__ == '__main__':  # we show an example usage here.\n    html = HTML('web/', 'test_html')\n    html.add_header('hello world')\n\n    ims, txts, links = [], [], []\n    for n in range(4):\n        ims.append('image_%d.png' % n)\n        txts.append('text_%d' % n)\n        links.append('image_%d.png' % n)\n    html.add_images(ims, txts, links)\n    html.save()\n"
  },
  {
    "path": "src/face3d/util/load_mats.py",
    "content": "\"\"\"This script is to load 3D face model for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport numpy as np\nfrom PIL import Image\nfrom scipy.io import loadmat, savemat\nfrom array import array\nimport os.path as osp\n\n# load expression basis\ndef LoadExpBasis(bfm_folder='BFM'):\n    n_vertex = 53215\n    Expbin = open(osp.join(bfm_folder, 'Exp_Pca.bin'), 'rb')\n    exp_dim = array('i')\n    exp_dim.fromfile(Expbin, 1)\n    expMU = array('f')\n    expPC = array('f')\n    expMU.fromfile(Expbin, 3*n_vertex)\n    expPC.fromfile(Expbin, 3*exp_dim[0]*n_vertex)\n    Expbin.close()\n\n    expPC = np.array(expPC)\n    expPC = np.reshape(expPC, [exp_dim[0], -1])\n    expPC = np.transpose(expPC)\n\n    expEV = np.loadtxt(osp.join(bfm_folder, 'std_exp.txt'))\n\n    return expPC, expEV\n\n\n# transfer original BFM09 to our face model\ndef transferBFM09(bfm_folder='BFM'):\n    print('Transfer BFM09 to BFM_model_front......')\n    original_BFM = loadmat(osp.join(bfm_folder, '01_MorphableModel.mat'))\n    shapePC = original_BFM['shapePC']  # shape basis\n    shapeEV = original_BFM['shapeEV']  # corresponding eigen value\n    shapeMU = original_BFM['shapeMU']  # mean face\n    texPC = original_BFM['texPC']  # texture basis\n    texEV = original_BFM['texEV']  # eigen value\n    texMU = original_BFM['texMU']  # mean texture\n\n    expPC, expEV = LoadExpBasis(bfm_folder)\n\n    # transfer BFM09 to our face model\n\n    idBase = shapePC*np.reshape(shapeEV, [-1, 199])\n    idBase = idBase/1e5  # unify the scale to decimeter\n    idBase = idBase[:, :80]  # use only first 80 basis\n\n    exBase = expPC*np.reshape(expEV, [-1, 79])\n    exBase = exBase/1e5  # unify the scale to decimeter\n    exBase = exBase[:, :64]  # use only first 64 basis\n\n    texBase = texPC*np.reshape(texEV, [-1, 199])\n    texBase = texBase[:, :80]  # use only first 80 basis\n\n    # our face model is cropped along face landmarks and contains only 35709 vertex.\n    # original BFM09 contains 53490 
vertex, and expression basis provided by Guo et al. contains 53215 vertex.\n    # thus we select corresponding vertex to get our face model.\n\n    index_exp = loadmat(osp.join(bfm_folder, 'BFM_front_idx.mat'))\n    index_exp = index_exp['idx'].astype(np.int32) - 1  # starts from 0 (to 53215)\n\n    index_shape = loadmat(osp.join(bfm_folder, 'BFM_exp_idx.mat'))\n    index_shape = index_shape['trimIndex'].astype(\n        np.int32) - 1  # starts from 0 (to 53490)\n    index_shape = index_shape[index_exp]\n\n    idBase = np.reshape(idBase, [-1, 3, 80])\n    idBase = idBase[index_shape, :, :]\n    idBase = np.reshape(idBase, [-1, 80])\n\n    texBase = np.reshape(texBase, [-1, 3, 80])\n    texBase = texBase[index_shape, :, :]\n    texBase = np.reshape(texBase, [-1, 80])\n\n    exBase = np.reshape(exBase, [-1, 3, 64])\n    exBase = exBase[index_exp, :, :]\n    exBase = np.reshape(exBase, [-1, 64])\n\n    meanshape = np.reshape(shapeMU, [-1, 3])/1e5\n    meanshape = meanshape[index_shape, :]\n    meanshape = np.reshape(meanshape, [1, -1])\n\n    meantex = np.reshape(texMU, [-1, 3])\n    meantex = meantex[index_shape, :]\n    meantex = np.reshape(meantex, [1, -1])\n\n    # other info contains triangles, region used for computing photometric loss,\n    # region used for skin texture regularization, and 68 landmarks index etc.\n    other_info = loadmat(osp.join(bfm_folder, 'facemodel_info.mat'))\n    frontmask2_idx = other_info['frontmask2_idx']\n    skinmask = other_info['skinmask']\n    keypoints = other_info['keypoints']\n    point_buf = other_info['point_buf']\n    tri = other_info['tri']\n    tri_mask2 = other_info['tri_mask2']\n\n    # save our face model\n    savemat(osp.join(bfm_folder, 'BFM_model_front.mat'), {'meanshape': meanshape, 'meantex': meantex, 'idBase': idBase, 'exBase': exBase, 'texBase': texBase,\n            'tri': tri, 'point_buf': point_buf, 'tri_mask2': tri_mask2, 'keypoints': keypoints, 'frontmask2_idx': frontmask2_idx, 'skinmask': 
skinmask})\n\n\n# load landmarks for standard face, which is used for image preprocessing\ndef load_lm3d(bfm_folder):\n\n    Lm3D = loadmat(osp.join(bfm_folder, 'similarity_Lm3D_all.mat'))\n    Lm3D = Lm3D['lm']\n\n    # calculate 5 facial landmarks using 68 landmarks\n    lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1\n    Lm3D = np.stack([Lm3D[lm_idx[0], :], np.mean(Lm3D[lm_idx[[1, 2]], :], 0), np.mean(\n        Lm3D[lm_idx[[3, 4]], :], 0), Lm3D[lm_idx[5], :], Lm3D[lm_idx[6], :]], axis=0)\n    Lm3D = Lm3D[[1, 2, 0, 3, 4], :]\n\n    return Lm3D\n\n\nif __name__ == '__main__':\n    transferBFM09()"
  },
  {
    "path": "src/face3d/util/my_awing_arch.py",
    "content": "import cv2\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef calculate_points(heatmaps):\n    # change heatmaps to landmarks\n    B, N, H, W = heatmaps.shape\n    HW = H * W\n    BN_range = np.arange(B * N)\n\n    heatline = heatmaps.reshape(B, N, HW)\n    indexes = np.argmax(heatline, axis=2)\n\n    preds = np.stack((indexes % W, indexes // W), axis=2)\n    preds = preds.astype(np.float, copy=False)\n\n    inr = indexes.ravel()\n\n    heatline = heatline.reshape(B * N, HW)\n    x_up = heatline[BN_range, inr + 1]\n    x_down = heatline[BN_range, inr - 1]\n    # y_up = heatline[BN_range, inr + W]\n\n    if any((inr + W) >= 4096):\n        y_up = heatline[BN_range, 4095]\n    else:\n        y_up = heatline[BN_range, inr + W]\n    if any((inr - W) <= 0):\n        y_down = heatline[BN_range, 0]\n    else:\n        y_down = heatline[BN_range, inr - W]\n\n    think_diff = np.sign(np.stack((x_up - x_down, y_up - y_down), axis=1))\n    think_diff *= .25\n\n    preds += think_diff.reshape(B, N, 2)\n    preds += .5\n    return preds\n\n\nclass AddCoordsTh(nn.Module):\n\n    def __init__(self, x_dim=64, y_dim=64, with_r=False, with_boundary=False):\n        super(AddCoordsTh, self).__init__()\n        self.x_dim = x_dim\n        self.y_dim = y_dim\n        self.with_r = with_r\n        self.with_boundary = with_boundary\n\n    def forward(self, input_tensor, heatmap=None):\n        \"\"\"\n        input_tensor: (batch, c, x_dim, y_dim)\n        \"\"\"\n        batch_size_tensor = input_tensor.shape[0]\n\n        xx_ones = torch.ones([1, self.y_dim], dtype=torch.int32, device=input_tensor.device)\n        xx_ones = xx_ones.unsqueeze(-1)\n\n        xx_range = torch.arange(self.x_dim, dtype=torch.int32, device=input_tensor.device).unsqueeze(0)\n        xx_range = xx_range.unsqueeze(1)\n\n        xx_channel = torch.matmul(xx_ones.float(), xx_range.float())\n        xx_channel = xx_channel.unsqueeze(-1)\n\n        
yy_ones = torch.ones([1, self.x_dim], dtype=torch.int32, device=input_tensor.device)\n        yy_ones = yy_ones.unsqueeze(1)\n\n        yy_range = torch.arange(self.y_dim, dtype=torch.int32, device=input_tensor.device).unsqueeze(0)\n        yy_range = yy_range.unsqueeze(-1)\n\n        yy_channel = torch.matmul(yy_range.float(), yy_ones.float())\n        yy_channel = yy_channel.unsqueeze(-1)\n\n        xx_channel = xx_channel.permute(0, 3, 2, 1)\n        yy_channel = yy_channel.permute(0, 3, 2, 1)\n\n        xx_channel = xx_channel / (self.x_dim - 1)\n        yy_channel = yy_channel / (self.y_dim - 1)\n\n        xx_channel = xx_channel * 2 - 1\n        yy_channel = yy_channel * 2 - 1\n\n        xx_channel = xx_channel.repeat(batch_size_tensor, 1, 1, 1)\n        yy_channel = yy_channel.repeat(batch_size_tensor, 1, 1, 1)\n\n        if self.with_boundary and heatmap is not None:\n            boundary_channel = torch.clamp(heatmap[:, -1:, :, :], 0.0, 1.0)\n\n            zero_tensor = torch.zeros_like(xx_channel)\n            xx_boundary_channel = torch.where(boundary_channel > 0.05, xx_channel, zero_tensor)\n            yy_boundary_channel = torch.where(boundary_channel > 0.05, yy_channel, zero_tensor)\n        if self.with_boundary and heatmap is not None:\n            xx_boundary_channel = xx_boundary_channel.to(input_tensor.device)\n            yy_boundary_channel = yy_boundary_channel.to(input_tensor.device)\n        ret = torch.cat([input_tensor, xx_channel, yy_channel], dim=1)\n\n        if self.with_r:\n            rr = torch.sqrt(torch.pow(xx_channel, 2) + torch.pow(yy_channel, 2))\n            rr = rr / torch.max(rr)\n            ret = torch.cat([ret, rr], dim=1)\n\n        if self.with_boundary and heatmap is not None:\n            ret = torch.cat([ret, xx_boundary_channel, yy_boundary_channel], dim=1)\n        return ret\n\n\nclass CoordConvTh(nn.Module):\n    \"\"\"CoordConv layer as in the paper.\"\"\"\n\n    def __init__(self, x_dim, y_dim, with_r, 
with_boundary, in_channels, first_one=False, *args, **kwargs):\n        super(CoordConvTh, self).__init__()\n        self.addcoords = AddCoordsTh(x_dim=x_dim, y_dim=y_dim, with_r=with_r, with_boundary=with_boundary)\n        in_channels += 2\n        if with_r:\n            in_channels += 1\n        if with_boundary and not first_one:\n            in_channels += 2\n        self.conv = nn.Conv2d(in_channels=in_channels, *args, **kwargs)\n\n    def forward(self, input_tensor, heatmap=None):\n        ret = self.addcoords(input_tensor, heatmap)\n        last_channel = ret[:, -2:, :, :]\n        ret = self.conv(ret)\n        return ret, last_channel\n\n\ndef conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False, dilation=1):\n    '3x3 convolution with padding'\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=strd, padding=padding, bias=bias, dilation=dilation)\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(BasicBlock, self).__init__()\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        # self.bn1 = nn.BatchNorm2d(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        # self.bn2 = nn.BatchNorm2d(planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass ConvBlock(nn.Module):\n\n    def __init__(self, in_planes, out_planes):\n        super(ConvBlock, self).__init__()\n        self.bn1 = nn.BatchNorm2d(in_planes)\n        self.conv1 = conv3x3(in_planes, int(out_planes / 2))\n        self.bn2 = nn.BatchNorm2d(int(out_planes / 2))\n        self.conv2 = 
conv3x3(int(out_planes / 2), int(out_planes / 4), padding=1, dilation=1)\n        self.bn3 = nn.BatchNorm2d(int(out_planes / 4))\n        self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4), padding=1, dilation=1)\n\n        if in_planes != out_planes:\n            self.downsample = nn.Sequential(\n                nn.BatchNorm2d(in_planes),\n                nn.ReLU(True),\n                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False),\n            )\n        else:\n            self.downsample = None\n\n    def forward(self, x):\n        residual = x\n\n        out1 = self.bn1(x)\n        out1 = F.relu(out1, True)\n        out1 = self.conv1(out1)\n\n        out2 = self.bn2(out1)\n        out2 = F.relu(out2, True)\n        out2 = self.conv2(out2)\n\n        out3 = self.bn3(out2)\n        out3 = F.relu(out3, True)\n        out3 = self.conv3(out3)\n\n        out3 = torch.cat((out1, out2, out3), 1)\n\n        if self.downsample is not None:\n            residual = self.downsample(residual)\n\n        out3 += residual\n\n        return out3\n\n\nclass HourGlass(nn.Module):\n\n    def __init__(self, num_modules, depth, num_features, first_one=False):\n        super(HourGlass, self).__init__()\n        self.num_modules = num_modules\n        self.depth = depth\n        self.features = num_features\n        self.coordconv = CoordConvTh(\n            x_dim=64,\n            y_dim=64,\n            with_r=True,\n            with_boundary=True,\n            in_channels=256,\n            first_one=first_one,\n            out_channels=256,\n            kernel_size=1,\n            stride=1,\n            padding=0)\n        self._generate_network(self.depth)\n\n    def _generate_network(self, level):\n        self.add_module('b1_' + str(level), ConvBlock(256, 256))\n\n        self.add_module('b2_' + str(level), ConvBlock(256, 256))\n\n        if level > 1:\n            self._generate_network(level - 1)\n        else:\n            
self.add_module('b2_plus_' + str(level), ConvBlock(256, 256))\n\n        self.add_module('b3_' + str(level), ConvBlock(256, 256))\n\n    def _forward(self, level, inp):\n        # Upper branch\n        up1 = inp\n        up1 = self._modules['b1_' + str(level)](up1)\n\n        # Lower branch\n        low1 = F.avg_pool2d(inp, 2, stride=2)\n        low1 = self._modules['b2_' + str(level)](low1)\n\n        if level > 1:\n            low2 = self._forward(level - 1, low1)\n        else:\n            low2 = low1\n            low2 = self._modules['b2_plus_' + str(level)](low2)\n\n        low3 = low2\n        low3 = self._modules['b3_' + str(level)](low3)\n\n        up2 = F.interpolate(low3, scale_factor=2, mode='nearest')\n\n        return up1 + up2\n\n    def forward(self, x, heatmap):\n        x, last_channel = self.coordconv(x, heatmap)\n        return self._forward(self.depth, x), last_channel\n\n\nclass FAN(nn.Module):\n\n    def __init__(self, num_modules=1, end_relu=False, gray_scale=False, num_landmarks=68, device='cuda'):\n        super(FAN, self).__init__()\n        self.device = device\n        self.num_modules = num_modules\n        self.gray_scale = gray_scale\n        self.end_relu = end_relu\n        self.num_landmarks = num_landmarks\n\n        # Base part\n        if self.gray_scale:\n            self.conv1 = CoordConvTh(\n                x_dim=256,\n                y_dim=256,\n                with_r=True,\n                with_boundary=False,\n                in_channels=3,\n                out_channels=64,\n                kernel_size=7,\n                stride=2,\n                padding=3)\n        else:\n            self.conv1 = CoordConvTh(\n                x_dim=256,\n                y_dim=256,\n                with_r=True,\n                with_boundary=False,\n                in_channels=3,\n                out_channels=64,\n                kernel_size=7,\n                stride=2,\n                padding=3)\n        self.bn1 = 
nn.BatchNorm2d(64)\n        self.conv2 = ConvBlock(64, 128)\n        self.conv3 = ConvBlock(128, 128)\n        self.conv4 = ConvBlock(128, 256)\n\n        # Stacking part\n        for hg_module in range(self.num_modules):\n            if hg_module == 0:\n                first_one = True\n            else:\n                first_one = False\n            self.add_module('m' + str(hg_module), HourGlass(1, 4, 256, first_one))\n            self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))\n            self.add_module('conv_last' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))\n            self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))\n            self.add_module('l' + str(hg_module), nn.Conv2d(256, num_landmarks + 1, kernel_size=1, stride=1, padding=0))\n\n            if hg_module < self.num_modules - 1:\n                self.add_module('bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))\n                self.add_module('al' + str(hg_module),\n                                nn.Conv2d(num_landmarks + 1, 256, kernel_size=1, stride=1, padding=0))\n\n    def forward(self, x):\n        x, _ = self.conv1(x)\n        x = F.relu(self.bn1(x), True)\n        # x = F.relu(self.bn1(self.conv1(x)), True)\n        x = F.avg_pool2d(self.conv2(x), 2, stride=2)\n        x = self.conv3(x)\n        x = self.conv4(x)\n\n        previous = x\n\n        outputs = []\n        boundary_channels = []\n        tmp_out = None\n        for i in range(self.num_modules):\n            hg, boundary_channel = self._modules['m' + str(i)](previous, tmp_out)\n\n            ll = hg\n            ll = self._modules['top_m_' + str(i)](ll)\n\n            ll = F.relu(self._modules['bn_end' + str(i)](self._modules['conv_last' + str(i)](ll)), True)\n\n            # Predict heatmaps\n            tmp_out = self._modules['l' + str(i)](ll)\n            if self.end_relu:\n                tmp_out = F.relu(tmp_out)  # HACK: Added 
relu\n            outputs.append(tmp_out)\n            boundary_channels.append(boundary_channel)\n\n            if i < self.num_modules - 1:\n                ll = self._modules['bl' + str(i)](ll)\n                tmp_out_ = self._modules['al' + str(i)](tmp_out)\n                previous = previous + ll + tmp_out_\n\n        return outputs, boundary_channels\n\n    def get_landmarks(self, img):\n        H, W, _ = img.shape\n        offset = W / 64, H / 64, 0, 0\n\n        img = cv2.resize(img, (256, 256))\n        inp = img[..., ::-1]\n        inp = torch.from_numpy(np.ascontiguousarray(inp.transpose((2, 0, 1)))).float()\n        inp = inp.to(self.device)\n        inp.div_(255.0).unsqueeze_(0)\n\n        outputs, _ = self.forward(inp)\n        out = outputs[-1][:, :-1, :, :]\n        heatmaps = out.detach().cpu().numpy()\n\n        pred = calculate_points(heatmaps).reshape(-1, 2)\n\n        pred *= offset[:2]\n        pred += offset[-2:]\n\n        return pred\n"
  },
  {
    "path": "src/face3d/util/nvdiffrast.py",
    "content": "\"\"\"This script is the differentiable renderer for Deep3DFaceRecon_pytorch\n    Attention, antialiasing step is missing in current version.\n\"\"\"\nimport pytorch3d.ops\nimport torch\nimport torch.nn.functional as F\nimport kornia\nfrom kornia.geometry.camera import pixel2cam\nimport numpy as np\nfrom typing import List\nfrom scipy.io import loadmat\nfrom torch import nn\n\nfrom pytorch3d.structures import Meshes\nfrom pytorch3d.renderer import (\n    look_at_view_transform,\n    FoVPerspectiveCameras,\n    DirectionalLights,\n    RasterizationSettings,\n    MeshRenderer,\n    MeshRasterizer,\n    SoftPhongShader,\n    TexturesUV,\n)\n\n# def ndc_projection(x=0.1, n=1.0, f=50.0):\n#     return np.array([[n/x,    0,            0,              0],\n#                      [  0, n/-x,            0,              0],\n#                      [  0,    0, -(f+n)/(f-n), -(2*f*n)/(f-n)],\n#                      [  0,    0,           -1,              0]]).astype(np.float32)\n\nclass MeshRenderer(nn.Module):\n    def __init__(self,\n                rasterize_fov,\n                znear=0.1,\n                zfar=10, \n                rasterize_size=224):\n        super(MeshRenderer, self).__init__()\n\n        # x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear\n        # self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul(\n        #         torch.diag(torch.tensor([1., -1, -1, 1])))\n        self.rasterize_size = rasterize_size\n        self.fov = rasterize_fov\n        self.znear = znear\n        self.zfar = zfar\n\n        self.rasterizer = None\n    \n    def forward(self, vertex, tri, feat=None):\n        \"\"\"\n        Return:\n            mask               -- torch.tensor, size (B, 1, H, W)\n            depth              -- torch.tensor, size (B, 1, H, W)\n            features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None\n\n        Parameters:\n            vertex          -- torch.tensor, size (B, N, 3)\n 
           tri             -- torch.tensor, size (B, M, 3) or (M, 3), triangles\n            feat(optional)  -- torch.tensor, size (B, N ,C), features\n        \"\"\"\n        device = vertex.device\n        rsize = int(self.rasterize_size)\n        # ndc_proj = self.ndc_proj.to(device)\n        # trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v\n        if vertex.shape[-1] == 3:\n            vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1)\n            vertex[..., 0] = -vertex[..., 0]\n\n\n        # vertex_ndc = vertex @ ndc_proj.t()\n        if self.rasterizer is None:\n            self.rasterizer = MeshRasterizer()\n            print(\"create rasterizer on device cuda:%d\"%device.index)\n        \n        # ranges = None\n        # if isinstance(tri, List) or len(tri.shape) == 3:\n        #     vum = vertex_ndc.shape[1]\n        #     fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device)\n        #     fstartidx = torch.cumsum(fnum, dim=0) - fnum\n        #     ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu()\n        #     for i in range(tri.shape[0]):\n        #         tri[i] = tri[i] + i*vum\n        #     vertex_ndc = torch.cat(vertex_ndc, dim=0)\n        #     tri = torch.cat(tri, dim=0)\n\n        # for range_mode vetex: [B*N, 4], tri: [B*M, 3], for instance_mode vetex: [B, N, 4], tri: [M, 3]\n        tri = tri.type(torch.int32).contiguous()\n\n        # rasterize\n        cameras = FoVPerspectiveCameras(\n            device=device,\n            fov=self.fov,\n            znear=self.znear,\n            zfar=self.zfar,\n        )\n\n        raster_settings = RasterizationSettings(\n            image_size=rsize\n        )\n\n        # print(vertex.shape, tri.shape)\n        mesh = Meshes(vertex.contiguous()[...,:3], tri.unsqueeze(0).repeat((vertex.shape[0],1,1)))\n\n        fragments = self.rasterizer(mesh, cameras = cameras, raster_settings = 
raster_settings)\n        rast_out = fragments.pix_to_face.squeeze(-1)\n        depth = fragments.zbuf\n\n        # render depth\n        depth = depth.permute(0, 3, 1, 2)\n        mask = (rast_out > 0).float().unsqueeze(1)\n        depth = mask * depth\n        \n\n        image = None\n        if feat is not None:\n            attributes = feat.reshape(-1,3)[mesh.faces_packed()]\n            image = pytorch3d.ops.interpolate_face_attributes(fragments.pix_to_face,\n                                                      fragments.bary_coords,\n                                                      attributes)\n            # print(image.shape)\n            image = image.squeeze(-2).permute(0, 3, 1, 2)\n            image = mask * image\n        \n        return mask, depth, image\n\n"
  },
  {
    "path": "src/face3d/util/preprocess.py",
    "content": "\"\"\"This script contains the image preprocessing code for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport numpy as np\nfrom scipy.io import loadmat\nfrom PIL import Image\nimport cv2\nimport os\nfrom skimage import transform as trans\nimport torch\nimport warnings\nwarnings.filterwarnings(\"ignore\", category=np.VisibleDeprecationWarning) \nwarnings.filterwarnings(\"ignore\", category=FutureWarning) \n\n\n# calculating least square problem for image alignment\ndef POS(xp, x):\n    npts = xp.shape[1]\n\n    A = np.zeros([2*npts, 8])\n\n    A[0:2*npts-1:2, 0:3] = x.transpose()\n    A[0:2*npts-1:2, 3] = 1\n\n    A[1:2*npts:2, 4:7] = x.transpose()\n    A[1:2*npts:2, 7] = 1\n\n    b = np.reshape(xp.transpose(), [2*npts, 1])\n\n    k, _, _, _ = np.linalg.lstsq(A, b)\n\n    R1 = k[0:3]\n    R2 = k[4:7]\n    sTx = k[3]\n    sTy = k[7]\n    s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2\n    t = np.stack([sTx, sTy], axis=0)\n\n    return t, s\n    \n# resize and crop images for face reconstruction\ndef resize_n_crop_img(img, lm, t, s, target_size=224., mask=None):\n    w0, h0 = img.size\n    w = (w0*s).astype(np.int32)\n    h = (h0*s).astype(np.int32)\n    left = (w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32)\n    right = left + target_size\n    up = (h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32)\n    below = up + target_size\n\n    img = img.resize((w, h), resample=Image.BICUBIC)\n    img = img.crop((left, up, right, below))\n\n    if mask is not None:\n        mask = mask.resize((w, h), resample=Image.BICUBIC)\n        mask = mask.crop((left, up, right, below))\n\n    lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] -\n                  t[1] + h0/2], axis=1)*s\n    lm = lm - np.reshape(\n            np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2])\n\n    return img, lm, mask\n\n# utils for face reconstruction\ndef extract_5p(lm):\n    lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1\n    lm5p = 
np.stack([lm[lm_idx[0], :], np.mean(lm[lm_idx[[1, 2]], :], 0), np.mean(\n        lm[lm_idx[[3, 4]], :], 0), lm[lm_idx[5], :], lm[lm_idx[6], :]], axis=0)\n    lm5p = lm5p[[1, 2, 0, 3, 4], :]\n    return lm5p\n\n# utils for face reconstruction\ndef align_img(img, lm, lm3D, mask=None, target_size=224., rescale_factor=102.):\n    \"\"\"\n    Return:\n        transparams        --numpy.array  (raw_W, raw_H, scale, tx, ty)\n        img_new            --PIL.Image  (target_size, target_size, 3)\n        lm_new             --numpy.array  (68, 2), y direction is opposite to v direction\n        mask_new           --PIL.Image  (target_size, target_size)\n    \n    Parameters:\n        img                --PIL.Image  (raw_H, raw_W, 3)\n        lm                 --numpy.array  (68, 2), y direction is opposite to v direction\n        lm3D               --numpy.array  (5, 3)\n        mask               --PIL.Image  (raw_H, raw_W, 3)\n    \"\"\"\n\n    w0, h0 = img.size\n    if lm.shape[0] != 5:\n        lm5p = extract_5p(lm)\n    else:\n        lm5p = lm\n\n    # calculate translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face\n    t, s = POS(lm5p.transpose(), lm3D.transpose())\n    s = rescale_factor/s\n\n    # processing the image\n    img_new, lm_new, mask_new = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask)\n    trans_params = np.array([w0, h0, s, t[0], t[1]])\n\n    return trans_params, img_new, lm_new, mask_new\n"
  },
  {
    "path": "src/face3d/util/skin_mask.py",
    "content": "\"\"\"This script is to generate skin attention mask for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport math\nimport numpy as np\nimport os\nimport cv2\n\nclass GMM:\n    def __init__(self, dim, num, w, mu, cov, cov_det, cov_inv):\n        self.dim = dim # feature dimension\n        self.num = num # number of Gaussian components\n        self.w = w # weights of Gaussian components (a list of scalars)\n        self.mu= mu # mean of Gaussian components (a list of 1xdim vectors)\n        self.cov = cov # covariance matrix of Gaussian components (a list of dimxdim matrices)\n        self.cov_det = cov_det # pre-computed determinet of covariance matrices (a list of scalars)\n        self.cov_inv = cov_inv # pre-computed inverse covariance matrices (a list of dimxdim matrices)\n\n        self.factor = [0]*num\n        for i in range(self.num):\n            self.factor[i] = (2*math.pi)**(self.dim/2) * self.cov_det[i]**0.5\n        \n    def likelihood(self, data):\n        assert(data.shape[1] == self.dim)\n        N = data.shape[0]\n        lh = np.zeros(N)\n\n        for i in range(self.num):\n            data_ = data - self.mu[i]\n\n            tmp = np.matmul(data_,self.cov_inv[i]) * data_\n            tmp = np.sum(tmp,axis=1)\n            power = -0.5 * tmp\n\n            p = np.array([math.exp(power[j]) for j in range(N)])\n            p = p/self.factor[i]\n            lh += p*self.w[i]\n        \n        return lh\n\n\ndef _rgb2ycbcr(rgb):\n    m = np.array([[65.481, 128.553, 24.966],\n                  [-37.797, -74.203, 112],\n                  [112, -93.786, -18.214]])\n    shape = rgb.shape\n    rgb = rgb.reshape((shape[0] * shape[1], 3))\n    ycbcr = np.dot(rgb, m.transpose() / 255.)\n    ycbcr[:, 0] += 16.\n    ycbcr[:, 1:] += 128.\n    return ycbcr.reshape(shape)\n\n\ndef _bgr2ycbcr(bgr):\n    rgb = bgr[..., ::-1]\n    return _rgb2ycbcr(rgb)\n\n\ngmm_skin_w = [0.24063933, 0.16365987, 0.26034665, 0.33535415]\ngmm_skin_mu = [np.array([113.71862, 
103.39613, 164.08226]),\n                np.array([150.19858, 105.18467, 155.51428]),\n                np.array([183.92976, 107.62468, 152.71820]),\n                np.array([114.90524, 113.59782, 151.38217])]\ngmm_skin_cov_det = [5692842.5, 5851930.5, 2329131., 1585971.]\ngmm_skin_cov_inv = [np.array([[0.0019472069, 0.0020450759, -0.00060243998],[0.0020450759, 0.017700525, 0.0051420014],[-0.00060243998, 0.0051420014, 0.0081308950]]),\n                    np.array([[0.0027110141, 0.0011036990, 0.0023122299],[0.0011036990, 0.010707724, 0.010742856],[0.0023122299, 0.010742856, 0.017481629]]),\n                    np.array([[0.0048026871, 0.00022935172, 0.0077668377],[0.00022935172, 0.011729696, 0.0081661865],[0.0077668377, 0.0081661865, 0.025374353]]),\n                    np.array([[0.0011989699, 0.0022453172, -0.0010748957],[0.0022453172, 0.047758564, 0.020332102],[-0.0010748957, 0.020332102, 0.024502251]])]\n\ngmm_skin = GMM(3, 4, gmm_skin_w, gmm_skin_mu, [], gmm_skin_cov_det, gmm_skin_cov_inv)\n\ngmm_nonskin_w = [0.12791070, 0.31130761, 0.34245777, 0.21832393]\ngmm_nonskin_mu = [np.array([99.200851, 112.07533, 140.20602]),\n                    np.array([110.91392, 125.52969, 130.19237]),\n                    np.array([129.75864, 129.96107, 126.96808]),\n                    np.array([112.29587, 128.85121, 129.05431])]\ngmm_nonskin_cov_det = [458703648., 6466488., 90611376., 133097.63]\ngmm_nonskin_cov_inv = [np.array([[0.00085371657, 0.00071197288, 0.00023958916],[0.00071197288, 0.0025935620, 0.00076557708],[0.00023958916, 0.00076557708, 0.0015042332]]),\n                    np.array([[0.00024650150, 0.00045542428, 0.00015019422],[0.00045542428, 0.026412144, 0.018419769],[0.00015019422, 0.018419769, 0.037497383]]),\n                    np.array([[0.00037054974, 0.00038146760, 0.00040408765],[0.00038146760, 0.0085505722, 0.0079136286],[0.00040408765, 0.0079136286, 0.010982352]]),\n                    np.array([[0.00013709733, 0.00051228428, 
0.00012777430],[0.00051228428, 0.28237113, 0.10528370],[0.00012777430, 0.10528370, 0.23468947]])]\n\ngmm_nonskin = GMM(3, 4, gmm_nonskin_w, gmm_nonskin_mu, [], gmm_nonskin_cov_det, gmm_nonskin_cov_inv)\n\nprior_skin = 0.8\nprior_nonskin = 1 - prior_skin\n\n\n# calculate skin attention mask\ndef skinmask(imbgr):\n    im = _bgr2ycbcr(imbgr)\n\n    data = im.reshape((-1,3))\n\n    lh_skin = gmm_skin.likelihood(data)\n    lh_nonskin = gmm_nonskin.likelihood(data)\n\n    tmp1 = prior_skin * lh_skin\n    tmp2 = prior_nonskin * lh_nonskin\n    post_skin = tmp1 / (tmp1+tmp2) # posterior probability\n\n    post_skin = post_skin.reshape((im.shape[0],im.shape[1]))\n\n    post_skin = np.round(post_skin*255)\n    post_skin = post_skin.astype(np.uint8)\n    post_skin = np.tile(np.expand_dims(post_skin,2),[1,1,3]) # reshape to H*W*3\n\n    return post_skin\n\n\ndef get_skin_mask(img_path):\n    print('generating skin masks......')\n    names = [i for i in sorted(os.listdir(\n        img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i]\n    save_path = os.path.join(img_path, 'mask')\n    if not os.path.isdir(save_path):\n        os.makedirs(save_path)\n    \n    for i in range(0, len(names)):\n        name = names[i]\n        print('%05d' % (i), ' ', name)\n        full_image_name = os.path.join(img_path, name)\n        img = cv2.imread(full_image_name).astype(np.float32)\n        skin_img = skinmask(img)\n        cv2.imwrite(os.path.join(save_path, name), skin_img.astype(np.uint8))\n"
  },
  {
    "path": "src/face3d/util/test_mean_face.txt",
    "content": "-5.228591537475585938e+01\n2.078247070312500000e-01\n-5.064269638061523438e+01\n-1.315765380859375000e+01\n-4.952939224243164062e+01\n-2.592591094970703125e+01\n-4.793047332763671875e+01\n-3.832135772705078125e+01\n-4.512159729003906250e+01\n-5.059623336791992188e+01\n-3.917720794677734375e+01\n-6.043736648559570312e+01\n-2.929953765869140625e+01\n-6.861183166503906250e+01\n-1.719801330566406250e+01\n-7.572736358642578125e+01\n-1.961936950683593750e+00\n-7.862001037597656250e+01\n1.467941284179687500e+01\n-7.607844543457031250e+01\n2.744073486328125000e+01\n-6.915261840820312500e+01\n3.855677795410156250e+01\n-5.950350570678710938e+01\n4.478240966796875000e+01\n-4.867547225952148438e+01\n4.714337158203125000e+01\n-3.800830078125000000e+01\n4.940315246582031250e+01\n-2.496297454833984375e+01\n5.117234802246093750e+01\n-1.241538238525390625e+01\n5.190507507324218750e+01\n8.244247436523437500e-01\n-4.150688934326171875e+01\n2.386329650878906250e+01\n-3.570307159423828125e+01\n3.017010498046875000e+01\n-2.790358734130859375e+01\n3.212951660156250000e+01\n-1.941773223876953125e+01\n3.156523132324218750e+01\n-1.138106536865234375e+01\n2.841992187500000000e+01\n5.993263244628906250e+00\n2.895182800292968750e+01\n1.343590545654296875e+01\n3.189880371093750000e+01\n2.203153991699218750e+01\n3.302221679687500000e+01\n2.992478942871093750e+01\n3.099150085449218750e+01\n3.628388977050781250e+01\n2.765748596191406250e+01\n-1.933914184570312500e+00\n1.405374145507812500e+01\n-2.153038024902343750e+00\n5.772636413574218750e+00\n-2.270050048828125000e+00\n-2.121643066406250000e+00\n-2.218330383300781250e+00\n-1.068978118896484375e+01\n-1.187252044677734375e+01\n-1.997912597656250000e+01\n-6.879402160644531250e+00\n-2.143579864501953125e+01\n-1.227821350097656250e+00\n-2.193494415283203125e+01\n4.623237609863281250e+00\n-2.152721405029296875e+01\n9.721397399902343750e+00\n-1.953671264648437500e+01\n-3.648714447021484375e+01\n9.811126708984375000e+00\n-3.1302429199218
75000e+01\n1.422447967529296875e+01\n-2.212834930419921875e+01\n1.493019866943359375e+01\n-1.500880432128906250e+01\n1.073588562011718750e+01\n-2.095037078857421875e+01\n9.054298400878906250e+00\n-3.050099182128906250e+01\n8.704177856445312500e+00\n1.173237609863281250e+01\n1.054329681396484375e+01\n1.856353759765625000e+01\n1.535009765625000000e+01\n2.893331909179687500e+01\n1.451992797851562500e+01\n3.452944946289062500e+01\n1.065280151367187500e+01\n2.875990295410156250e+01\n8.654792785644531250e+00\n1.942100524902343750e+01\n9.422447204589843750e+00\n-2.204488372802734375e+01\n-3.983994293212890625e+01\n-1.324458312988281250e+01\n-3.467377471923828125e+01\n-6.749649047851562500e+00\n-3.092894744873046875e+01\n-9.183349609375000000e-01\n-3.196458435058593750e+01\n4.220649719238281250e+00\n-3.090406036376953125e+01\n1.089889526367187500e+01\n-3.497008514404296875e+01\n1.874589538574218750e+01\n-4.065438079833984375e+01\n1.124106597900390625e+01\n-4.438417816162109375e+01\n5.181709289550781250e+00\n-4.649170684814453125e+01\n-1.158607482910156250e+00\n-4.680406951904296875e+01\n-7.918922424316406250e+00\n-4.671575164794921875e+01\n-1.452505493164062500e+01\n-4.416526031494140625e+01\n-2.005007171630859375e+01\n-3.997841644287109375e+01\n-1.054919433593750000e+01\n-3.849683380126953125e+01\n-1.051826477050781250e+00\n-3.794863128662109375e+01\n6.412681579589843750e+00\n-3.804645538330078125e+01\n1.627674865722656250e+01\n-4.039697265625000000e+01\n6.373878479003906250e+00\n-4.087213897705078125e+01\n-8.551712036132812500e-01\n-4.157129669189453125e+01\n-1.014953613281250000e+01\n-4.128469085693359375e+01\n"
  },
  {
    "path": "src/face3d/util/util.py",
    "content": "\"\"\"This script contains basic utilities for Deep3DFaceRecon_pytorch\n\"\"\"\nfrom __future__ import print_function\nimport numpy as np\nimport torch\nfrom PIL import Image\nimport os\nimport importlib\nimport argparse\nfrom argparse import Namespace\nimport torchvision\n\n\ndef str2bool(v):\n    if isinstance(v, bool):\n        return v\n    if v.lower() in ('yes', 'true', 't', 'y', '1'):\n        return True\n    elif v.lower() in ('no', 'false', 'f', 'n', '0'):\n        return False\n    else:\n        raise argparse.ArgumentTypeError('Boolean value expected.')\n\n\ndef copyconf(default_opt, **kwargs):\n    conf = Namespace(**vars(default_opt))\n    for key in kwargs:\n        setattr(conf, key, kwargs[key])\n    return conf\n\ndef genvalconf(train_opt, **kwargs):\n    conf = Namespace(**vars(train_opt))\n    attr_dict = train_opt.__dict__\n    for key, value in attr_dict.items():\n        if 'val' in key and key.split('_')[0] in attr_dict:\n            setattr(conf, key.split('_')[0], value)\n\n    for key in kwargs:\n        setattr(conf, key, kwargs[key])\n\n    return conf\n        \ndef find_class_in_module(target_cls_name, module):\n    target_cls_name = target_cls_name.replace('_', '').lower()\n    clslib = importlib.import_module(module)\n    cls = None\n    for name, clsobj in clslib.__dict__.items():\n        if name.lower() == target_cls_name:\n            cls = clsobj\n\n    assert cls is not None, \"In %s, there should be a class whose name matches %s in lowercase without underscore(_)\" % (module, target_cls_name)\n\n    return cls\n\n\ndef tensor2im(input_image, imtype=np.uint8):\n    \"\"\"\"Converts a Tensor array into a numpy image array.\n\n    Parameters:\n        input_image (tensor) --  the input image tensor array, range(0, 1)\n        imtype (type)        --  the desired type of the converted numpy array\n    \"\"\"\n    if not isinstance(input_image, np.ndarray):\n        if isinstance(input_image, torch.Tensor):  # get 
the data from a variable\n            image_tensor = input_image.data\n        else:\n            return input_image\n        image_numpy = image_tensor.clamp(0.0, 1.0).cpu().float().numpy()  # convert it into a numpy array\n        if image_numpy.shape[0] == 1:  # grayscale to RGB\n            image_numpy = np.tile(image_numpy, (3, 1, 1))\n        image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0  # post-processing: tranpose and scaling\n    else:  # if it is a numpy array, do nothing\n        image_numpy = input_image\n    return image_numpy.astype(imtype)\n\n\ndef diagnose_network(net, name='network'):\n    \"\"\"Calculate and print the mean of average absolute(gradients)\n\n    Parameters:\n        net (torch network) -- Torch network\n        name (str) -- the name of the network\n    \"\"\"\n    mean = 0.0\n    count = 0\n    for param in net.parameters():\n        if param.grad is not None:\n            mean += torch.mean(torch.abs(param.grad.data))\n            count += 1\n    if count > 0:\n        mean = mean / count\n    print(name)\n    print(mean)\n\n\ndef save_image(image_numpy, image_path, aspect_ratio=1.0):\n    \"\"\"Save a numpy image to the disk\n\n    Parameters:\n        image_numpy (numpy array) -- input numpy array\n        image_path (str)          -- the path of the image\n    \"\"\"\n\n    image_pil = Image.fromarray(image_numpy)\n    h, w, _ = image_numpy.shape\n\n    if aspect_ratio is None:\n        pass\n    elif aspect_ratio > 1.0:\n        image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)\n    elif aspect_ratio < 1.0:\n        image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)\n    image_pil.save(image_path)\n\n\ndef print_numpy(x, val=True, shp=False):\n    \"\"\"Print the mean, min, max, median, std, and size of a numpy array\n\n    Parameters:\n        val (bool) -- if print the values of the numpy array\n        shp (bool) -- if print the shape of the numpy array\n    \"\"\"\n   
 x = x.astype(np.float64)\n    if shp:\n        print('shape,', x.shape)\n    if val:\n        x = x.flatten()\n        print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (\n            np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))\n\n\ndef mkdirs(paths):\n    \"\"\"create empty directories if they don't exist\n\n    Parameters:\n        paths (str list) -- a list of directory paths\n    \"\"\"\n    if isinstance(paths, list) and not isinstance(paths, str):\n        for path in paths:\n            mkdir(path)\n    else:\n        mkdir(paths)\n\n\ndef mkdir(path):\n    \"\"\"create a single empty directory if it didn't exist\n\n    Parameters:\n        path (str) -- a single directory path\n    \"\"\"\n    if not os.path.exists(path):\n        os.makedirs(path)\n\n\ndef correct_resize_label(t, size):\n    device = t.device\n    t = t.detach().cpu()\n    resized = []\n    for i in range(t.size(0)):\n        one_t = t[i, :1]\n        one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0))\n        one_np = one_np[:, :, 0]\n        one_image = Image.fromarray(one_np).resize(size, Image.NEAREST)\n        resized_t = torch.from_numpy(np.array(one_image)).long()\n        resized.append(resized_t)\n    return torch.stack(resized, dim=0).to(device)\n\n\ndef correct_resize(t, size, mode=Image.BICUBIC):\n    device = t.device\n    t = t.detach().cpu()\n    resized = []\n    for i in range(t.size(0)):\n        one_t = t[i:i + 1]\n        one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC)\n        resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0\n        resized.append(resized_t)\n    return torch.stack(resized, dim=0).to(device)\n\ndef draw_landmarks(img, landmark, color='r', step=2):\n    \"\"\"\n    Return:\n        img              -- numpy.array, (B, H, W, 3) img with landmark, RGB order, range (0, 255)\n        \n\n    Parameters:\n        img              -- numpy.array, 
(B, H, W, 3), RGB order, range (0, 255)\n        landmark         -- numpy.array, (B, 68, 2), y direction is opposite to v direction\n        color            -- str, 'r' or 'b' (red or blue)\n    \"\"\"\n    if color =='r':\n        c = np.array([255., 0, 0])\n    else:\n        c = np.array([0, 0, 255.])\n\n    _, H, W, _ = img.shape\n    img, landmark = img.copy(), landmark.copy()\n    landmark[..., 1] = H - 1 - landmark[..., 1]\n    landmark = np.round(landmark).astype(np.int32)\n    for i in range(landmark.shape[1]):\n        x, y = landmark[:, i, 0], landmark[:, i, 1]\n        for j in range(-step, step):\n            for k in range(-step, step):\n                u = np.clip(x + j, 0, W - 1)\n                v = np.clip(y + k, 0, H - 1)\n                for m in range(landmark.shape[0]):\n                    img[m, v[m], u[m]] = c\n    return img\n"
  },
  {
    "path": "src/face3d/util/visualizer.py",
    "content": "\"\"\"This script defines the visualizer for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport numpy as np\nimport os\nimport sys\nimport ntpath\nimport time\nfrom . import util, html\nfrom subprocess import Popen, PIPE\nfrom torch.utils.tensorboard import SummaryWriter\n\ndef save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):\n    \"\"\"Save images to the disk.\n\n    Parameters:\n        webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details)\n        visuals (OrderedDict)    -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs\n        image_path (str)         -- the string is used to create image paths\n        aspect_ratio (float)     -- the aspect ratio of saved images\n        width (int)              -- the images will be resized to width x width\n\n    This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.\n    \"\"\"\n    image_dir = webpage.get_image_dir()\n    short_path = ntpath.basename(image_path[0])\n    name = os.path.splitext(short_path)[0]\n\n    webpage.add_header(name)\n    ims, txts, links = [], [], []\n\n    for label, im_data in visuals.items():\n        im = util.tensor2im(im_data)\n        image_name = '%s/%s.png' % (label, name)\n        os.makedirs(os.path.join(image_dir, label), exist_ok=True)\n        save_path = os.path.join(image_dir, image_name)\n        util.save_image(im, save_path, aspect_ratio=aspect_ratio)\n        ims.append(image_name)\n        txts.append(label)\n        links.append(image_name)\n    webpage.add_images(ims, txts, links, width=width)\n\n\nclass Visualizer():\n    \"\"\"This class includes several functions that can display/save images and print/save logging information.\n\n    It uses a Python library tensprboardX for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.\n    \"\"\"\n\n    def __init__(self, opt):\n      
  \"\"\"Initialize the Visualizer class\n\n        Parameters:\n            opt -- stores all the experiment flags; needs to be a subclass of BaseOptions\n        Step 1: Cache the training/test options\n        Step 2: create a tensorboard writer\n        Step 3: create an HTML object for saveing HTML filters\n        Step 4: create a logging file to store training losses\n        \"\"\"\n        self.opt = opt  # cache the option\n        self.use_html = opt.isTrain and not opt.no_html\n        self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, 'logs', opt.name))\n        self.win_size = opt.display_winsize\n        self.name = opt.name\n        self.saved = False\n        if self.use_html:  # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/\n            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')\n            self.img_dir = os.path.join(self.web_dir, 'images')\n            print('create web directory %s...' 
% self.web_dir)\n            util.mkdirs([self.web_dir, self.img_dir])\n        # create a logging file to store training losses\n        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')\n        with open(self.log_name, \"a\") as log_file:\n            now = time.strftime(\"%c\")\n            log_file.write('================ Training Loss (%s) ================\\n' % now)\n\n    def reset(self):\n        \"\"\"Reset the self.saved status\"\"\"\n        self.saved = False\n\n\n    def display_current_results(self, visuals, total_iters, epoch, save_result):\n        \"\"\"Display current results on tensorboad; save current results to an HTML file.\n\n        Parameters:\n            visuals (OrderedDict) - - dictionary of images to display or save\n            total_iters (int) -- total iterations\n            epoch (int) - - the current epoch\n            save_result (bool) - - if save the current results to an HTML file\n        \"\"\"\n        for label, image in visuals.items():\n            self.writer.add_image(label, util.tensor2im(image), total_iters, dataformats='HWC')\n\n        if self.use_html and (save_result or not self.saved):  # save images to an HTML file if they haven't been saved.\n            self.saved = True\n            # save images to the disk\n            for label, image in visuals.items():\n                image_numpy = util.tensor2im(image)\n                img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))\n                util.save_image(image_numpy, img_path)\n\n            # update website\n            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0)\n            for n in range(epoch, 0, -1):\n                webpage.add_header('epoch [%d]' % n)\n                ims, txts, links = [], [], []\n\n                for label, image_numpy in visuals.items():\n                    image_numpy = util.tensor2im(image)\n                    img_path = 
'epoch%.3d_%s.png' % (n, label)\n                    ims.append(img_path)\n                    txts.append(label)\n                    links.append(img_path)\n                webpage.add_images(ims, txts, links, width=self.win_size)\n            webpage.save()\n\n    def plot_current_losses(self, total_iters, losses):\n        # G_loss_collection = {}\n        # D_loss_collection = {}\n        # for name, value in losses.items():\n        #     if 'G' in name or 'NCE' in name or 'idt' in name:\n        #         G_loss_collection[name] = value\n        #     else:\n        #         D_loss_collection[name] = value\n        # self.writer.add_scalars('G_collec', G_loss_collection, total_iters)\n        # self.writer.add_scalars('D_collec', D_loss_collection, total_iters)\n        for name, value in losses.items():\n            self.writer.add_scalar(name, value, total_iters)\n\n    # losses: same format as |losses| of plot_current_losses\n    def print_current_losses(self, epoch, iters, losses, t_comp, t_data):\n        \"\"\"print current losses on console; also save the losses to the disk\n\n        Parameters:\n            epoch (int) -- current epoch\n            iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)\n            losses (OrderedDict) -- training losses stored in the format of (name, float) pairs\n            t_comp (float) -- computational time per data point (normalized by batch_size)\n            t_data (float) -- data loading time per data point (normalized by batch_size)\n        \"\"\"\n        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)\n        for k, v in losses.items():\n            message += '%s: %.3f ' % (k, v)\n\n        print(message)  # print the message\n        with open(self.log_name, \"a\") as log_file:\n            log_file.write('%s\\n' % message)  # save the message\n\n\nclass MyVisualizer:\n    def __init__(self, opt):\n        
\"\"\"Initialize the Visualizer class\n\n        Parameters:\n            opt -- stores all the experiment flags; needs to be a subclass of BaseOptions\n        Step 1: Cache the training/test options\n        Step 2: create a tensorboard writer\n        Step 3: create an HTML object for saveing HTML filters\n        Step 4: create a logging file to store training losses\n        \"\"\"\n        self.opt = opt  # cache the optio\n        self.name = opt.name\n        self.img_dir = os.path.join(opt.checkpoints_dir, opt.name, 'results')\n        \n        if opt.phase != 'test':\n            self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, 'logs'))\n            # create a logging file to store training losses\n            self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')\n            with open(self.log_name, \"a\") as log_file:\n                now = time.strftime(\"%c\")\n                log_file.write('================ Training Loss (%s) ================\\n' % now)\n\n\n    def display_current_results(self, visuals, total_iters, epoch, dataset='train', save_results=False, count=0, name=None,\n            add_image=True):\n        \"\"\"Display current results on tensorboad; save current results to an HTML file.\n\n        Parameters:\n            visuals (OrderedDict) - - dictionary of images to display or save\n            total_iters (int) -- total iterations\n            epoch (int) - - the current epoch\n            dataset (str) - - 'train' or 'val' or 'test'\n        \"\"\"\n        # if (not add_image) and (not save_results): return\n        \n        for label, image in visuals.items():\n            for i in range(image.shape[0]):\n                image_numpy = util.tensor2im(image[i])\n                if add_image:\n                    self.writer.add_image(label + '%s_%02d'%(dataset, i + count),\n                            image_numpy, total_iters, dataformats='HWC')\n\n                if save_results:\n  
                  save_path = os.path.join(self.img_dir, dataset, 'epoch_%s_%06d'%(epoch, total_iters))\n                    if not os.path.isdir(save_path):\n                        os.makedirs(save_path)\n\n                    if name is not None:\n                        img_path = os.path.join(save_path, '%s.png' % name)\n                    else:\n                        img_path = os.path.join(save_path, '%s_%03d.png' % (label, i + count))\n                    util.save_image(image_numpy, img_path)\n\n\n    def plot_current_losses(self, total_iters, losses, dataset='train'):\n        for name, value in losses.items():\n            self.writer.add_scalar(name + '/%s'%dataset, value, total_iters)\n\n    # losses: same format as |losses| of plot_current_losses\n    def print_current_losses(self, epoch, iters, losses, t_comp, t_data, dataset='train'):\n        \"\"\"print current losses on console; also save the losses to the disk\n\n        Parameters:\n            epoch (int) -- current epoch\n            iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)\n            losses (OrderedDict) -- training losses stored in the format of (name, float) pairs\n            t_comp (float) -- computational time per data point (normalized by batch_size)\n            t_data (float) -- data loading time per data point (normalized by batch_size)\n        \"\"\"\n        message = '(dataset: %s, epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (\n            dataset, epoch, iters, t_comp, t_data)\n        for k, v in losses.items():\n            message += '%s: %.3f ' % (k, v)\n\n        print(message)  # print the message\n        with open(self.log_name, \"a\") as log_file:\n            log_file.write('%s\\n' % message)  # save the message\n"
  },
  {
    "path": "src/face3d/visualize.py",
    "content": "# check the sync of 3dmm feature and the audio\nimport cv2\nimport numpy as np\nfrom src.face3d.models.bfm import ParametricFaceModel\nfrom src.face3d.models.facerecon_model import FaceReconModel\nimport torch\nimport subprocess, platform\nimport scipy.io as scio\nfrom tqdm import tqdm \n\n# draft\ndef gen_composed_video(args, device, first_frame_coeff, coeff_path, audio_path, save_path, exp_dim=64):\n    \n    coeff_first = scio.loadmat(first_frame_coeff)['full_3dmm']\n\n    coeff_pred = scio.loadmat(coeff_path)['coeff_3dmm']\n\n    coeff_full = np.repeat(coeff_first, coeff_pred.shape[0], axis=0) # 257\n\n    coeff_full[:, 80:144] = coeff_pred[:, 0:64]\n    coeff_full[:, 224:227]  = coeff_pred[:, 64:67] # 3 dim translation\n    coeff_full[:, 254:]  = coeff_pred[:, 67:] # 3 dim translation\n\n    tmp_video_path = '/tmp/face3dtmp.mp4'\n\n    facemodel = FaceReconModel(args)\n    \n    video = cv2.VideoWriter(tmp_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (224, 224))\n\n    for k in tqdm(range(coeff_pred.shape[0]), 'face3d rendering:'):\n        cur_coeff_full = torch.tensor(coeff_full[k:k+1], device=device)\n\n        facemodel.forward(cur_coeff_full, device)\n\n        predicted_landmark = facemodel.pred_lm # TODO.\n        predicted_landmark = predicted_landmark.cpu().numpy().squeeze()\n\n        rendered_img = facemodel.pred_face\n        rendered_img = 255. * rendered_img.cpu().numpy().squeeze().transpose(1,2,0)\n        out_img = rendered_img[:, :, :3].astype(np.uint8)\n\n        video.write(np.uint8(out_img[:,:,::-1]))\n\n    video.release()\n\n    command = 'ffmpeg -v quiet -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, tmp_video_path, save_path)\n    subprocess.call(command, shell=platform.system() != 'Windows')\n\n"
  },
  {
    "path": "src/facerender/animate.py",
    "content": "import os\nimport cv2\nimport yaml\nimport numpy as np\nimport warnings\nfrom skimage import img_as_ubyte\nimport safetensors\nimport safetensors.torch \nwarnings.filterwarnings('ignore')\n\n\nimport imageio\nimport torch\nimport torchvision\n\n\nfrom src.facerender.modules.keypoint_detector import HEEstimator, KPDetector\nfrom src.facerender.modules.mapping import MappingNet\nfrom src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator\nfrom src.facerender.modules.make_animation import make_animation \n\nfrom pydub import AudioSegment \nfrom src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list\nfrom src.utils.paste_pic import paste_pic\nfrom src.utils.videoio import save_video_with_watermark\n\ntry:\n    import webui  # in webui\n    in_webui = True\nexcept:\n    in_webui = False\n\nclass AnimateFromCoeff():\n\n    def __init__(self, sadtalker_path, device):\n\n        with open(sadtalker_path['facerender_yaml']) as f:\n            config = yaml.safe_load(f)\n\n        generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],\n                                                    **config['model_params']['common_params'])\n        kp_extractor = KPDetector(**config['model_params']['kp_detector_params'],\n                                    **config['model_params']['common_params'])\n        he_estimator = HEEstimator(**config['model_params']['he_estimator_params'],\n                               **config['model_params']['common_params'])\n        mapping = MappingNet(**config['model_params']['mapping_params'])\n\n        generator.to(device)\n        kp_extractor.to(device)\n        he_estimator.to(device)\n        mapping.to(device)\n        for param in generator.parameters():\n            param.requires_grad = False\n        for param in kp_extractor.parameters():\n            param.requires_grad = False \n        for param in he_estimator.parameters():\n      
      param.requires_grad = False\n        for param in mapping.parameters():\n            param.requires_grad = False\n\n        if sadtalker_path is not None:\n            if 'checkpoint' in sadtalker_path: # use safe tensor\n                self.load_cpk_facevid2vid_safetensor(sadtalker_path['checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=None)\n            else:\n                self.load_cpk_facevid2vid(sadtalker_path['free_view_checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator)\n        else:\n            raise AttributeError(\"Checkpoint should be specified for video head pose estimator.\")\n\n        if  sadtalker_path['mappingnet_checkpoint'] is not None:\n            self.load_cpk_mapping(sadtalker_path['mappingnet_checkpoint'], mapping=mapping)\n        else:\n            raise AttributeError(\"Checkpoint should be specified for video head pose estimator.\") \n\n        self.kp_extractor = kp_extractor\n        self.generator = generator\n        self.he_estimator = he_estimator\n        self.mapping = mapping\n\n        self.kp_extractor.eval()\n        self.generator.eval()\n        self.he_estimator.eval()\n        self.mapping.eval()\n         \n        self.device = device\n    \n    def load_cpk_facevid2vid_safetensor(self, checkpoint_path, generator=None, \n                        kp_detector=None, he_estimator=None,  \n                        device=\"cpu\"):\n\n        checkpoint = safetensors.torch.load_file(checkpoint_path)\n\n        if generator is not None:\n            x_generator = {}\n            for k,v in checkpoint.items():\n                if 'generator' in k:\n                    x_generator[k.replace('generator.', '')] = v\n            generator.load_state_dict(x_generator)\n        if kp_detector is not None:\n            x_generator = {}\n            for k,v in checkpoint.items():\n                if 'kp_extractor' in k:\n                    
x_generator[k.replace('kp_extractor.', '')] = v\n            kp_detector.load_state_dict(x_generator)\n        if he_estimator is not None:\n            x_generator = {}\n            for k,v in checkpoint.items():\n                if 'he_estimator' in k:\n                    x_generator[k.replace('he_estimator.', '')] = v\n            he_estimator.load_state_dict(x_generator)\n        \n        return None\n\n    def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None, \n                        kp_detector=None, he_estimator=None, optimizer_generator=None, \n                        optimizer_discriminator=None, optimizer_kp_detector=None, \n                        optimizer_he_estimator=None, device=\"cpu\"):\n        checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))\n        if generator is not None:\n            generator.load_state_dict(checkpoint['generator'])\n        if kp_detector is not None:\n            kp_detector.load_state_dict(checkpoint['kp_detector'])\n        if he_estimator is not None:\n            he_estimator.load_state_dict(checkpoint['he_estimator'])\n        if discriminator is not None:\n            try:\n               discriminator.load_state_dict(checkpoint['discriminator'])\n            except:\n               print ('No discriminator in the state-dict. Dicriminator will be randomly initialized')\n        if optimizer_generator is not None:\n            optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])\n        if optimizer_discriminator is not None:\n            try:\n                optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])\n            except RuntimeError as e:\n                print ('No discriminator optimizer in the state-dict. 
Optimizer will be not initialized')\n        if optimizer_kp_detector is not None:\n            optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])\n        if optimizer_he_estimator is not None:\n            optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator'])\n\n        return checkpoint['epoch']\n    \n    def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,\n                 optimizer_mapping=None, optimizer_discriminator=None, device='cpu'):\n        checkpoint = torch.load(checkpoint_path,  map_location=torch.device(device))\n        if mapping is not None:\n            mapping.load_state_dict(checkpoint['mapping'])\n        if discriminator is not None:\n            discriminator.load_state_dict(checkpoint['discriminator'])\n        if optimizer_mapping is not None:\n            optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping'])\n        if optimizer_discriminator is not None:\n            optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])\n\n        return checkpoint['epoch']\n\n    def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256):\n\n        source_image=x['source_image'].type(torch.FloatTensor)\n        source_semantics=x['source_semantics'].type(torch.FloatTensor)\n        target_semantics=x['target_semantics_list'].type(torch.FloatTensor) \n        source_image=source_image.to(self.device)\n        source_semantics=source_semantics.to(self.device)\n        target_semantics=target_semantics.to(self.device)\n        if 'yaw_c_seq' in x:\n            yaw_c_seq = x['yaw_c_seq'].type(torch.FloatTensor)\n            yaw_c_seq = x['yaw_c_seq'].to(self.device)\n        else:\n            yaw_c_seq = None\n        if 'pitch_c_seq' in x:\n            pitch_c_seq = x['pitch_c_seq'].type(torch.FloatTensor)\n            pitch_c_seq = x['pitch_c_seq'].to(self.device)\n   
     else:\n            pitch_c_seq = None\n        if 'roll_c_seq' in x:\n            roll_c_seq = x['roll_c_seq'].type(torch.FloatTensor) \n            roll_c_seq = x['roll_c_seq'].to(self.device)\n        else:\n            roll_c_seq = None\n\n        frame_num = x['frame_num']\n\n        predictions_video = make_animation(source_image, source_semantics, target_semantics,\n                                        self.generator, self.kp_extractor, self.he_estimator, self.mapping, \n                                        yaw_c_seq, pitch_c_seq, roll_c_seq, use_exp = True)\n\n        predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:])\n        predictions_video = predictions_video[:frame_num]\n\n        video = []\n        for idx in range(predictions_video.shape[0]):\n            image = predictions_video[idx]\n            image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32)\n            video.append(image)\n        result = img_as_ubyte(video)\n\n        ### the generated video is 256x256, so we keep the aspect ratio, \n        original_size = crop_info[0]\n        if original_size:\n            result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ]\n        \n        video_name = x['video_name']  + '.mp4'\n        path = os.path.join(video_save_dir, 'temp_'+video_name)\n        \n        imageio.mimsave(path, result,  fps=float(25))\n\n        av_path = os.path.join(video_save_dir, video_name)\n        return_path = av_path \n        \n        audio_path =  x['audio_path'] \n        audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]\n        new_audio_path = os.path.join(video_save_dir, audio_name+'.wav')\n        start_time = 0\n        # cog will not keep the .mp3 filename\n        sound = AudioSegment.from_file(audio_path)\n        frames = frame_num \n        end_time = start_time + frames*1/25*1000\n        
word1=sound.set_frame_rate(16000)\n        word = word1[start_time:end_time]\n        word.export(new_audio_path, format=\"wav\")\n\n        save_video_with_watermark(path, new_audio_path, av_path, watermark= False)\n        print(f'The generated video is named {video_save_dir}/{video_name}') \n\n        if 'full' in preprocess.lower():\n            # only add watermark to the full image.\n            video_name_full = x['video_name']  + '_full.mp4'\n            full_video_path = os.path.join(video_save_dir, video_name_full)\n            return_path = full_video_path\n            paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop= True if 'ext' in preprocess.lower() else False)\n            print(f'The generated video is named {video_save_dir}/{video_name_full}') \n        else:\n            full_video_path = av_path \n\n        #### paste back then enhancers\n        if enhancer:\n            video_name_enhancer = x['video_name']  + '_enhanced.mp4'\n            enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer)\n            av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer) \n            return_path = av_path_enhancer\n\n            try:\n                enhanced_images_gen_with_len = enhancer_generator_with_len(full_video_path, method=enhancer, bg_upsampler=background_enhancer)\n                imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))\n            except:\n                enhanced_images_gen_with_len = enhancer_list(full_video_path, method=enhancer, bg_upsampler=background_enhancer)\n                imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))\n            \n            save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= False)\n            print(f'The generated video is named {video_save_dir}/{video_name_enhancer}')\n            os.remove(enhanced_path)\n\n        os.remove(path)\n        
os.remove(new_audio_path)\n\n        return return_path\n\n"
  },
  {
    "path": "src/facerender/modules/dense_motion.py",
    "content": "from torch import nn\nimport torch.nn.functional as F\nimport torch\nfrom src.facerender.modules.util import Hourglass, make_coordinate_grid, kp2gaussian\n\nfrom src.facerender.sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d\n\n\nclass DenseMotionNetwork(nn.Module):\n    \"\"\"\n    Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving\n    \"\"\"\n\n    def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress,\n                 estimate_occlusion_map=False):\n        super(DenseMotionNetwork, self).__init__()\n        # self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(feature_channel+1), max_features=max_features, num_blocks=num_blocks)\n        self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks)\n\n        self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3)\n\n        self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1)\n        self.norm = BatchNorm3d(compress, affine=True)\n\n        if estimate_occlusion_map:\n            # self.occlusion = nn.Conv2d(reshape_channel*reshape_depth, 1, kernel_size=7, padding=3)\n            self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3)\n        else:\n            self.occlusion = None\n\n        self.num_kp = num_kp\n\n\n    def create_sparse_motions(self, feature, kp_driving, kp_source):\n        bs, _, d, h, w = feature.shape\n        identity_grid = make_coordinate_grid((d, h, w), type=kp_source['value'].type())\n        identity_grid = identity_grid.view(1, 1, d, h, w, 3)\n        coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 1, 3)\n        \n        # if 'jacobian' in kp_driving:\n        if 'jacobian' in kp_driving and 
kp_driving['jacobian'] is not None:\n            jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian']))\n            jacobian = jacobian.unsqueeze(-3).unsqueeze(-3).unsqueeze(-3)\n            jacobian = jacobian.repeat(1, 1, d, h, w, 1, 1)\n            coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))\n            coordinate_grid = coordinate_grid.squeeze(-1)                  \n\n\n        driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 1, 3)    # (bs, num_kp, d, h, w, 3)\n\n        #adding background feature\n        identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1)\n        sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1)                #bs num_kp+1 d h w 3\n        \n        # sparse_motions = driving_to_source\n\n        return sparse_motions\n\n    def create_deformed_feature(self, feature, sparse_motions):\n        bs, _, d, h, w = feature.shape\n        feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1)      # (bs, num_kp+1, 1, c, d, h, w)\n        feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w)                         # (bs*(num_kp+1), c, d, h, w)\n        sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1))                       # (bs*(num_kp+1), d, h, w, 3) !!!!\n        sparse_deformed = F.grid_sample(feature_repeat, sparse_motions)\n        sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w))                        # (bs, num_kp+1, c, d, h, w)\n        return sparse_deformed\n\n    def create_heatmap_representations(self, feature, kp_driving, kp_source):\n        spatial_size = feature.shape[3:]\n        gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01)\n        gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01)\n        heatmap = gaussian_driving - 
gaussian_source\n\n        # adding background feature\n        zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type())\n        heatmap = torch.cat([zeros, heatmap], dim=1)\n        heatmap = heatmap.unsqueeze(2)         # (bs, num_kp+1, 1, d, h, w)\n        return heatmap\n\n    def forward(self, feature, kp_driving, kp_source):\n        bs, _, d, h, w = feature.shape\n\n        feature = self.compress(feature)\n        feature = self.norm(feature)\n        feature = F.relu(feature)\n\n        out_dict = dict()\n        sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source)\n        deformed_feature = self.create_deformed_feature(feature, sparse_motion)\n\n        heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source)\n\n        input_ = torch.cat([heatmap, deformed_feature], dim=2)\n        input_ = input_.view(bs, -1, d, h, w)\n\n        # input = deformed_feature.view(bs, -1, d, h, w)      # (bs, num_kp+1 * c, d, h, w)\n\n        prediction = self.hourglass(input_)\n\n\n        mask = self.mask(prediction)\n        mask = F.softmax(mask, dim=1)\n        out_dict['mask'] = mask\n        mask = mask.unsqueeze(2)                                   # (bs, num_kp+1, 1, d, h, w)\n        \n        zeros_mask = torch.zeros_like(mask)   \n        mask = torch.where(mask < 1e-3, zeros_mask, mask) \n\n        sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4)    # (bs, num_kp+1, 3, d, h, w)\n        deformation = (sparse_motion * mask).sum(dim=1)            # (bs, 3, d, h, w)\n        deformation = deformation.permute(0, 2, 3, 4, 1)           # (bs, d, h, w, 3)\n\n        out_dict['deformation'] = deformation\n\n        if self.occlusion:\n            bs, c, d, h, w = prediction.shape\n            prediction = prediction.view(bs, -1, h, w)\n            occlusion_map = torch.sigmoid(self.occlusion(prediction))\n            out_dict['occlusion_map'] = 
occlusion_map\n\n        return out_dict\n"
  },
  {
    "path": "src/facerender/modules/discriminator.py",
    "content": "from torch import nn\nimport torch.nn.functional as F\nfrom facerender.modules.util import kp2gaussian\nimport torch\n\n\nclass DownBlock2d(nn.Module):\n    \"\"\"\n    Simple block for processing video (encoder).\n    \"\"\"\n\n    def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):\n        super(DownBlock2d, self).__init__()\n        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)\n\n        if sn:\n            self.conv = nn.utils.spectral_norm(self.conv)\n\n        if norm:\n            self.norm = nn.InstanceNorm2d(out_features, affine=True)\n        else:\n            self.norm = None\n        self.pool = pool\n\n    def forward(self, x):\n        out = x\n        out = self.conv(out)\n        if self.norm:\n            out = self.norm(out)\n        out = F.leaky_relu(out, 0.2)\n        if self.pool:\n            out = F.avg_pool2d(out, (2, 2))\n        return out\n\n\nclass Discriminator(nn.Module):\n    \"\"\"\n    Discriminator similar to Pix2Pix\n    \"\"\"\n\n    def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,\n                 sn=False, **kwargs):\n        super(Discriminator, self).__init__()\n\n        down_blocks = []\n        for i in range(num_blocks):\n            down_blocks.append(\n                DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)),\n                            min(max_features, block_expansion * (2 ** (i + 1))),\n                            norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))\n\n        self.down_blocks = nn.ModuleList(down_blocks)\n        self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)\n        if sn:\n            self.conv = nn.utils.spectral_norm(self.conv)\n\n    def forward(self, x):\n        feature_maps = []\n        out = x\n\n        for down_block in 
self.down_blocks:\n            feature_maps.append(down_block(out))\n            out = feature_maps[-1]\n        prediction_map = self.conv(out)\n\n        return feature_maps, prediction_map\n\n\nclass MultiScaleDiscriminator(nn.Module):\n    \"\"\"\n    Multi-scale (scale) discriminator\n    \"\"\"\n\n    def __init__(self, scales=(), **kwargs):\n        super(MultiScaleDiscriminator, self).__init__()\n        self.scales = scales\n        discs = {}\n        for scale in scales:\n            discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)\n        self.discs = nn.ModuleDict(discs)\n\n    def forward(self, x):\n        out_dict = {}\n        for scale, disc in self.discs.items():\n            scale = str(scale).replace('-', '.')\n            key = 'prediction_' + scale\n            feature_maps, prediction_map = disc(x[key])\n            out_dict['feature_maps_' + scale] = feature_maps\n            out_dict['prediction_map_' + scale] = prediction_map\n        return out_dict\n"
  },
  {
    "path": "src/facerender/modules/generator.py",
    "content": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom src.facerender.modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d, ResBlock3d, SPADEResnetBlock\nfrom src.facerender.modules.dense_motion import DenseMotionNetwork\n\n\nclass OcclusionAwareGenerator(nn.Module):\n    \"\"\"\n    Generator follows NVIDIA architecture.\n    \"\"\"\n\n    def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth,\n                 num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):\n        super(OcclusionAwareGenerator, self).__init__()\n\n        if dense_motion_params is not None:\n            self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel,\n                                                           estimate_occlusion_map=estimate_occlusion_map,\n                                                           **dense_motion_params)\n        else:\n            self.dense_motion_network = None\n\n        self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(7, 7), padding=(3, 3))\n\n        down_blocks = []\n        for i in range(num_down_blocks):\n            in_features = min(max_features, block_expansion * (2 ** i))\n            out_features = min(max_features, block_expansion * (2 ** (i + 1)))\n            down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))\n        self.down_blocks = nn.ModuleList(down_blocks)\n\n        self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)\n\n        self.reshape_channel = reshape_channel\n        self.reshape_depth = reshape_depth\n\n        self.resblocks_3d = torch.nn.Sequential()\n        for i in range(num_resblocks):\n            self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, 
kernel_size=3, padding=1))\n\n        out_features = block_expansion * (2 ** (num_down_blocks))\n        self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True)\n        self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1)\n\n        self.resblocks_2d = torch.nn.Sequential()\n        for i in range(num_resblocks):\n            self.resblocks_2d.add_module('2dr' + str(i), ResBlock2d(out_features, kernel_size=3, padding=1))\n\n        up_blocks = []\n        for i in range(num_down_blocks):\n            in_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i)))\n            out_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i - 1)))\n            up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))\n        self.up_blocks = nn.ModuleList(up_blocks)\n\n        self.final = nn.Conv2d(block_expansion, image_channel, kernel_size=(7, 7), padding=(3, 3))\n        self.estimate_occlusion_map = estimate_occlusion_map\n        self.image_channel = image_channel\n\n    def deform_input(self, inp, deformation):\n        _, d_old, h_old, w_old, _ = deformation.shape\n        _, _, d, h, w = inp.shape\n        if d_old != d or h_old != h or w_old != w:\n            deformation = deformation.permute(0, 4, 1, 2, 3)\n            deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear')\n            deformation = deformation.permute(0, 2, 3, 4, 1)\n        return F.grid_sample(inp, deformation)\n\n    def forward(self, source_image, kp_driving, kp_source):\n        # Encoding (downsampling) part\n        out = self.first(source_image)\n        for i in range(len(self.down_blocks)):\n            out = self.down_blocks[i](out)\n        out = self.second(out)\n        bs, c, h, w = out.shape\n        # print(out.shape)\n        feature_3d = out.view(bs, self.reshape_channel, 
self.reshape_depth, h ,w) \n        feature_3d = self.resblocks_3d(feature_3d)\n\n        # Transforming feature representation according to deformation and occlusion\n        output_dict = {}\n        if self.dense_motion_network is not None:\n            dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving,\n                                                     kp_source=kp_source)\n            output_dict['mask'] = dense_motion['mask']\n\n            if 'occlusion_map' in dense_motion:\n                occlusion_map = dense_motion['occlusion_map']\n                output_dict['occlusion_map'] = occlusion_map\n            else:\n                occlusion_map = None\n            deformation = dense_motion['deformation']\n            out = self.deform_input(feature_3d, deformation)\n\n            bs, c, d, h, w = out.shape\n            out = out.view(bs, c*d, h, w)\n            out = self.third(out)\n            out = self.fourth(out)\n\n            if occlusion_map is not None:\n                if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:\n                    occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')\n                out = out * occlusion_map\n\n            # output_dict[\"deformed\"] = self.deform_input(source_image, deformation)  # 3d deformation cannot deform 2d image\n\n        # Decoding part\n        out = self.resblocks_2d(out)\n        for i in range(len(self.up_blocks)):\n            out = self.up_blocks[i](out)\n        out = self.final(out)\n        out = F.sigmoid(out)\n\n        output_dict[\"prediction\"] = out\n\n        return output_dict\n\n\nclass SPADEDecoder(nn.Module):\n    def __init__(self):\n        super().__init__()\n        ic = 256\n        oc = 64\n        norm_G = 'spadespectralinstance'\n        label_nc = 256\n        \n        self.fc = nn.Conv2d(ic, 2 * ic, 3, padding=1)\n        self.G_middle_0 = SPADEResnetBlock(2 * ic, 2 
* ic, norm_G, label_nc)\n        self.G_middle_1 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)\n        self.G_middle_2 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)\n        self.G_middle_3 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)\n        self.G_middle_4 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)\n        self.G_middle_5 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)\n        self.up_0 = SPADEResnetBlock(2 * ic, ic, norm_G, label_nc)\n        self.up_1 = SPADEResnetBlock(ic, oc, norm_G, label_nc)\n        self.conv_img = nn.Conv2d(oc, 3, 3, padding=1)\n        self.up = nn.Upsample(scale_factor=2)\n        \n    def forward(self, feature):\n        seg = feature\n        x = self.fc(feature)\n        x = self.G_middle_0(x, seg)\n        x = self.G_middle_1(x, seg)\n        x = self.G_middle_2(x, seg)\n        x = self.G_middle_3(x, seg)\n        x = self.G_middle_4(x, seg)\n        x = self.G_middle_5(x, seg)\n        x = self.up(x)                \n        x = self.up_0(x, seg)         # 256, 128, 128\n        x = self.up(x)                \n        x = self.up_1(x, seg)         # 64, 256, 256\n\n        x = self.conv_img(F.leaky_relu(x, 2e-1))\n        # x = torch.tanh(x)\n        x = F.sigmoid(x)\n        \n        return x\n\n\nclass OcclusionAwareSPADEGenerator(nn.Module):\n\n    def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth,\n                 num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):\n        super(OcclusionAwareSPADEGenerator, self).__init__()\n\n        if dense_motion_params is not None:\n            self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel,\n                                                           estimate_occlusion_map=estimate_occlusion_map,\n                                                           
**dense_motion_params)\n        else:\n            self.dense_motion_network = None\n\n        self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1))\n\n        down_blocks = []\n        for i in range(num_down_blocks):\n            in_features = min(max_features, block_expansion * (2 ** i))\n            out_features = min(max_features, block_expansion * (2 ** (i + 1)))\n            down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))\n        self.down_blocks = nn.ModuleList(down_blocks)\n\n        self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)\n\n        self.reshape_channel = reshape_channel\n        self.reshape_depth = reshape_depth\n\n        self.resblocks_3d = torch.nn.Sequential()\n        for i in range(num_resblocks):\n            self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))\n\n        out_features = block_expansion * (2 ** (num_down_blocks))\n        self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True)\n        self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1)\n\n        self.estimate_occlusion_map = estimate_occlusion_map\n        self.image_channel = image_channel\n\n        self.decoder = SPADEDecoder()\n\n    def deform_input(self, inp, deformation):\n        _, d_old, h_old, w_old, _ = deformation.shape\n        _, _, d, h, w = inp.shape\n        if d_old != d or h_old != h or w_old != w:\n            deformation = deformation.permute(0, 4, 1, 2, 3)\n            deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear')\n            deformation = deformation.permute(0, 2, 3, 4, 1)\n        return F.grid_sample(inp, deformation)\n\n    def forward(self, source_image, kp_driving, kp_source):\n        # Encoding (downsampling) part\n        out = 
self.first(source_image)\n        for i in range(len(self.down_blocks)):\n            out = self.down_blocks[i](out)\n        out = self.second(out)\n        bs, c, h, w = out.shape\n        # print(out.shape)\n        feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h ,w) \n        feature_3d = self.resblocks_3d(feature_3d)\n\n        # Transforming feature representation according to deformation and occlusion\n        output_dict = {}\n        if self.dense_motion_network is not None:\n            dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving,\n                                                     kp_source=kp_source)\n            output_dict['mask'] = dense_motion['mask']\n\n            # import pdb; pdb.set_trace()\n\n            if 'occlusion_map' in dense_motion:\n                occlusion_map = dense_motion['occlusion_map']\n                output_dict['occlusion_map'] = occlusion_map\n            else:\n                occlusion_map = None\n            deformation = dense_motion['deformation']\n            out = self.deform_input(feature_3d, deformation)\n\n            bs, c, d, h, w = out.shape\n            out = out.view(bs, c*d, h, w)\n            out = self.third(out)\n            out = self.fourth(out)\n\n            # occlusion_map = torch.where(occlusion_map < 0.95, 0, occlusion_map)\n            \n            if occlusion_map is not None:\n                if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:\n                    occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')\n                out = out * occlusion_map\n\n        # Decoding part\n        out = self.decoder(out)\n\n        output_dict[\"prediction\"] = out\n        \n        return output_dict"
  },
  {
    "path": "src/facerender/modules/keypoint_detector.py",
    "content": "from torch import nn\nimport torch\nimport torch.nn.functional as F\n\nfrom src.facerender.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d\nfrom src.facerender.modules.util import KPHourglass, make_coordinate_grid, AntiAliasInterpolation2d, ResBottleneck\n\n\nclass KPDetector(nn.Module):\n    \"\"\"\n    Detecting canonical keypoints. Return keypoint position and jacobian near each keypoint.\n    \"\"\"\n\n    def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, reshape_channel, reshape_depth,\n                 num_blocks, temperature, estimate_jacobian=False, scale_factor=1, single_jacobian_map=False):\n        super(KPDetector, self).__init__()\n\n        self.predictor = KPHourglass(block_expansion, in_features=image_channel,\n                                     max_features=max_features,  reshape_features=reshape_channel, reshape_depth=reshape_depth, num_blocks=num_blocks)\n\n        # self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=7, padding=3)\n        self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=3, padding=1)\n\n        if estimate_jacobian:\n            self.num_jacobian_maps = 1 if single_jacobian_map else num_kp\n            # self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=7, padding=3)\n            self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=3, padding=1)\n            '''\n            initial as:\n            [[1 0 0]\n             [0 1 0]\n             [0 0 1]]\n            '''\n            self.jacobian.weight.data.zero_()\n            self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float))\n        else:\n            self.jacobian = None\n\n        self.temperature = temperature\n        
self.scale_factor = scale_factor\n        if self.scale_factor != 1:\n            self.down = AntiAliasInterpolation2d(image_channel, self.scale_factor)\n\n    def gaussian2kp(self, heatmap):\n        \"\"\"\n        Extract the mean from a heatmap\n        \"\"\"\n        shape = heatmap.shape\n        heatmap = heatmap.unsqueeze(-1)\n        grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0)\n        value = (heatmap * grid).sum(dim=(2, 3, 4))\n        kp = {'value': value}\n\n        return kp\n\n    def forward(self, x):\n        if self.scale_factor != 1:\n            x = self.down(x)\n\n        feature_map = self.predictor(x)\n        prediction = self.kp(feature_map)\n\n        final_shape = prediction.shape\n        heatmap = prediction.view(final_shape[0], final_shape[1], -1)\n        heatmap = F.softmax(heatmap / self.temperature, dim=2)\n        heatmap = heatmap.view(*final_shape)\n\n        out = self.gaussian2kp(heatmap)\n\n        if self.jacobian is not None:\n            jacobian_map = self.jacobian(feature_map)\n            jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 9, final_shape[2],\n                                                final_shape[3], final_shape[4])\n            heatmap = heatmap.unsqueeze(2)\n\n            jacobian = heatmap * jacobian_map\n            jacobian = jacobian.view(final_shape[0], final_shape[1], 9, -1)\n            jacobian = jacobian.sum(dim=-1)\n            jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 3, 3)\n            out['jacobian'] = jacobian\n\n        return out\n\n\nclass HEEstimator(nn.Module):\n    \"\"\"\n    Estimating head pose and expression.\n    \"\"\"\n\n    def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, num_bins=66, estimate_jacobian=True):\n        super(HEEstimator, self).__init__()\n\n        self.conv1 = nn.Conv2d(in_channels=image_channel, out_channels=block_expansion, 
kernel_size=7, padding=3, stride=2)\n        self.norm1 = BatchNorm2d(block_expansion, affine=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n\n        self.conv2 = nn.Conv2d(in_channels=block_expansion, out_channels=256, kernel_size=1)\n        self.norm2 = BatchNorm2d(256, affine=True)\n\n        self.block1 = nn.Sequential()\n        for i in range(3):\n            self.block1.add_module('b1_'+ str(i), ResBottleneck(in_features=256, stride=1))\n\n        self.conv3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1)\n        self.norm3 = BatchNorm2d(512, affine=True)\n        self.block2 = ResBottleneck(in_features=512, stride=2)\n\n        self.block3 = nn.Sequential()\n        for i in range(3):\n            self.block3.add_module('b3_'+ str(i), ResBottleneck(in_features=512, stride=1))\n\n        self.conv4 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1)\n        self.norm4 = BatchNorm2d(1024, affine=True)\n        self.block4 = ResBottleneck(in_features=1024, stride=2)\n\n        self.block5 = nn.Sequential()\n        for i in range(5):\n            self.block5.add_module('b5_'+ str(i), ResBottleneck(in_features=1024, stride=1))\n\n        self.conv5 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=1)\n        self.norm5 = BatchNorm2d(2048, affine=True)\n        self.block6 = ResBottleneck(in_features=2048, stride=2)\n\n        self.block7 = nn.Sequential()\n        for i in range(2):\n            self.block7.add_module('b7_'+ str(i), ResBottleneck(in_features=2048, stride=1))\n\n        self.fc_roll = nn.Linear(2048, num_bins)\n        self.fc_pitch = nn.Linear(2048, num_bins)\n        self.fc_yaw = nn.Linear(2048, num_bins)\n\n        self.fc_t = nn.Linear(2048, 3)\n\n        self.fc_exp = nn.Linear(2048, 3*num_kp)\n\n    def forward(self, x):\n        out = self.conv1(x)\n        out = self.norm1(out)\n        out = F.relu(out)\n        out = self.maxpool(out)\n\n        out = 
self.conv2(out)\n        out = self.norm2(out)\n        out = F.relu(out)\n\n        out = self.block1(out)\n\n        out = self.conv3(out)\n        out = self.norm3(out)\n        out = F.relu(out)\n        out = self.block2(out)\n\n        out = self.block3(out)\n\n        out = self.conv4(out)\n        out = self.norm4(out)\n        out = F.relu(out)\n        out = self.block4(out)\n\n        out = self.block5(out)\n\n        out = self.conv5(out)\n        out = self.norm5(out)\n        out = F.relu(out)\n        out = self.block6(out)\n\n        out = self.block7(out)\n\n        out = F.adaptive_avg_pool2d(out, 1)\n        out = out.view(out.shape[0], -1)\n\n        yaw = self.fc_roll(out)\n        pitch = self.fc_pitch(out)\n        roll = self.fc_yaw(out)\n        t = self.fc_t(out)\n        exp = self.fc_exp(out)\n\n        return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}\n\n"
  },
  {
    "path": "src/facerender/modules/make_animation.py",
    "content": "from scipy.spatial import ConvexHull\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom tqdm import tqdm \n\ndef normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,\n                 use_relative_movement=False, use_relative_jacobian=False):\n    if adapt_movement_scale:\n        source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume\n        driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume\n        adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)\n    else:\n        adapt_movement_scale = 1\n\n    kp_new = {k: v for k, v in kp_driving.items()}\n\n    if use_relative_movement:\n        kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])\n        kp_value_diff *= adapt_movement_scale\n        kp_new['value'] = kp_value_diff + kp_source['value']\n\n        if use_relative_jacobian:\n            jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))\n            kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])\n\n    return kp_new\n\ndef headpose_pred_to_degree(pred):\n    device = pred.device\n    idx_tensor = [idx for idx in range(66)]\n    idx_tensor = torch.FloatTensor(idx_tensor).type_as(pred).to(device)\n    pred = F.softmax(pred)\n    degree = torch.sum(pred*idx_tensor, 1) * 3 - 99\n    return degree\n\ndef get_rotation_matrix(yaw, pitch, roll):\n    yaw = yaw / 180 * 3.14\n    pitch = pitch / 180 * 3.14\n    roll = roll / 180 * 3.14\n\n    roll = roll.unsqueeze(1)\n    pitch = pitch.unsqueeze(1)\n    yaw = yaw.unsqueeze(1)\n\n    pitch_mat = torch.cat([torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch), \n                          torch.zeros_like(pitch), torch.cos(pitch), -torch.sin(pitch),\n                          torch.zeros_like(pitch), torch.sin(pitch), torch.cos(pitch)], dim=1)\n    pitch_mat = 
pitch_mat.view(pitch_mat.shape[0], 3, 3)\n\n    yaw_mat = torch.cat([torch.cos(yaw), torch.zeros_like(yaw), torch.sin(yaw), \n                           torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw),\n                           -torch.sin(yaw), torch.zeros_like(yaw), torch.cos(yaw)], dim=1)\n    yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3)\n\n    roll_mat = torch.cat([torch.cos(roll), -torch.sin(roll), torch.zeros_like(roll),  \n                         torch.sin(roll), torch.cos(roll), torch.zeros_like(roll),\n                         torch.zeros_like(roll), torch.zeros_like(roll), torch.ones_like(roll)], dim=1)\n    roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3)\n\n    rot_mat = torch.einsum('bij,bjk,bkm->bim', pitch_mat, yaw_mat, roll_mat)\n\n    return rot_mat\n\ndef keypoint_transformation(kp_canonical, he, wo_exp=False):\n    kp = kp_canonical['value']    # (bs, k, 3) \n    yaw, pitch, roll= he['yaw'], he['pitch'], he['roll']      \n    yaw = headpose_pred_to_degree(yaw) \n    pitch = headpose_pred_to_degree(pitch)\n    roll = headpose_pred_to_degree(roll)\n\n    if 'yaw_in' in he:\n        yaw = he['yaw_in']\n    if 'pitch_in' in he:\n        pitch = he['pitch_in']\n    if 'roll_in' in he:\n        roll = he['roll_in']\n\n    rot_mat = get_rotation_matrix(yaw, pitch, roll)    # (bs, 3, 3)\n\n    t, exp = he['t'], he['exp']\n    if wo_exp:\n        exp =  exp*0  \n    \n    # keypoint rotation\n    kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp)\n\n    # keypoint translation\n    t[:, 0] = t[:, 0]*0\n    t[:, 2] = t[:, 2]*0\n    t = t.unsqueeze(1).repeat(1, kp.shape[1], 1)\n    kp_t = kp_rotated + t\n\n    # add expression deviation \n    exp = exp.view(exp.shape[0], -1, 3)\n    kp_transformed = kp_t + exp\n\n    return {'value': kp_transformed}\n\n\n\ndef make_animation(source_image, source_semantics, target_semantics,\n                            generator, kp_detector, he_estimator, mapping, \n                            
yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,\n                            use_exp=True, use_half=False):\n    with torch.no_grad():\n        predictions = []\n\n        kp_canonical = kp_detector(source_image)\n        he_source = mapping(source_semantics)\n        kp_source = keypoint_transformation(kp_canonical, he_source)\n    \n        for frame_idx in tqdm(range(target_semantics.shape[1]), 'Face Renderer:'):\n            # still check the dimension\n            # print(target_semantics.shape, source_semantics.shape)\n            target_semantics_frame = target_semantics[:, frame_idx]\n            he_driving = mapping(target_semantics_frame)\n            if yaw_c_seq is not None:\n                he_driving['yaw_in'] = yaw_c_seq[:, frame_idx]\n            if pitch_c_seq is not None:\n                he_driving['pitch_in'] = pitch_c_seq[:, frame_idx] \n            if roll_c_seq is not None:\n                he_driving['roll_in'] = roll_c_seq[:, frame_idx] \n            \n            kp_driving = keypoint_transformation(kp_canonical, he_driving)\n                \n            kp_norm = kp_driving\n            out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)\n            '''\n            source_image_new = out['prediction'].squeeze(1)\n            kp_canonical_new =  kp_detector(source_image_new)\n            he_source_new = he_estimator(source_image_new) \n            kp_source_new = keypoint_transformation(kp_canonical_new, he_source_new, wo_exp=True)\n            kp_driving_new = keypoint_transformation(kp_canonical_new, he_driving, wo_exp=True)\n            out = generator(source_image_new, kp_source=kp_source_new, kp_driving=kp_driving_new)\n            '''\n            predictions.append(out['prediction'])\n        predictions_ts = torch.stack(predictions, dim=1)\n    return predictions_ts\n\nclass AnimateModel(torch.nn.Module):\n    \"\"\"\n    Merge all generator related updates into single model for better multi-gpu usage\n    
\"\"\"\n\n    def __init__(self, generator, kp_extractor, mapping):\n        super(AnimateModel, self).__init__()\n        self.kp_extractor = kp_extractor\n        self.generator = generator\n        self.mapping = mapping\n\n        self.kp_extractor.eval()\n        self.generator.eval()\n        self.mapping.eval()\n\n    def forward(self, x):\n        \n        source_image = x['source_image']\n        source_semantics = x['source_semantics']\n        target_semantics = x['target_semantics']\n        yaw_c_seq = x['yaw_c_seq']\n        pitch_c_seq = x['pitch_c_seq']\n        roll_c_seq = x['roll_c_seq']\n\n        predictions_video = make_animation(source_image, source_semantics, target_semantics,\n                                        self.generator, self.kp_extractor,\n                                        self.mapping, use_exp = True,\n                                        yaw_c_seq=yaw_c_seq, pitch_c_seq=pitch_c_seq, roll_c_seq=roll_c_seq)\n        \n        return predictions_video"
  },
  {
    "path": "src/facerender/modules/mapping.py",
    "content": "import numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass MappingNet(nn.Module):\n    def __init__(self, coeff_nc, descriptor_nc, layer, num_kp, num_bins):\n        super( MappingNet, self).__init__()\n\n        self.layer = layer\n        nonlinearity = nn.LeakyReLU(0.1)\n\n        self.first = nn.Sequential(\n            torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True))\n\n        for i in range(layer):\n            net = nn.Sequential(nonlinearity,\n                torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3))\n            setattr(self, 'encoder' + str(i), net)   \n\n        self.pooling = nn.AdaptiveAvgPool1d(1)\n        self.output_nc = descriptor_nc\n\n        self.fc_roll = nn.Linear(descriptor_nc, num_bins)\n        self.fc_pitch = nn.Linear(descriptor_nc, num_bins)\n        self.fc_yaw = nn.Linear(descriptor_nc, num_bins)\n        self.fc_t = nn.Linear(descriptor_nc, 3)\n        self.fc_exp = nn.Linear(descriptor_nc, 3*num_kp)\n\n    def forward(self, input_3dmm):\n        out = self.first(input_3dmm)\n        for i in range(self.layer):\n            model = getattr(self, 'encoder' + str(i))\n            out = model(out) + out[:,:,3:-3]\n        out = self.pooling(out)\n        out = out.view(out.shape[0], -1)\n        #print('out:', out.shape)\n\n        yaw = self.fc_yaw(out)\n        pitch = self.fc_pitch(out)\n        roll = self.fc_roll(out)\n        t = self.fc_t(out)\n        exp = self.fc_exp(out)\n\n        return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp} "
  },
  {
    "path": "src/facerender/modules/util.py",
    "content": "from torch import nn\n\nimport torch.nn.functional as F\nimport torch\n\nfrom src.facerender.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d\nfrom src.facerender.sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d\n\nimport torch.nn.utils.spectral_norm as spectral_norm\n\n\ndef kp2gaussian(kp, spatial_size, kp_variance):\n    \"\"\"\n    Transform a keypoint into gaussian like representation\n    \"\"\"\n    mean = kp['value']\n\n    coordinate_grid = make_coordinate_grid(spatial_size, mean.type())\n    number_of_leading_dimensions = len(mean.shape) - 1\n    shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape\n    coordinate_grid = coordinate_grid.view(*shape)\n    repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1)\n    coordinate_grid = coordinate_grid.repeat(*repeats)\n\n    # Preprocess kp shape\n    shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3)\n    mean = mean.view(*shape)\n\n    mean_sub = (coordinate_grid - mean)\n\n    out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)\n\n    return out\n\ndef make_coordinate_grid_2d(spatial_size, type):\n    \"\"\"\n    Create a meshgrid [-1,1] x [-1,1] of given spatial_size.\n    \"\"\"\n    h, w = spatial_size\n    x = torch.arange(w).type(type)\n    y = torch.arange(h).type(type)\n\n    x = (2 * (x / (w - 1)) - 1)\n    y = (2 * (y / (h - 1)) - 1)\n\n    yy = y.view(-1, 1).repeat(1, w)\n    xx = x.view(1, -1).repeat(h, 1)\n\n    meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)\n\n    return meshed\n\n\ndef make_coordinate_grid(spatial_size, type):\n    d, h, w = spatial_size\n    x = torch.arange(w).type(type)\n    y = torch.arange(h).type(type)\n    z = torch.arange(d).type(type)\n\n    x = (2 * (x / (w - 1)) - 1)\n    y = (2 * (y / (h - 1)) - 1)\n    z = (2 * (z / (d - 1)) - 1)\n   \n    yy = y.view(1, -1, 1).repeat(d, 1, w)\n    xx = x.view(1, 1, -1).repeat(d, h, 1)\n    zz = z.view(-1, 1, 1).repeat(1, h, 
w)\n\n    meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3)\n\n    return meshed\n\n\nclass ResBottleneck(nn.Module):\n    def __init__(self, in_features, stride):\n        super(ResBottleneck, self).__init__()\n        self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features//4, kernel_size=1)\n        self.conv2 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features//4, kernel_size=3, padding=1, stride=stride)\n        self.conv3 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features, kernel_size=1)\n        self.norm1 = BatchNorm2d(in_features//4, affine=True)\n        self.norm2 = BatchNorm2d(in_features//4, affine=True)\n        self.norm3 = BatchNorm2d(in_features, affine=True)\n\n        self.stride = stride\n        if self.stride != 1:\n            self.skip = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=1, stride=stride)\n            self.norm4 = BatchNorm2d(in_features, affine=True)\n\n    def forward(self, x):\n        out = self.conv1(x)\n        out = self.norm1(out)\n        out = F.relu(out)\n        out = self.conv2(out)\n        out = self.norm2(out)\n        out = F.relu(out)\n        out = self.conv3(out)\n        out = self.norm3(out)\n        if self.stride != 1:\n            x = self.skip(x)\n            x = self.norm4(x)\n        out += x\n        out = F.relu(out)\n        return out\n\n\nclass ResBlock2d(nn.Module):\n    \"\"\"\n    Res block, preserve spatial resolution.\n    \"\"\"\n\n    def __init__(self, in_features, kernel_size, padding):\n        super(ResBlock2d, self).__init__()\n        self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,\n                               padding=padding)\n        self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,\n                               padding=padding)\n        self.norm1 = BatchNorm2d(in_features, 
affine=True)\n        self.norm2 = BatchNorm2d(in_features, affine=True)\n\n    def forward(self, x):\n        out = self.norm1(x)\n        out = F.relu(out)\n        out = self.conv1(out)\n        out = self.norm2(out)\n        out = F.relu(out)\n        out = self.conv2(out)\n        out += x\n        return out\n\n\nclass ResBlock3d(nn.Module):\n    \"\"\"\n    Res block, preserve spatial resolution.\n    \"\"\"\n\n    def __init__(self, in_features, kernel_size, padding):\n        super(ResBlock3d, self).__init__()\n        self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,\n                               padding=padding)\n        self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,\n                               padding=padding)\n        self.norm1 = BatchNorm3d(in_features, affine=True)\n        self.norm2 = BatchNorm3d(in_features, affine=True)\n\n    def forward(self, x):\n        out = self.norm1(x)\n        out = F.relu(out)\n        out = self.conv1(out)\n        out = self.norm2(out)\n        out = F.relu(out)\n        out = self.conv2(out)\n        out += x\n        return out\n\n\nclass UpBlock2d(nn.Module):\n    \"\"\"\n    Upsampling block for use in decoder.\n    \"\"\"\n\n    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):\n        super(UpBlock2d, self).__init__()\n\n        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,\n                              padding=padding, groups=groups)\n        self.norm = BatchNorm2d(out_features, affine=True)\n\n    def forward(self, x):\n        out = F.interpolate(x, scale_factor=2)\n        out = self.conv(out)\n        out = self.norm(out)\n        out = F.relu(out)\n        return out\n\nclass UpBlock3d(nn.Module):\n    \"\"\"\n    Upsampling block for use in decoder.\n    \"\"\"\n\n    def __init__(self, in_features, 
out_features, kernel_size=3, padding=1, groups=1):\n        super(UpBlock3d, self).__init__()\n\n        self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,\n                              padding=padding, groups=groups)\n        self.norm = BatchNorm3d(out_features, affine=True)\n\n    def forward(self, x):\n        # out = F.interpolate(x, scale_factor=(1, 2, 2), mode='trilinear')\n        out = F.interpolate(x, scale_factor=(1, 2, 2))\n        out = self.conv(out)\n        out = self.norm(out)\n        out = F.relu(out)\n        return out\n\n\nclass DownBlock2d(nn.Module):\n    \"\"\"\n    Downsampling block for use in encoder.\n    \"\"\"\n\n    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):\n        super(DownBlock2d, self).__init__()\n        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,\n                              padding=padding, groups=groups)\n        self.norm = BatchNorm2d(out_features, affine=True)\n        self.pool = nn.AvgPool2d(kernel_size=(2, 2))\n\n    def forward(self, x):\n        out = self.conv(x)\n        out = self.norm(out)\n        out = F.relu(out)\n        out = self.pool(out)\n        return out\n\n\nclass DownBlock3d(nn.Module):\n    \"\"\"\n    Downsampling block for use in encoder.\n    \"\"\"\n\n    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):\n        super(DownBlock3d, self).__init__()\n        '''\n        self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,\n                              padding=padding, groups=groups, stride=(1, 2, 2))\n        '''\n        self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,\n                              padding=padding, groups=groups)\n        self.norm = BatchNorm3d(out_features, affine=True)\n        self.pool = 
nn.AvgPool3d(kernel_size=(1, 2, 2))\n\n    def forward(self, x):\n        out = self.conv(x)\n        out = self.norm(out)\n        out = F.relu(out)\n        out = self.pool(out)\n        return out\n\n\nclass SameBlock2d(nn.Module):\n    \"\"\"\n    Simple block, preserve spatial resolution.\n    \"\"\"\n\n    def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False):\n        super(SameBlock2d, self).__init__()\n        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,\n                              kernel_size=kernel_size, padding=padding, groups=groups)\n        self.norm = BatchNorm2d(out_features, affine=True)\n        if lrelu:\n            self.ac = nn.LeakyReLU()\n        else:\n            self.ac = nn.ReLU()\n\n    def forward(self, x):\n        out = self.conv(x)\n        out = self.norm(out)\n        out = self.ac(out)\n        return out\n\n\nclass Encoder(nn.Module):\n    \"\"\"\n    Hourglass Encoder\n    \"\"\"\n\n    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):\n        super(Encoder, self).__init__()\n\n        down_blocks = []\n        for i in range(num_blocks):\n            down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),\n                                           min(max_features, block_expansion * (2 ** (i + 1))),\n                                           kernel_size=3, padding=1))\n        self.down_blocks = nn.ModuleList(down_blocks)\n\n    def forward(self, x):\n        outs = [x]\n        for down_block in self.down_blocks:\n            outs.append(down_block(outs[-1]))\n        return outs\n\n\nclass Decoder(nn.Module):\n    \"\"\"\n    Hourglass Decoder\n    \"\"\"\n\n    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):\n        super(Decoder, self).__init__()\n\n        up_blocks = []\n\n        for i in range(num_blocks)[::-1]:\n            
in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))\n            out_filters = min(max_features, block_expansion * (2 ** i))\n            up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1))\n\n        self.up_blocks = nn.ModuleList(up_blocks)\n        # self.out_filters = block_expansion\n        self.out_filters = block_expansion + in_features\n\n        self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1)\n        self.norm = BatchNorm3d(self.out_filters, affine=True)\n\n    def forward(self, x):\n        out = x.pop()\n        # for up_block in self.up_blocks[:-1]:\n        for up_block in self.up_blocks:\n            out = up_block(out)\n            skip = x.pop()\n            out = torch.cat([out, skip], dim=1)\n        # out = self.up_blocks[-1](out)\n        out = self.conv(out)\n        out = self.norm(out)\n        out = F.relu(out)\n        return out\n\n\nclass Hourglass(nn.Module):\n    \"\"\"\n    Hourglass architecture.\n    \"\"\"\n\n    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):\n        super(Hourglass, self).__init__()\n        self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)\n        self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)\n        self.out_filters = self.decoder.out_filters\n\n    def forward(self, x):\n        return self.decoder(self.encoder(x))\n\n\nclass KPHourglass(nn.Module):\n    \"\"\"\n    Hourglass architecture.\n    \"\"\" \n\n    def __init__(self, block_expansion, in_features, reshape_features, reshape_depth, num_blocks=3, max_features=256):\n        super(KPHourglass, self).__init__()\n        \n        self.down_blocks = nn.Sequential()\n        for i in range(num_blocks):\n            self.down_blocks.add_module('down'+ str(i), DownBlock2d(in_features if i == 0 else min(max_features, 
block_expansion * (2 ** i)),\n                                                                   min(max_features, block_expansion * (2 ** (i + 1))),\n                                                                   kernel_size=3, padding=1))\n\n        in_filters = min(max_features, block_expansion * (2 ** num_blocks))\n        self.conv = nn.Conv2d(in_channels=in_filters, out_channels=reshape_features, kernel_size=1)\n\n        self.up_blocks = nn.Sequential()\n        for i in range(num_blocks):\n            in_filters = min(max_features, block_expansion * (2 ** (num_blocks - i)))\n            out_filters = min(max_features, block_expansion * (2 ** (num_blocks - i - 1)))\n            self.up_blocks.add_module('up'+ str(i), UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1))\n\n        self.reshape_depth = reshape_depth\n        self.out_filters = out_filters\n\n    def forward(self, x):\n        out = self.down_blocks(x)\n        out = self.conv(out)\n        bs, c, h, w = out.shape\n        out = out.view(bs, c//self.reshape_depth, self.reshape_depth, h, w)\n        out = self.up_blocks(out)\n\n        return out\n        \n\n\nclass AntiAliasInterpolation2d(nn.Module):\n    \"\"\"\n    Band-limited downsampling, for better preservation of the input signal.\n    \"\"\"\n    def __init__(self, channels, scale):\n        super(AntiAliasInterpolation2d, self).__init__()\n        sigma = (1 / scale - 1) / 2\n        kernel_size = 2 * round(sigma * 4) + 1\n        self.ka = kernel_size // 2\n        self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka\n\n        kernel_size = [kernel_size, kernel_size]\n        sigma = [sigma, sigma]\n        # The gaussian kernel is the product of the\n        # gaussian function of each dimension.\n        kernel = 1\n        meshgrids = torch.meshgrid(\n            [\n                torch.arange(size, dtype=torch.float32)\n                for size in kernel_size\n                ]\n        )\n        for size, 
std, mgrid in zip(kernel_size, sigma, meshgrids):\n            mean = (size - 1) / 2\n            kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))\n\n        # Make sure sum of values in gaussian kernel equals 1.\n        kernel = kernel / torch.sum(kernel)\n        # Reshape to depthwise convolutional weight\n        kernel = kernel.view(1, 1, *kernel.size())\n        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))\n\n        self.register_buffer('weight', kernel)\n        self.groups = channels\n        self.scale = scale\n        inv_scale = 1 / scale\n        self.int_inv_scale = int(inv_scale)\n\n    def forward(self, input):\n        if self.scale == 1.0:\n            return input\n\n        out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))\n        out = F.conv2d(out, weight=self.weight, groups=self.groups)\n        out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]\n\n        return out\n\n\nclass SPADE(nn.Module):\n    def __init__(self, norm_nc, label_nc):\n        super().__init__()\n\n        self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)\n        nhidden = 128\n\n        self.mlp_shared = nn.Sequential(\n            nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),\n            nn.ReLU())\n        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)\n        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)\n\n    def forward(self, x, segmap):\n        normalized = self.param_free_norm(x)\n        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')\n        actv = self.mlp_shared(segmap)\n        gamma = self.mlp_gamma(actv)\n        beta = self.mlp_beta(actv)\n        out = normalized * (1 + gamma) + beta\n        return out\n    \n\nclass SPADEResnetBlock(nn.Module):\n    def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):\n        super().__init__()\n        # Attributes\n        self.learned_shortcut = (fin != 
fout)\n        fmiddle = min(fin, fout)\n        self.use_se = use_se\n        # create conv layers\n        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)\n        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)\n        if self.learned_shortcut:\n            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)\n        # apply spectral norm if specified\n        if 'spectral' in norm_G:\n            self.conv_0 = spectral_norm(self.conv_0)\n            self.conv_1 = spectral_norm(self.conv_1)\n            if self.learned_shortcut:\n                self.conv_s = spectral_norm(self.conv_s)\n        # define normalization layers\n        self.norm_0 = SPADE(fin, label_nc)\n        self.norm_1 = SPADE(fmiddle, label_nc)\n        if self.learned_shortcut:\n            self.norm_s = SPADE(fin, label_nc)\n\n    def forward(self, x, seg1):\n        x_s = self.shortcut(x, seg1)\n        dx = self.conv_0(self.actvn(self.norm_0(x, seg1)))\n        dx = self.conv_1(self.actvn(self.norm_1(dx, seg1)))\n        out = x_s + dx\n        return out\n\n    def shortcut(self, x, seg1):\n        if self.learned_shortcut:\n            x_s = self.conv_s(self.norm_s(x, seg1))\n        else:\n            x_s = x\n        return x_s\n\n    def actvn(self, x):\n        return F.leaky_relu(x, 2e-1)\n\nclass audio2image(nn.Module):\n    def __init__(self, generator, kp_extractor, he_estimator_video, he_estimator_audio, train_params):\n        super().__init__()\n        # Attributes\n        self.generator = generator\n        self.kp_extractor = kp_extractor\n        self.he_estimator_video = he_estimator_video\n        self.he_estimator_audio = he_estimator_audio\n        self.train_params = train_params\n\n    def headpose_pred_to_degree(self, pred):\n        device = pred.device\n        idx_tensor = [idx for idx in range(66)]\n        idx_tensor = torch.FloatTensor(idx_tensor).to(device)\n     
   pred = F.softmax(pred)\n        degree = torch.sum(pred*idx_tensor, 1) * 3 - 99\n\n        return degree\n    \n    def get_rotation_matrix(self, yaw, pitch, roll):\n        yaw = yaw / 180 * 3.14\n        pitch = pitch / 180 * 3.14\n        roll = roll / 180 * 3.14\n\n        roll = roll.unsqueeze(1)\n        pitch = pitch.unsqueeze(1)\n        yaw = yaw.unsqueeze(1)\n\n        roll_mat = torch.cat([torch.ones_like(roll), torch.zeros_like(roll), torch.zeros_like(roll), \n                          torch.zeros_like(roll), torch.cos(roll), -torch.sin(roll),\n                          torch.zeros_like(roll), torch.sin(roll), torch.cos(roll)], dim=1)\n        roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3)\n\n        pitch_mat = torch.cat([torch.cos(pitch), torch.zeros_like(pitch), torch.sin(pitch), \n                           torch.zeros_like(pitch), torch.ones_like(pitch), torch.zeros_like(pitch),\n                           -torch.sin(pitch), torch.zeros_like(pitch), torch.cos(pitch)], dim=1)\n        pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3)\n\n        yaw_mat = torch.cat([torch.cos(yaw), -torch.sin(yaw), torch.zeros_like(yaw),  \n                         torch.sin(yaw), torch.cos(yaw), torch.zeros_like(yaw),\n                         torch.zeros_like(yaw), torch.zeros_like(yaw), torch.ones_like(yaw)], dim=1)\n        yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3)\n\n        rot_mat = torch.einsum('bij,bjk,bkm->bim', roll_mat, pitch_mat, yaw_mat)\n\n        return rot_mat\n\n    def keypoint_transformation(self, kp_canonical, he):\n        kp = kp_canonical['value']    # (bs, k, 3)\n        yaw, pitch, roll = he['yaw'], he['pitch'], he['roll']\n        t, exp = he['t'], he['exp']\n    \n        yaw = self.headpose_pred_to_degree(yaw)\n        pitch = self.headpose_pred_to_degree(pitch)\n        roll = self.headpose_pred_to_degree(roll)\n\n        rot_mat = self.get_rotation_matrix(yaw, pitch, roll)    # (bs, 3, 3)\n    \n        # keypoint 
rotation\n        kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp)\n\n    \n\n        # keypoint translation\n        t = t.unsqueeze_(1).repeat(1, kp.shape[1], 1)\n        kp_t = kp_rotated + t\n\n        # add expression deviation \n        exp = exp.view(exp.shape[0], -1, 3)\n        kp_transformed = kp_t + exp\n\n        return {'value': kp_transformed}\n\n    def forward(self, source_image, target_audio):\n        pose_source = self.he_estimator_video(source_image)\n        pose_generated = self.he_estimator_audio(target_audio)\n        kp_canonical = self.kp_extractor(source_image)\n        kp_source = self.keypoint_transformation(kp_canonical, pose_source)\n        kp_transformed_generated = self.keypoint_transformation(kp_canonical, pose_generated)\n        generated = self.generator(source_image, kp_source=kp_source, kp_driving=kp_transformed_generated)\n        return generated"
  },
  {
    "path": "src/facerender/sync_batchnorm/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : __init__.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n# \n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nfrom .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d\nfrom .replicate import DataParallelWithCallback, patch_replication_callback\n"
  },
  {
    "path": "src/facerender/sync_batchnorm/batchnorm.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : batchnorm.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n# \n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nimport collections\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch.nn.modules.batchnorm import _BatchNorm\nfrom torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast\n\nfrom .comm import SyncMaster\n\n__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']\n\n\ndef _sum_ft(tensor):\n    \"\"\"sum over the first and last dimention\"\"\"\n    return tensor.sum(dim=0).sum(dim=-1)\n\n\ndef _unsqueeze_ft(tensor):\n    \"\"\"add new dementions at the front and the tail\"\"\"\n    return tensor.unsqueeze(0).unsqueeze(-1)\n\n\n_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])\n_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])\n\n\nclass _SynchronizedBatchNorm(_BatchNorm):\n    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):\n        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)\n\n        self._sync_master = SyncMaster(self._data_parallel_master)\n\n        self._is_parallel = False\n        self._parallel_id = None\n        self._slave_pipe = None\n\n    def forward(self, input):\n        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.\n        if not (self._is_parallel and self.training):\n            return F.batch_norm(\n                input, self.running_mean, self.running_var, self.weight, self.bias,\n                self.training, self.momentum, self.eps)\n\n        # Resize the input to (B, C, -1).\n        input_shape = input.size()\n        input = input.view(input.size(0), self.num_features, -1)\n\n        # Compute the 
sum and square-sum.\n        sum_size = input.size(0) * input.size(2)\n        input_sum = _sum_ft(input)\n        input_ssum = _sum_ft(input ** 2)\n\n        # Reduce-and-broadcast the statistics.\n        if self._parallel_id == 0:\n            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))\n        else:\n            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))\n\n        # Compute the output.\n        if self.affine:\n            # MJY:: Fuse the multiplication for speed.\n            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)\n        else:\n            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)\n\n        # Reshape it.\n        return output.view(input_shape)\n\n    def __data_parallel_replicate__(self, ctx, copy_id):\n        self._is_parallel = True\n        self._parallel_id = copy_id\n\n        # parallel_id == 0 means master device.\n        if self._parallel_id == 0:\n            ctx.sync_master = self._sync_master\n        else:\n            self._slave_pipe = ctx.sync_master.register_slave(copy_id)\n\n    def _data_parallel_master(self, intermediates):\n        \"\"\"Reduce the sum and square-sum, compute the statistics, and broadcast it.\"\"\"\n\n        # Always using same \"device order\" makes the ReduceAdd operation faster.\n        # Thanks to:: Tete Xiao (http://tetexiao.com/)\n        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())\n\n        to_reduce = [i[1][:2] for i in intermediates]\n        to_reduce = [j for i in to_reduce for j in i]  # flatten\n        target_gpus = [i[1].sum.get_device() for i in intermediates]\n\n        sum_size = sum([i[1].sum_size for i in intermediates])\n        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)\n        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)\n\n        
broadcasted = Broadcast.apply(target_gpus, mean, inv_std)\n\n        outputs = []\n        for i, rec in enumerate(intermediates):\n            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))\n\n        return outputs\n\n    def _compute_mean_std(self, sum_, ssum, size):\n        \"\"\"Compute the mean and standard-deviation with sum and square-sum. This method\n        also maintains the moving average on the master device.\"\"\"\n        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'\n        mean = sum_ / size\n        sumvar = ssum - sum_ * mean\n        unbias_var = sumvar / (size - 1)\n        bias_var = sumvar / size\n\n        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data\n        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data\n\n        return mean, bias_var.clamp(self.eps) ** -0.5\n\n\nclass SynchronizedBatchNorm1d(_SynchronizedBatchNorm):\n    r\"\"\"Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a\n    mini-batch.\n\n    .. 
math::\n\n        y = \\frac{x - mean[x]}{ \\sqrt{Var[x] + \\epsilon}} * gamma + beta\n\n    This module differs from the built-in PyTorch BatchNorm1d as the mean and\n    standard-deviation are reduced across all devices during training.\n\n    For example, when one uses `nn.DataParallel` to wrap the network during\n    training, PyTorch's implementation normalize the tensor on each device using\n    the statistics only on that device, which accelerated the computation and\n    is also easy to implement, but the statistics might be inaccurate.\n    Instead, in this synchronized version, the statistics will be computed\n    over all training samples distributed on multiple devices.\n    \n    Note that, for one-GPU or CPU-only case, this module behaves exactly same\n    as the built-in PyTorch implementation.\n\n    The mean and standard-deviation are calculated per-dimension over\n    the mini-batches and gamma and beta are learnable parameter vectors\n    of size C (where C is the input size).\n\n    During training, this layer keeps a running estimate of its computed mean\n    and variance. The running sum is kept with a default momentum of 0.1.\n\n    During evaluation, this running mean/variance is used for normalization.\n\n    Because the BatchNorm is done over the `C` dimension, computing statistics\n    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm\n\n    Args:\n        num_features: num_features from an expected input of size\n            `batch_size x num_features [x width]`\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Default: 0.1\n        affine: a boolean value that when set to ``True``, gives the layer learnable\n            affine parameters. 
Default: ``True``\n\n    Shape:\n        - Input: :math:`(N, C)` or :math:`(N, C, L)`\n        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)\n\n    Examples:\n        >>> # With Learnable Parameters\n        >>> m = SynchronizedBatchNorm1d(100)\n        >>> # Without Learnable Parameters\n        >>> m = SynchronizedBatchNorm1d(100, affine=False)\n        >>> input = torch.autograd.Variable(torch.randn(20, 100))\n        >>> output = m(input)\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.dim() != 2 and input.dim() != 3:\n            raise ValueError('expected 2D or 3D input (got {}D input)'\n                             .format(input.dim()))\n        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)\n\n\nclass SynchronizedBatchNorm2d(_SynchronizedBatchNorm):\n    r\"\"\"Applies Batch Normalization over a 4d input that is seen as a mini-batch\n    of 3d inputs\n\n    .. math::\n\n        y = \\frac{x - mean[x]}{ \\sqrt{Var[x] + \\epsilon}} * gamma + beta\n\n    This module differs from the built-in PyTorch BatchNorm2d as the mean and\n    standard-deviation are reduced across all devices during training.\n\n    For example, when one uses `nn.DataParallel` to wrap the network during\n    training, PyTorch's implementation normalize the tensor on each device using\n    the statistics only on that device, which accelerated the computation and\n    is also easy to implement, but the statistics might be inaccurate.\n    Instead, in this synchronized version, the statistics will be computed\n    over all training samples distributed on multiple devices.\n    \n    Note that, for one-GPU or CPU-only case, this module behaves exactly same\n    as the built-in PyTorch implementation.\n\n    The mean and standard-deviation are calculated per-dimension over\n    the mini-batches and gamma and beta are learnable parameter vectors\n    of size C (where C is the input size).\n\n    During training, this layer keeps a 
running estimate of its computed mean\n    and variance. The running sum is kept with a default momentum of 0.1.\n\n    During evaluation, this running mean/variance is used for normalization.\n\n    Because the BatchNorm is done over the `C` dimension, computing statistics\n    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm\n\n    Args:\n        num_features: num_features from an expected input of\n            size batch_size x num_features x height x width\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Default: 0.1\n        affine: a boolean value that when set to ``True``, gives the layer learnable\n            affine parameters. Default: ``True``\n\n    Shape:\n        - Input: :math:`(N, C, H, W)`\n        - Output: :math:`(N, C, H, W)` (same shape as input)\n\n    Examples:\n        >>> # With Learnable Parameters\n        >>> m = SynchronizedBatchNorm2d(100)\n        >>> # Without Learnable Parameters\n        >>> m = SynchronizedBatchNorm2d(100, affine=False)\n        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))\n        >>> output = m(input)\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.dim() != 4:\n            raise ValueError('expected 4D input (got {}D input)'\n                             .format(input.dim()))\n        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)\n\n\nclass SynchronizedBatchNorm3d(_SynchronizedBatchNorm):\n    r\"\"\"Applies Batch Normalization over a 5d input that is seen as a mini-batch\n    of 4d inputs\n\n    .. 
math::\n\n        y = \\frac{x - mean[x]}{ \\sqrt{Var[x] + \\epsilon}} * gamma + beta\n\n    This module differs from the built-in PyTorch BatchNorm3d as the mean and\n    standard-deviation are reduced across all devices during training.\n\n    For example, when one uses `nn.DataParallel` to wrap the network during\n    training, PyTorch's implementation normalize the tensor on each device using\n    the statistics only on that device, which accelerated the computation and\n    is also easy to implement, but the statistics might be inaccurate.\n    Instead, in this synchronized version, the statistics will be computed\n    over all training samples distributed on multiple devices.\n    \n    Note that, for one-GPU or CPU-only case, this module behaves exactly same\n    as the built-in PyTorch implementation.\n\n    The mean and standard-deviation are calculated per-dimension over\n    the mini-batches and gamma and beta are learnable parameter vectors\n    of size C (where C is the input size).\n\n    During training, this layer keeps a running estimate of its computed mean\n    and variance. The running sum is kept with a default momentum of 0.1.\n\n    During evaluation, this running mean/variance is used for normalization.\n\n    Because the BatchNorm is done over the `C` dimension, computing statistics\n    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm\n    or Spatio-temporal BatchNorm\n\n    Args:\n        num_features: num_features from an expected input of\n            size batch_size x num_features x depth x height x width\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Default: 0.1\n        affine: a boolean value that when set to ``True``, gives the layer learnable\n            affine parameters. 
Default: ``True``\n\n    Shape:\n        - Input: :math:`(N, C, D, H, W)`\n        - Output: :math:`(N, C, D, H, W)` (same shape as input)\n\n    Examples:\n        >>> # With Learnable Parameters\n        >>> m = SynchronizedBatchNorm3d(100)\n        >>> # Without Learnable Parameters\n        >>> m = SynchronizedBatchNorm3d(100, affine=False)\n        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))\n        >>> output = m(input)\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.dim() != 5:\n            raise ValueError('expected 5D input (got {}D input)'\n                             .format(input.dim()))\n        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)\n"
  },
  {
    "path": "src/facerender/sync_batchnorm/comm.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : comm.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n# \n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nimport queue\nimport collections\nimport threading\n\n__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']\n\n\nclass FutureResult(object):\n    \"\"\"A thread-safe future implementation. Used only as one-to-one pipe.\"\"\"\n\n    def __init__(self):\n        self._result = None\n        self._lock = threading.Lock()\n        self._cond = threading.Condition(self._lock)\n\n    def put(self, result):\n        with self._lock:\n            assert self._result is None, 'Previous result has\\'t been fetched.'\n            self._result = result\n            self._cond.notify()\n\n    def get(self):\n        with self._lock:\n            if self._result is None:\n                self._cond.wait()\n\n            res = self._result\n            self._result = None\n            return res\n\n\n_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])\n_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])\n\n\nclass SlavePipe(_SlavePipeBase):\n    \"\"\"Pipe for master-slave communication.\"\"\"\n\n    def run_slave(self, msg):\n        self.queue.put((self.identifier, msg))\n        ret = self.result.get()\n        self.queue.put(True)\n        return ret\n\n\nclass SyncMaster(object):\n    \"\"\"An abstract `SyncMaster` object.\n\n    - During the replication, as the data parallel will trigger an callback of each module, all slave devices should\n    call `register(id)` and obtain an `SlavePipe` to communicate with the master.\n    - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,\n    and passed to a registered callback.\n    - After receiving the messages, the 
master device should gather the information and determine to message passed\n    back to each slave devices.\n    \"\"\"\n\n    def __init__(self, master_callback):\n        \"\"\"\n\n        Args:\n            master_callback: a callback to be invoked after having collected messages from slave devices.\n        \"\"\"\n        self._master_callback = master_callback\n        self._queue = queue.Queue()\n        self._registry = collections.OrderedDict()\n        self._activated = False\n\n    def __getstate__(self):\n        return {'master_callback': self._master_callback}\n\n    def __setstate__(self, state):\n        self.__init__(state['master_callback'])\n\n    def register_slave(self, identifier):\n        \"\"\"\n        Register an slave device.\n\n        Args:\n            identifier: an identifier, usually is the device id.\n\n        Returns: a `SlavePipe` object which can be used to communicate with the master device.\n\n        \"\"\"\n        if self._activated:\n            assert self._queue.empty(), 'Queue is not clean before next initialization.'\n            self._activated = False\n            self._registry.clear()\n        future = FutureResult()\n        self._registry[identifier] = _MasterRegistry(future)\n        return SlavePipe(identifier, self._queue, future)\n\n    def run_master(self, master_msg):\n        \"\"\"\n        Main entry for the master device in each forward pass.\n        The messages were first collected from each devices (including the master device), and then\n        an callback will be invoked to compute the message to be sent back to each devices\n        (including the master device).\n\n        Args:\n            master_msg: the message that the master want to send to itself. This will be placed as the first\n            message when calling `master_callback`. 
For detailed usage, see `_SynchronizedBatchNorm` for an example.\n\n        Returns: the message to be sent back to the master device.\n\n        \"\"\"\n        self._activated = True\n\n        intermediates = [(0, master_msg)]\n        for i in range(self.nr_slaves):\n            intermediates.append(self._queue.get())\n\n        results = self._master_callback(intermediates)\n        assert results[0][0] == 0, 'The first result should belongs to the master.'\n\n        for i, res in results:\n            if i == 0:\n                continue\n            self._registry[i].result.put(res)\n\n        for i in range(self.nr_slaves):\n            assert self._queue.get() is True\n\n        return results[0][1]\n\n    @property\n    def nr_slaves(self):\n        return len(self._registry)\n"
  },
  {
    "path": "src/facerender/sync_batchnorm/replicate.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : replicate.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n# \n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nimport functools\n\nfrom torch.nn.parallel.data_parallel import DataParallel\n\n__all__ = [\n    'CallbackContext',\n    'execute_replication_callbacks',\n    'DataParallelWithCallback',\n    'patch_replication_callback'\n]\n\n\nclass CallbackContext(object):\n    pass\n\n\ndef execute_replication_callbacks(modules):\n    \"\"\"\n    Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.\n\n    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`\n\n    Note that, as all modules are isomorphism, we assign each sub-module with a context\n    (shared among multiple copies of this module on different devices).\n    Through this context, different copies can share some information.\n\n    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback\n    of any slave copies.\n    \"\"\"\n    master_copy = modules[0]\n    nr_modules = len(list(master_copy.modules()))\n    ctxs = [CallbackContext() for _ in range(nr_modules)]\n\n    for i, module in enumerate(modules):\n        for j, m in enumerate(module.modules()):\n            if hasattr(m, '__data_parallel_replicate__'):\n                m.__data_parallel_replicate__(ctxs[j], i)\n\n\nclass DataParallelWithCallback(DataParallel):\n    \"\"\"\n    Data Parallel with a replication callback.\n\n    An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by\n    original `replicate` function.\n    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`\n\n    Examples:\n        > sync_bn = 
SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)\n        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])\n        # sync_bn.__data_parallel_replicate__ will be invoked.\n    \"\"\"\n\n    def replicate(self, module, device_ids):\n        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)\n        execute_replication_callbacks(modules)\n        return modules\n\n\ndef patch_replication_callback(data_parallel):\n    \"\"\"\n    Monkey-patch an existing `DataParallel` object. Add the replication callback.\n    Useful when you have customized `DataParallel` implementation.\n\n    Examples:\n        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)\n        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])\n        > patch_replication_callback(sync_bn)\n        # this is equivalent to\n        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)\n        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])\n    \"\"\"\n\n    assert isinstance(data_parallel, DataParallel)\n\n    old_replicate = data_parallel.replicate\n\n    @functools.wraps(old_replicate)\n    def new_replicate(module, device_ids):\n        modules = old_replicate(module, device_ids)\n        execute_replication_callbacks(modules)\n        return modules\n\n    data_parallel.replicate = new_replicate\n"
  },
  {
    "path": "src/facerender/sync_batchnorm/unittest.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : unittest.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n# \n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nimport unittest\n\nimport numpy as np\nfrom torch.autograd import Variable\n\n\ndef as_numpy(v):\n    if isinstance(v, Variable):\n        v = v.data\n    return v.cpu().numpy()\n\n\nclass TorchTestCase(unittest.TestCase):\n    def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):\n        npa, npb = as_numpy(a), as_numpy(b)\n        self.assertTrue(\n                np.allclose(npa, npb, atol=atol),\n                'Tensor close check failed\\n{}\\n{}\\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())\n        )\n"
  },
  {
    "path": "src/generate_batch.py",
    "content": "import os\n\nfrom tqdm import tqdm\nimport torch\nimport numpy as np\nimport random\nimport scipy.io as scio\nimport src.utils.audio as audio\n\ndef crop_pad_audio(wav, audio_length):\n    if len(wav) > audio_length:\n        wav = wav[:audio_length]\n    elif len(wav) < audio_length:\n        wav = np.pad(wav, [0, audio_length - len(wav)], mode='constant', constant_values=0)\n    return wav\n\ndef parse_audio_length(audio_length, sr, fps):\n    bit_per_frames = sr / fps\n\n    num_frames = int(audio_length / bit_per_frames)\n    audio_length = int(num_frames * bit_per_frames)\n\n    return audio_length, num_frames\n\ndef generate_blink_seq(num_frames):\n    ratio = np.zeros((num_frames,1))\n    frame_id = 0\n    while frame_id in range(num_frames):\n        start = 80\n        if frame_id+start+9<=num_frames - 1:\n            ratio[frame_id+start:frame_id+start+9, 0] = [0.5,0.6,0.7,0.9,1, 0.9, 0.7,0.6,0.5]\n            frame_id = frame_id+start+9\n        else:\n            break\n    return ratio \n\ndef generate_blink_seq_randomly(num_frames):\n    ratio = np.zeros((num_frames,1))\n    if num_frames<=20:\n        return ratio\n    frame_id = 0\n    while frame_id in range(num_frames):\n        start = random.choice(range(min(10,num_frames), min(int(num_frames/2), 70))) \n        if frame_id+start+5<=num_frames - 1:\n            ratio[frame_id+start:frame_id+start+5, 0] = [0.5, 0.9, 1.0, 0.9, 0.5]\n            frame_id = frame_id+start+5\n        else:\n            break\n    return ratio\n\ndef get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=False, idlemode=False, length_of_audio=False, use_blink=True):\n\n    syncnet_mel_step_size = 16\n    fps = 25\n\n    pic_name = os.path.splitext(os.path.split(first_coeff_path)[-1])[0]\n    audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]\n\n    \n    if idlemode:\n        num_frames = int(length_of_audio * 25)\n        indiv_mels = np.zeros((num_frames, 80, 
16))\n    else:\n        wav = audio.load_wav(audio_path, 16000) \n        wav_length, num_frames = parse_audio_length(len(wav), 16000, 25)\n        wav = crop_pad_audio(wav, wav_length)\n        orig_mel = audio.melspectrogram(wav).T\n        spec = orig_mel.copy()         # nframes 80\n        indiv_mels = []\n\n        for i in tqdm(range(num_frames), 'mel:'):\n            start_frame_num = i-2\n            start_idx = int(80. * (start_frame_num / float(fps)))\n            end_idx = start_idx + syncnet_mel_step_size\n            seq = list(range(start_idx, end_idx))\n            seq = [ min(max(item, 0), orig_mel.shape[0]-1) for item in seq ]\n            m = spec[seq, :]\n            indiv_mels.append(m.T)\n        indiv_mels = np.asarray(indiv_mels)         # T 80 16\n\n    ratio = generate_blink_seq_randomly(num_frames)      # T\n    source_semantics_path = first_coeff_path\n    source_semantics_dict = scio.loadmat(source_semantics_path)\n    ref_coeff = source_semantics_dict['coeff_3dmm'][:1,:70]         #1 70\n    ref_coeff = np.repeat(ref_coeff, num_frames, axis=0)\n\n    if ref_eyeblink_coeff_path is not None:\n        ratio[:num_frames] = 0\n        refeyeblink_coeff_dict = scio.loadmat(ref_eyeblink_coeff_path)\n        refeyeblink_coeff = refeyeblink_coeff_dict['coeff_3dmm'][:,:64]\n        refeyeblink_num_frames = refeyeblink_coeff.shape[0]\n        if refeyeblink_num_frames<num_frames:\n            div = num_frames//refeyeblink_num_frames\n            re = num_frames%refeyeblink_num_frames\n            refeyeblink_coeff_list = [refeyeblink_coeff for i in range(div)]\n            refeyeblink_coeff_list.append(refeyeblink_coeff[:re, :64])\n            refeyeblink_coeff = np.concatenate(refeyeblink_coeff_list, axis=0)\n            print(refeyeblink_coeff.shape[0])\n\n        ref_coeff[:, :64] = refeyeblink_coeff[:num_frames, :64] \n    \n    indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1).unsqueeze(0) # bs T 1 80 16\n\n    if use_blink:\n        
ratio = torch.FloatTensor(ratio).unsqueeze(0)                       # bs T\n    else:\n        ratio = torch.FloatTensor(ratio).unsqueeze(0).fill_(0.) \n                               # bs T\n    ref_coeff = torch.FloatTensor(ref_coeff).unsqueeze(0)                # bs 1 70\n\n    indiv_mels = indiv_mels.to(device)\n    ratio = ratio.to(device)\n    ref_coeff = ref_coeff.to(device)\n\n    return {'indiv_mels': indiv_mels,  \n            'ref': ref_coeff, \n            'num_frames': num_frames, \n            'ratio_gt': ratio,\n            'audio_name': audio_name, 'pic_name': pic_name}\n\n"
  },
  {
    "path": "src/generate_facerender_batch.py",
    "content": "import os\nimport numpy as np\nfrom PIL import Image\nfrom skimage import io, img_as_float32, transform\nimport torch\nimport scipy.io as scio\n\ndef get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path, \n                        batch_size, input_yaw_list=None, input_pitch_list=None, input_roll_list=None, \n                        expression_scale=1.0, still_mode = False, preprocess='crop', size = 256):\n\n    semantic_radius = 13\n    video_name = os.path.splitext(os.path.split(coeff_path)[-1])[0]\n    txt_path = os.path.splitext(coeff_path)[0]\n\n    data={}\n\n    img1 = Image.open(pic_path)\n    source_image = np.array(img1)\n    source_image = img_as_float32(source_image)\n    source_image = transform.resize(source_image, (size, size, 3))\n    source_image = source_image.transpose((2, 0, 1))\n    source_image_ts = torch.FloatTensor(source_image).unsqueeze(0)\n    source_image_ts = source_image_ts.repeat(batch_size, 1, 1, 1)\n    data['source_image'] = source_image_ts\n \n    source_semantics_dict = scio.loadmat(first_coeff_path)\n    generated_dict = scio.loadmat(coeff_path)\n\n    if 'full' not in preprocess.lower():\n        source_semantics = source_semantics_dict['coeff_3dmm'][:1,:70]         #1 70\n        generated_3dmm = generated_dict['coeff_3dmm'][:,:70]\n\n    else:\n        source_semantics = source_semantics_dict['coeff_3dmm'][:1,:73]         #1 70\n        generated_3dmm = generated_dict['coeff_3dmm'][:,:70]\n\n    source_semantics_new = transform_semantic_1(source_semantics, semantic_radius)\n    source_semantics_ts = torch.FloatTensor(source_semantics_new).unsqueeze(0)\n    source_semantics_ts = source_semantics_ts.repeat(batch_size, 1, 1)\n    data['source_semantics'] = source_semantics_ts\n\n    # target \n    generated_3dmm[:, :64] = generated_3dmm[:, :64] * expression_scale\n\n    if 'full' in preprocess.lower():\n        generated_3dmm = np.concatenate([generated_3dmm, np.repeat(source_semantics[:,70:], 
generated_3dmm.shape[0], axis=0)], axis=1)\n\n    if still_mode:\n        generated_3dmm[:, 64:] = np.repeat(source_semantics[:, 64:], generated_3dmm.shape[0], axis=0)\n\n    with open(txt_path+'.txt', 'w') as f:\n        for coeff in generated_3dmm:\n            for i in coeff:\n                f.write(str(i)[:7]   + '  '+'\\t')\n            f.write('\\n')\n\n    target_semantics_list = [] \n    frame_num = generated_3dmm.shape[0]\n    data['frame_num'] = frame_num\n    for frame_idx in range(frame_num):\n        target_semantics = transform_semantic_target(generated_3dmm, frame_idx, semantic_radius)\n        target_semantics_list.append(target_semantics)\n\n    remainder = frame_num%batch_size\n    if remainder!=0:\n        for _ in range(batch_size-remainder):\n            target_semantics_list.append(target_semantics)\n\n    target_semantics_np = np.array(target_semantics_list)             #frame_num 70 semantic_radius*2+1\n    target_semantics_np = target_semantics_np.reshape(batch_size, -1, target_semantics_np.shape[-2], target_semantics_np.shape[-1])\n    data['target_semantics_list'] = torch.FloatTensor(target_semantics_np)\n    data['video_name'] = video_name\n    data['audio_path'] = audio_path\n    \n    if input_yaw_list is not None:\n        yaw_c_seq = gen_camera_pose(input_yaw_list, frame_num, batch_size)\n        data['yaw_c_seq'] = torch.FloatTensor(yaw_c_seq)\n    if input_pitch_list is not None:\n        pitch_c_seq = gen_camera_pose(input_pitch_list, frame_num, batch_size)\n        data['pitch_c_seq'] = torch.FloatTensor(pitch_c_seq)\n    if input_roll_list is not None:\n        roll_c_seq = gen_camera_pose(input_roll_list, frame_num, batch_size) \n        data['roll_c_seq'] = torch.FloatTensor(roll_c_seq)\n \n    return data\n\ndef transform_semantic_1(semantic, semantic_radius):\n    semantic_list =  [semantic for i in range(0, semantic_radius*2+1)]\n    coeff_3dmm = np.concatenate(semantic_list, 0)\n    return coeff_3dmm.transpose(1,0)\n\ndef 
transform_semantic_target(coeff_3dmm, frame_index, semantic_radius):\n    num_frames = coeff_3dmm.shape[0]\n    seq = list(range(frame_index- semantic_radius, frame_index + semantic_radius+1))\n    index = [ min(max(item, 0), num_frames-1) for item in seq ] \n    coeff_3dmm_g = coeff_3dmm[index, :]\n    return coeff_3dmm_g.transpose(1,0)\n\ndef gen_camera_pose(camera_degree_list, frame_num, batch_size):\n\n    new_degree_list = [] \n    if len(camera_degree_list) == 1:\n        for _ in range(frame_num):\n            new_degree_list.append(camera_degree_list[0]) \n        remainder = frame_num%batch_size\n        if remainder!=0:\n            for _ in range(batch_size-remainder):\n                new_degree_list.append(new_degree_list[-1])\n        new_degree_np = np.array(new_degree_list).reshape(batch_size, -1) \n        return new_degree_np\n\n    degree_sum = 0.\n    for i, degree in enumerate(camera_degree_list[1:]):\n        degree_sum += abs(degree-camera_degree_list[i])\n    \n    degree_per_frame = degree_sum/(frame_num-1)\n    for i, degree in enumerate(camera_degree_list[1:]):\n        degree_last = camera_degree_list[i]\n        degree_step = degree_per_frame * abs(degree-degree_last)/(degree-degree_last)\n        new_degree_list =  new_degree_list + list(np.arange(degree_last, degree, degree_step))\n    if len(new_degree_list) > frame_num:\n        new_degree_list = new_degree_list[:frame_num]\n    elif len(new_degree_list) < frame_num:\n        for _ in range(frame_num-len(new_degree_list)):\n            new_degree_list.append(new_degree_list[-1])\n    print(len(new_degree_list))\n    print(frame_num)\n\n    remainder = frame_num%batch_size\n    if remainder!=0:\n        for _ in range(batch_size-remainder):\n            new_degree_list.append(new_degree_list[-1])\n    new_degree_np = np.array(new_degree_list).reshape(batch_size, -1) \n    return new_degree_np\n    \n"
  },
  {
    "path": "src/gradio_demo.py",
    "content": "import torch, uuid\r\nimport os, sys, shutil\r\nfrom src.utils.preprocess import CropAndExtract\r\nfrom src.test_audio2coeff import Audio2Coeff  \r\nfrom src.facerender.animate import AnimateFromCoeff\r\nfrom src.generate_batch import get_data\r\nfrom src.generate_facerender_batch import get_facerender_data\r\n\r\nfrom src.utils.init_path import init_path\r\n\r\nfrom pydub import AudioSegment\r\n\r\n\r\ndef mp3_to_wav(mp3_filename,wav_filename,frame_rate):\r\n    mp3_file = AudioSegment.from_file(file=mp3_filename)\r\n    mp3_file.set_frame_rate(frame_rate).export(wav_filename,format=\"wav\")\r\n\r\n\r\nclass SadTalker():\r\n\r\n    def __init__(self, checkpoint_path='checkpoints', config_path='src/config', lazy_load=False):\r\n\r\n        if torch.cuda.is_available() :\r\n            device = \"cuda\"\r\n        else:\r\n            device = \"cpu\"\r\n        \r\n        self.device = device\r\n\r\n        os.environ['TORCH_HOME']= checkpoint_path\r\n\r\n        self.checkpoint_path = checkpoint_path\r\n        self.config_path = config_path\r\n      \r\n\r\n    def test(self, source_image, driven_audio, preprocess='crop', \r\n        still_mode=False,  use_enhancer=False, batch_size=1, size=256, \r\n        pose_style = 0, exp_scale=1.0, \r\n        use_ref_video = False,\r\n        ref_video = None,\r\n        ref_info = None,\r\n        use_idle_mode = False,\r\n        length_of_audio = 0, use_blink=True,\r\n        result_dir='./results/'):\r\n\r\n        self.sadtalker_paths = init_path(self.checkpoint_path, self.config_path, size, False, preprocess)\r\n        print(self.sadtalker_paths)\r\n            \r\n        self.audio_to_coeff = Audio2Coeff(self.sadtalker_paths, self.device)\r\n        self.preprocess_model = CropAndExtract(self.sadtalker_paths, self.device)\r\n        self.animate_from_coeff = AnimateFromCoeff(self.sadtalker_paths, self.device)\r\n\r\n        time_tag = str(uuid.uuid4())\r\n        save_dir = 
os.path.join(result_dir, time_tag)\r\n        os.makedirs(save_dir, exist_ok=True)\r\n\r\n        input_dir = os.path.join(save_dir, 'input')\r\n        os.makedirs(input_dir, exist_ok=True)\r\n\r\n        print(source_image)\r\n        pic_path = os.path.join(input_dir, os.path.basename(source_image)) \r\n        shutil.move(source_image, input_dir)\r\n\r\n        if driven_audio is not None and os.path.isfile(driven_audio):\r\n            audio_path = os.path.join(input_dir, os.path.basename(driven_audio))  \r\n\r\n            #### mp3 to wav\r\n            if '.mp3' in audio_path:\r\n                mp3_to_wav(driven_audio, audio_path.replace('.mp3', '.wav'), 16000)\r\n                audio_path = audio_path.replace('.mp3', '.wav')\r\n            else:\r\n                shutil.move(driven_audio, input_dir)\r\n\r\n        elif use_idle_mode:\r\n            audio_path = os.path.join(input_dir, 'idlemode_'+str(length_of_audio)+'.wav') ## generate audio from this new audio_path\r\n            from pydub import AudioSegment\r\n            one_sec_segment = AudioSegment.silent(duration=1000*length_of_audio)  #duration in milliseconds\r\n            one_sec_segment.export(audio_path, format=\"wav\")\r\n        else:\r\n            print(use_ref_video, ref_info)\r\n            assert use_ref_video == True and ref_info == 'all'\r\n\r\n        if use_ref_video and ref_info == 'all': # full ref mode\r\n            ref_video_videoname = os.path.basename(ref_video)\r\n            audio_path = os.path.join(save_dir, ref_video_videoname+'.wav')\r\n            print('new audiopath:',audio_path)\r\n            # if ref_video contains audio, set the audio from ref_video.\r\n            cmd = r\"ffmpeg -y -hide_banner -loglevel error -i %s %s\"%(ref_video, audio_path)\r\n            os.system(cmd)        \r\n\r\n        os.makedirs(save_dir, exist_ok=True)\r\n        \r\n        #crop image and extract 3dmm from image\r\n        first_frame_dir = os.path.join(save_dir, 
'first_frame_dir')\r\n        os.makedirs(first_frame_dir, exist_ok=True)\r\n        first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(pic_path, first_frame_dir, preprocess, True, size)\r\n        \r\n        if first_coeff_path is None:\r\n            raise AttributeError(\"No face is detected\")\r\n\r\n        if use_ref_video:\r\n            print('using ref video for genreation')\r\n            ref_video_videoname = os.path.splitext(os.path.split(ref_video)[-1])[0]\r\n            ref_video_frame_dir = os.path.join(save_dir, ref_video_videoname)\r\n            os.makedirs(ref_video_frame_dir, exist_ok=True)\r\n            print('3DMM Extraction for the reference video providing pose')\r\n            ref_video_coeff_path, _, _ =  self.preprocess_model.generate(ref_video, ref_video_frame_dir, preprocess, source_image_flag=False)\r\n        else:\r\n            ref_video_coeff_path = None\r\n\r\n        if use_ref_video:\r\n            if ref_info == 'pose':\r\n                ref_pose_coeff_path = ref_video_coeff_path\r\n                ref_eyeblink_coeff_path = None\r\n            elif ref_info == 'blink':\r\n                ref_pose_coeff_path = None\r\n                ref_eyeblink_coeff_path = ref_video_coeff_path\r\n            elif ref_info == 'pose+blink':\r\n                ref_pose_coeff_path = ref_video_coeff_path\r\n                ref_eyeblink_coeff_path = ref_video_coeff_path\r\n            elif ref_info == 'all':            \r\n                ref_pose_coeff_path = None\r\n                ref_eyeblink_coeff_path = None\r\n            else:\r\n                raise('error in refinfo')\r\n        else:\r\n            ref_pose_coeff_path = None\r\n            ref_eyeblink_coeff_path = None\r\n\r\n        #audio2ceoff\r\n        if use_ref_video and ref_info == 'all':\r\n            coeff_path = ref_video_coeff_path # self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)\r\n        else:\r\n            
batch = get_data(first_coeff_path, audio_path, self.device, ref_eyeblink_coeff_path=ref_eyeblink_coeff_path, still=still_mode, idlemode=use_idle_mode, length_of_audio=length_of_audio, use_blink=use_blink) # longer audio?\r\n            coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)\r\n\r\n        #coeff2video\r\n        data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode, preprocess=preprocess, size=size, expression_scale = exp_scale)\r\n        return_path = self.animate_from_coeff.generate(data, save_dir,  pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None, preprocess=preprocess, img_size=size)\r\n        video_name = data['video_name']\r\n        print(f'The generated video is named {video_name} in {save_dir}')\r\n\r\n        del self.preprocess_model\r\n        del self.audio_to_coeff\r\n        del self.animate_from_coeff\r\n\r\n        if torch.cuda.is_available():\r\n            torch.cuda.empty_cache()\r\n            torch.cuda.synchronize()\r\n            \r\n        import gc; gc.collect()\r\n        \r\n        return return_path\r\n\r\n    "
  },
  {
    "path": "src/test_audio2coeff.py",
    "content": "import os \nimport torch\nimport numpy as np\nfrom scipy.io import savemat, loadmat\nfrom yacs.config import CfgNode as CN\nfrom scipy.signal import savgol_filter\n\nimport safetensors\nimport safetensors.torch \n\nfrom src.audio2pose_models.audio2pose import Audio2Pose\nfrom src.audio2exp_models.networks import SimpleWrapperV2 \nfrom src.audio2exp_models.audio2exp import Audio2Exp\nfrom src.utils.safetensor_helper import load_x_from_safetensor  \n\ndef load_cpk(checkpoint_path, model=None, optimizer=None, device=\"cpu\"):\n    checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))\n    if model is not None:\n        model.load_state_dict(checkpoint['model'])\n    if optimizer is not None:\n        optimizer.load_state_dict(checkpoint['optimizer'])\n\n    return checkpoint['epoch']\n\nclass Audio2Coeff():\n\n    def __init__(self, sadtalker_path, device):\n        #load config\n        fcfg_pose = open(sadtalker_path['audio2pose_yaml_path'])\n        cfg_pose = CN.load_cfg(fcfg_pose)\n        cfg_pose.freeze()\n        fcfg_exp = open(sadtalker_path['audio2exp_yaml_path'])\n        cfg_exp = CN.load_cfg(fcfg_exp)\n        cfg_exp.freeze()\n\n        # load audio2pose_model\n        self.audio2pose_model = Audio2Pose(cfg_pose, None, device=device)\n        self.audio2pose_model = self.audio2pose_model.to(device)\n        self.audio2pose_model.eval()\n        for param in self.audio2pose_model.parameters():\n            param.requires_grad = False \n        \n        try:\n            if sadtalker_path['use_safetensor']:\n                checkpoints = safetensors.torch.load_file(sadtalker_path['checkpoint'])\n                self.audio2pose_model.load_state_dict(load_x_from_safetensor(checkpoints, 'audio2pose'))\n            else:\n                load_cpk(sadtalker_path['audio2pose_checkpoint'], model=self.audio2pose_model, device=device)\n        except:\n            raise Exception(\"Failed in loading 
audio2pose_checkpoint\")\n\n        # load audio2exp_model\n        netG = SimpleWrapperV2()\n        netG = netG.to(device)\n        for param in netG.parameters():\n            netG.requires_grad = False\n        netG.eval()\n        try:\n            if sadtalker_path['use_safetensor']:\n                checkpoints = safetensors.torch.load_file(sadtalker_path['checkpoint'])\n                netG.load_state_dict(load_x_from_safetensor(checkpoints, 'audio2exp'))\n            else:\n                load_cpk(sadtalker_path['audio2exp_checkpoint'], model=netG, device=device)\n        except:\n            raise Exception(\"Failed in loading audio2exp_checkpoint\")\n        self.audio2exp_model = Audio2Exp(netG, cfg_exp, device=device, prepare_training_loss=False)\n        self.audio2exp_model = self.audio2exp_model.to(device)\n        for param in self.audio2exp_model.parameters():\n            param.requires_grad = False\n        self.audio2exp_model.eval()\n \n        self.device = device\n\n    def generate(self, batch, coeff_save_dir, pose_style, ref_pose_coeff_path=None):\n\n        with torch.no_grad():\n            #test\n            results_dict_exp= self.audio2exp_model.test(batch)\n            exp_pred = results_dict_exp['exp_coeff_pred']                         #bs T 64\n\n            #for class_id in  range(1):\n            #class_id = 0#(i+10)%45\n            #class_id = random.randint(0,46)                                   #46 styles can be selected \n            batch['class'] = torch.LongTensor([pose_style]).to(self.device)\n            results_dict_pose = self.audio2pose_model.test(batch) \n            pose_pred = results_dict_pose['pose_pred']                        #bs T 6\n\n            pose_len = pose_pred.shape[1]\n            if pose_len<13: \n                pose_len = int((pose_len-1)/2)*2+1\n                pose_pred = torch.Tensor(savgol_filter(np.array(pose_pred.cpu()), pose_len, 2, axis=1)).to(self.device)\n            else:\n             
   pose_pred = torch.Tensor(savgol_filter(np.array(pose_pred.cpu()), 13, 2, axis=1)).to(self.device) \n            \n            coeffs_pred = torch.cat((exp_pred, pose_pred), dim=-1)            #bs T 70\n\n            coeffs_pred_numpy = coeffs_pred[0].clone().detach().cpu().numpy() \n\n            if ref_pose_coeff_path is not None: \n                 coeffs_pred_numpy = self.using_refpose(coeffs_pred_numpy, ref_pose_coeff_path)\n        \n            savemat(os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name'])),  \n                    {'coeff_3dmm': coeffs_pred_numpy})\n\n            return os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name']))\n    \n    def using_refpose(self, coeffs_pred_numpy, ref_pose_coeff_path):\n        num_frames = coeffs_pred_numpy.shape[0]\n        refpose_coeff_dict = loadmat(ref_pose_coeff_path)\n        refpose_coeff = refpose_coeff_dict['coeff_3dmm'][:,64:70]\n        refpose_num_frames = refpose_coeff.shape[0]\n        if refpose_num_frames<num_frames:\n            div = num_frames//refpose_num_frames\n            re = num_frames%refpose_num_frames\n            refpose_coeff_list = [refpose_coeff for i in range(div)]\n            refpose_coeff_list.append(refpose_coeff[:re, :])\n            refpose_coeff = np.concatenate(refpose_coeff_list, axis=0)\n\n        #### relative head pose\n        coeffs_pred_numpy[:, 64:70] = coeffs_pred_numpy[:, 64:70] + ( refpose_coeff[:num_frames, :] - refpose_coeff[0:1, :] )\n        return coeffs_pred_numpy\n\n\n"
  },
  {
    "path": "src/utils/audio.py",
    "content": "import librosa\nimport librosa.filters\nimport numpy as np\n# import tensorflow as tf\nfrom scipy import signal\nfrom scipy.io import wavfile\nfrom src.utils.hparams import hparams as hp\n\ndef load_wav(path, sr):\n    return librosa.core.load(path, sr=sr)[0]\n\ndef save_wav(wav, path, sr):\n    wav *= 32767 / max(0.01, np.max(np.abs(wav)))\n    #proposed by @dsmiller\n    wavfile.write(path, sr, wav.astype(np.int16))\n\ndef save_wavenet_wav(wav, path, sr):\n    librosa.output.write_wav(path, wav, sr=sr)\n\ndef preemphasis(wav, k, preemphasize=True):\n    if preemphasize:\n        return signal.lfilter([1, -k], [1], wav)\n    return wav\n\ndef inv_preemphasis(wav, k, inv_preemphasize=True):\n    if inv_preemphasize:\n        return signal.lfilter([1], [1, -k], wav)\n    return wav\n\ndef get_hop_size():\n    hop_size = hp.hop_size\n    if hop_size is None:\n        assert hp.frame_shift_ms is not None\n        hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)\n    return hop_size\n\ndef linearspectrogram(wav):\n    D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))\n    S = _amp_to_db(np.abs(D)) - hp.ref_level_db\n    \n    if hp.signal_normalization:\n        return _normalize(S)\n    return S\n\ndef melspectrogram(wav):\n    D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))\n    S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db\n    \n    if hp.signal_normalization:\n        return _normalize(S)\n    return S\n\ndef _lws_processor():\n    import lws\n    return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode=\"speech\")\n\ndef _stft(y):\n    if hp.use_lws:\n        return _lws_processor(hp).stft(y).T\n    else:\n        return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)\n\n##########################################################\n#Those are only correct when using lws!!! 
(This was messing with Wavenet quality for a long time!)\ndef num_frames(length, fsize, fshift):\n    \"\"\"Compute number of time frames of spectrogram\n    \"\"\"\n    pad = (fsize - fshift)\n    if length % fshift == 0:\n        M = (length + pad * 2 - fsize) // fshift + 1\n    else:\n        M = (length + pad * 2 - fsize) // fshift + 2\n    return M\n\n\ndef pad_lr(x, fsize, fshift):\n    \"\"\"Compute left and right padding\n    \"\"\"\n    M = num_frames(len(x), fsize, fshift)\n    pad = (fsize - fshift)\n    T = len(x) + 2 * pad\n    r = (M - 1) * fshift + fsize - T\n    return pad, pad + r\n##########################################################\n#Librosa correct padding\ndef librosa_pad_lr(x, fsize, fshift):\n    return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]\n\n# Conversions\n_mel_basis = None\n\ndef _linear_to_mel(spectogram):\n    global _mel_basis\n    if _mel_basis is None:\n        _mel_basis = _build_mel_basis()\n    return np.dot(_mel_basis, spectogram)\n\ndef _build_mel_basis():\n    assert hp.fmax <= hp.sample_rate // 2\n    return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,\n                               fmin=hp.fmin, fmax=hp.fmax)\n\ndef _amp_to_db(x):\n    min_level = np.exp(hp.min_level_db / 20 * np.log(10))\n    return 20 * np.log10(np.maximum(min_level, x))\n\ndef _db_to_amp(x):\n    return np.power(10.0, (x) * 0.05)\n\ndef _normalize(S):\n    if hp.allow_clipping_in_normalization:\n        if hp.symmetric_mels:\n            return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,\n                           -hp.max_abs_value, hp.max_abs_value)\n        else:\n            return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)\n    \n    assert S.max() <= 0 and S.min() - hp.min_level_db >= 0\n    if hp.symmetric_mels:\n        return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) 
- hp.max_abs_value\n    else:\n        return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))\n\ndef _denormalize(D):\n    if hp.allow_clipping_in_normalization:\n        if hp.symmetric_mels:\n            return (((np.clip(D, -hp.max_abs_value,\n                              hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))\n                    + hp.min_level_db)\n        else:\n            return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)\n    \n    if hp.symmetric_mels:\n        return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)\n    else:\n        return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)\n"
  },
  {
    "path": "src/utils/croper.py",
    "content": "import os\nimport cv2\nimport time\nimport glob\nimport argparse\nimport scipy\nimport numpy as np\nfrom PIL import Image\nimport torch\nfrom tqdm import tqdm\nfrom itertools import cycle\n\nfrom src.face3d.extract_kp_videos_safe import KeypointExtractor\nfrom facexlib.alignment import landmark_98_to_68\n\nimport numpy as np\nfrom PIL import Image\n\nclass Preprocesser:\n    def __init__(self, device='cuda'):\n        self.predictor = KeypointExtractor(device)\n\n    def get_landmark(self, img_np):\n        \"\"\"get landmark with dlib\n        :return: np.array shape=(68, 2)\n        \"\"\"\n        with torch.no_grad():\n            dets = self.predictor.det_net.detect_faces(img_np, 0.97)\n\n        if len(dets) == 0:\n            return None\n        det = dets[0]\n\n        img = img_np[int(det[1]):int(det[3]), int(det[0]):int(det[2]), :]\n        lm = landmark_98_to_68(self.predictor.detector.get_landmarks(img)) # [0]\n\n        #### keypoints to the original location\n        lm[:,0] += int(det[0])\n        lm[:,1] += int(det[1])\n\n        return lm\n\n    def align_face(self, img, lm, output_size=1024):\n        \"\"\"\n        :param filepath: str\n        :return: PIL Image\n        \"\"\"\n        lm_chin = lm[0: 17]  # left-right\n        lm_eyebrow_left = lm[17: 22]  # left-right\n        lm_eyebrow_right = lm[22: 27]  # left-right\n        lm_nose = lm[27: 31]  # top-down\n        lm_nostrils = lm[31: 36]  # top-down\n        lm_eye_left = lm[36: 42]  # left-clockwise\n        lm_eye_right = lm[42: 48]  # left-clockwise\n        lm_mouth_outer = lm[48: 60]  # left-clockwise\n        lm_mouth_inner = lm[60: 68]  # left-clockwise\n\n        # Calculate auxiliary vectors.\n        eye_left = np.mean(lm_eye_left, axis=0)\n        eye_right = np.mean(lm_eye_right, axis=0)\n        eye_avg = (eye_left + eye_right) * 0.5\n        eye_to_eye = eye_right - eye_left\n        mouth_left = lm_mouth_outer[0]\n        mouth_right = 
lm_mouth_outer[6]\n        mouth_avg = (mouth_left + mouth_right) * 0.5\n        eye_to_mouth = mouth_avg - eye_avg\n\n        # Choose oriented crop rectangle.\n        x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]  # Addition of binocular difference and double mouth difference\n        x /= np.hypot(*x)   # hypot函数计算直角三角形的斜边长，用斜边长对三角形两条直边做归一化\n        x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)    # 双眼差和眼嘴差，选较大的作为基准尺度\n        y = np.flipud(x) * [-1, 1]\n        c = eye_avg + eye_to_mouth * 0.1\n        quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])   # 定义四边形，以面部基准位置为中心上下左右平移得到四个顶点\n        qsize = np.hypot(*x) * 2    # 定义四边形的大小（边长），为基准尺度的2倍\n\n        # Shrink.\n        # 如果计算出的四边形太大了，就按比例缩小它\n        shrink = int(np.floor(qsize / output_size * 0.5))\n        if shrink > 1:\n            rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))\n            img = img.resize(rsize, Image.ANTIALIAS)\n            quad /= shrink\n            qsize /= shrink\n        else:\n            rsize = (int(np.rint(float(img.size[0]))), int(np.rint(float(img.size[1]))))\n\n        # Crop.\n        border = max(int(np.rint(qsize * 0.1)), 3)\n        crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),\n                int(np.ceil(max(quad[:, 1]))))\n        crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),\n                min(crop[3] + border, img.size[1]))\n        if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:\n            # img = img.crop(crop)\n            quad -= crop[0:2]\n\n        # Pad.\n        pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),\n               int(np.ceil(max(quad[:, 1]))))\n        pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),\n               
max(pad[3] - img.size[1] + border, 0))\n        # if enable_padding and max(pad) > border - 4:\n        #     pad = np.maximum(pad, int(np.rint(qsize * 0.3)))\n        #     img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')\n        #     h, w, _ = img.shape\n        #     y, x, _ = np.ogrid[:h, :w, :1]\n        #     mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),\n        #                       1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))\n        #     blur = qsize * 0.02\n        #     img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)\n        #     img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)\n        #     img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')\n        #     quad += pad[:2]\n\n        # Transform.\n        quad = (quad + 0.5).flatten()\n        lx = max(min(quad[0], quad[2]), 0)\n        ly = max(min(quad[1], quad[7]), 0)\n        rx = min(max(quad[4], quad[6]), img.size[0])\n        ry = min(max(quad[3], quad[5]), img.size[0])\n\n        # Save aligned image.\n        return rsize, crop, [lx, ly, rx, ry]\n    \n    def crop(self, img_np_list, still=False, xsize=512):    # first frame for all video\n        img_np = img_np_list[0]\n        lm = self.get_landmark(img_np)\n\n        if lm is None:\n            raise 'can not detect the landmark from source image'\n        rsize, crop, quad = self.align_face(img=Image.fromarray(img_np), lm=lm, output_size=xsize)\n        clx, cly, crx, cry = crop\n        lx, ly, rx, ry = quad\n        lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)\n        for _i in range(len(img_np_list)):\n            _inp = img_np_list[_i]\n            _inp = cv2.resize(_inp, (rsize[0], rsize[1]))\n            _inp = _inp[cly:cry, clx:crx]\n            if not still:\n                _inp = _inp[ly:ry, lx:rx]\n          
  img_np_list[_i] = _inp\n        return img_np_list, crop, quad\n\n"
  },
  {
    "path": "src/utils/face_enhancer.py",
    "content": "import os\nimport torch \n\nfrom gfpgan import GFPGANer\n\nfrom tqdm import tqdm\n\nfrom src.utils.videoio import load_video_to_cv2\n\nimport cv2\n\n\nclass GeneratorWithLen(object):\n    \"\"\" From https://stackoverflow.com/a/7460929 \"\"\"\n\n    def __init__(self, gen, length):\n        self.gen = gen\n        self.length = length\n\n    def __len__(self):\n        return self.length\n\n    def __iter__(self):\n        return self.gen\n\ndef enhancer_list(images, method='gfpgan', bg_upsampler='realesrgan'):\n    gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)\n    return list(gen)\n\ndef enhancer_generator_with_len(images, method='gfpgan', bg_upsampler='realesrgan'):\n    \"\"\" Provide a generator with a __len__ method so that it can passed to functions that\n    call len()\"\"\"\n\n    if os.path.isfile(images): # handle video to images\n        # TODO: Create a generator version of load_video_to_cv2\n        images = load_video_to_cv2(images)\n\n    gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)\n    gen_with_len = GeneratorWithLen(gen, len(images))\n    return gen_with_len\n\ndef enhancer_generator_no_len(images, method='gfpgan', bg_upsampler='realesrgan'):\n    \"\"\" Provide a generator function so that all of the enhanced images don't need\n    to be stored in memory at the same time. This can save tons of RAM compared to\n    the enhancer function. 
\"\"\"\n\n    print('face enhancer....')\n    if not isinstance(images, list) and os.path.isfile(images): # handle video to images\n        images = load_video_to_cv2(images)\n\n    # ------------------------ set up GFPGAN restorer ------------------------\n    if  method == 'gfpgan':\n        arch = 'clean'\n        channel_multiplier = 2\n        model_name = 'GFPGANv1.4'\n        url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'\n    elif method == 'RestoreFormer':\n        arch = 'RestoreFormer'\n        channel_multiplier = 2\n        model_name = 'RestoreFormer'\n        url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'\n    elif method == 'codeformer': # TODO:\n        arch = 'CodeFormer'\n        channel_multiplier = 2\n        model_name = 'CodeFormer'\n        url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'\n    else:\n        raise ValueError(f'Wrong model version {method}.')\n\n\n    # ------------------------ set up background upsampler ------------------------\n    if bg_upsampler == 'realesrgan':\n        if not torch.cuda.is_available():  # CPU\n            import warnings\n            warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. 
'\n                          'If you really want to use it, please modify the corresponding codes.')\n            bg_upsampler = None\n        else:\n            from basicsr.archs.rrdbnet_arch import RRDBNet\n            from realesrgan import RealESRGANer\n            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)\n            bg_upsampler = RealESRGANer(\n                scale=2,\n                model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',\n                model=model,\n                tile=400,\n                tile_pad=10,\n                pre_pad=0,\n                half=True)  # need to set False in CPU mode\n    else:\n        bg_upsampler = None\n\n    # determine model paths\n    model_path = os.path.join('gfpgan/weights', model_name + '.pth')\n    \n    if not os.path.isfile(model_path):\n        model_path = os.path.join('checkpoints', model_name + '.pth')\n    \n    if not os.path.isfile(model_path):\n        # download pre-trained models from url\n        model_path = url\n\n    restorer = GFPGANer(\n        model_path=model_path,\n        upscale=2,\n        arch=arch,\n        channel_multiplier=channel_multiplier,\n        bg_upsampler=bg_upsampler)\n\n    # ------------------------ restore ------------------------\n    for idx in tqdm(range(len(images)), 'Face Enhancer:'):\n        \n        img = cv2.cvtColor(images[idx], cv2.COLOR_RGB2BGR)\n        \n        # restore faces and background if necessary\n        cropped_faces, restored_faces, r_img = restorer.enhance(\n            img,\n            has_aligned=False,\n            only_center_face=False,\n            paste_back=True)\n        \n        r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB)\n        yield r_img\n"
  },
  {
    "path": "src/utils/hparams.py",
    "content": "from glob import glob\nimport os\n\nclass HParams:\n\tdef __init__(self, **kwargs):\n\t\tself.data = {}\n\n\t\tfor key, value in kwargs.items():\n\t\t\tself.data[key] = value\n\n\tdef __getattr__(self, key):\n\t\tif key not in self.data:\n\t\t\traise AttributeError(\"'HParams' object has no attribute %s\" % key)\n\t\treturn self.data[key]\n\n\tdef set_hparam(self, key, value):\n\t\tself.data[key] = value\n\n\n# Default hyperparameters\nhparams = HParams(\n\tnum_mels=80,  # Number of mel-spectrogram channels and local conditioning dimensionality\n\t#  network\n\trescale=True,  # Whether to rescale audio prior to preprocessing\n\trescaling_max=0.9,  # Rescaling value\n\t\n\t# Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction\n\t# It\"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder\n\t# Does not work if n_ffit is not multiple of hop_size!!\n\tuse_lws=False,\n\t\n\tn_fft=800,  # Extra window size is filled with 0 paddings to match this parameter\n\thop_size=200,  # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)\n\twin_size=800,  # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)\n\tsample_rate=16000,  # 16000Hz (corresponding to librispeech) (sox --i <filename>)\n\t\n\tframe_shift_ms=None,  # Can replace hop_size parameter. (Recommended: 12.5)\n\t\n\t# Mel and Linear spectrograms normalization/scaling and clipping\n\tsignal_normalization=True,\n\t# Whether to normalize mel spectrograms to some predefined range (following below parameters)\n\tallow_clipping_in_normalization=True,  # Only relevant if mel_normalization = True\n\tsymmetric_mels=True,\n\t# Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, \n\t# faster and cleaner convergence)\n\tmax_abs_value=4.,\n\t# max absolute value of data. 
If symmetric, data will be [-max, max] else [0, max] (Must not \n\t# be too big to avoid gradient explosion, \n\t# not too small for fast convergence)\n\t# Contribution by @begeekmyfriend\n\t# Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude \n\t# levels. Also allows for better G&L phase reconstruction)\n\tpreemphasize=True,  # whether to apply filter\n\tpreemphasis=0.97,  # filter coefficient.\n\t\n\t# Limits\n\tmin_level_db=-100,\n\tref_level_db=20,\n\tfmin=55,\n\t# Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To \n\t# test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])\n\tfmax=7600,  # To be increased/reduced depending on data.\n\n\t###################### Our training parameters #################################\n\timg_size=96,\n\tfps=25,\n\t\n\tbatch_size=16,\n\tinitial_learning_rate=1e-4,\n\tnepochs=300000,  ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs\n\tnum_workers=20,\n\tcheckpoint_interval=3000,\n\teval_interval=3000,\n\twriter_interval=300,\n    save_optimizer_state=True,\n\n    syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. 
\n\tsyncnet_batch_size=64,\n\tsyncnet_lr=1e-4,\n\tsyncnet_eval_interval=1000,\n\tsyncnet_checkpoint_interval=10000,\n\n\tdisc_wt=0.07,\n\tdisc_initial_learning_rate=1e-4,\n)\n\n\n\n# Default hyperparameters\nhparamsdebug = HParams(\n\tnum_mels=80,  # Number of mel-spectrogram channels and local conditioning dimensionality\n\t#  network\n\trescale=True,  # Whether to rescale audio prior to preprocessing\n\trescaling_max=0.9,  # Rescaling value\n\t\n\t# Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction\n\t# It\"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder\n\t# Does not work if n_ffit is not multiple of hop_size!!\n\tuse_lws=False,\n\t\n\tn_fft=800,  # Extra window size is filled with 0 paddings to match this parameter\n\thop_size=200,  # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)\n\twin_size=800,  # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)\n\tsample_rate=16000,  # 16000Hz (corresponding to librispeech) (sox --i <filename>)\n\t\n\tframe_shift_ms=None,  # Can replace hop_size parameter. (Recommended: 12.5)\n\t\n\t# Mel and Linear spectrograms normalization/scaling and clipping\n\tsignal_normalization=True,\n\t# Whether to normalize mel spectrograms to some predefined range (following below parameters)\n\tallow_clipping_in_normalization=True,  # Only relevant if mel_normalization = True\n\tsymmetric_mels=True,\n\t# Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, \n\t# faster and cleaner convergence)\n\tmax_abs_value=4.,\n\t# max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not \n\t# be too big to avoid gradient explosion, \n\t# not too small for fast convergence)\n\t# Contribution by @begeekmyfriend\n\t# Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude \n\t# levels. 
Also allows for better G&L phase reconstruction)\n\tpreemphasize=True,  # whether to apply filter\n\tpreemphasis=0.97,  # filter coefficient.\n\t\n\t# Limits\n\tmin_level_db=-100,\n\tref_level_db=20,\n\tfmin=55,\n\t# Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To \n\t# test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])\n\tfmax=7600,  # To be increased/reduced depending on data.\n\n\t###################### Our training parameters #################################\n\timg_size=96,\n\tfps=25,\n\t\n\tbatch_size=2,\n\tinitial_learning_rate=1e-3,\n\tnepochs=100000,  ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs\n\tnum_workers=0,\n\tcheckpoint_interval=10000,\n\teval_interval=10,\n\twriter_interval=5,\n    save_optimizer_state=True,\n\n    syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. \n\tsyncnet_batch_size=64,\n\tsyncnet_lr=1e-4,\n\tsyncnet_eval_interval=10000,\n\tsyncnet_checkpoint_interval=10000,\n\n\tdisc_wt=0.07,\n\tdisc_initial_learning_rate=1e-4,\n)\n\n\ndef hparams_debug_string():\n\tvalues = hparams.values()\n\thp = [\"  %s: %s\" % (name, values[name]) for name in sorted(values) if name != \"sentences\"]\n\treturn \"Hyperparameters:\\n\" + \"\\n\".join(hp)\n"
  },
  {
    "path": "src/utils/init_path.py",
    "content": "import os\nimport glob\n\ndef init_path(checkpoint_dir, config_dir, size=512, old_version=False, preprocess='crop'):\n\n    if old_version:\n        #### load all the checkpoint of `pth`\n        sadtalker_paths = {\n                'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'),\n                'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'),\n                'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'),\n                'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'),\n                'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth')\n        }\n\n        use_safetensor = False\n    elif len(glob.glob(os.path.join(checkpoint_dir, '*.safetensors'))):\n        print('using safetensor as default')\n        sadtalker_paths = {\n            \"checkpoint\":os.path.join(checkpoint_dir, 'SadTalker_V0.0.2_'+str(size)+'.safetensors'),\n            }\n        use_safetensor = True\n    else:\n        print(\"WARNING: The new version of the model will be updated by safetensor, you may need to download it mannully. 
We run the old version of the checkpoint this time!\")\n        use_safetensor = False\n        \n        sadtalker_paths = {\n                'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'),\n                'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'),\n                'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'),\n                'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'),\n                'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth')\n        }\n\n    sadtalker_paths['dir_of_BFM_fitting'] = os.path.join(config_dir) # , 'BFM_Fitting'\n    sadtalker_paths['audio2pose_yaml_path'] = os.path.join(config_dir, 'auido2pose.yaml')\n    sadtalker_paths['audio2exp_yaml_path'] = os.path.join(config_dir, 'auido2exp.yaml')\n    sadtalker_paths['use_safetensor'] =  use_safetensor # os.path.join(config_dir, 'auido2exp.yaml')\n\n    if 'full' in preprocess:\n        sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00109-model.pth.tar')\n        sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender_still.yaml')\n    else:\n        sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00229-model.pth.tar')\n        sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender.yaml')\n\n    return sadtalker_paths"
  },
  {
    "path": "src/utils/model2safetensor.py",
    "content": "import torch\nimport yaml\nimport os\n\nimport safetensors\nfrom safetensors.torch import save_file\nfrom yacs.config import CfgNode as CN\nimport sys\n\nsys.path.append('/apdcephfs/private_shadowcun/SadTalker')\n\nfrom src.face3d.models import networks\n\nfrom src.facerender.modules.keypoint_detector import HEEstimator, KPDetector\nfrom src.facerender.modules.mapping import MappingNet\nfrom src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator\n\nfrom src.audio2pose_models.audio2pose import Audio2Pose\nfrom src.audio2exp_models.networks import SimpleWrapperV2 \nfrom src.test_audio2coeff import load_cpk\n\nsize = 256\n############ face vid2vid\nconfig_path = os.path.join('src', 'config', 'facerender.yaml')\ncurrent_root_path = '.'\n\npath_of_net_recon_model = os.path.join(current_root_path, 'checkpoints', 'epoch_20.pth')\nnet_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='')\ncheckpoint = torch.load(path_of_net_recon_model, map_location='cpu')    \nnet_recon.load_state_dict(checkpoint['net_recon'])\n\nwith open(config_path) as f:\n    config = yaml.safe_load(f)\n\ngenerator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],\n                                            **config['model_params']['common_params'])\nkp_extractor = KPDetector(**config['model_params']['kp_detector_params'],\n                            **config['model_params']['common_params'])\nhe_estimator = HEEstimator(**config['model_params']['he_estimator_params'],\n                        **config['model_params']['common_params'])\nmapping = MappingNet(**config['model_params']['mapping_params'])\n\ndef load_cpk_facevid2vid(checkpoint_path, generator=None, discriminator=None, \n                        kp_detector=None, he_estimator=None, optimizer_generator=None, \n                        optimizer_discriminator=None, optimizer_kp_detector=None, \n                        
optimizer_he_estimator=None, device=\"cpu\"):\n\n    checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))\n    if generator is not None:\n        generator.load_state_dict(checkpoint['generator'])\n    if kp_detector is not None:\n        kp_detector.load_state_dict(checkpoint['kp_detector'])\n    if he_estimator is not None:\n        he_estimator.load_state_dict(checkpoint['he_estimator'])\n    if discriminator is not None:\n        try:\n            discriminator.load_state_dict(checkpoint['discriminator'])\n        except:\n            print ('No discriminator in the state-dict. Dicriminator will be randomly initialized')\n    if optimizer_generator is not None:\n        optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])\n    if optimizer_discriminator is not None:\n        try:\n            optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])\n        except RuntimeError as e:\n            print ('No discriminator optimizer in the state-dict. 
Optimizer will be not initialized')\n    if optimizer_kp_detector is not None:\n        optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])\n    if optimizer_he_estimator is not None:\n        optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator'])\n\n    return checkpoint['epoch']\n\n\ndef load_cpk_facevid2vid_safetensor(checkpoint_path, generator=None, \n                        kp_detector=None, he_estimator=None,  \n                        device=\"cpu\"):\n\n    checkpoint = safetensors.torch.load_file(checkpoint_path)\n\n    if generator is not None:\n        x_generator = {}\n        for k,v in checkpoint.items():\n            if 'generator' in k:\n                x_generator[k.replace('generator.', '')] = v\n        generator.load_state_dict(x_generator)\n    if kp_detector is not None:\n        x_generator = {}\n        for k,v in checkpoint.items():\n            if 'kp_extractor' in k:\n                x_generator[k.replace('kp_extractor.', '')] = v\n        kp_detector.load_state_dict(x_generator)\n    if he_estimator is not None:\n        x_generator = {}\n        for k,v in checkpoint.items():\n            if 'he_estimator' in k:\n                x_generator[k.replace('he_estimator.', '')] = v\n        he_estimator.load_state_dict(x_generator)\n    \n    return None\n\nfree_view_checkpoint = '/apdcephfs/private_shadowcun/SadTalker/checkpoints/facevid2vid_'+str(size)+'-model.pth.tar'\nload_cpk_facevid2vid(free_view_checkpoint, kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator)\n\nwav2lip_checkpoint = os.path.join(current_root_path, 'checkpoints', 'wav2lip.pth')\n\naudio2pose_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2pose_00140-model.pth')\naudio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml')\n\naudio2exp_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2exp_00300-model.pth')\naudio2exp_yaml_path = 
os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml')\n\nfcfg_pose = open(audio2pose_yaml_path)\ncfg_pose = CN.load_cfg(fcfg_pose)\ncfg_pose.freeze()\naudio2pose_model = Audio2Pose(cfg_pose, wav2lip_checkpoint)\naudio2pose_model.eval()\nload_cpk(audio2pose_checkpoint, model=audio2pose_model, device='cpu')\n\n# load audio2exp_model\nnetG = SimpleWrapperV2()\nnetG.eval()\nload_cpk(audio2exp_checkpoint, model=netG, device='cpu')\n\nclass SadTalker(torch.nn.Module):\n    def __init__(self, kp_extractor, generator, netG, audio2pose, face_3drecon):\n        super(SadTalker, self).__init__()\n        self.kp_extractor = kp_extractor\n        self.generator = generator\n        self.audio2exp = netG\n        self.audio2pose = audio2pose\n        self.face_3drecon = face_3drecon\n\n\nmodel = SadTalker(kp_extractor, generator, netG, audio2pose_model, net_recon)\n\n# here, we want to convert it to safetensor\nsave_file(model.state_dict(), \"checkpoints/SadTalker_V0.0.2_\"+str(size)+\".safetensors\")\n\n### test\nload_cpk_facevid2vid_safetensor('checkpoints/SadTalker_V0.0.2_'+str(size)+'.safetensors', kp_detector=kp_extractor, generator=generator, he_estimator=None)"
  },
  {
    "path": "src/utils/paste_pic.py",
    "content": "import cv2, os\nimport numpy as np\nfrom tqdm import tqdm\nimport uuid\n\nfrom src.utils.videoio import save_video_with_watermark \n\ndef paste_pic(video_path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop=False):\n\n    if not os.path.isfile(pic_path):\n        raise ValueError('pic_path must be a valid path to video/image file')\n    elif pic_path.split('.')[-1] in ['jpg', 'png', 'jpeg']:\n        # loader for first frame\n        full_img = cv2.imread(pic_path)\n    else:\n        # loader for videos\n        video_stream = cv2.VideoCapture(pic_path)\n        fps = video_stream.get(cv2.CAP_PROP_FPS)\n        full_frames = [] \n        while 1:\n            still_reading, frame = video_stream.read()\n            if not still_reading:\n                video_stream.release()\n                break \n            break \n        full_img = frame\n    frame_h = full_img.shape[0]\n    frame_w = full_img.shape[1]\n\n    video_stream = cv2.VideoCapture(video_path)\n    fps = video_stream.get(cv2.CAP_PROP_FPS)\n    crop_frames = []\n    while 1:\n        still_reading, frame = video_stream.read()\n        if not still_reading:\n            video_stream.release()\n            break\n        crop_frames.append(frame)\n    \n    if len(crop_info) != 3:\n        print(\"you didn't crop the image\")\n        return\n    else:\n        r_w, r_h = crop_info[0]\n        clx, cly, crx, cry = crop_info[1]\n        lx, ly, rx, ry = crop_info[2]\n        lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)\n        # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx\n        # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx\n\n        if extended_crop:\n            oy1, oy2, ox1, ox2 = cly, cry, clx, crx\n        else:\n            oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx\n\n    tmp_path = str(uuid.uuid4())+'.mp4'\n    out_tmp = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (frame_w, frame_h))\n    for crop_frame 
in tqdm(crop_frames, 'seamlessClone:'):\n        p = cv2.resize(crop_frame.astype(np.uint8), (ox2-ox1, oy2 - oy1)) \n\n        mask = 255*np.ones(p.shape, p.dtype)\n        location = ((ox1+ox2) // 2, (oy1+oy2) // 2)\n        gen_img = cv2.seamlessClone(p, full_img, mask, location, cv2.NORMAL_CLONE)\n        out_tmp.write(gen_img)\n\n    out_tmp.release()\n\n    save_video_with_watermark(tmp_path, new_audio_path, full_video_path, watermark=False)\n    os.remove(tmp_path)\n"
  },
  {
    "path": "src/utils/preprocess.py",
    "content": "import numpy as np\nimport cv2, os, sys, torch\nfrom tqdm import tqdm\nfrom PIL import Image \n\n# 3dmm extraction\nimport safetensors\nimport safetensors.torch \nfrom src.face3d.util.preprocess import align_img\nfrom src.face3d.util.load_mats import load_lm3d\nfrom src.face3d.models import networks\n\nfrom scipy.io import loadmat, savemat\nfrom src.utils.croper import Preprocesser\n\n\nimport warnings\n\nfrom src.utils.safetensor_helper import load_x_from_safetensor \nwarnings.filterwarnings(\"ignore\")\n\ndef split_coeff(coeffs):\n        \"\"\"\n        Return:\n            coeffs_dict     -- a dict of torch.tensors\n\n        Parameters:\n            coeffs          -- torch.tensor, size (B, 256)\n        \"\"\"\n        id_coeffs = coeffs[:, :80]\n        exp_coeffs = coeffs[:, 80: 144]\n        tex_coeffs = coeffs[:, 144: 224]\n        angles = coeffs[:, 224: 227]\n        gammas = coeffs[:, 227: 254]\n        translations = coeffs[:, 254:]\n        return {\n            'id': id_coeffs,\n            'exp': exp_coeffs,\n            'tex': tex_coeffs,\n            'angle': angles,\n            'gamma': gammas,\n            'trans': translations\n        }\n\n\nclass CropAndExtract():\n    def __init__(self, sadtalker_path, device):\n\n        self.propress = Preprocesser(device)\n        self.net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='').to(device)\n        \n        if sadtalker_path['use_safetensor']:\n            checkpoint = safetensors.torch.load_file(sadtalker_path['checkpoint'])    \n            self.net_recon.load_state_dict(load_x_from_safetensor(checkpoint, 'face_3drecon'))\n        else:\n            checkpoint = torch.load(sadtalker_path['path_of_net_recon_model'], map_location=torch.device(device))    \n            self.net_recon.load_state_dict(checkpoint['net_recon'])\n\n        self.net_recon.eval()\n        self.lm3d_std = load_lm3d(sadtalker_path['dir_of_BFM_fitting'])\n        
self.device = device\n    \n    def generate(self, input_path, save_dir, crop_or_resize='crop', source_image_flag=False, pic_size=256):\n\n        pic_name = os.path.splitext(os.path.split(input_path)[-1])[0]  \n\n        landmarks_path =  os.path.join(save_dir, pic_name+'_landmarks.txt') \n        coeff_path =  os.path.join(save_dir, pic_name+'.mat')  \n        png_path =  os.path.join(save_dir, pic_name+'.png')  \n\n        #load input\n        if not os.path.isfile(input_path):\n            raise ValueError('input_path must be a valid path to video/image file')\n        elif input_path.split('.')[-1] in ['jpg', 'png', 'jpeg']:\n            # loader for first frame\n            full_frames = [cv2.imread(input_path)]\n            fps = 25\n        else:\n            # loader for videos\n            video_stream = cv2.VideoCapture(input_path)\n            fps = video_stream.get(cv2.CAP_PROP_FPS)\n            full_frames = [] \n            while 1:\n                still_reading, frame = video_stream.read()\n                if not still_reading:\n                    video_stream.release()\n                    break \n                full_frames.append(frame) \n                if source_image_flag:\n                    break\n\n        x_full_frames= [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  for frame in full_frames] \n\n        #### crop images as the \n        if 'crop' in crop_or_resize.lower(): # default crop\n            x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in crop_or_resize.lower() else False, xsize=512)\n            clx, cly, crx, cry = crop\n            lx, ly, rx, ry = quad\n            lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)\n            oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx\n            crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad)\n        elif 'full' in crop_or_resize.lower():\n            x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in 
crop_or_resize.lower() else False, xsize=512)\n            clx, cly, crx, cry = crop\n            lx, ly, rx, ry = quad\n            lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)\n            oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx\n            crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad)\n        else: # resize mode\n            oy1, oy2, ox1, ox2 = 0, x_full_frames[0].shape[0], 0, x_full_frames[0].shape[1] \n            crop_info = ((ox2 - ox1, oy2 - oy1), None, None)\n\n        frames_pil = [Image.fromarray(cv2.resize(frame,(pic_size, pic_size))) for frame in x_full_frames]\n        if len(frames_pil) == 0:\n            print('No face is detected in the input file')\n            return None, None\n\n        # save crop info\n        for frame in frames_pil:\n            cv2.imwrite(png_path, cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))\n\n        # 2. get the landmark according to the detected face. \n        if not os.path.isfile(landmarks_path): \n            lm = self.propress.predictor.extract_keypoint(frames_pil, landmarks_path)\n        else:\n            print(' Using saved landmarks.')\n            lm = np.loadtxt(landmarks_path).astype(np.float32)\n            lm = lm.reshape([len(x_full_frames), -1, 2])\n\n        if not os.path.isfile(coeff_path):\n            # load 3dmm paramter generator from Deep3DFaceRecon_pytorch \n            video_coeffs, full_coeffs = [],  []\n            for idx in tqdm(range(len(frames_pil)), desc='3DMM Extraction In Video:'):\n                frame = frames_pil[idx]\n                W,H = frame.size\n                lm1 = lm[idx].reshape([-1, 2])\n            \n                if np.mean(lm1) == -1:\n                    lm1 = (self.lm3d_std[:, :2]+1)/2.\n                    lm1 = np.concatenate(\n                        [lm1[:, :1]*W, lm1[:, 1:2]*H], 1\n                    )\n                else:\n                    lm1[:, -1] = H - 1 - lm1[:, -1]\n\n                trans_params, im1, lm1, _ = 
align_img(frame, lm1, self.lm3d_std)\n \n                trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)]).astype(np.float32)\n                im_t = torch.tensor(np.array(im1)/255., dtype=torch.float32).permute(2, 0, 1).to(self.device).unsqueeze(0)\n                \n                with torch.no_grad():\n                    full_coeff = self.net_recon(im_t)\n                    coeffs = split_coeff(full_coeff)\n\n                pred_coeff = {key:coeffs[key].cpu().numpy() for key in coeffs}\n \n                pred_coeff = np.concatenate([\n                    pred_coeff['exp'], \n                    pred_coeff['angle'],\n                    pred_coeff['trans'],\n                    trans_params[2:][None],\n                    ], 1)\n                video_coeffs.append(pred_coeff)\n                full_coeffs.append(full_coeff.cpu().numpy())\n\n            semantic_npy = np.array(video_coeffs)[:,0] \n\n            savemat(coeff_path, {'coeff_3dmm': semantic_npy, 'full_3dmm': np.array(full_coeffs)[0]})\n\n        return coeff_path, png_path, crop_info\n"
  },
  {
    "path": "src/utils/safetensor_helper.py",
    "content": "\n\ndef load_x_from_safetensor(checkpoint, key):\n    x_generator = {}\n    for k,v in checkpoint.items():\n        if key in k:\n            x_generator[k.replace(key+'.', '')] = v\n    return x_generator"
  },
  {
    "path": "src/utils/text2speech.py",
    "content": "import os\nimport tempfile\nfrom TTS.api import TTS\n\n\nclass TTSTalker():\n    def __init__(self) -> None:\n        model_name = TTS().list_models()[0]\n        self.tts = TTS(model_name)\n\n    def test(self, text, language='en'):\n\n        tempf  = tempfile.NamedTemporaryFile(\n                delete = False,\n                suffix = ('.'+'wav'),\n            )\n\n        self.tts.tts_to_file(text, speaker=self.tts.speakers[0], language=language, file_path=tempf.name)\n\n        return tempf.name\n"
  },
  {
    "path": "src/utils/videoio.py",
    "content": "import shutil\nimport uuid\n\nimport os\n\nimport cv2\n\ndef load_video_to_cv2(input_path):\n    video_stream = cv2.VideoCapture(input_path)\n    fps = video_stream.get(cv2.CAP_PROP_FPS)\n    full_frames = [] \n    while 1:\n        still_reading, frame = video_stream.read()\n        if not still_reading:\n            video_stream.release()\n            break \n        full_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))\n    return full_frames\n\ndef save_video_with_watermark(video, audio, save_path, watermark=False):\n    temp_file = str(uuid.uuid4())+'.mp4'\n    cmd = r'ffmpeg -y -hide_banner -loglevel error -i \"%s\" -i \"%s\" -vcodec copy \"%s\"' % (video, audio, temp_file)\n    os.system(cmd)\n\n    if watermark is False:\n        shutil.move(temp_file, save_path)\n    else:\n        # watermark\n        try:\n            ##### check if stable-diffusion-webui\n            import webui\n            from modules import paths\n            watarmark_path = paths.script_path+\"/extensions/SadTalker/docs/sadtalker_logo.png\"\n        except:\n            # get the root path of sadtalker.\n            dir_path = os.path.dirname(os.path.realpath(__file__))\n            watarmark_path = dir_path+\"/../../docs/sadtalker_logo.png\"\n\n        cmd = r'ffmpeg -y -hide_banner -loglevel error -i \"%s\" -i \"%s\" -filter_complex \"[1]scale=100:-1[wm];[0][wm]overlay=(main_w-overlay_w)-10:10\" \"%s\"' % (temp_file, watarmark_path, save_path)\n        os.system(cmd)\n        os.remove(temp_file)"
  },
  {
    "path": "webui.bat",
    "content": "@echo off\n\nIF NOT EXIST venv (\npython -m venv venv\n) ELSE (\necho venv folder already exists, skipping creation...\n)\ncall .\\venv\\Scripts\\activate.bat\n\nset PYTHON=\"venv\\Scripts\\Python.exe\"\necho venv %PYTHON%\n\n%PYTHON% Launcher.py\n\necho.\necho Launch unsuccessful. Exiting.\npause"
  },
  {
    "path": "webui.sh",
    "content": "#!/usr/bin/env bash\n\n\n# On macOS, default TORCH_COMMAND to pinned torch/torchvision versions\nif [[ \"$OSTYPE\" == \"darwin\"* ]]; then\n    export TORCH_COMMAND=\"pip install torch==1.12.1 torchvision==0.13.1\"\nfi\n\n# python3 executable\nif [[ -z \"${python_cmd}\" ]]\nthen\n    python_cmd=\"python3\"\nfi\n\n# git executable\nif [[ -z \"${GIT}\" ]]\nthen\n    export GIT=\"git\"\nfi\n\n# python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv)\nif [[ -z \"${venv_dir}\" ]]\nthen\n    venv_dir=\"venv\"\nfi\n\nif [[ -z \"${LAUNCH_SCRIPT}\" ]]\nthen\n    LAUNCH_SCRIPT=\"launcher.py\"\nfi\n\n# running as root is allowed by default; the root check below only aborts when this is 0\ncan_run_as_root=1\n\n# read any command line flags to the webui.sh script (-f explicitly allows running as root)\nwhile getopts \"f\" flag > /dev/null 2>&1\ndo\n    case ${flag} in\n        f) can_run_as_root=1;;\n        *) break;;\n    esac\ndone\n\n# Disable sentry logging\nexport ERROR_REPORTING=FALSE\n\n# Do not reinstall existing pip packages on Debian/Ubuntu\nexport PIP_IGNORE_INSTALLED=0\n\n# Pretty print\ndelimiter=\"################################################################\"\n\nprintf \"\\n%s\\n\" \"${delimiter}\"\nprintf \"\\e[1m\\e[32mInstall script for SadTalker + Web UI\\n\"\nprintf \"\\e[1m\\e[34mTested on Debian 11 (Bullseye)\\e[0m\"\nprintf \"\\n%s\\n\" \"${delimiter}\"\n\n# Abort if running as root while can_run_as_root is 0\nif [[ $(id -u) -eq 0 && can_run_as_root -eq 0 ]]\nthen\n    printf \"\\n%s\\n\" \"${delimiter}\"\n    printf \"\\e[1m\\e[31mERROR: This script must not be launched as root, aborting...\\e[0m\"\n    printf \"\\n%s\\n\" \"${delimiter}\"\n    exit 1\nelse\n    printf \"\\n%s\\n\" \"${delimiter}\"\n    printf \"Running on \\e[1m\\e[32m%s\\e[0m user\" \"$(whoami)\"\n    printf \"\\n%s\\n\" \"${delimiter}\"\nfi\n\nif [[ -d .git ]]\nthen\n    printf \"\\n%s\\n\" \"${delimiter}\"\n    printf \"Repo already cloned, using it as install directory\"\n    printf \"\\n%s\\n\" \"${delimiter}\"\n    install_dir=\"${PWD}/../\"\n    
clone_dir=\"${PWD##*/}\"\nfi\n\n# Check prerequisites\ngpu_info=$(lspci 2>/dev/null | grep VGA)\ncase \"$gpu_info\" in\n    *\"Navi 1\"*|*\"Navi 2\"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0\n    ;;\n    *\"Renoir\"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0\n        printf \"\\n%s\\n\" \"${delimiter}\"\n        printf \"Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half\"\n        printf \"\\n%s\\n\" \"${delimiter}\"\n    ;;\n    *) \n    ;;\nesac\nif echo \"$gpu_info\" | grep -q \"AMD\" && [[ -z \"${TORCH_COMMAND}\" ]]\nthen\n    export TORCH_COMMAND=\"pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2\"\nfi  \n\nfor preq in \"${GIT}\" \"${python_cmd}\"\ndo\n    if ! hash \"${preq}\" &>/dev/null\n    then\n        printf \"\\n%s\\n\" \"${delimiter}\"\n        printf \"\\e[1m\\e[31mERROR: %s is not installed, aborting...\\e[0m\" \"${preq}\"\n        printf \"\\n%s\\n\" \"${delimiter}\"\n        exit 1\n    fi\ndone\n\nif ! \"${python_cmd}\" -c \"import venv\" &>/dev/null\nthen\n    printf \"\\n%s\\n\" \"${delimiter}\"\n    printf \"\\e[1m\\e[31mERROR: python3-venv is not installed, aborting...\\e[0m\"\n    printf \"\\n%s\\n\" \"${delimiter}\"\n    exit 1\nfi\n\nprintf \"\\n%s\\n\" \"${delimiter}\"\nprintf \"Create and activate python venv\"\nprintf \"\\n%s\\n\" \"${delimiter}\"\ncd \"${install_dir}\"/\"${clone_dir}\"/ || { printf \"\\e[1m\\e[31mERROR: Can't cd to %s/%s/, aborting...\\e[0m\" \"${install_dir}\" \"${clone_dir}\"; exit 1; }\nif [[ ! 
-d \"${venv_dir}\" ]]\nthen\n    \"${python_cmd}\" -m venv \"${venv_dir}\"\n    first_launch=1\nfi\n# shellcheck source=/dev/null\nif [[ -f \"${venv_dir}\"/bin/activate ]]\nthen\n    source \"${venv_dir}\"/bin/activate\nelse\n    printf \"\\n%s\\n\" \"${delimiter}\"\n    printf \"\\e[1m\\e[31mERROR: Cannot activate python venv, aborting...\\e[0m\"\n    printf \"\\n%s\\n\" \"${delimiter}\"\n    exit 1\nfi\n\nprintf \"\\n%s\\n\" \"${delimiter}\"\nprintf \"Launching launcher.py...\"\nprintf \"\\n%s\\n\" \"${delimiter}\"      \nexec \"${python_cmd}\" \"${LAUNCH_SCRIPT}\" \"$@\""
  }
]