Repository: OpenTalker/SadTalker Branch: main Commit: cd4c0465ae0b Files: 141 Total size: 575.1 KB Directory structure: gitextract_pkbb6gh_/ ├── .gitignore ├── LICENSE ├── README.md ├── app_sadtalker.py ├── cog.yaml ├── docs/ │ ├── FAQ.md │ ├── best_practice.md │ ├── changlelog.md │ ├── face3d.md │ ├── install.md │ └── webui_extension.md ├── inference.py ├── launcher.py ├── predict.py ├── quick_demo.ipynb ├── req.txt ├── requirements.txt ├── requirements3d.txt ├── scripts/ │ ├── download_models.sh │ ├── extension.py │ └── test.sh ├── src/ │ ├── audio2exp_models/ │ │ ├── audio2exp.py │ │ └── networks.py │ ├── audio2pose_models/ │ │ ├── audio2pose.py │ │ ├── audio_encoder.py │ │ ├── cvae.py │ │ ├── discriminator.py │ │ ├── networks.py │ │ └── res_unet.py │ ├── config/ │ │ ├── auido2exp.yaml │ │ ├── auido2pose.yaml │ │ ├── facerender.yaml │ │ ├── facerender_still.yaml │ │ └── similarity_Lm3D_all.mat │ ├── face3d/ │ │ ├── data/ │ │ │ ├── __init__.py │ │ │ ├── base_dataset.py │ │ │ ├── flist_dataset.py │ │ │ ├── image_folder.py │ │ │ └── template_dataset.py │ │ ├── extract_kp_videos.py │ │ ├── extract_kp_videos_safe.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── arcface_torch/ │ │ │ │ ├── README.md │ │ │ │ ├── backbones/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── iresnet.py │ │ │ │ │ ├── iresnet2060.py │ │ │ │ │ └── mobilefacenet.py │ │ │ │ ├── configs/ │ │ │ │ │ ├── 3millions.py │ │ │ │ │ ├── 3millions_pfc.py │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── base.py │ │ │ │ │ ├── glint360k_mbf.py │ │ │ │ │ ├── glint360k_r100.py │ │ │ │ │ ├── glint360k_r18.py │ │ │ │ │ ├── glint360k_r34.py │ │ │ │ │ ├── glint360k_r50.py │ │ │ │ │ ├── ms1mv3_mbf.py │ │ │ │ │ ├── ms1mv3_r18.py │ │ │ │ │ ├── ms1mv3_r2060.py │ │ │ │ │ ├── ms1mv3_r34.py │ │ │ │ │ ├── ms1mv3_r50.py │ │ │ │ │ └── speed.py │ │ │ │ ├── dataset.py │ │ │ │ ├── docs/ │ │ │ │ │ ├── eval.md │ │ │ │ │ ├── install.md │ │ │ │ │ ├── modelzoo.md │ │ │ │ │ └── speed_benchmark.md │ │ │ │ ├── eval/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ 
└── verification.py │ │ │ │ ├── eval_ijbc.py │ │ │ │ ├── inference.py │ │ │ │ ├── losses.py │ │ │ │ ├── onnx_helper.py │ │ │ │ ├── onnx_ijbc.py │ │ │ │ ├── partial_fc.py │ │ │ │ ├── requirement.txt │ │ │ │ ├── run.sh │ │ │ │ ├── torch2onnx.py │ │ │ │ ├── train.py │ │ │ │ └── utils/ │ │ │ │ ├── __init__.py │ │ │ │ ├── plot.py │ │ │ │ ├── utils_amp.py │ │ │ │ ├── utils_callbacks.py │ │ │ │ ├── utils_config.py │ │ │ │ ├── utils_logging.py │ │ │ │ └── utils_os.py │ │ │ ├── base_model.py │ │ │ ├── bfm.py │ │ │ ├── facerecon_model.py │ │ │ ├── losses.py │ │ │ ├── networks.py │ │ │ └── template_model.py │ │ ├── options/ │ │ │ ├── __init__.py │ │ │ ├── base_options.py │ │ │ ├── inference_options.py │ │ │ ├── test_options.py │ │ │ └── train_options.py │ │ ├── util/ │ │ │ ├── BBRegressorParam_r.mat │ │ │ ├── __init__.py │ │ │ ├── detect_lm68.py │ │ │ ├── generate_list.py │ │ │ ├── html.py │ │ │ ├── load_mats.py │ │ │ ├── my_awing_arch.py │ │ │ ├── nvdiffrast.py │ │ │ ├── preprocess.py │ │ │ ├── skin_mask.py │ │ │ ├── test_mean_face.txt │ │ │ ├── util.py │ │ │ └── visualizer.py │ │ └── visualize.py │ ├── facerender/ │ │ ├── animate.py │ │ ├── modules/ │ │ │ ├── dense_motion.py │ │ │ ├── discriminator.py │ │ │ ├── generator.py │ │ │ ├── keypoint_detector.py │ │ │ ├── make_animation.py │ │ │ ├── mapping.py │ │ │ └── util.py │ │ └── sync_batchnorm/ │ │ ├── __init__.py │ │ ├── batchnorm.py │ │ ├── comm.py │ │ ├── replicate.py │ │ └── unittest.py │ ├── generate_batch.py │ ├── generate_facerender_batch.py │ ├── gradio_demo.py │ ├── test_audio2coeff.py │ └── utils/ │ ├── audio.py │ ├── croper.py │ ├── face_enhancer.py │ ├── hparams.py │ ├── init_path.py │ ├── model2safetensor.py │ ├── paste_pic.py │ ├── preprocess.py │ ├── safetensor_helper.py │ ├── text2speech.py │ └── videoio.py ├── webui.bat └── webui.sh ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore 
================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. 
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. #pdm.lock # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. # https://pdm.fming.dev/#use-with-ide .pdm.toml # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ examples/results/* gfpgan/* checkpoints/* assets/* results/* Dockerfile start_docker.sh start.sh checkpoints # Mac .DS_Store ================================================ FILE: LICENSE ================================================ Tencent is pleased to support the open source community by making SadTalker available. Copyright (C), a Tencent company. All rights reserved. SadTalker is licensed under the Apache 2.0 License, except for the third-party components listed below. Terms of the Apache License Version 2.0: --------------------------------------------- Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. 
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================
    [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb)   [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/vinthony/SadTalker)   [![sd webui-colab](https://img.shields.io/badge/Automatic1111-Colab-green)](https://colab.research.google.com/github/camenduru/stable-diffusion-webui-colab/blob/main/video/stable/stable_diffusion_1_5_video_webui_colab.ipynb)  
[![Replicate](https://replicate.com/cjwbw/sadtalker/badge)](https://replicate.com/cjwbw/sadtalker) [![Discord](https://dcbadge.vercel.app/api/server/rrayYqZ4tf?style=flat)](https://discord.gg/rrayYqZ4tf)
Wenxuan Zhang <sup>*,1,2</sup> &nbsp; Xiaodong Cun <sup>*,2</sup> &nbsp; Xuan Wang <sup>3</sup> &nbsp; Yong Zhang <sup>2</sup> &nbsp; Xi Shen <sup>2</sup>
Yu Guo <sup>1</sup> &nbsp; Ying Shan <sup>2</sup> &nbsp; Fei Wang <sup>1</sup>

1 Xi'an Jiaotong University   2 Tencent AI Lab   3 Ant Group  

CVPR 2023

![sadtalker](https://user-images.githubusercontent.com/4397546/222490039-b1f6156b-bf00-405b-9fda-0c9a9156f991.gif) TL;DR:       single portrait image 🙎‍♂️      +       audio 🎤       =       talking head video 🎞.
## Highlights - The license has been updated to Apache 2.0, and we've removed the non-commercial restriction. - **SadTalker has now officially been integrated into Discord, where you can use it for free by sending files. You can also generate high-quality videos from text prompts. Join: [![Discord](https://dcbadge.vercel.app/api/server/rrayYqZ4tf?style=flat)](https://discord.gg/rrayYqZ4tf)** - We've published a [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) extension. Check out more details [here](docs/webui_extension.md). [Demo Video](https://user-images.githubusercontent.com/4397546/231495639-5d4bb925-ea64-4a36-a519-6389917dac29.mp4) - Full image mode is now available! [More details...](https://github.com/OpenTalker/SadTalker#full-bodyimage-generation) | still+enhancer in v0.0.1 | still + enhancer in v0.0.2 | [input image @bagbag1815](https://twitter.com/bagbag1815/status/1642754319094108161) | |:--------------------: |:--------------------: | :----: | | | | - Several new modes (Still, reference, and resize modes) are now available! - We're happy to see more community demos on [bilibili](https://search.bilibili.com/all?keyword=sadtalker), [YouTube](https://www.youtube.com/results?search_query=sadtalker) and [X (#sadtalker)](https://twitter.com/search?q=%23sadtalker&src). ## Changelog The previous changelog can be found [here](docs/changlelog.md). - __[2023.06.12]__: Added more new features to the WebUI extension; see the discussion [here](https://github.com/OpenTalker/SadTalker/discussions/386). - __[2023.06.05]__: Released a new 512x512px (beta) face model. Fixed some bugs and improved performance.
- __[2023.04.15]__: Added a WebUI Colab notebook by [@camenduru](https://github.com/camenduru/): [![sd webui-colab](https://img.shields.io/badge/Automatic1111-Colab-green)](https://colab.research.google.com/github/camenduru/stable-diffusion-webui-colab/blob/main/video/stable/stable_diffusion_1_5_video_webui_colab.ipynb) - __[2023.04.12]__: Added a more detailed WebUI installation document and fixed a problem when reinstalling. - __[2023.04.12]__: Fixed WebUI safety issues caused by 3rd-party packages, and optimized the output path in `sd-webui-extension`. - __[2023.04.08]__: In v0.0.2, we added a logo watermark to the generated video to prevent abuse. _This watermark has since been removed in a later release._ - __[2023.04.08]__: In v0.0.2, we added features for full image animation and a link to download checkpoints from Baidu. We also optimized the enhancer logic. ## To-Do We're tracking new updates in [issue #280](https://github.com/OpenTalker/SadTalker/issues/280). ## Troubleshooting If you have any problems, please read our [FAQs](docs/FAQ.md) before opening an issue. ## 1. Installation. Community tutorials: [中文Windows教程 (Chinese Windows tutorial)](https://www.bilibili.com/video/BV1Dc411W7V6/) | [日本語コース (Japanese tutorial)](https://br-d.fanbox.cc/posts/5685086). ### Linux/Unix 1. Install [Anaconda](https://www.anaconda.com/), Python and `git`. 2. Create the env and install the requirements. ```bash git clone https://github.com/OpenTalker/SadTalker.git cd SadTalker conda create -n sadtalker python=3.8 conda activate sadtalker pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113 conda install ffmpeg pip install -r requirements.txt ### Coqui TTS is optional for gradio demo. ### pip install TTS ``` ### Windows A video tutorial in Chinese is available [here](https://www.bilibili.com/video/BV1Dc411W7V6/). You can also follow these instructions: 1.
Install [Python 3.8](https://www.python.org/downloads/windows/) and check "Add Python to PATH". 2. Install [git](https://git-scm.com/download/win) manually or using [Scoop](https://scoop.sh/): `scoop install git`. 3. Install `ffmpeg`, following [this tutorial](https://www.wikihow.com/Install-FFmpeg-on-Windows) or using [scoop](https://scoop.sh/): `scoop install ffmpeg`. 4. Download the SadTalker repository by running `git clone https://github.com/OpenTalker/SadTalker.git`. 5. Download the checkpoints and gfpgan models from the [downloads section](#2-download-models). 6. Run `webui.bat` from Windows Explorer as a normal (non-administrator) user, and a Gradio-powered WebUI demo will be started. ### macOS A tutorial on installing SadTalker on macOS can be found [here](docs/install.md). ### Docker, WSL, etc Please check out additional tutorials [here](docs/install.md). ## 2. Download Models You can run the following script on Linux/macOS to automatically download all the models: ```bash bash scripts/download_models.sh ``` We also provide an offline patch (`gfpgan/`), so no model will be downloaded when generating. ### Pre-Trained Models * [Google Drive](https://drive.google.com/file/d/1gwWh45pF7aelNP_P78uDJL8Sycep-K7j/view?usp=sharing) * [GitHub Releases](https://github.com/OpenTalker/SadTalker/releases) * [Baidu (百度云盘)](https://pan.baidu.com/s/1kb1BCPaLOWX1JJb9Czbn6w?pwd=sadt) (Password: `sadt`) ### GFPGAN Offline Patch * [Google Drive](https://drive.google.com/file/d/19AIBsmfcHW6BRJmeqSFlG5fL445Xmsyi?usp=sharing) * [GitHub Releases](https://github.com/OpenTalker/SadTalker/releases) * [Baidu (百度云盘)](https://pan.baidu.com/s/1P4fRgk9gaSutZnn8YW034Q?pwd=sadt) (Password: `sadt`)
Model Details Model descriptions: ##### New version | Model | Description | :--- | :---------- |checkpoints/mapping_00229-model.pth.tar | Pre-trained MappingNet in Sadtalker. |checkpoints/mapping_00109-model.pth.tar | Pre-trained MappingNet in Sadtalker. |checkpoints/SadTalker_V0.0.2_256.safetensors | Packaged SadTalker checkpoints of the old version (256 face render). |checkpoints/SadTalker_V0.0.2_512.safetensors | Packaged SadTalker checkpoints of the old version (512 face render). |gfpgan/weights | Face detection and enhancement models used in `facexlib` and `gfpgan`. ##### Old version | Model | Description | :--- | :---------- |checkpoints/auido2exp_00300-model.pth | Pre-trained ExpNet in Sadtalker. |checkpoints/auido2pose_00140-model.pth | Pre-trained PoseVAE in Sadtalker. |checkpoints/mapping_00229-model.pth.tar | Pre-trained MappingNet in Sadtalker. |checkpoints/mapping_00109-model.pth.tar | Pre-trained MappingNet in Sadtalker. |checkpoints/facevid2vid_00189-model.pth.tar | Pre-trained face-vid2vid model from [the reappearance of face-vid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis). |checkpoints/epoch_20.pth | Pre-trained 3DMM extractor in [Deep3DFaceReconstruction](https://github.com/microsoft/Deep3DFaceReconstruction). |checkpoints/wav2lip.pth | Highly accurate lip-sync model in [Wav2lip](https://github.com/Rudrabha/Wav2Lip). |checkpoints/shape_predictor_68_face_landmarks.dat | Face landmark model used in [dlib](http://dlib.net/). |checkpoints/BFM | 3DMM library file. |checkpoints/hub | Face detection models used in [face alignment](https://github.com/1adrianb/face-alignment). |gfpgan/weights | Face detection and enhancement models used in `facexlib` and `gfpgan`. The final folder structure is shown below: image
## 3. Quick Start Please read our document on [best practices and configuration tips](docs/best_practice.md) ### WebUI Demos **Online Demo**: [HuggingFace](https://huggingface.co/spaces/vinthony/SadTalker) | [SDWebUI-Colab](https://colab.research.google.com/github/camenduru/stable-diffusion-webui-colab/blob/main/video/stable/stable_diffusion_1_5_video_webui_colab.ipynb) | [Colab](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb) **Local WebUI extension**: Please refer to [WebUI docs](docs/webui_extension.md). **Local gradio demo (recommended)**: A Gradio instance similar to our [Hugging Face demo](https://huggingface.co/spaces/vinthony/SadTalker) can be run locally: ```bash ## you need to manually install TTS (https://github.com/coqui-ai/TTS) via `pip install TTS` in advance. python app_sadtalker.py ``` You can also start it more easily: - Windows: just double-click `webui.bat`; the requirements will be installed automatically. - Linux/Mac OS: run `bash webui.sh` to start the webui. ### CLI usage ##### Animating a portrait image from default config: ```bash python inference.py --driven_audio <audio.wav> \ --source_image <video.mp4 or picture.png> \ --enhancer gfpgan ``` The results will be saved in `results/$SOME_TIMESTAMP/*.mp4`. ##### Full body/image Generation: Use `--still` to generate a natural full body video. You can add `--enhancer` to improve the quality of the generated video. ```bash python inference.py --driven_audio <audio.wav> \ --source_image <video.mp4 or picture.png> \ --result_dir <a directory to store results> \ --still \ --preprocess full \ --enhancer gfpgan ``` More examples, configurations and tips can be found in the [ >>> best practice documents <<<](docs/best_practice.md).
## Citation If you find our work useful in your research, please consider citing: ```bibtex @article{zhang2022sadtalker, title={SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation}, author={Zhang, Wenxuan and Cun, Xiaodong and Wang, Xuan and Zhang, Yong and Shen, Xi and Guo, Yu and Shan, Ying and Wang, Fei}, journal={arXiv preprint arXiv:2211.12194}, year={2022} } ``` ## Acknowledgements Facerender code borrows heavily from [zhanglonghao's reproduction of face-vid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis) and [PIRender](https://github.com/RenYurui/PIRender). We thank the authors for sharing their wonderful code. During training, we also used models from [Deep3DFaceReconstruction](https://github.com/microsoft/Deep3DFaceReconstruction) and [Wav2lip](https://github.com/Rudrabha/Wav2Lip). We thank the authors for their wonderful work. We also use the following 3rd-party libraries: - **Face Utils**: https://github.com/xinntao/facexlib - **Face Enhancement**: https://github.com/TencentARC/GFPGAN - **Image/Video Enhancement**: https://github.com/xinntao/Real-ESRGAN ## Extensions: - [SadTalker-Video-Lip-Sync](https://github.com/Zz-ww/SadTalker-Video-Lip-Sync) from [@Zz-ww](https://github.com/Zz-ww): SadTalker for Video Lip Editing ## Related Works - [StyleHEAT: One-Shot High-Resolution Editable Talking Face Generation via Pre-trained StyleGAN (ECCV 2022)](https://github.com/FeiiYin/StyleHEAT) - [CodeTalker: Speech-Driven 3D Facial Animation with Discrete Motion Prior (CVPR 2023)](https://github.com/Doubiiu/CodeTalker) - [VideoReTalking: Audio-based Lip Synchronization for Talking Head Video Editing In the Wild (SIGGRAPH Asia 2022)](https://github.com/vinthony/video-retalking) - [DPE: Disentanglement of Pose and Expression for General Video Portrait Editing (CVPR 2023)](https://github.com/Carlyx/DPE) - [3D GAN Inversion with Facial Symmetry Prior (CVPR
2023)](https://github.com/FeiiYin/SPI/) - [T2M-GPT: Generating Human Motion from Textual Descriptions with Discrete Representations (CVPR 2023)](https://github.com/Mael-zys/T2M-GPT) ## Disclaimer This is not an official product of Tencent. ``` 1. Please carefully read and comply with the open-source license applicable to this code before using it. 2. Please carefully read and comply with the intellectual property declaration applicable to this code before using it. 3. This open-source code runs completely offline and does not collect any personal information or other data. If you use this code to provide services to end-users and collect related data, please take necessary compliance measures according to applicable laws and regulations (such as publishing privacy policies, adopting necessary data security strategies, etc.). If the collected data involves personal information, user consent must be obtained (if applicable). Any legal liabilities arising from this are unrelated to Tencent. 4. Without Tencent's written permission, you are not authorized to use the names or logos legally owned by Tencent, such as "Tencent." Otherwise, you may be liable for legal responsibilities. 5. This open-source code does not have the ability to directly provide services to end-users. If you need to use this code for further model training or demos, as part of your product to provide services to end-users, or for similar use, please comply with applicable laws and regulations for your product or service. Any legal liabilities arising from this are unrelated to Tencent. 6. 
It is prohibited to use this open-source code for activities that harm the legitimate rights and interests of others (including but not limited to fraud, deception, infringement of others' portrait rights, reputation rights, etc.), or other behaviors that violate applicable laws and regulations or go against social ethics and good customs (including providing incorrect or false information, spreading pornographic, terrorist, and violent information, etc.). Otherwise, you may be liable for legal responsibilities. ``` LOGO: color and font suggestion: [ChatGPT](https://chat.openai.com), logo font: [Montserrat Alternates ](https://fonts.google.com/specimen/Montserrat+Alternates?preview.text=SadTalker&preview.text_type=custom&query=mont). All the copyrights of the demo images and audio are from community users or the generation from stable diffusion. Feel free to contact us if you would like use to remove them. ================================================ FILE: app_sadtalker.py ================================================ import os, sys import gradio as gr from src.gradio_demo import SadTalker try: import webui # in webui in_webui = True except: in_webui = False def toggle_audio_file(choice): if choice == False: return gr.update(visible=True), gr.update(visible=False) else: return gr.update(visible=False), gr.update(visible=True) def ref_video_fn(path_of_ref_video): if path_of_ref_video is not None: return gr.update(value=True) else: return gr.update(value=False) def sadtalker_demo(checkpoint_path='checkpoints', config_path='src/config', warpfn=None): sad_talker = SadTalker(checkpoint_path, config_path, lazy_load=True) with gr.Blocks(analytics_enabled=False) as sadtalker_interface: gr.Markdown("

😭 SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation (CVPR 2023)

\
Arxiv       \ Homepage       \ Github
") with gr.Row().style(equal_height=False): with gr.Column(variant='panel'): with gr.Tabs(elem_id="sadtalker_source_image"): with gr.TabItem('Upload image'): with gr.Row(): source_image = gr.Image(label="Source image", source="upload", type="filepath", elem_id="img2img_image").style(width=512) with gr.Tabs(elem_id="sadtalker_driven_audio"): with gr.TabItem('Upload OR TTS'): with gr.Column(variant='panel'): driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath") if sys.platform != 'win32' and not in_webui: from src.utils.text2speech import TTSTalker tts_talker = TTSTalker() with gr.Column(variant='panel'): input_text = gr.Textbox(label="Generating audio from text", lines=5, placeholder="please enter some text here, we genreate the audio from text using @Coqui.ai TTS.") tts = gr.Button('Generate audio',elem_id="sadtalker_audio_generate", variant='primary') tts.click(fn=tts_talker.test, inputs=[input_text], outputs=[driven_audio]) with gr.Column(variant='panel'): with gr.Tabs(elem_id="sadtalker_checkbox"): with gr.TabItem('Settings'): gr.Markdown("need help? 
please visit our [best practice page](https://github.com/OpenTalker/SadTalker/blob/main/docs/best_practice.md) for more detials") with gr.Column(variant='panel'): # width = gr.Slider(minimum=64, elem_id="img2img_width", maximum=2048, step=8, label="Manually Crop Width", value=512) # img2img_width # height = gr.Slider(minimum=64, elem_id="img2img_height", maximum=2048, step=8, label="Manually Crop Height", value=512) # img2img_width pose_style = gr.Slider(minimum=0, maximum=46, step=1, label="Pose style", value=0) # size_of_image = gr.Radio([256, 512], value=256, label='face model resolution', info="use 256/512 model?") # preprocess_type = gr.Radio(['crop', 'resize','full', 'extcrop', 'extfull'], value='crop', label='preprocess', info="How to handle input image?") is_still_mode = gr.Checkbox(label="Still Mode (fewer head motion, works with preprocess `full`)") batch_size = gr.Slider(label="batch size in generation", step=1, maximum=10, value=2) enhancer = gr.Checkbox(label="GFPGAN as Face enhancer") submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary') with gr.Tabs(elem_id="sadtalker_genearted"): gen_video = gr.Video(label="Generated video", format="mp4").style(width=256) if warpfn: submit.click( fn=warpfn(sad_talker.test), inputs=[source_image, driven_audio, preprocess_type, is_still_mode, enhancer, batch_size, size_of_image, pose_style ], outputs=[gen_video] ) else: submit.click( fn=sad_talker.test, inputs=[source_image, driven_audio, preprocess_type, is_still_mode, enhancer, batch_size, size_of_image, pose_style ], outputs=[gen_video] ) return sadtalker_interface if __name__ == "__main__": demo = sadtalker_demo() demo.queue() demo.launch() ================================================ FILE: cog.yaml ================================================ build: gpu: true cuda: "11.3" python_version: "3.8" system_packages: - "ffmpeg" - "libgl1-mesa-glx" - "libglib2.0-0" python_packages: - "torch==1.12.1" - "torchvision==0.13.1" - 
"torchaudio==0.12.1" - "joblib==1.1.0" - "scikit-image==0.19.3" - "basicsr==1.4.2" - "facexlib==0.3.0" - "resampy==0.3.1" - "pydub==0.25.1" - "scipy==1.10.1" - "kornia==0.6.8" - "face_alignment==1.3.5" - "imageio==2.19.3" - "imageio-ffmpeg==0.4.7" - "librosa==0.9.2" # - "tqdm==4.65.0" - "yacs==0.1.8" - "gfpgan==1.3.8" - "dlib-bin==19.24.1" - "av==10.0.0" - "trimesh==3.9.20" run: - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/s3fd-619a316812.pth" "https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth" - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/2DFAN4-cd938726ad.zip" "https://www.adrianbulat.com/downloads/python-fan/2DFAN4-cd938726ad.zip" predict: "predict.py:Predictor" ================================================ FILE: docs/FAQ.md ================================================ ## Frequency Asked Question **Q: `ffmpeg` is not recognized as an internal or external command** In Linux, you can install the ffmpeg via `conda install ffmpeg`. Or on Mac OS X, try to install ffmpeg via `brew install ffmpeg`. On windows, make sure you have `ffmpeg` in the `%PATH%` as suggested in [#54](https://github.com/Winfredy/SadTalker/issues/54), then, following [this](https://www.geeksforgeeks.org/how-to-install-ffmpeg-on-windows/) installation to install `ffmpeg`. **Q: Running Requirments.** Please refer to the discussion here: https://github.com/Winfredy/SadTalker/issues/124#issuecomment-1508113989 **Q: ModuleNotFoundError: No module named 'ai'** please check the checkpoint's size of the `epoch_20.pth`. (https://github.com/Winfredy/SadTalker/issues/167, https://github.com/Winfredy/SadTalker/issues/113) **Q: Illegal Hardware Error: Mac M1** please reinstall the `dlib` by `pip install dlib` individually. 
(https://github.com/Winfredy/SadTalker/issues/129, https://github.com/Winfredy/SadTalker/issues/109) **Q: FileNotFoundError: [Errno 2] No such file or directory: checkpoints\BFM_Fitting\similarity_Lm3D_all.mat** Make sure you have downloaded the checkpoints and gfpgan as described [here](https://github.com/Winfredy/SadTalker#-2-download-trained-models) and placed them in the right place. **Q: RuntimeError: unexpected EOF, expected 237192 more bytes. The file might be corrupted.** The files are not downloaded automatically. Please update the code and download the gfpgan folders as described [here](https://github.com/Winfredy/SadTalker#-2-download-trained-models). **Q: CUDA out of memory error** Please refer to https://stackoverflow.com/questions/73747731/runtimeerror-cuda-out-of-memory-how-setting-max-split-size-mb ``` # windows set PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 python inference.py ... # linux export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 python inference.py ... ``` **Q: Error while decoding stream #0:0: Invalid data found when processing input [mp3float @ 0000015037628c00] Header missing** Our method only supports wav or mp3 files as input; please make sure the input audio files are really in one of these formats. ================================================ FILE: docs/best_practice.md ================================================ # Best Practices and Tips for configuration > Our model only works on REAL people or portrait images similar to a REAL person. The anime talking head generation method will be released in the future. Advanced configuration options for `inference.py`: | Name | Configuration | default | Explanation | |:------------- |:------------- |:----- | :------------- | | Enhance Mode | `--enhancer` | None | Using `gfpgan` or `RestoreFormer` to enhance the generated face via a face restoration network | Background Enhancer | `--background_enhancer` | None | Using `realesrgan` to enhance the full video.
| Still Mode | `--still` | False | Using the same pose parameters as the original image, with less head motion. | Expressive Mode | `--expression_scale` | 1.0 | A larger value will make the expression motion stronger. | save path | `--result_dir` |`./results` | The results will be saved in this location. | preprocess | `--preprocess` | `crop` | Run and produce the results on the cropped input image. Other choices: `resize`, where the images will be resized to the specific resolution. `full` Run the full image animation, use with `--still` to get better results. | ref Mode (eye) | `--ref_eyeblink` | None | A video path, where we borrow the eyeblink from this reference video to provide more natural eyebrow movement. | ref Mode (pose) | `--ref_pose` | None | A video path, where we borrow the pose from the head reference video. | 3D Mode | `--face3dvis` | False | Needs additional installation. More details on generating the 3d face can be found [here](face3d.md). | free-view Mode | `--input_yaw`,
`--input_pitch`,
`--input_roll` | None | Generating a novel-view or free-view 4D talking head from a single image. More details can be found [here](https://github.com/Winfredy/SadTalker#generating-4d-free-view-talking-examples-from-audio-and-a-single-image). ### About `--preprocess` Our system automatically handles the input images via `crop`, `resize` and `full`. In `crop` mode, we crop the image around the facial keypoints and generate the animated facial avatar from the cropped image. The animation of both expression and head pose is realistic. > Still mode will stop the eyeblink and head pose movement. | [input image @bagbag1815](https://twitter.com/bagbag1815/status/1642754319094108161) | crop | crop w/still | |:--------------------: |:--------------------: | :----: | | | ![full_body_2](example_crop.gif) | ![full_body_2](example_crop_still.gif) | In `resize` mode, we resize the whole image to generate the full talking head video. Thus, an image similar to an ID photo can be produced. ⚠️ It will produce bad results for full-body images. | | | |:--------------------: |:--------------------: | | ❌ not suitable for resize mode | ✅ good for resize mode | | | | In `full` mode, our model will automatically process the cropped region and paste it back into the original image. Remember to use `--still` to keep the original head pose. | input | `--still` | `--still` & `enhancer` | |:--------------------: |:--------------------: | :--:| | | | ### About `--enhancer` For higher resolution, we integrate [gfpgan](https://github.com/TencentARC/GFPGAN) and [real-esrgan](https://github.com/xinntao/Real-ESRGAN) for different purposes. Just add `--enhancer gfpgan` (or `--enhancer RestoreFormer`) to enhance the face, or `--background_enhancer realesrgan` to enhance the full image. ```bash # make sure above packages are available: pip install gfpgan pip install realesrgan ``` ### About `--face3dvis` This flag generates the 3D-rendered face and its 3D facial landmarks. More details can be found [here](face3d.md).
| Input | Animated 3d face | |:-------------: | :-------------: | | | | > Kindly ensure to activate the audio as the default audio playing is incompatible with GitHub. #### Reference eye-link mode. | Input, w/ reference video , reference video | |:-------------: | | ![free_view](using_ref_video.gif)| | If the reference video is shorter than the input audio, we will loop the reference video . #### Generating 4D free-view talking examples from audio and a single image We use `input_yaw`, `input_pitch`, `input_roll` to control head pose. For example, `--input_yaw -20 30 10` means the input head yaw degree changes from -20 to 30 and then changes from 30 to 10. ```bash python inference.py --driven_audio \ --source_image \ --result_dir
\ --input_yaw -20 30 10 ``` | Results, Free-view results, Novel view results | |:-------------: | | ![free_view](free_view_result.gif)| ================================================ FILE: docs/changlelog.md ================================================ ## changelogs - __[2023.04.06]__: stable-diffusion webui extension is released. - __[2023.04.03]__: Enable TTS in huggingface and gradio local demo. - __[2023.03.30]__: Launch beta version of the full body mode. - __[2023.03.30]__: Launch new feature: by using reference videos, our algorithm can generate videos with more natural eye blinking and some eyebrow movement. - __[2023.03.29]__: `resize mode` is online via `python inference.py --preprocess resize`! It produces a larger crop of the image, as discussed in https://github.com/Winfredy/SadTalker/issues/35. - __[2023.03.29]__: local gradio demo is online! `python app.py` to start the demo. A new `requirements.txt` is used to avoid the bugs in `librosa`. - __[2023.03.28]__: Online demo is launched in [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/vinthony/SadTalker), thanks AK! - __[2023.03.22]__: Launch new feature: generating the 3d face animation from a single image. New applications about it will be updated. - __[2023.03.22]__: Launch new feature: `still mode`, where only a small head pose will be produced via `python inference.py --still`. - __[2023.03.18]__: Support `expression intensity`, now you can change the intensity of the generated motion: `python inference.py --expression_scale 1.3 (some value > 1)`. - __[2023.03.18]__: Reconfig the data folders, now you can download the checkpoint automatically using `bash scripts/download_models.sh`. - __[2023.03.18]__: We have officially integrated [GFPGAN](https://github.com/TencentARC/GFPGAN) for face enhancement; use `python inference.py --enhancer gfpgan` for better visualization performance.
- __[2023.03.14]__: Specify the version of package `joblib` to remove the errors in using `librosa`, [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb) is online! - __[2023.03.06]__: Solve some bugs in code and errors in installation - __[2023.03.03]__: Release the test code for audio-driven single image animation! - __[2023.02.28]__: SadTalker has been accepted by CVPR 2023! ================================================ FILE: docs/face3d.md ================================================ ## 3D Face Visualization We use `pytorch3d` to visualize the 3D faces from a single image. The requirements for 3D visualization are difficult to install, so here's a tutorial: ```bash git clone https://github.com/OpenTalker/SadTalker.git cd SadTalker conda create -n sadtalker3d python=3.8 source activate sadtalker3d conda install ffmpeg conda install -c fvcore -c iopath -c conda-forge fvcore iopath conda install libgcc gmp pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113 # install pytorch3d pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu113_pyt1110/download.html pip install -r requirements3d.txt ### install gfpgan for enhancer pip install git+https://github.com/TencentARC/GFPGAN ### if a gcc version problem occurs at `from pytorch import _C` in pytorch3d, add the anaconda lib path to LD_LIBRARY_PATH export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/$YOUR_ANACONDA_PATH/lib/ ``` Then, generate the result via: ```bash python inference.py --driven_audio path/to/audio.wav \ --source_image path/to/image.png \ --result_dir path/to/results \ --face3dvis ``` The 3D face video is written as `3dface.mp4` inside the timestamped working folder under `--result_dir` (add `--verbose` to keep that folder). More applications of 3D face rendering will be released soon.
================================================ FILE: docs/install.md ================================================ ### macOS This method has been tested on an M1 Mac (13.3) ```bash git clone https://github.com/OpenTalker/SadTalker.git cd SadTalker conda create -n sadtalker python=3.8 conda activate sadtalker # install pytorch 2.0 pip install torch torchvision torchaudio conda install ffmpeg pip install -r requirements.txt pip install dlib # macOS needs to install the original dlib. ``` ### Windows Native - Make sure you have `ffmpeg` in your `%PATH%` as suggested in [#54](https://github.com/Winfredy/SadTalker/issues/54); follow [this](https://www.geeksforgeeks.org/how-to-install-ffmpeg-on-windows/) tutorial to install `ffmpeg`, or install it with scoop. ### Windows WSL - Make sure the library path is set: `export LD_LIBRARY_PATH=/usr/lib/wsl/lib:$LD_LIBRARY_PATH` ### Docker Installation A community Docker image by [@thegenerativegeneration](https://github.com/thegenerativegeneration) is available on [Docker Hub](https://hub.docker.com/repository/docker/wawa9000/sadtalker) and can be used directly: ```bash docker run --gpus "all" --rm -v $(pwd):/host_dir wawa9000/sadtalker \ --driven_audio /host_dir/deyu.wav \ --source_image /host_dir/image.jpg \ --expression_scale 1.0 \ --still \ --result_dir /host_dir ``` ================================================ FILE: docs/webui_extension.md ================================================ ## Run SadTalker as a Stable Diffusion WebUI Extension. 1. Install the latest version of [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) and install SadTalker via `extension`. image 2. Download the checkpoints manually. For Linux and Mac: ```bash cd SOMEWHERE_YOU_LIKE bash <(wget -qO- https://raw.githubusercontent.com/OpenTalker/SadTalker/main/scripts/download_models.sh) ``` For Windows, you can download all the checkpoints [here](https://github.com/OpenTalker/SadTalker/tree/main#2-download-models). 3.1.
Option 1: put the checkpoints in `stable-diffusion-webui/models/SadTalker` or `stable-diffusion-webui/extensions/SadTalker/checkpoints/`; the checkpoints will be detected automatically. 3.2. Option 2: Set `SADTALKER_CHECKPOINTS` in `webui_user.sh` (linux) or `webui_user.bat` (windows) as follows: > This only works if you start webui directly from `webui_user.sh` or `webui_user.bat`. ```bash # Windows (webui_user.bat) set SADTALKER_CHECKPOINTS=D:\SadTalker\checkpoints # Linux/macOS (webui_user.sh) export SADTALKER_CHECKPOINTS=/path/to/SadTalker/checkpoints ``` 4. Start the WebUI via `webui.sh` or `webui_user.sh` (linux), `webui_user.bat` (windows), or any other method. SadTalker can also be used in stable-diffusion-webui directly. image ## Questions 1. If you are running on CPU, you need to specify `--disable-safe-unpickle` in `webui_user.sh` or `webui_user.bat`. ```bash # windows (webui_user.bat) set COMMANDLINE_ARGS="--disable-safe-unpickle" # linux (webui_user.sh) export COMMANDLINE_ARGS="--disable-safe-unpickle" ``` (If you're unable to use the `full` mode, please read this [discussion](https://github.com/Winfredy/SadTalker/issues/78).)
================================================ FILE: inference.py ================================================ from glob import glob import shutil import torch from time import strftime import os, sys, time from argparse import ArgumentParser from src.utils.preprocess import CropAndExtract from src.test_audio2coeff import Audio2Coeff from src.facerender.animate import AnimateFromCoeff from src.generate_batch import get_data from src.generate_facerender_batch import get_facerender_data from src.utils.init_path import init_path def main(args): #torch.backends.cudnn.enabled = False pic_path = args.source_image audio_path = args.driven_audio save_dir = os.path.join(args.result_dir, strftime("%Y_%m_%d_%H.%M.%S")) os.makedirs(save_dir, exist_ok=True) pose_style = args.pose_style device = args.device batch_size = args.batch_size input_yaw_list = args.input_yaw input_pitch_list = args.input_pitch input_roll_list = args.input_roll ref_eyeblink = args.ref_eyeblink ref_pose = args.ref_pose current_root_path = os.path.split(sys.argv[0])[0] sadtalker_paths = init_path(args.checkpoint_dir, os.path.join(current_root_path, 'src/config'), args.size, args.old_version, args.preprocess) #init model preprocess_model = CropAndExtract(sadtalker_paths, device) audio_to_coeff = Audio2Coeff(sadtalker_paths, device) animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device) #crop image and extract 3dmm from image first_frame_dir = os.path.join(save_dir, 'first_frame_dir') os.makedirs(first_frame_dir, exist_ok=True) print('3DMM Extraction for source image') first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(pic_path, first_frame_dir, args.preprocess,\ source_image_flag=True, pic_size=args.size) if first_coeff_path is None: print("Can't get the coeffs of the input") return if ref_eyeblink is not None: ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[0] ref_eyeblink_frame_dir = os.path.join(save_dir, ref_eyeblink_videoname) 
os.makedirs(ref_eyeblink_frame_dir, exist_ok=True) print('3DMM Extraction for the reference video providing eye blinking') ref_eyeblink_coeff_path, _, _ = preprocess_model.generate(ref_eyeblink, ref_eyeblink_frame_dir, args.preprocess, source_image_flag=False) else: ref_eyeblink_coeff_path=None if ref_pose is not None: if ref_pose == ref_eyeblink: ref_pose_coeff_path = ref_eyeblink_coeff_path else: ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0] ref_pose_frame_dir = os.path.join(save_dir, ref_pose_videoname) os.makedirs(ref_pose_frame_dir, exist_ok=True) print('3DMM Extraction for the reference video providing pose') ref_pose_coeff_path, _, _ = preprocess_model.generate(ref_pose, ref_pose_frame_dir, args.preprocess, source_image_flag=False) else: ref_pose_coeff_path=None #audio2ceoff batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=args.still) coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path) # 3dface render if args.face3dvis: from src.face3d.visualize import gen_composed_video gen_composed_video(args, device, first_coeff_path, coeff_path, audio_path, os.path.join(save_dir, '3dface.mp4')) #coeff2video data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, input_yaw_list, input_pitch_list, input_roll_list, expression_scale=args.expression_scale, still_mode=args.still, preprocess=args.preprocess, size=args.size) result = animate_from_coeff.generate(data, save_dir, pic_path, crop_info, \ enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess, img_size=args.size) shutil.move(result, save_dir+'.mp4') print('The generated video is named:', save_dir+'.mp4') if not args.verbose: shutil.rmtree(save_dir) if __name__ == '__main__': parser = ArgumentParser() parser.add_argument("--driven_audio", default='./examples/driven_audio/bus_chinese.wav', help="path to driven audio") 
parser.add_argument("--source_image", default='./examples/source_image/full_body_1.png', help="path to source image") parser.add_argument("--ref_eyeblink", default=None, help="path to reference video providing eye blinking") parser.add_argument("--ref_pose", default=None, help="path to reference video providing pose") parser.add_argument("--checkpoint_dir", default='./checkpoints', help="path to output") parser.add_argument("--result_dir", default='./results', help="path to output") parser.add_argument("--pose_style", type=int, default=0, help="input pose style from [0, 46)") parser.add_argument("--batch_size", type=int, default=2, help="the batch size of facerender") parser.add_argument("--size", type=int, default=256, help="the image size of the facerender") parser.add_argument("--expression_scale", type=float, default=1., help="the batch size of facerender") parser.add_argument('--input_yaw', nargs='+', type=int, default=None, help="the input yaw degree of the user ") parser.add_argument('--input_pitch', nargs='+', type=int, default=None, help="the input pitch degree of the user") parser.add_argument('--input_roll', nargs='+', type=int, default=None, help="the input roll degree of the user") parser.add_argument('--enhancer', type=str, default=None, help="Face enhancer, [gfpgan, RestoreFormer]") parser.add_argument('--background_enhancer', type=str, default=None, help="background enhancer, [realesrgan]") parser.add_argument("--cpu", dest="cpu", action="store_true") parser.add_argument("--face3dvis", action="store_true", help="generate 3d face and 3d landmarks") parser.add_argument("--still", action="store_true", help="can crop back to the original videos for the full body aniamtion") parser.add_argument("--preprocess", default='crop', choices=['crop', 'extcrop', 'resize', 'full', 'extfull'], help="how to preprocess the images" ) parser.add_argument("--verbose",action="store_true", help="saving the intermedia output or not" ) 
parser.add_argument("--old_version",action="store_true", help="use the pth other than safetensor version" ) # net structure and parameters parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='useless') parser.add_argument('--init_path', type=str, default=None, help='Useless') parser.add_argument('--use_last_fc',default=False, help='zero initialize the last fc') parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/') parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model') # default renderer parameters parser.add_argument('--focal', type=float, default=1015.) parser.add_argument('--center', type=float, default=112.) parser.add_argument('--camera_d', type=float, default=10.) parser.add_argument('--z_near', type=float, default=5.) parser.add_argument('--z_far', type=float, default=15.) args = parser.parse_args() if torch.cuda.is_available() and not args.cpu: args.device = "cuda" else: args.device = "cpu" main(args) ================================================ FILE: launcher.py ================================================ # this scripts installs necessary requirements and launches main program in webui.py # borrow from : https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/launch.py import subprocess import os import sys import importlib.util import shlex import platform import json python = sys.executable git = os.environ.get('GIT', "git") index_url = os.environ.get('INDEX_URL', "") stored_commit_hash = None skip_install = False dir_repos = "repositories" script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) if 'GRADIO_ANALYTICS_ENABLED' not in os.environ: os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False' def check_python_version(): is_windows = platform.system() == "Windows" major = sys.version_info.major minor = sys.version_info.minor micro = sys.version_info.micro if is_windows: supported_minors 
= [10] else: supported_minors = [7, 8, 9, 10, 11] if not (major == 3 and minor in supported_minors): raise (f""" INCOMPATIBLE PYTHON VERSION This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}. If you encounter an error with "RuntimeError: Couldn't install torch." message, or any other error regarding unsuccessful package (library) installation, please downgrade (or upgrade) to the latest version of 3.10 Python and delete current Python and "venv" folder in WebUI's directory. You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/ {"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""} Use --skip-python-version-check to suppress this warning. """) def commit_hash(): global stored_commit_hash if stored_commit_hash is not None: return stored_commit_hash try: stored_commit_hash = run(f"{git} rev-parse HEAD").strip() except Exception: stored_commit_hash = "" return stored_commit_hash def run(command, desc=None, errdesc=None, custom_env=None, live=False): if desc is not None: print(desc) if live: result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env) if result.returncode != 0: raise RuntimeError(f"""{errdesc or 'Error running command'}. Command: {command} Error code: {result.returncode}""") return "" result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env) if result.returncode != 0: message = f"""{errdesc or 'Error running command'}. 
Command: {command} Error code: {result.returncode} stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else ''} stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else ''} """ raise RuntimeError(message) return result.stdout.decode(encoding="utf8", errors="ignore") def check_run(command): result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) return result.returncode == 0 def is_installed(package): try: spec = importlib.util.find_spec(package) except ModuleNotFoundError: return False return spec is not None def repo_dir(name): return os.path.join(script_path, dir_repos, name) def run_python(code, desc=None, errdesc=None): return run(f'"{python}" -c "{code}"', desc, errdesc) def run_pip(args, desc=None): if skip_install: return index_url_line = f' --index-url {index_url}' if index_url != '' else '' return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}") def check_run_python(code): return check_run(f'"{python}" -c "{code}"') def git_clone(url, dir, name, commithash=None): # TODO clone into temporary dir and move if successful if os.path.exists(dir): if commithash is None: return current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip() if current_hash == commithash: return run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}") run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}") return run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}") if commithash is not None: run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}") def git_pull_recursive(dir): for subdir, _, _ in os.walk(dir): 
if os.path.exists(os.path.join(subdir, '.git')): try: output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash']) print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n") except subprocess.CalledProcessError as e: print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n") def run_extension_installer(extension_dir): path_installer = os.path.join(extension_dir, "install.py") if not os.path.isfile(path_installer): return try: env = os.environ.copy() env['PYTHONPATH'] = os.path.abspath(".") print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env)) except Exception as e: print(e, file=sys.stderr) def prepare_environment(): global skip_install torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113") ## check windows if sys.platform != 'win32': requirements_file = os.environ.get('REQS_FILE', "req.txt") else: requirements_file = os.environ.get('REQS_FILE', "requirements.txt") commit = commit_hash() print(f"Python {sys.version}") print(f"Commit hash: {commit}") if not is_installed("torch") or not is_installed("torchvision"): run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True) run_pip(f"install -r \"{requirements_file}\"", "requirements for SadTalker WebUI (may take longer time in first time)") if sys.platform != 'win32' and not is_installed('tts'): run_pip(f"install TTS", "install TTS individually in SadTalker, which might not work on windows.") def start(): print(f"Launching SadTalker Web UI") from app_sadtalker import sadtalker_demo demo = sadtalker_demo() demo.queue() demo.launch() if __name__ == "__main__": prepare_environment() start() ================================================ FILE: predict.py 
================================================ """run bash scripts/download_models.sh first to prepare the weights file""" import os import shutil from argparse import Namespace from src.utils.preprocess import CropAndExtract from src.test_audio2coeff import Audio2Coeff from src.facerender.animate import AnimateFromCoeff from src.generate_batch import get_data from src.generate_facerender_batch import get_facerender_data from src.utils.init_path import init_path from cog import BasePredictor, Input, Path checkpoints = "checkpoints" class Predictor(BasePredictor): def setup(self): """Load the model into memory to make running multiple predictions efficient""" device = "cuda" sadtalker_paths = init_path(checkpoints,os.path.join("src","config")) # init model self.preprocess_model = CropAndExtract(sadtalker_paths, device ) self.audio_to_coeff = Audio2Coeff( sadtalker_paths, device, ) self.animate_from_coeff = { "full": AnimateFromCoeff( sadtalker_paths, device, ), "others": AnimateFromCoeff( sadtalker_paths, device, ), } def predict( self, source_image: Path = Input( description="Upload the source image, it can be video.mp4 or picture.png", ), driven_audio: Path = Input( description="Upload the driven audio, accepts .wav and .mp4 file", ), enhancer: str = Input( description="Choose a face enhancer", choices=["gfpgan", "RestoreFormer"], default="gfpgan", ), preprocess: str = Input( description="how to preprocess the images", choices=["crop", "resize", "full"], default="full", ), ref_eyeblink: Path = Input( description="path to reference video providing eye blinking", default=None, ), ref_pose: Path = Input( description="path to reference video providing pose", default=None, ), still: bool = Input( description="can crop back to the original videos for the full body aniamtion when preprocess is full", default=True, ), ) -> Path: """Run a single prediction on the model""" animate_from_coeff = ( self.animate_from_coeff["full"] if preprocess == "full" else 
self.animate_from_coeff["others"] ) args = load_default() args.pic_path = str(source_image) args.audio_path = str(driven_audio) device = "cuda" args.still = still args.ref_eyeblink = None if ref_eyeblink is None else str(ref_eyeblink) args.ref_pose = None if ref_pose is None else str(ref_pose) # crop image and extract 3dmm from image results_dir = "results" if os.path.exists(results_dir): shutil.rmtree(results_dir) os.makedirs(results_dir) first_frame_dir = os.path.join(results_dir, "first_frame_dir") os.makedirs(first_frame_dir) print("3DMM Extraction for source image") first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate( args.pic_path, first_frame_dir, preprocess, source_image_flag=True ) if first_coeff_path is None: print("Can't get the coeffs of the input") return if ref_eyeblink is not None: ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[ 0 ] ref_eyeblink_frame_dir = os.path.join(results_dir, ref_eyeblink_videoname) os.makedirs(ref_eyeblink_frame_dir, exist_ok=True) print("3DMM Extraction for the reference video providing eye blinking") ref_eyeblink_coeff_path, _, _ = self.preprocess_model.generate( ref_eyeblink, ref_eyeblink_frame_dir ) else: ref_eyeblink_coeff_path = None if ref_pose is not None: if ref_pose == ref_eyeblink: ref_pose_coeff_path = ref_eyeblink_coeff_path else: ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0] ref_pose_frame_dir = os.path.join(results_dir, ref_pose_videoname) os.makedirs(ref_pose_frame_dir, exist_ok=True) print("3DMM Extraction for the reference video providing pose") ref_pose_coeff_path, _, _ = self.preprocess_model.generate( ref_pose, ref_pose_frame_dir ) else: ref_pose_coeff_path = None # audio2ceoff batch = get_data( first_coeff_path, args.audio_path, device, ref_eyeblink_coeff_path, still=still, ) coeff_path = self.audio_to_coeff.generate( batch, results_dir, args.pose_style, ref_pose_coeff_path ) # coeff2video print("coeff2video") data = 
get_facerender_data( coeff_path, crop_pic_path, first_coeff_path, args.audio_path, args.batch_size, args.input_yaw, args.input_pitch, args.input_roll, expression_scale=args.expression_scale, still_mode=still, preprocess=preprocess, ) animate_from_coeff.generate( data, results_dir, args.pic_path, crop_info, enhancer=enhancer, background_enhancer=args.background_enhancer, preprocess=preprocess) output = "/tmp/out.mp4" mp4_path = os.path.join(results_dir, [f for f in os.listdir(results_dir) if "enhanced.mp4" in f][0]) shutil.copy(mp4_path, output) return Path(output) def load_default(): return Namespace( pose_style=0, batch_size=2, expression_scale=1.0, input_yaw=None, input_pitch=None, input_roll=None, background_enhancer=None, face3dvis=False, net_recon="resnet50", init_path=None, use_last_fc=False, bfm_folder="./src/config/", bfm_model="BFM_model_front.mat", focal=1015.0, center=112.0, camera_d=10.0, z_near=5.0, z_far=15.0, ) ================================================ FILE: quick_demo.ipynb ================================================ { "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "M74Gs_TjYl_B" }, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "view-in-github" }, "source": [ "### SadTalker:Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation \n", "\n", "[arxiv](https://arxiv.org/abs/2211.12194) | [project](https://sadtalker.github.io) | [Github](https://github.com/Winfredy/SadTalker)\n", "\n", "Wenxuan Zhang, Xiaodong Cun, Xuan Wang, Yong Zhang, Xi Shen, Yu Guo, Ying Shan, Fei Wang.\n", "\n", "Xi'an Jiaotong University, Tencent AI Lab, Ant Group\n", "\n", "CVPR 2023\n", "\n", "TL;DR: A realistic and stylized talking head video generation method from a single image and 
audio\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "kA89DV-sKS4i" }, "source": [ "Installation (around 5 mins)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qJ4CplXsYl_E" }, "outputs": [], "source": [ "### make sure that CUDA is available in Edit -> Nootbook settings -> GPU\n", "!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Mdq6j4E5KQAR" }, "outputs": [], "source": [ "!update-alternatives --install /usr/local/bin/python3 python3 /usr/bin/python3.8 2\n", "!update-alternatives --install /usr/local/bin/python3 python3 /usr/bin/python3.9 1\n", "!sudo apt install python3.8\n", "\n", "!sudo apt-get install python3.8-distutils\n", "\n", "!python --version\n", "\n", "!apt-get update\n", "\n", "!apt install software-properties-common\n", "\n", "!sudo dpkg --remove --force-remove-reinstreq python3-pip python3-setuptools python3-wheel\n", "\n", "!apt-get install python3-pip\n", "\n", "print('Git clone project and install requirements...')\n", "!git clone https://github.com/Winfredy/SadTalker &> /dev/null\n", "%cd SadTalker\n", "!export PYTHONPATH=/content/SadTalker:$PYTHONPATH\n", "!python3.8 -m pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113\n", "!apt update\n", "!apt install ffmpeg &> /dev/null\n", "!python3.8 -m pip install -r requirements.txt" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "DddcKB_nKsnk" }, "source": [ "Download models (1 mins)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eDw3_UN8K2xa" }, "outputs": [], "source": [ "print('Download pre-trained models...')\n", "!rm -rf checkpoints\n", "!bash scripts/download_models.sh" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kK7DYeo7Yl_H" }, "outputs": [], "source": [ "# borrow from 
makeittalk\n", "import ipywidgets as widgets\n", "import glob\n", "import matplotlib.pyplot as plt\n", "print(\"Choose the image name to animate: (saved in folder 'examples/')\")\n", "img_list = glob.glob1('examples/source_image', '*.png')\n", "img_list.sort()\n", "img_list = [item.split('.')[0] for item in img_list]\n", "default_head_name = widgets.Dropdown(options=img_list, value='full3')\n", "def on_change(change):\n", " if change['type'] == 'change' and change['name'] == 'value':\n", " plt.imshow(plt.imread('examples/source_image/{}.png'.format(default_head_name.value)))\n", " plt.axis('off')\n", " plt.show()\n", "default_head_name.observe(on_change)\n", "display(default_head_name)\n", "plt.imshow(plt.imread('examples/source_image/{}.png'.format(default_head_name.value)))\n", "plt.axis('off')\n", "plt.show()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "-khNZcnGK4UK" }, "source": [ "Animation" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ToBlDusjK5sS" }, "outputs": [], "source": [ "# selected audio from exmaple/driven_audio\n", "img = 'examples/source_image/{}.png'.format(default_head_name.value)\n", "print(img)\n", "!python3.8 inference.py --driven_audio ./examples/driven_audio/RD_Radio31_000.wav \\\n", " --source_image {img} \\\n", " --result_dir ./results --still --preprocess full --enhancer gfpgan" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fAjwGmKKYl_I" }, "outputs": [], "source": [ "# visualize code from makeittalk\n", "from IPython.display import HTML\n", "from base64 import b64encode\n", "import os, sys\n", "\n", "# get the last from results\n", "\n", "results = sorted(os.listdir('./results/'))\n", "\n", "mp4_name = glob.glob('./results/*.mp4')[0]\n", "\n", "mp4 = open('{}'.format(mp4_name),'rb').read()\n", "data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n", "\n", "print('Display animation: {}'.format(mp4_name), file=sys.stderr)\n", 
"display(HTML(\"\"\"\n", " \n", " \"\"\" % data_url))\n" ] } ], "metadata": { "accelerator": "GPU", "colab": { "provenance": [] }, "gpuClass": "standard", "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.9.7" }, "vscode": { "interpreter": { "hash": "db5031b3636a3f037ea48eb287fd3d023feb9033aefc2a9652a92e470fb0851b" } } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: req.txt ================================================ llvmlite==0.38.1 numpy==1.21.6 face_alignment==1.3.5 imageio==2.19.3 imageio-ffmpeg==0.4.7 librosa==0.10.0.post2 numba==0.55.1 resampy==0.3.1 pydub==0.25.1 scipy==1.10.1 kornia==0.6.8 tqdm yacs==0.1.8 pyyaml joblib==1.1.0 scikit-image==0.19.3 basicsr==1.4.2 facexlib==0.3.0 gradio gfpgan av safetensors ================================================ FILE: requirements.txt ================================================ numpy==1.23.4 face_alignment==1.3.5 imageio==2.19.3 imageio-ffmpeg==0.4.7 librosa==0.9.2 # numba resampy==0.3.1 pydub==0.25.1 scipy==1.10.1 kornia==0.6.8 tqdm yacs==0.1.8 pyyaml joblib==1.1.0 scikit-image==0.19.3 basicsr==1.4.2 facexlib==0.3.0 gradio gfpgan av safetensors ================================================ FILE: requirements3d.txt ================================================ numpy==1.23.4 face_alignment==1.3.5 imageio==2.19.3 imageio-ffmpeg==0.4.7 librosa==0.9.2 # numba resampy==0.3.1 pydub==0.25.1 scipy==1.5.3 kornia==0.6.8 tqdm yacs==0.1.8 pyyaml joblib==1.1.0 scikit-image==0.19.3 basicsr==1.4.2 facexlib==0.3.0 trimesh==3.9.20 gradio gfpgan safetensors ================================================ FILE: scripts/download_models.sh ================================================ mkdir ./checkpoints # lagency download link # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2exp_00300-model.pth -O ./checkpoints/auido2exp_00300-model.pth # wget -nc 
https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2pose_00140-model.pth -O ./checkpoints/auido2pose_00140-model.pth # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/epoch_20.pth -O ./checkpoints/epoch_20.pth # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/facevid2vid_00189-model.pth.tar -O ./checkpoints/facevid2vid_00189-model.pth.tar # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/shape_predictor_68_face_landmarks.dat -O ./checkpoints/shape_predictor_68_face_landmarks.dat # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/wav2lip.pth -O ./checkpoints/wav2lip.pth # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00229-model.pth.tar -O ./checkpoints/mapping_00229-model.pth.tar # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00109-model.pth.tar -O ./checkpoints/mapping_00109-model.pth.tar # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/hub.zip -O ./checkpoints/hub.zip # unzip -n ./checkpoints/hub.zip -d ./checkpoints/ #### download the new links. 
wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00109-model.pth.tar -O ./checkpoints/mapping_00109-model.pth.tar wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00229-model.pth.tar -O ./checkpoints/mapping_00229-model.pth.tar wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_256.safetensors -O ./checkpoints/SadTalker_V0.0.2_256.safetensors wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_512.safetensors -O ./checkpoints/SadTalker_V0.0.2_512.safetensors # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/BFM_Fitting.zip -O ./checkpoints/BFM_Fitting.zip # unzip -n ./checkpoints/BFM_Fitting.zip -d ./checkpoints/ ### enhancer mkdir -p ./gfpgan/weights wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth -O ./gfpgan/weights/alignment_WFLW_4HG.pth wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth -O ./gfpgan/weights/detection_Resnet50_Final.pth wget -nc https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -O ./gfpgan/weights/GFPGANv1.4.pth wget -nc https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth -O ./gfpgan/weights/parsing_parsenet.pth ================================================ FILE: scripts/extension.py ================================================ import os, sys from pathlib import Path import tempfile import gradio as gr from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call from modules.shared import opts, OptionInfo from modules import shared, paths, script_callbacks import launch import glob from huggingface_hub import snapshot_download def check_all_files_safetensor(current_dir): kv = { "SadTalker_V0.0.2_256.safetensors": "sadtalker-256", "SadTalker_V0.0.2_512.safetensors": "sadtalker-512", 
"mapping_00109-model.pth.tar" : "mapping-109" , "mapping_00229-model.pth.tar" : "mapping-229" , } if not os.path.isdir(current_dir): return False dirs = os.listdir(current_dir) for f in dirs: if f in kv.keys(): del kv[f] return len(kv.keys()) == 0 def check_all_files(current_dir): kv = { "auido2exp_00300-model.pth": "audio2exp", "auido2pose_00140-model.pth": "audio2pose", "epoch_20.pth": "face_recon", "facevid2vid_00189-model.pth.tar": "face-render", "mapping_00109-model.pth.tar" : "mapping-109" , "mapping_00229-model.pth.tar" : "mapping-229" , "wav2lip.pth": "wav2lip", "shape_predictor_68_face_landmarks.dat": "dlib", } if not os.path.isdir(current_dir): return False dirs = os.listdir(current_dir) for f in dirs: if f in kv.keys(): del kv[f] return len(kv.keys()) == 0 def download_model(local_dir='./checkpoints'): REPO_ID = 'vinthony/SadTalker' snapshot_download(repo_id=REPO_ID, local_dir=local_dir, local_dir_use_symlinks=False) def get_source_image(image): return image def get_img_from_txt2img(x): talker_path = Path(paths.script_path) / "outputs" imgs_from_txt_dir = str(talker_path / "txt2img-images/") imgs = glob.glob(imgs_from_txt_dir+'/*/*.png') imgs.sort(key=lambda x:os.path.getmtime(os.path.join(imgs_from_txt_dir, x))) img_from_txt_path = os.path.join(imgs_from_txt_dir, imgs[-1]) return img_from_txt_path, img_from_txt_path def get_img_from_img2img(x): talker_path = Path(paths.script_path) / "outputs" imgs_from_img_dir = str(talker_path / "img2img-images/") imgs = glob.glob(imgs_from_img_dir+'/*/*.png') imgs.sort(key=lambda x:os.path.getmtime(os.path.join(imgs_from_img_dir, x))) img_from_img_path = os.path.join(imgs_from_img_dir, imgs[-1]) return img_from_img_path, img_from_img_path def get_default_checkpoint_path(): # check the path of models/checkpoints and extensions/ checkpoint_path = Path(paths.script_path) / "models"/ "SadTalker" extension_checkpoint_path = Path(paths.script_path) / "extensions"/ "SadTalker" / "checkpoints" if 
check_all_files_safetensor(checkpoint_path): # print('founding sadtalker checkpoint in ' + str(checkpoint_path)) return checkpoint_path if check_all_files_safetensor(extension_checkpoint_path): # print('founding sadtalker checkpoint in ' + str(extension_checkpoint_path)) return extension_checkpoint_path if check_all_files(checkpoint_path): # print('founding sadtalker checkpoint in ' + str(checkpoint_path)) return checkpoint_path if check_all_files(extension_checkpoint_path): # print('founding sadtalker checkpoint in ' + str(extension_checkpoint_path)) return extension_checkpoint_path return None def install(): kv = { "face_alignment": "face-alignment==1.3.5", "imageio": "imageio==2.19.3", "imageio_ffmpeg": "imageio-ffmpeg==0.4.7", "librosa":"librosa==0.8.0", "pydub":"pydub==0.25.1", "scipy":"scipy==1.8.1", "tqdm": "tqdm", "yacs":"yacs==0.1.8", "yaml": "pyyaml", "av":"av", "gfpgan": "gfpgan", } # # dlib is not necessary currently # if 'darwin' in sys.platform: # kv['dlib'] = "dlib" # else: # kv['dlib'] = 'dlib-bin' # #### we need to have a newer version of imageio for our method. # launch.run_pip("install imageio==2.19.3", "requirements for SadTalker") for k,v in kv.items(): if not launch.is_installed(k): print(k, launch.is_installed(k)) launch.run_pip("install "+ v, "requirements for SadTalker") if os.getenv('SADTALKER_CHECKPOINTS'): print('load Sadtalker Checkpoints from '+ os.getenv('SADTALKER_CHECKPOINTS')) elif get_default_checkpoint_path() is not None: os.environ['SADTALKER_CHECKPOINTS'] = str(get_default_checkpoint_path()) else: print( """" SadTalker will not support download all the files from hugging face, which will take a long time. 
please manually set the SADTALKER_CHECKPOINTS in `webui_user.bat`(windows) or `webui_user.sh`(linux) """ ) # python = sys.executable # launch.run(f'"{python}" -m pip uninstall -y huggingface_hub', live=True) # launch.run(f'"{python}" -m pip install --upgrade git+https://github.com/huggingface/huggingface_hub@main', live=True) # ### run the scripts to downlod models to correct localtion. # # print('download models for SadTalker') # # launch.run("cd " + paths.script_path+"/extensions/SadTalker && bash ./scripts/download_models.sh", live=True) # # print('SadTalker is successfully installed!') # download_model(paths.script_path+'/extensions/SadTalker/checkpoints') def on_ui_tabs(): install() sys.path.extend([paths.script_path+'/extensions/SadTalker']) repo_dir = paths.script_path+'/extensions/SadTalker/' result_dir = opts.sadtalker_result_dir os.makedirs(result_dir, exist_ok=True) from app_sadtalker import sadtalker_demo if os.getenv('SADTALKER_CHECKPOINTS'): checkpoint_path = os.getenv('SADTALKER_CHECKPOINTS') else: checkpoint_path = repo_dir+'checkpoints/' audio_to_video = sadtalker_demo(checkpoint_path=checkpoint_path, config_path=repo_dir+'src/config', warpfn = wrap_queued_call) return [(audio_to_video, "SadTalker", "extension")] def on_ui_settings(): talker_path = Path(paths.script_path) / "outputs" section = ('extension', "SadTalker") opts.add_option("sadtalker_result_dir", OptionInfo(str(talker_path / "SadTalker/"), "Path to save results of sadtalker", section=section)) script_callbacks.on_ui_settings(on_ui_settings) script_callbacks.on_ui_tabs(on_ui_tabs) ================================================ FILE: scripts/test.sh ================================================ # ### some test command before commit. 
# python inference.py --preprocess crop --size 256 # python inference.py --preprocess crop --size 512 # python inference.py --preprocess extcrop --size 256 # python inference.py --preprocess extcrop --size 512 # python inference.py --preprocess resize --size 256 # python inference.py --preprocess resize --size 512 # python inference.py --preprocess full --size 256 # python inference.py --preprocess full --size 512 # python inference.py --preprocess extfull --size 256 # python inference.py --preprocess extfull --size 512 python inference.py --preprocess full --size 256 --enhancer gfpgan python inference.py --preprocess full --size 512 --enhancer gfpgan python inference.py --preprocess full --size 256 --enhancer gfpgan --still python inference.py --preprocess full --size 512 --enhancer gfpgan --still ================================================ FILE: src/audio2exp_models/audio2exp.py ================================================ from tqdm import tqdm import torch from torch import nn class Audio2Exp(nn.Module): def __init__(self, netG, cfg, device, prepare_training_loss=False): super(Audio2Exp, self).__init__() self.cfg = cfg self.device = device self.netG = netG.to(device) def test(self, batch): mel_input = batch['indiv_mels'] # bs T 1 80 16 bs = mel_input.shape[0] T = mel_input.shape[1] exp_coeff_pred = [] for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames current_mel_input = mel_input[:,i:i+10] #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64 ref = batch['ref'][:, :, :64][:, i:i+10] ratio = batch['ratio_gt'][:, i:i+10] #bs T audiox = current_mel_input.view(-1, 1, 80, 16) # bs*T 1 80 16 curr_exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64 exp_coeff_pred += [curr_exp_coeff_pred] # BS x T x 64 results_dict = { 'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1) } return results_dict ================================================ FILE: src/audio2exp_models/networks.py 
================================================ import torch import torch.nn.functional as F from torch import nn class Conv2d(nn.Module): def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs): super().__init__(*args, **kwargs) self.conv_block = nn.Sequential( nn.Conv2d(cin, cout, kernel_size, stride, padding), nn.BatchNorm2d(cout) ) self.act = nn.ReLU() self.residual = residual self.use_act = use_act def forward(self, x): out = self.conv_block(x) if self.residual: out += x if self.use_act: return self.act(out) else: return out class SimpleWrapperV2(nn.Module): def __init__(self) -> None: super().__init__() self.audio_encoder = nn.Sequential( Conv2d(1, 32, kernel_size=3, stride=1, padding=1), Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), Conv2d(64, 128, kernel_size=3, stride=3, padding=1), Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), Conv2d(256, 512, kernel_size=3, stride=1, padding=0), Conv2d(512, 512, kernel_size=1, stride=1, padding=0), ) #### load the pre-trained audio_encoder #self.audio_encoder = self.audio_encoder.to(device) ''' wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict'] state_dict = self.audio_encoder.state_dict() for k,v in wav2lip_state_dict.items(): if 'audio_encoder' in k: print('init:', k) state_dict[k.replace('module.audio_encoder.', '')] = v self.audio_encoder.load_state_dict(state_dict) ''' self.mapping1 = nn.Linear(512+64+1, 64) 
#self.mapping2 = nn.Linear(30, 64) #nn.init.constant_(self.mapping1.weight, 0.) nn.init.constant_(self.mapping1.bias, 0.) def forward(self, x, ref, ratio): x = self.audio_encoder(x).view(x.size(0), -1) ref_reshape = ref.reshape(x.size(0), -1) ratio = ratio.reshape(x.size(0), -1) y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1)) out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # resudial return out ================================================ FILE: src/audio2pose_models/audio2pose.py ================================================ import torch from torch import nn from src.audio2pose_models.cvae import CVAE from src.audio2pose_models.discriminator import PoseSequenceDiscriminator from src.audio2pose_models.audio_encoder import AudioEncoder class Audio2Pose(nn.Module): def __init__(self, cfg, wav2lip_checkpoint, device='cuda'): super().__init__() self.cfg = cfg self.seq_len = cfg.MODEL.CVAE.SEQ_LEN self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE self.device = device self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device) self.audio_encoder.eval() for param in self.audio_encoder.parameters(): param.requires_grad = False self.netG = CVAE(cfg) self.netD_motion = PoseSequenceDiscriminator(cfg) def forward(self, x): batch = {} coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73 batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70] #bs frame_len 6 batch['ref'] = coeff_gt[:, 0, 64:70] #bs 6 batch['class'] = x['class'].squeeze(0).cuda() # bs indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16 # forward audio_emb_list = [] audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512 batch['audio_emb'] = audio_emb batch = self.netG(batch) pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6 pose_gt = coeff_gt[:, 1:, 64:70].clone() # bs frame_len 6 pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred # bs frame_len 6 batch['pose_pred'] = pose_pred batch['pose_gt'] = pose_gt return 
batch def test(self, x): batch = {} ref = x['ref'] #bs 1 70 batch['ref'] = x['ref'][:,0,-6:] batch['class'] = x['class'] bs = ref.shape[0] indiv_mels= x['indiv_mels'] # bs T 1 80 16 indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame num_frames = x['num_frames'] num_frames = int(num_frames) - 1 # div = num_frames//self.seq_len re = num_frames%self.seq_len audio_emb_list = [] pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype, device=batch['ref'].device)] for i in range(div): z = torch.randn(bs, self.latent_dim).to(ref.device) batch['z'] = z audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512 batch['audio_emb'] = audio_emb batch = self.netG.test(batch) pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6 if re != 0: z = torch.randn(bs, self.latent_dim).to(ref.device) batch['z'] = z audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512 if audio_emb.shape[1] != self.seq_len: pad_dim = self.seq_len-audio_emb.shape[1] pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1) audio_emb = torch.cat([pad_audio_emb, audio_emb], 1) batch['audio_emb'] = audio_emb batch = self.netG.test(batch) pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:]) pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1) batch['pose_motion_pred'] = pose_motion_pred pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6 batch['pose_pred'] = pose_pred return batch ================================================ FILE: src/audio2pose_models/audio_encoder.py ================================================ import torch from torch import nn from torch.nn import functional as F class Conv2d(nn.Module): def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): super().__init__(*args, **kwargs) self.conv_block = nn.Sequential( nn.Conv2d(cin, cout, kernel_size, stride, 
padding), nn.BatchNorm2d(cout) ) self.act = nn.ReLU() self.residual = residual def forward(self, x): out = self.conv_block(x) if self.residual: out += x return self.act(out) class AudioEncoder(nn.Module): def __init__(self, wav2lip_checkpoint, device): super(AudioEncoder, self).__init__() self.audio_encoder = nn.Sequential( Conv2d(1, 32, kernel_size=3, stride=1, padding=1), Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), Conv2d(64, 128, kernel_size=3, stride=3, padding=1), Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), Conv2d(256, 512, kernel_size=3, stride=1, padding=0), Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) #### load the pre-trained audio_encoder, we do not need to load wav2lip model here. 
# wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict'] # state_dict = self.audio_encoder.state_dict() # for k,v in wav2lip_state_dict.items(): # if 'audio_encoder' in k: # state_dict[k.replace('module.audio_encoder.', '')] = v # self.audio_encoder.load_state_dict(state_dict) def forward(self, audio_sequences): # audio_sequences = (B, T, 1, 80, 16) B = audio_sequences.size(0) audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0) audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1 dim = audio_embedding.shape[1] audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1)) return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512 ================================================ FILE: src/audio2pose_models/cvae.py ================================================ import torch import torch.nn.functional as F from torch import nn from src.audio2pose_models.res_unet import ResUnet def class2onehot(idx, class_num): assert torch.max(idx).item() < class_num onehot = torch.zeros(idx.size(0), class_num).to(idx.device) onehot.scatter_(1, idx, 1) return onehot class CVAE(nn.Module): def __init__(self, cfg): super().__init__() encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES latent_size = cfg.MODEL.CVAE.LATENT_SIZE num_classes = cfg.DATASET.NUM_CLASSES audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE seq_len = cfg.MODEL.CVAE.SEQ_LEN self.latent_size = latent_size self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes, audio_emb_in_size, audio_emb_out_size, seq_len) self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes, audio_emb_in_size, audio_emb_out_size, seq_len) def reparameterize(self, mu, logvar): std = torch.exp(0.5 * logvar) eps = torch.randn_like(std) return mu + eps * std def forward(self, batch): batch = 
self.encoder(batch) mu = batch['mu'] logvar = batch['logvar'] z = self.reparameterize(mu, logvar) batch['z'] = z return self.decoder(batch) def test(self, batch): ''' class_id = batch['class'] z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device) batch['z'] = z ''' return self.decoder(batch) class ENCODER(nn.Module): def __init__(self, layer_sizes, latent_size, num_classes, audio_emb_in_size, audio_emb_out_size, seq_len): super().__init__() self.resunet = ResUnet() self.num_classes = num_classes self.seq_len = seq_len self.MLP = nn.Sequential() layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6 for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])): self.MLP.add_module( name="L{:d}".format(i), module=nn.Linear(in_size, out_size)) self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU()) self.linear_means = nn.Linear(layer_sizes[-1], latent_size) self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size) self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size) self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size)) def forward(self, batch): class_id = batch['class'] pose_motion_gt = batch['pose_motion_gt'] #bs seq_len 6 ref = batch['ref'] #bs 6 bs = pose_motion_gt.shape[0] audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size #pose encode pose_emb = self.resunet(pose_motion_gt.unsqueeze(1)) #bs 1 seq_len 6 pose_emb = pose_emb.reshape(bs, -1) #bs seq_len*6 #audio mapping print(audio_in.shape) audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size audio_out = audio_out.reshape(bs, -1) class_bias = self.classbias[class_id] #bs latent_size x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size x_out = self.MLP(x_in) mu = self.linear_means(x_out) logvar = self.linear_means(x_out) #bs latent_size batch.update({'mu':mu, 'logvar':logvar}) return batch class DECODER(nn.Module): def __init__(self, 
layer_sizes, latent_size, num_classes, audio_emb_in_size, audio_emb_out_size, seq_len): super().__init__() self.resunet = ResUnet() self.num_classes = num_classes self.seq_len = seq_len self.MLP = nn.Sequential() input_size = latent_size + seq_len*audio_emb_out_size + 6 for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)): self.MLP.add_module( name="L{:d}".format(i), module=nn.Linear(in_size, out_size)) if i+1 < len(layer_sizes): self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU()) else: self.MLP.add_module(name="sigmoid", module=nn.Sigmoid()) self.pose_linear = nn.Linear(6, 6) self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size) self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size)) def forward(self, batch): z = batch['z'] #bs latent_size bs = z.shape[0] class_id = batch['class'] ref = batch['ref'] #bs 6 audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size #print('audio_in: ', audio_in[:, :, :10]) audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size #print('audio_out: ', audio_out[:, :, :10]) audio_out = audio_out.reshape([bs, -1]) # bs seq_len*audio_emb_out_size class_bias = self.classbias[class_id] #bs latent_size z = z + class_bias x_in = torch.cat([ref, z, audio_out], dim=-1) x_out = self.MLP(x_in) # bs layer_sizes[-1] x_out = x_out.reshape((bs, self.seq_len, -1)) #print('x_out: ', x_out) pose_emb = self.resunet(x_out.unsqueeze(1)) #bs 1 seq_len 6 pose_motion_pred = self.pose_linear(pose_emb.squeeze(1)) #bs seq_len 6 batch.update({'pose_motion_pred':pose_motion_pred}) return batch ================================================ FILE: src/audio2pose_models/discriminator.py ================================================ import torch import torch.nn.functional as F from torch import nn class ConvNormRelu(nn.Module): def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False, kernel_size=None, stride=None, padding=None, norm='BN', 
leaky=False): super().__init__() if kernel_size is None: if downsample: kernel_size, stride, padding = 4, 2, 1 else: kernel_size, stride, padding = 3, 1, 1 if conv_type == '2d': self.conv = nn.Conv2d( in_channels, out_channels, kernel_size, stride, padding, bias=False, ) if norm == 'BN': self.norm = nn.BatchNorm2d(out_channels) elif norm == 'IN': self.norm = nn.InstanceNorm2d(out_channels) else: raise NotImplementedError elif conv_type == '1d': self.conv = nn.Conv1d( in_channels, out_channels, kernel_size, stride, padding, bias=False, ) if norm == 'BN': self.norm = nn.BatchNorm1d(out_channels) elif norm == 'IN': self.norm = nn.InstanceNorm1d(out_channels) else: raise NotImplementedError nn.init.kaiming_normal_(self.conv.weight) self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True) def forward(self, x): x = self.conv(x) if isinstance(self.norm, nn.InstanceNorm1d): x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1)) # normalize on [C] else: x = self.norm(x) x = self.act(x) return x class PoseSequenceDiscriminator(nn.Module): def __init__(self, cfg): super().__init__() self.cfg = cfg leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU self.seq = nn.Sequential( ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky), # B, 256, 64 ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky), # B, 512, 32 ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky), # B, 1024, 16 nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True) # B, 1, 16 ) def forward(self, x): x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2) x = self.seq(x) x = x.squeeze(1) return x ================================================ FILE: src/audio2pose_models/networks.py ================================================ import torch.nn as nn import torch class ResidualConv(nn.Module): def __init__(self, input_dim, output_dim, stride, padding): super(ResidualConv, self).__init__() 
self.conv_block = nn.Sequential( nn.BatchNorm2d(input_dim), nn.ReLU(), nn.Conv2d( input_dim, output_dim, kernel_size=3, stride=stride, padding=padding ), nn.BatchNorm2d(output_dim), nn.ReLU(), nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1), ) self.conv_skip = nn.Sequential( nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1), nn.BatchNorm2d(output_dim), ) def forward(self, x): return self.conv_block(x) + self.conv_skip(x) class Upsample(nn.Module): def __init__(self, input_dim, output_dim, kernel, stride): super(Upsample, self).__init__() self.upsample = nn.ConvTranspose2d( input_dim, output_dim, kernel_size=kernel, stride=stride ) def forward(self, x): return self.upsample(x) class Squeeze_Excite_Block(nn.Module): def __init__(self, channel, reduction=16): super(Squeeze_Excite_Block, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Sequential( nn.Linear(channel, channel // reduction, bias=False), nn.ReLU(inplace=True), nn.Linear(channel // reduction, channel, bias=False), nn.Sigmoid(), ) def forward(self, x): b, c, _, _ = x.size() y = self.avg_pool(x).view(b, c) y = self.fc(y).view(b, c, 1, 1) return x * y.expand_as(x) class ASPP(nn.Module): def __init__(self, in_dims, out_dims, rate=[6, 12, 18]): super(ASPP, self).__init__() self.aspp_block1 = nn.Sequential( nn.Conv2d( in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0] ), nn.ReLU(inplace=True), nn.BatchNorm2d(out_dims), ) self.aspp_block2 = nn.Sequential( nn.Conv2d( in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1] ), nn.ReLU(inplace=True), nn.BatchNorm2d(out_dims), ) self.aspp_block3 = nn.Sequential( nn.Conv2d( in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2] ), nn.ReLU(inplace=True), nn.BatchNorm2d(out_dims), ) self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1) self._init_weights() def forward(self, x): x1 = self.aspp_block1(x) x2 = self.aspp_block2(x) x3 = self.aspp_block3(x) out = torch.cat([x1, x2, 
x3], dim=1) return self.output(out) def _init_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight) elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() class Upsample_(nn.Module): def __init__(self, scale=2): super(Upsample_, self).__init__() self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale) def forward(self, x): return self.upsample(x) class AttentionBlock(nn.Module): def __init__(self, input_encoder, input_decoder, output_dim): super(AttentionBlock, self).__init__() self.conv_encoder = nn.Sequential( nn.BatchNorm2d(input_encoder), nn.ReLU(), nn.Conv2d(input_encoder, output_dim, 3, padding=1), nn.MaxPool2d(2, 2), ) self.conv_decoder = nn.Sequential( nn.BatchNorm2d(input_decoder), nn.ReLU(), nn.Conv2d(input_decoder, output_dim, 3, padding=1), ) self.conv_attn = nn.Sequential( nn.BatchNorm2d(output_dim), nn.ReLU(), nn.Conv2d(output_dim, 1, 1), ) def forward(self, x1, x2): out = self.conv_encoder(x1) + self.conv_decoder(x2) out = self.conv_attn(out) return out * x2 ================================================ FILE: src/audio2pose_models/res_unet.py ================================================ import torch import torch.nn as nn from src.audio2pose_models.networks import ResidualConv, Upsample class ResUnet(nn.Module): def __init__(self, channel=1, filters=[32, 64, 128, 256]): super(ResUnet, self).__init__() self.input_layer = nn.Sequential( nn.Conv2d(channel, filters[0], kernel_size=3, padding=1), nn.BatchNorm2d(filters[0]), nn.ReLU(), nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1), ) self.input_skip = nn.Sequential( nn.Conv2d(channel, filters[0], kernel_size=3, padding=1) ) self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1) self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1) self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1) self.upsample_1 = Upsample(filters[3], 
filters[3], kernel=(2,1), stride=(2,1)) self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1) self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1)) self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1) self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1)) self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1) self.output_layer = nn.Sequential( nn.Conv2d(filters[0], 1, 1, 1), nn.Sigmoid(), ) def forward(self, x): # Encode x1 = self.input_layer(x) + self.input_skip(x) x2 = self.residual_conv_1(x1) x3 = self.residual_conv_2(x2) # Bridge x4 = self.bridge(x3) # Decode x4 = self.upsample_1(x4) x5 = torch.cat([x4, x3], dim=1) x6 = self.up_residual_conv1(x5) x6 = self.upsample_2(x6) x7 = torch.cat([x6, x2], dim=1) x8 = self.up_residual_conv2(x7) x8 = self.upsample_3(x8) x9 = torch.cat([x8, x1], dim=1) x10 = self.up_residual_conv3(x9) output = self.output_layer(x10) return output ================================================ FILE: src/config/auido2exp.yaml ================================================ DATASET: TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt TRAIN_BATCH_SIZE: 32 EVAL_BATCH_SIZE: 32 EXP: True EXP_DIM: 64 FRAME_LEN: 32 COEFF_LEN: 73 NUM_CLASSES: 46 AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb DEBUG: True NUM_REPEATS: 2 T: 40 MODEL: FRAMEWORK: V2 AUDIOENCODER: LEAKY_RELU: True NORM: 'IN' DISCRIMINATOR: LEAKY_RELU: False INPUT_CHANNELS: 6 CVAE: AUDIO_EMB_IN_SIZE: 512 AUDIO_EMB_OUT_SIZE: 128 SEQ_LEN: 32 LATENT_SIZE: 256 ENCODER_LAYER_SIZES: [192, 1024] DECODER_LAYER_SIZES: [1024, 192] TRAIN: 
MAX_EPOCH: 300 GENERATOR: LR: 2.0e-5 DISCRIMINATOR: LR: 1.0e-5 LOSS: W_FEAT: 0 W_COEFF_EXP: 2 W_LM: 1.0e-2 W_LM_MOUTH: 0 W_REG: 0 W_SYNC: 0 W_COLOR: 0 W_EXPRESSION: 0 W_LIPREADING: 0.01 W_LIPREADING_VV: 0 W_EYE_BLINK: 4 TAG: NAME: small_dataset ================================================ FILE: src/config/auido2pose.yaml ================================================ DATASET: TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt TRAIN_BATCH_SIZE: 64 EVAL_BATCH_SIZE: 1 EXP: True EXP_DIM: 64 FRAME_LEN: 32 COEFF_LEN: 73 NUM_CLASSES: 46 AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb DEBUG: True MODEL: AUDIOENCODER: LEAKY_RELU: True NORM: 'IN' DISCRIMINATOR: LEAKY_RELU: False INPUT_CHANNELS: 6 CVAE: AUDIO_EMB_IN_SIZE: 512 AUDIO_EMB_OUT_SIZE: 6 SEQ_LEN: 32 LATENT_SIZE: 64 ENCODER_LAYER_SIZES: [192, 128] DECODER_LAYER_SIZES: [128, 192] TRAIN: MAX_EPOCH: 150 GENERATOR: LR: 1.0e-4 DISCRIMINATOR: LR: 1.0e-4 LOSS: LAMBDA_REG: 1 LAMBDA_LANDMARKS: 0 LAMBDA_VERTICES: 0 LAMBDA_GAN_MOTION: 0.7 LAMBDA_GAN_COEFF: 0 LAMBDA_KL: 1 TAG: NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder ================================================ FILE: src/config/facerender.yaml ================================================ model_params: common_params: num_kp: 15 image_channel: 3 feature_channel: 32 estimate_jacobian: False # True kp_detector_params: temperature: 0.1 block_expansion: 32 max_features: 1024 scale_factor: 0.25 # 0.25 num_blocks: 5 reshape_channel: 16384 # 16384 = 1024 * 16 reshape_depth: 16 he_estimator_params: block_expansion: 64 max_features: 2048 num_bins: 66 generator_params: block_expansion: 64 max_features: 512 num_down_blocks: 2 reshape_channel: 32 reshape_depth: 16 # 512 = 32 * 16 num_resblocks: 6 
estimate_occlusion_map: True dense_motion_params: block_expansion: 32 max_features: 1024 num_blocks: 5 reshape_depth: 16 compress: 4 discriminator_params: scales: [1] block_expansion: 32 max_features: 512 num_blocks: 4 sn: True mapping_params: coeff_nc: 70 descriptor_nc: 1024 layer: 3 num_kp: 15 num_bins: 66 ================================================ FILE: src/config/facerender_still.yaml ================================================ model_params: common_params: num_kp: 15 image_channel: 3 feature_channel: 32 estimate_jacobian: False # True kp_detector_params: temperature: 0.1 block_expansion: 32 max_features: 1024 scale_factor: 0.25 # 0.25 num_blocks: 5 reshape_channel: 16384 # 16384 = 1024 * 16 reshape_depth: 16 he_estimator_params: block_expansion: 64 max_features: 2048 num_bins: 66 generator_params: block_expansion: 64 max_features: 512 num_down_blocks: 2 reshape_channel: 32 reshape_depth: 16 # 512 = 32 * 16 num_resblocks: 6 estimate_occlusion_map: True dense_motion_params: block_expansion: 32 max_features: 1024 num_blocks: 5 reshape_depth: 16 compress: 4 discriminator_params: scales: [1] block_expansion: 32 max_features: 512 num_blocks: 4 sn: True mapping_params: coeff_nc: 73 descriptor_nc: 1024 layer: 3 num_kp: 15 num_bins: 66 ================================================ FILE: src/face3d/data/__init__.py ================================================ """This package includes all the modules related to data loading and preprocessing To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. You need to implement four functions: -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). -- <__len__>: return the size of dataset. -- <__getitem__>: get a data point from data loader. -- : (optionally) add dataset-specific options and set default options. Now you can use the dataset class by specifying flag '--dataset_mode dummy'. 
See our template dataset class 'template_dataset.py' for more details. """ import numpy as np import importlib import torch.utils.data from face3d.data.base_dataset import BaseDataset def find_dataset_using_name(dataset_name): """Import the module "data/[dataset_name]_dataset.py". In the file, the class called DatasetNameDataset() will be instantiated. It has to be a subclass of BaseDataset, and it is case-insensitive. """ dataset_filename = "data." + dataset_name + "_dataset" datasetlib = importlib.import_module(dataset_filename) dataset = None target_dataset_name = dataset_name.replace('_', '') + 'dataset' for name, cls in datasetlib.__dict__.items(): if name.lower() == target_dataset_name.lower() \ and issubclass(cls, BaseDataset): dataset = cls if dataset is None: raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) return dataset def get_option_setter(dataset_name): """Return the static method of the dataset class.""" dataset_class = find_dataset_using_name(dataset_name) return dataset_class.modify_commandline_options def create_dataset(opt, rank=0): """Create a dataset given the option. This function wraps the class CustomDatasetDataLoader. This is the main interface between this package and 'train.py'/'test.py' Example: >>> from data import create_dataset >>> dataset = create_dataset(opt) """ data_loader = CustomDatasetDataLoader(opt, rank=rank) dataset = data_loader.load_data() return dataset class CustomDatasetDataLoader(): """Wrapper class of Dataset class that performs multi-threaded data loading""" def __init__(self, opt, rank=0): """Initialize this class Step 1: create a dataset instance given the name [dataset_mode] Step 2: create a multi-threaded data loader. 
""" self.opt = opt dataset_class = find_dataset_using_name(opt.dataset_mode) self.dataset = dataset_class(opt) self.sampler = None print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__)) if opt.use_ddp and opt.isTrain: world_size = opt.world_size self.sampler = torch.utils.data.distributed.DistributedSampler( self.dataset, num_replicas=world_size, rank=rank, shuffle=not opt.serial_batches ) self.dataloader = torch.utils.data.DataLoader( self.dataset, sampler=self.sampler, num_workers=int(opt.num_threads / world_size), batch_size=int(opt.batch_size / world_size), drop_last=True) else: self.dataloader = torch.utils.data.DataLoader( self.dataset, batch_size=opt.batch_size, shuffle=(not opt.serial_batches) and opt.isTrain, num_workers=int(opt.num_threads), drop_last=True ) def set_epoch(self, epoch): self.dataset.current_epoch = epoch if self.sampler is not None: self.sampler.set_epoch(epoch) def load_data(self): return self def __len__(self): """Return the number of data in the dataset""" return min(len(self.dataset), self.opt.max_dataset_size) def __iter__(self): """Return a batch of data""" for i, data in enumerate(self.dataloader): if i * self.opt.batch_size >= self.opt.max_dataset_size: break yield data ================================================ FILE: src/face3d/data/base_dataset.py ================================================ """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. """ import random import numpy as np import torch.utils.data as data from PIL import Image import torchvision.transforms as transforms from abc import ABC, abstractmethod class BaseDataset(data.Dataset, ABC): """This class is an abstract base class (ABC) for datasets. 
To create a subclass, you need to implement the following four functions: -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). -- <__len__>: return the size of dataset. -- <__getitem__>: get a data point. -- : (optionally) add dataset-specific options and set default options. """ def __init__(self, opt): """Initialize the class; save the options in the class Parameters: opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions """ self.opt = opt # self.root = opt.dataroot self.current_epoch = 0 @staticmethod def modify_commandline_options(parser, is_train): """Add new dataset-specific options, and rewrite default values for existing options. Parameters: parser -- original option parser is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. Returns: the modified parser. """ return parser @abstractmethod def __len__(self): """Return the total number of images in the dataset.""" return 0 @abstractmethod def __getitem__(self, index): """Return a data point and its metadata information. Parameters: index - - a random integer for data indexing Returns: a dictionary of data with their names. It ususally contains the data itself and its metadata information. 
""" pass def get_transform(grayscale=False): transform_list = [] if grayscale: transform_list.append(transforms.Grayscale(1)) transform_list += [transforms.ToTensor()] return transforms.Compose(transform_list) def get_affine_mat(opt, size): shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False w, h = size if 'shift' in opt.preprocess: shift_pixs = int(opt.shift_pixs) shift_x = random.randint(-shift_pixs, shift_pixs) shift_y = random.randint(-shift_pixs, shift_pixs) if 'scale' in opt.preprocess: scale = 1 + opt.scale_delta * (2 * random.random() - 1) if 'rot' in opt.preprocess: rot_angle = opt.rot_angle * (2 * random.random() - 1) rot_rad = -rot_angle * np.pi/180 if 'flip' in opt.preprocess: flip = random.random() > 0.5 shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3]) flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3]) shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3]) rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3]) scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3]) shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3]) affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin affine_inv = np.linalg.inv(affine) return affine, affine_inv, flip def apply_img_affine(img, affine_inv, method=Image.BICUBIC): return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC) def apply_lm_affine(landmark, affine, flip, size): _, h = size lm = landmark.copy() lm[:, 1] = h - 1 - lm[:, 1] lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1) lm = lm @ np.transpose(affine) lm[:, :2] = lm[:, :2] / lm[:, 2:] lm = lm[:, :2] lm[:, 1] = h - 1 - lm[:, 1] if flip: lm_ = lm.copy() lm_[:17] = lm[16::-1] lm_[17:22] = lm[26:21:-1] lm_[22:27] = lm[21:16:-1] lm_[31:36] = lm[35:30:-1] lm_[36:40] = lm[45:41:-1] 
lm_[40:42] = lm[47:45:-1] lm_[42:46] = lm[39:35:-1] lm_[46:48] = lm[41:39:-1] lm_[48:55] = lm[54:47:-1] lm_[55:60] = lm[59:54:-1] lm_[60:65] = lm[64:59:-1] lm_[65:68] = lm[67:64:-1] lm = lm_ return lm ================================================ FILE: src/face3d/data/flist_dataset.py ================================================ """This script defines the custom dataset for Deep3DFaceRecon_pytorch """ import os.path from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine from data.image_folder import make_dataset from PIL import Image import random import util.util as util import numpy as np import json import torch from scipy.io import loadmat, savemat import pickle from util.preprocess import align_img, estimate_norm from util.load_mats import load_lm3d def default_flist_reader(flist): """ flist format: impath label\nimpath label\n ...(same to caffe's filelist) """ imlist = [] with open(flist, 'r') as rf: for line in rf.readlines(): impath = line.strip() imlist.append(impath) return imlist def jason_flist_reader(flist): with open(flist, 'r') as fp: info = json.load(fp) return info def parse_label(label): return torch.tensor(np.array(label).astype(np.float32)) class FlistDataset(BaseDataset): """ It requires one directories to host training images '/path/to/data/train' You can train the model with the dataset flag '--dataroot /path/to/data'. """ def __init__(self, opt): """Initialize this dataset class. 
Parameters: opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions """ BaseDataset.__init__(self, opt) self.lm3d_std = load_lm3d(opt.bfm_folder) msk_names = default_flist_reader(opt.flist) self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names] self.size = len(self.msk_paths) self.opt = opt self.name = 'train' if opt.isTrain else 'val' if '_' in opt.flist: self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0] def __getitem__(self, index): """Return a data point and its metadata information. Parameters: index (int) -- a random integer for data indexing Returns a dictionary that contains A, B, A_paths and B_paths img (tensor) -- an image in the input domain msk (tensor) -- its corresponding attention mask lm (tensor) -- its corresponding 3d landmarks im_paths (str) -- image paths aug_flag (bool) -- a flag used to tell whether its raw or augmented """ msk_path = self.msk_paths[index % self.size] # make sure index is within then range img_path = msk_path.replace('mask/', '') lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt' raw_img = Image.open(img_path).convert('RGB') raw_msk = Image.open(msk_path).convert('RGB') raw_lm = np.loadtxt(lm_path).astype(np.float32) _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk) aug_flag = self.opt.use_aug and self.opt.isTrain if aug_flag: img, lm, msk = self._augmentation(img, lm, self.opt, msk) _, H = img.size M = estimate_norm(lm, H) transform = get_transform() img_tensor = transform(img) msk_tensor = transform(msk)[:1, ...] 
lm_tensor = parse_label(lm) M_tensor = parse_label(M) return {'imgs': img_tensor, 'lms': lm_tensor, 'msks': msk_tensor, 'M': M_tensor, 'im_paths': img_path, 'aug_flag': aug_flag, 'dataset': self.name} def _augmentation(self, img, lm, opt, msk=None): affine, affine_inv, flip = get_affine_mat(opt, img.size) img = apply_img_affine(img, affine_inv) lm = apply_lm_affine(lm, affine, flip, img.size) if msk is not None: msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR) return img, lm, msk def __len__(self): """Return the total number of images in the dataset. """ return self.size ================================================ FILE: src/face3d/data/image_folder.py ================================================ """A modified image folder class We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) so that this class can load images from both current directory and its subdirectories. """ import numpy as np import torch.utils.data as data from PIL import Image import os import os.path IMG_EXTENSIONS = [ '.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif', '.TIF', '.tiff', '.TIFF', ] def is_image_file(filename): return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) def make_dataset(dir, max_dataset_size=float("inf")): images = [] assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir for root, _, fnames in sorted(os.walk(dir, followlinks=True)): for fname in fnames: if is_image_file(fname): path = os.path.join(root, fname) images.append(path) return images[:min(max_dataset_size, len(images))] def default_loader(path): return Image.open(path).convert('RGB') class ImageFolder(data.Dataset): def __init__(self, root, transform=None, return_paths=False, loader=default_loader): imgs = make_dataset(root) if len(imgs) == 0: raise(RuntimeError("Found 0 images in: " + root + "\n" "Supported image extensions are: " + 
",".join(IMG_EXTENSIONS))) self.root = root self.imgs = imgs self.transform = transform self.return_paths = return_paths self.loader = loader def __getitem__(self, index): path = self.imgs[index] img = self.loader(path) if self.transform is not None: img = self.transform(img) if self.return_paths: return img, path else: return img def __len__(self): return len(self.imgs) ================================================ FILE: src/face3d/data/template_dataset.py ================================================ """Dataset class template This module provides a template for users to implement custom datasets. You can specify '--dataset_mode template' to use this dataset. The class name should be consistent with both the filename and its dataset_mode option. The filename should be _dataset.py The class name should be Dataset.py You need to implement the following functions: -- : Add dataset-specific options and rewrite default values for existing options. -- <__init__>: Initialize this dataset class. -- <__getitem__>: Return a data point and its metadata information. -- <__len__>: Return the number of images. """ from data.base_dataset import BaseDataset, get_transform # from data.image_folder import make_dataset # from PIL import Image class TemplateDataset(BaseDataset): """A template dataset class for you to implement custom datasets.""" @staticmethod def modify_commandline_options(parser, is_train): """Add new dataset-specific options, and rewrite default values for existing options. Parameters: parser -- original option parser is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. Returns: the modified parser. """ parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option') parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values return parser def __init__(self, opt): """Initialize this dataset class. 
Parameters: opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions A few things can be done here. - save the options (have been done in BaseDataset) - get image paths and meta information of the dataset. - define the image transformation. """ # save the option and dataset root BaseDataset.__init__(self, opt) # get the image paths of your dataset; self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root # define the default transform function. You can use ; You can also define your custom transform function self.transform = get_transform(opt) def __getitem__(self, index): """Return a data point and its metadata information. Parameters: index -- a random integer for data indexing Returns: a dictionary of data with their names. It usually contains the data itself and its metadata information. Step 1: get a random image path: e.g., path = self.image_paths[index] Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB'). Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image) Step 4: return a data point as a dictionary. 
""" path = 'temp' # needs to be a string data_A = None # needs to be a tensor data_B = None # needs to be a tensor return {'data_A': data_A, 'data_B': data_B, 'path': path} def __len__(self): """Return the total number of images.""" return len(self.image_paths) ================================================ FILE: src/face3d/extract_kp_videos.py ================================================ import os import cv2 import time import glob import argparse import face_alignment import numpy as np from PIL import Image from tqdm import tqdm from itertools import cycle from torch.multiprocessing import Pool, Process, set_start_method class KeypointExtractor(): def __init__(self, device): self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device=device) def extract_keypoint(self, images, name=None, info=True): if isinstance(images, list): keypoints = [] if info: i_range = tqdm(images,desc='landmark Det:') else: i_range = images for image in i_range: current_kp = self.extract_keypoint(image) if np.mean(current_kp) == -1 and keypoints: keypoints.append(keypoints[-1]) else: keypoints.append(current_kp[None]) keypoints = np.concatenate(keypoints, 0) np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) return keypoints else: while True: try: keypoints = self.detector.get_landmarks_from_image(np.array(images))[0] break except RuntimeError as e: if str(e).startswith('CUDA'): print("Warning: out of memory, sleep for 1s") time.sleep(1) else: print(e) break except TypeError: print('No face detected in this image') shape = [68, 2] keypoints = -1. 
* np.ones(shape) break if name is not None: np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) return keypoints def read_video(filename): frames = [] cap = cv2.VideoCapture(filename) while cap.isOpened(): ret, frame = cap.read() if ret: frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frame = Image.fromarray(frame) frames.append(frame) else: break cap.release() return frames def run(data): filename, opt, device = data os.environ['CUDA_VISIBLE_DEVICES'] = device kp_extractor = KeypointExtractor() images = read_video(filename) name = filename.split('/')[-2:] os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True) kp_extractor.extract_keypoint( images, name=os.path.join(opt.output_dir, name[-2], name[-1]) ) if __name__ == '__main__': set_start_method('spawn') parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('--input_dir', type=str, help='the folder of the input files') parser.add_argument('--output_dir', type=str, help='the folder of the output files') parser.add_argument('--device_ids', type=str, default='0,1') parser.add_argument('--workers', type=int, default=4) opt = parser.parse_args() filenames = list() VIDEO_EXTENSIONS_LOWERCASE = {'mp4'} VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE}) extensions = VIDEO_EXTENSIONS for ext in extensions: os.listdir(f'{opt.input_dir}') print(f'{opt.input_dir}/*.{ext}') filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}')) print('Total number of videos:', len(filenames)) pool = Pool(opt.workers) args_list = cycle([opt]) device_ids = opt.device_ids.split(",") device_ids = cycle(device_ids) for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))): None ================================================ FILE: src/face3d/extract_kp_videos_safe.py ================================================ import os import cv2 import time import glob import argparse import numpy as 
np from PIL import Image import torch from tqdm import tqdm from itertools import cycle from torch.multiprocessing import Pool, Process, set_start_method from facexlib.alignment import landmark_98_to_68 from facexlib.detection import init_detection_model from facexlib.utils import load_file_from_url from src.face3d.util.my_awing_arch import FAN def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None): if model_name == 'awing_fan': model = FAN(num_modules=4, num_landmarks=98, device=device) model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth' else: raise NotImplementedError(f'{model_name} is not implemented.') model_path = load_file_from_url( url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath) model.load_state_dict(torch.load(model_path, map_location=device)['state_dict'], strict=True) model.eval() model = model.to(device) return model class KeypointExtractor(): def __init__(self, device='cuda'): ### gfpgan/weights try: import webui # in webui root_path = 'extensions/SadTalker/gfpgan/weights' except: root_path = 'gfpgan/weights' self.detector = init_alignment_model('awing_fan',device=device, model_rootpath=root_path) self.det_net = init_detection_model('retinaface_resnet50', half=False,device=device, model_rootpath=root_path) def extract_keypoint(self, images, name=None, info=True): if isinstance(images, list): keypoints = [] if info: i_range = tqdm(images,desc='landmark Det:') else: i_range = images for image in i_range: current_kp = self.extract_keypoint(image) # current_kp = self.detector.get_landmarks(np.array(image)) if np.mean(current_kp) == -1 and keypoints: keypoints.append(keypoints[-1]) else: keypoints.append(current_kp[None]) keypoints = np.concatenate(keypoints, 0) np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) return keypoints else: while True: try: with torch.no_grad(): # face detection -> face alignment. 
img = np.array(images) bboxes = self.det_net.detect_faces(images, 0.97) bboxes = bboxes[0] img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :] keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0] #### keypoints to the original location keypoints[:,0] += int(bboxes[0]) keypoints[:,1] += int(bboxes[1]) break except RuntimeError as e: if str(e).startswith('CUDA'): print("Warning: out of memory, sleep for 1s") time.sleep(1) else: print(e) break except TypeError: print('No face detected in this image') shape = [68, 2] keypoints = -1. * np.ones(shape) break if name is not None: np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) return keypoints def read_video(filename): frames = [] cap = cv2.VideoCapture(filename) while cap.isOpened(): ret, frame = cap.read() if ret: frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frame = Image.fromarray(frame) frames.append(frame) else: break cap.release() return frames def run(data): filename, opt, device = data os.environ['CUDA_VISIBLE_DEVICES'] = device kp_extractor = KeypointExtractor() images = read_video(filename) name = filename.split('/')[-2:] os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True) kp_extractor.extract_keypoint( images, name=os.path.join(opt.output_dir, name[-2], name[-1]) ) if __name__ == '__main__': set_start_method('spawn') parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('--input_dir', type=str, help='the folder of the input files') parser.add_argument('--output_dir', type=str, help='the folder of the output files') parser.add_argument('--device_ids', type=str, default='0,1') parser.add_argument('--workers', type=int, default=4) opt = parser.parse_args() filenames = list() VIDEO_EXTENSIONS_LOWERCASE = {'mp4'} VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE}) extensions = VIDEO_EXTENSIONS for ext in extensions: 
os.listdir(f'{opt.input_dir}') print(f'{opt.input_dir}/*.{ext}') filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}')) print('Total number of videos:', len(filenames)) pool = Pool(opt.workers) args_list = cycle([opt]) device_ids = opt.device_ids.split(",") device_ids = cycle(device_ids) for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))): None ================================================ FILE: src/face3d/models/__init__.py ================================================ """This package contains modules related to objective functions, optimizations, and network architectures. To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. You need to implement the following five functions: -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). -- : unpack data from dataset and apply preprocessing. -- : produce intermediate results. -- : calculate loss, gradients, and update network weights. -- : (optionally) add model-specific options and set default options. In the function <__init__>, you need to define four lists: -- self.loss_names (str list): specify the training losses that you want to plot and save. -- self.model_names (str list): define networks used in our training. -- self.visual_names (str list): specify the images that you want to display and save. -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. Now you can use the model class by specifying flag '--model dummy'. See our template model class 'template_model.py' for more details. """ import importlib from src.face3d.models.base_model import BaseModel def find_model_using_name(model_name): """Import the module "models/[model_name]_model.py". 
In the file, the class called DatasetNameModel() will be instantiated. It has to be a subclass of BaseModel, and it is case-insensitive. """ model_filename = "face3d.models." + model_name + "_model" modellib = importlib.import_module(model_filename) model = None target_model_name = model_name.replace('_', '') + 'model' for name, cls in modellib.__dict__.items(): if name.lower() == target_model_name.lower() \ and issubclass(cls, BaseModel): model = cls if model is None: print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) exit(0) return model def get_option_setter(model_name): """Return the static method of the model class.""" model_class = find_model_using_name(model_name) return model_class.modify_commandline_options def create_model(opt): """Create a model given the option. This function warps the class CustomDatasetDataLoader. This is the main interface between this package and 'train.py'/'test.py' Example: >>> from models import create_model >>> model = create_model(opt) """ model = find_model_using_name(opt.model) instance = model(opt) print("model [%s] was created" % type(instance).__name__) return instance ================================================ FILE: src/face3d/models/arcface_torch/README.md ================================================ # Distributed Arcface Training in Pytorch This is a deep learning library that makes face recognition efficient, and effective, which can train tens of millions identity on a single server. ## Requirements - Install [pytorch](http://pytorch.org) (torch>=1.6.0), our doc for [install.md](docs/install.md). - `pip install -r requirements.txt`. - Download the dataset from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_) . ## How to Training To train a model, run `train.py` with the path to the configs: ### 1. 
Single node, 8 GPUs: ```shell python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50 ``` ### 2. Multiple nodes, each node 8 GPUs: Node 0: ```shell python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50 ``` Node 1: ```shell python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50 ``` ### 3. Training resnet2060 with 8 GPUs: ```shell python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py ``` ## Model Zoo - The models are available for non-commercial research purposes only. - All models can be found here. - [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw - [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d) ### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/) The ICCV2021-MFR test set consists of non-celebrities, so we can ensure that it has very little overlap with publicly available face recognition training sets such as MS1M and CASIA, which are mostly collected from online celebrities. As a result, we can evaluate the FAIR performance of different algorithms. For the **ICCV2021-MFR-ALL** set, TAR is measured on the all-to-all 1:1 protocol, with FAR less than 0.000001 (1e-6). The globalised multi-racial test set contains 242,143 identities and 1,624,305 images. For the **ICCV2021-MFR-MASK** set, TAR is measured on the mask-to-nonmask 1:1 protocol, with FAR less than 0.0001 (1e-4). The mask test set contains 6,964 identities, 6,964 masked images and 13,928 non-masked images. There are 13,928 positive pairs and 96,983,824 negative pairs in total.
| Datasets | backbone | Training throughput | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** | | :---: | :--- | :--- | :--- |:--- |:--- | | MS1MV3 | r18 | - | 91 | **47.85** | **68.33** | | Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** | | MS1MV3 | r34 | - | 130 | **58.72** | **77.36** | | Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** | | MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** | | Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** | | MS1MV3 | r100 | - | 248 | **69.09** | **84.31** | | Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** | | MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** | | Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** | ### Performance on IJB-C and Verification Datasets | Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log | | :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- | | MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)| | MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)| | MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)| | MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)| | MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)| | Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)| |
Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)| | Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)| | Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)| [comment]: <> (More details see [model.md](docs/modelzoo.md) in docs.) ## [Speed Benchmark](docs/speed_benchmark.md) **Arcface Torch** can train on large-scale face recognition training sets efficiently and quickly. When the number of classes in the training set is greater than 300K and training is sufficient, the Partial FC sampling strategy reaches the same accuracy with several times faster training and a smaller GPU memory footprint. Partial FC is a sparse variant of the model parallel architecture for large-scale face recognition. Partial FC uses a sparse softmax, where each batch dynamically samples a subset of class centers for training. In each iteration, only a sparse part of the parameters is updated, which greatly reduces GPU memory usage and computation. With Partial FC, we can scale the training set to 29 million identities, the largest to date. Partial FC also supports multi-machine distributed training and mixed precision training. ![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png) For more details, see [speed_benchmark.md](docs/speed_benchmark.md) in docs. ### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better) `-` means training failed because of GPU memory limitations.
| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | | :--- | :--- | :--- | :--- | |125000 | 4681 | 4824 | 5004 | |1400000 | **1672** | 3043 | 4738 | |5500000 | **-** | **1389** | 3975 | |8000000 | **-** | **-** | 3565 | |16000000 | **-** | **-** | 2679 | |29000000 | **-** | **-** | **1855** | ### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better) | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | | :--- | :--- | :--- | :--- | |125000 | 7358 | 5306 | 4868 | |1400000 | 32252 | 11178 | 6056 | |5500000 | **-** | 32188 | 9854 | |8000000 | **-** | **-** | 12310 | |16000000 | **-** | **-** | 19950 | |29000000 | **-** | **-** | 32324 | ## Evaluation ICCV2021-MFR and IJB-C More details see [eval.md](docs/eval.md) in docs. ## Test We tested many versions of PyTorch. Please create an issue if you are having trouble. - [x] torch 1.6.0 - [x] torch 1.7.1 - [x] torch 1.8.0 - [x] torch 1.9.0 ## Citation ``` @inproceedings{deng2019arcface, title={Arcface: Additive angular margin loss for deep face recognition}, author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos}, booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, pages={4690--4699}, year={2019} } @inproceedings{an2020partical_fc, title={Partial FC: Training 10 Million Identities on a Single Machine}, author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and Zhang, Debing and Fu Ying}, booktitle={Arxiv 2010.05222}, year={2020} } ``` ================================================ FILE: src/face3d/models/arcface_torch/backbones/__init__.py ================================================ from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200 from .mobilefacenet import get_mbf def get_model(name, **kwargs): # resnet if name == "r18": return iresnet18(False, **kwargs) elif name == 
"r34": return iresnet34(False, **kwargs) elif name == "r50": return iresnet50(False, **kwargs) elif name == "r100": return iresnet100(False, **kwargs) elif name == "r200": return iresnet200(False, **kwargs) elif name == "r2060": from .iresnet2060 import iresnet2060 return iresnet2060(False, **kwargs) elif name == "mbf": fp16 = kwargs.get("fp16", False) num_features = kwargs.get("num_features", 512) return get_mbf(fp16=fp16, num_features=num_features) else: raise ValueError() ================================================ FILE: src/face3d/models/arcface_torch/backbones/iresnet.py ================================================ import torch from torch import nn __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200'] def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation) def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class IBasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1): super(IBasicBlock, self).__init__() if groups != 1 or base_width != 64: raise ValueError('BasicBlock only supports groups=1 and base_width=64') if dilation > 1: raise NotImplementedError("Dilation > 1 not supported in BasicBlock") self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,) self.conv1 = conv3x3(inplanes, planes) self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,) self.prelu = nn.PReLU(planes) self.conv2 = conv3x3(planes, planes, stride) self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,) self.downsample = downsample self.stride = stride def forward(self, x): identity = x out = self.bn1(x) out = self.conv1(out) out = self.bn2(out) out = self.prelu(out) out = self.conv2(out) out = self.bn3(out) if self.downsample is 
not None: identity = self.downsample(x) out += identity return out class IResNet(nn.Module): fc_scale = 7 * 7 def __init__(self, block, layers, dropout=0, num_features=512, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): super(IResNet, self).__init__() self.fp16 = fp16 self.inplanes = 64 self.dilation = 1 if replace_stride_with_dilation is None: replace_stride_with_dilation = [False, False, False] if len(replace_stride_with_dilation) != 3: raise ValueError("replace_stride_with_dilation should be None " "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) self.prelu = nn.PReLU(self.inplanes) self.layer1 = self._make_layer(block, 64, layers[0], stride=2) self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,) self.dropout = nn.Dropout(p=dropout, inplace=True) self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) self.features = nn.BatchNorm1d(num_features, eps=1e-05) nn.init.constant_(self.features.weight, 1.0) self.features.weight.requires_grad = False for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.normal_(m.weight, 0, 0.1) elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) if zero_init_residual: for m in self.modules(): if isinstance(m, IBasicBlock): nn.init.constant_(m.bn2.weight, 0) def _make_layer(self, block, planes, blocks, stride=1, dilate=False): downsample = None previous_dilation = 
self.dilation if dilate: self.dilation *= stride stride = 1 if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( conv1x1(self.inplanes, planes * block.expansion, stride), nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), ) layers = [] layers.append( block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation)) self.inplanes = planes * block.expansion for _ in range(1, blocks): layers.append( block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=self.dilation)) return nn.Sequential(*layers) def forward(self, x): with torch.cuda.amp.autocast(self.fp16): x = self.conv1(x) x = self.bn1(x) x = self.prelu(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.bn2(x) x = torch.flatten(x, 1) x = self.dropout(x) x = self.fc(x.float() if self.fp16 else x) x = self.features(x) return x def _iresnet(arch, block, layers, pretrained, progress, **kwargs): model = IResNet(block, layers, **kwargs) if pretrained: raise ValueError() return model def iresnet18(pretrained=False, progress=True, **kwargs): return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) def iresnet34(pretrained=False, progress=True, **kwargs): return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs) def iresnet50(pretrained=False, progress=True, **kwargs): return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained, progress, **kwargs) def iresnet100(pretrained=False, progress=True, **kwargs): return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained, progress, **kwargs) def iresnet200(pretrained=False, progress=True, **kwargs): return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained, progress, **kwargs) ================================================ FILE: src/face3d/models/arcface_torch/backbones/iresnet2060.py ================================================ import torch 
from torch import nn assert torch.__version__ >= "1.8.1" from torch.utils.checkpoint import checkpoint_sequential __all__ = ['iresnet2060'] def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation) def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class IBasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1): super(IBasicBlock, self).__init__() if groups != 1 or base_width != 64: raise ValueError('BasicBlock only supports groups=1 and base_width=64') if dilation > 1: raise NotImplementedError("Dilation > 1 not supported in BasicBlock") self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, ) self.conv1 = conv3x3(inplanes, planes) self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, ) self.prelu = nn.PReLU(planes) self.conv2 = conv3x3(planes, planes, stride) self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, ) self.downsample = downsample self.stride = stride def forward(self, x): identity = x out = self.bn1(x) out = self.conv1(out) out = self.bn2(out) out = self.prelu(out) out = self.conv2(out) out = self.bn3(out) if self.downsample is not None: identity = self.downsample(x) out += identity return out class IResNet(nn.Module): fc_scale = 7 * 7 def __init__(self, block, layers, dropout=0, num_features=512, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): super(IResNet, self).__init__() self.fp16 = fp16 self.inplanes = 64 self.dilation = 1 if replace_stride_with_dilation is None: replace_stride_with_dilation = [False, False, False] if len(replace_stride_with_dilation) != 3: raise ValueError("replace_stride_with_dilation should be None " "or a 3-element tuple, got 
{}".format(replace_stride_with_dilation)) self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) self.prelu = nn.PReLU(self.inplanes) self.layer1 = self._make_layer(block, 64, layers[0], stride=2) self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, ) self.dropout = nn.Dropout(p=dropout, inplace=True) self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) self.features = nn.BatchNorm1d(num_features, eps=1e-05) nn.init.constant_(self.features.weight, 1.0) self.features.weight.requires_grad = False for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.normal_(m.weight, 0, 0.1) elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) if zero_init_residual: for m in self.modules(): if isinstance(m, IBasicBlock): nn.init.constant_(m.bn2.weight, 0) def _make_layer(self, block, planes, blocks, stride=1, dilate=False): downsample = None previous_dilation = self.dilation if dilate: self.dilation *= stride stride = 1 if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( conv1x1(self.inplanes, planes * block.expansion, stride), nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), ) layers = [] layers.append( block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation)) self.inplanes = planes * block.expansion for _ in range(1, blocks): layers.append( block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=self.dilation)) return 
nn.Sequential(*layers) def checkpoint(self, func, num_seg, x): if self.training: return checkpoint_sequential(func, num_seg, x) else: return func(x) def forward(self, x): with torch.cuda.amp.autocast(self.fp16): x = self.conv1(x) x = self.bn1(x) x = self.prelu(x) x = self.layer1(x) x = self.checkpoint(self.layer2, 20, x) x = self.checkpoint(self.layer3, 100, x) x = self.layer4(x) x = self.bn2(x) x = torch.flatten(x, 1) x = self.dropout(x) x = self.fc(x.float() if self.fp16 else x) x = self.features(x) return x def _iresnet(arch, block, layers, pretrained, progress, **kwargs): model = IResNet(block, layers, **kwargs) if pretrained: raise ValueError() return model def iresnet2060(pretrained=False, progress=True, **kwargs): return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs) ================================================ FILE: src/face3d/models/arcface_torch/backbones/mobilefacenet.py ================================================ ''' Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py Original author cavalleria ''' import torch.nn as nn from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module import torch class Flatten(Module): def forward(self, x): return x.view(x.size(0), -1) class ConvBlock(Module): def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): super(ConvBlock, self).__init__() self.layers = nn.Sequential( Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False), BatchNorm2d(num_features=out_c), PReLU(num_parameters=out_c) ) def forward(self, x): return self.layers(x) class LinearBlock(Module): def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): super(LinearBlock, self).__init__() self.layers = nn.Sequential( Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False), BatchNorm2d(num_features=out_c) ) def 
forward(self, x): return self.layers(x) class DepthWise(Module): def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1): super(DepthWise, self).__init__() self.residual = residual self.layers = nn.Sequential( ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)), ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride), LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) ) def forward(self, x): short_cut = None if self.residual: short_cut = x x = self.layers(x) if self.residual: output = short_cut + x else: output = x return output class Residual(Module): def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)): super(Residual, self).__init__() modules = [] for _ in range(num_block): modules.append(DepthWise(c, c, True, kernel, stride, padding, groups)) self.layers = Sequential(*modules) def forward(self, x): return self.layers(x) class GDC(Module): def __init__(self, embedding_size): super(GDC, self).__init__() self.layers = nn.Sequential( LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)), Flatten(), Linear(512, embedding_size, bias=False), BatchNorm1d(embedding_size)) def forward(self, x): return self.layers(x) class MobileFaceNet(Module): def __init__(self, fp16=False, num_features=512): super(MobileFaceNet, self).__init__() scale = 2 self.fp16 = fp16 self.layers = nn.Sequential( ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)), ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64), DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128), Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256), Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), 
stride=(1, 1), padding=(1, 1)), DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512), Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), ) self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0)) self.features = GDC(num_features) self._initialize_weights() def _initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() elif isinstance(m, nn.Linear): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: m.bias.data.zero_() def forward(self, x): with torch.cuda.amp.autocast(self.fp16): x = self.layers(x) x = self.conv_sep(x.float() if self.fp16 else x) x = self.features(x) return x def get_mbf(fp16, num_features): return MobileFaceNet(fp16, num_features) ================================================ FILE: src/face3d/models/arcface_torch/configs/3millions.py ================================================ from easydict import EasyDict as edict # configs for test speed config = edict() config.loss = "arcface" config.network = "r50" config.resume = False config.output = None config.embedding_size = 512 config.sample_rate = 1.0 config.fp16 = True config.momentum = 0.9 config.weight_decay = 5e-4 config.batch_size = 128 config.lr = 0.1 # batch size is 512 config.rec = "synthetic" config.num_classes = 300 * 10000 config.num_epoch = 30 config.warmup_epoch = -1 config.decay_epoch = [10, 16, 22] config.val_targets = [] ================================================ FILE: src/face3d/models/arcface_torch/configs/3millions_pfc.py ================================================ from easydict import EasyDict as edict # configs for test speed config = edict() config.loss = "arcface" config.network = "r50" 
config.resume = False config.output = None config.embedding_size = 512 config.sample_rate = 0.1 config.fp16 = True config.momentum = 0.9 config.weight_decay = 5e-4 config.batch_size = 128 config.lr = 0.1 # batch size is 512 config.rec = "synthetic" config.num_classes = 300 * 10000 config.num_epoch = 30 config.warmup_epoch = -1 config.decay_epoch = [10, 16, 22] config.val_targets = [] ================================================ FILE: src/face3d/models/arcface_torch/configs/__init__.py ================================================ ================================================ FILE: src/face3d/models/arcface_torch/configs/base.py ================================================ from easydict import EasyDict as edict # make training faster # our RAM is 256G # mount -t tmpfs -o size=140G tmpfs /train_tmp config = edict() config.loss = "arcface" config.network = "r50" config.resume = False config.output = "ms1mv3_arcface_r50" config.dataset = "ms1m-retinaface-t1" config.embedding_size = 512 config.sample_rate = 1 config.fp16 = False config.momentum = 0.9 config.weight_decay = 5e-4 config.batch_size = 128 config.lr = 0.1 # batch size is 512 if config.dataset == "emore": config.rec = "/train_tmp/faces_emore" config.num_classes = 85742 config.num_image = 5822653 config.num_epoch = 16 config.warmup_epoch = -1 config.decay_epoch = [8, 14, ] config.val_targets = ["lfw", ] elif config.dataset == "ms1m-retinaface-t1": config.rec = "/train_tmp/ms1m-retinaface-t1" config.num_classes = 93431 config.num_image = 5179510 config.num_epoch = 25 config.warmup_epoch = -1 config.decay_epoch = [11, 17, 22] config.val_targets = ["lfw", "cfp_fp", "agedb_30"] elif config.dataset == "glint360k": config.rec = "/train_tmp/glint360k" config.num_classes = 360232 config.num_image = 17091657 config.num_epoch = 20 config.warmup_epoch = -1 config.decay_epoch = [8, 12, 15, 18] config.val_targets = ["lfw", "cfp_fp", "agedb_30"] elif config.dataset == "webface": config.rec = 
"/train_tmp/faces_webface_112x112" config.num_classes = 10572 config.num_image = "forget" config.num_epoch = 34 config.warmup_epoch = -1 config.decay_epoch = [20, 28, 32] config.val_targets = ["lfw", "cfp_fp", "agedb_30"] ================================================ FILE: src/face3d/models/arcface_torch/configs/glint360k_mbf.py ================================================ from easydict import EasyDict as edict # make training faster # our RAM is 256G # mount -t tmpfs -o size=140G tmpfs /train_tmp config = edict() config.loss = "cosface" config.network = "mbf" config.resume = False config.output = None config.embedding_size = 512 config.sample_rate = 0.1 config.fp16 = True config.momentum = 0.9 config.weight_decay = 2e-4 config.batch_size = 128 config.lr = 0.1 # batch size is 512 config.rec = "/train_tmp/glint360k" config.num_classes = 360232 config.num_image = 17091657 config.num_epoch = 20 config.warmup_epoch = -1 config.decay_epoch = [8, 12, 15, 18] config.val_targets = ["lfw", "cfp_fp", "agedb_30"] ================================================ FILE: src/face3d/models/arcface_torch/configs/glint360k_r100.py ================================================ from easydict import EasyDict as edict # make training faster # our RAM is 256G # mount -t tmpfs -o size=140G tmpfs /train_tmp config = edict() config.loss = "cosface" config.network = "r100" config.resume = False config.output = None config.embedding_size = 512 config.sample_rate = 1.0 config.fp16 = True config.momentum = 0.9 config.weight_decay = 5e-4 config.batch_size = 128 config.lr = 0.1 # batch size is 512 config.rec = "/train_tmp/glint360k" config.num_classes = 360232 config.num_image = 17091657 config.num_epoch = 20 config.warmup_epoch = -1 config.decay_epoch = [8, 12, 15, 18] config.val_targets = ["lfw", "cfp_fp", "agedb_30"] ================================================ FILE: src/face3d/models/arcface_torch/configs/glint360k_r18.py ================================================ from 
easydict import EasyDict as edict # make training faster # our RAM is 256G # mount -t tmpfs -o size=140G tmpfs /train_tmp config = edict() config.loss = "cosface" config.network = "r18" config.resume = False config.output = None config.embedding_size = 512 config.sample_rate = 1.0 config.fp16 = True config.momentum = 0.9 config.weight_decay = 5e-4 config.batch_size = 128 config.lr = 0.1 # batch size is 512 config.rec = "/train_tmp/glint360k" config.num_classes = 360232 config.num_image = 17091657 config.num_epoch = 20 config.warmup_epoch = -1 config.decay_epoch = [8, 12, 15, 18] config.val_targets = ["lfw", "cfp_fp", "agedb_30"] ================================================ FILE: src/face3d/models/arcface_torch/configs/glint360k_r34.py ================================================ from easydict import EasyDict as edict # make training faster # our RAM is 256G # mount -t tmpfs -o size=140G tmpfs /train_tmp config = edict() config.loss = "cosface" config.network = "r34" config.resume = False config.output = None config.embedding_size = 512 config.sample_rate = 1.0 config.fp16 = True config.momentum = 0.9 config.weight_decay = 5e-4 config.batch_size = 128 config.lr = 0.1 # batch size is 512 config.rec = "/train_tmp/glint360k" config.num_classes = 360232 config.num_image = 17091657 config.num_epoch = 20 config.warmup_epoch = -1 config.decay_epoch = [8, 12, 15, 18] config.val_targets = ["lfw", "cfp_fp", "agedb_30"] ================================================ FILE: src/face3d/models/arcface_torch/configs/glint360k_r50.py ================================================ from easydict import EasyDict as edict # make training faster # our RAM is 256G # mount -t tmpfs -o size=140G tmpfs /train_tmp config = edict() config.loss = "cosface" config.network = "r50" config.resume = False config.output = None config.embedding_size = 512 config.sample_rate = 1.0 config.fp16 = True config.momentum = 0.9 config.weight_decay = 5e-4 config.batch_size = 128 config.lr = 0.1 # 
batch size is 512 config.rec = "/train_tmp/glint360k" config.num_classes = 360232 config.num_image = 17091657 config.num_epoch = 20 config.warmup_epoch = -1 config.decay_epoch = [8, 12, 15, 18] config.val_targets = ["lfw", "cfp_fp", "agedb_30"] ================================================ FILE: src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py ================================================ from easydict import EasyDict as edict # make training faster # our RAM is 256G # mount -t tmpfs -o size=140G tmpfs /train_tmp config = edict() config.loss = "arcface" config.network = "mbf" config.resume = False config.output = None config.embedding_size = 512 config.sample_rate = 1.0 config.fp16 = True config.momentum = 0.9 config.weight_decay = 2e-4 config.batch_size = 128 config.lr = 0.1 # batch size is 512 config.rec = "/train_tmp/ms1m-retinaface-t1" config.num_classes = 93431 config.num_image = 5179510 config.num_epoch = 30 config.warmup_epoch = -1 config.decay_epoch = [10, 20, 25] config.val_targets = ["lfw", "cfp_fp", "agedb_30"] ================================================ FILE: src/face3d/models/arcface_torch/configs/ms1mv3_r18.py ================================================ from easydict import EasyDict as edict # make training faster # our RAM is 256G # mount -t tmpfs -o size=140G tmpfs /train_tmp config = edict() config.loss = "arcface" config.network = "r18" config.resume = False config.output = None config.embedding_size = 512 config.sample_rate = 1.0 config.fp16 = True config.momentum = 0.9 config.weight_decay = 5e-4 config.batch_size = 128 config.lr = 0.1 # batch size is 512 config.rec = "/train_tmp/ms1m-retinaface-t1" config.num_classes = 93431 config.num_image = 5179510 config.num_epoch = 25 config.warmup_epoch = -1 config.decay_epoch = [10, 16, 22] config.val_targets = ["lfw", "cfp_fp", "agedb_30"] ================================================ FILE: src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py 
================================================ from easydict import EasyDict as edict # make training faster # our RAM is 256G # mount -t tmpfs -o size=140G tmpfs /train_tmp config = edict() config.loss = "arcface" config.network = "r2060" config.resume = False config.output = None config.embedding_size = 512 config.sample_rate = 1.0 config.fp16 = True config.momentum = 0.9 config.weight_decay = 5e-4 config.batch_size = 64 config.lr = 0.1 # batch size is 512 config.rec = "/train_tmp/ms1m-retinaface-t1" config.num_classes = 93431 config.num_image = 5179510 config.num_epoch = 25 config.warmup_epoch = -1 config.decay_epoch = [10, 16, 22] config.val_targets = ["lfw", "cfp_fp", "agedb_30"] ================================================ FILE: src/face3d/models/arcface_torch/configs/ms1mv3_r34.py ================================================ from easydict import EasyDict as edict # make training faster # our RAM is 256G # mount -t tmpfs -o size=140G tmpfs /train_tmp config = edict() config.loss = "arcface" config.network = "r34" config.resume = False config.output = None config.embedding_size = 512 config.sample_rate = 1.0 config.fp16 = True config.momentum = 0.9 config.weight_decay = 5e-4 config.batch_size = 128 config.lr = 0.1 # batch size is 512 config.rec = "/train_tmp/ms1m-retinaface-t1" config.num_classes = 93431 config.num_image = 5179510 config.num_epoch = 25 config.warmup_epoch = -1 config.decay_epoch = [10, 16, 22] config.val_targets = ["lfw", "cfp_fp", "agedb_30"] ================================================ FILE: src/face3d/models/arcface_torch/configs/ms1mv3_r50.py ================================================ from easydict import EasyDict as edict # make training faster # our RAM is 256G # mount -t tmpfs -o size=140G tmpfs /train_tmp config = edict() config.loss = "arcface" config.network = "r50" config.resume = False config.output = None config.embedding_size = 512 config.sample_rate = 1.0 config.fp16 = True config.momentum = 0.9 
config.weight_decay = 5e-4 config.batch_size = 128 config.lr = 0.1 # batch size is 512 config.rec = "/train_tmp/ms1m-retinaface-t1" config.num_classes = 93431 config.num_image = 5179510 config.num_epoch = 25 config.warmup_epoch = -1 config.decay_epoch = [10, 16, 22] config.val_targets = ["lfw", "cfp_fp", "agedb_30"] ================================================ FILE: src/face3d/models/arcface_torch/configs/speed.py ================================================ from easydict import EasyDict as edict # configs for test speed config = edict() config.loss = "arcface" config.network = "r50" config.resume = False config.output = None config.embedding_size = 512 config.sample_rate = 1.0 config.fp16 = True config.momentum = 0.9 config.weight_decay = 5e-4 config.batch_size = 128 config.lr = 0.1 # batch size is 512 config.rec = "synthetic" config.num_classes = 100 * 10000 config.num_epoch = 30 config.warmup_epoch = -1 config.decay_epoch = [10, 16, 22] config.val_targets = [] ================================================ FILE: src/face3d/models/arcface_torch/dataset.py ================================================ import numbers import os import queue as Queue import threading import mxnet as mx import numpy as np import torch from torch.utils.data import DataLoader, Dataset from torchvision import transforms class BackgroundGenerator(threading.Thread): def __init__(self, generator, local_rank, max_prefetch=6): super(BackgroundGenerator, self).__init__() self.queue = Queue.Queue(max_prefetch) self.generator = generator self.local_rank = local_rank self.daemon = True self.start() def run(self): torch.cuda.set_device(self.local_rank) for item in self.generator: self.queue.put(item) self.queue.put(None) def next(self): next_item = self.queue.get() if next_item is None: raise StopIteration return next_item def __next__(self): return self.next() def __iter__(self): return self class DataLoaderX(DataLoader): def __init__(self, local_rank, **kwargs): super(DataLoaderX, 
self).__init__(**kwargs) self.stream = torch.cuda.Stream(local_rank) self.local_rank = local_rank def __iter__(self): self.iter = super(DataLoaderX, self).__iter__() self.iter = BackgroundGenerator(self.iter, self.local_rank) self.preload() return self def preload(self): self.batch = next(self.iter, None) if self.batch is None: return None with torch.cuda.stream(self.stream): for k in range(len(self.batch)): self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True) def __next__(self): torch.cuda.current_stream().wait_stream(self.stream) batch = self.batch if batch is None: raise StopIteration self.preload() return batch class MXFaceDataset(Dataset): def __init__(self, root_dir, local_rank): super(MXFaceDataset, self).__init__() self.transform = transforms.Compose( [transforms.ToPILImage(), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ]) self.root_dir = root_dir self.local_rank = local_rank path_imgrec = os.path.join(root_dir, 'train.rec') path_imgidx = os.path.join(root_dir, 'train.idx') self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') s = self.imgrec.read_idx(0) header, _ = mx.recordio.unpack(s) if header.flag > 0: self.header0 = (int(header.label[0]), int(header.label[1])) self.imgidx = np.array(range(1, int(header.label[0]))) else: self.imgidx = np.array(list(self.imgrec.keys)) def __getitem__(self, index): idx = self.imgidx[index] s = self.imgrec.read_idx(idx) header, img = mx.recordio.unpack(s) label = header.label if not isinstance(label, numbers.Number): label = label[0] label = torch.tensor(label, dtype=torch.long) sample = mx.image.imdecode(img).asnumpy() if self.transform is not None: sample = self.transform(sample) return sample, label def __len__(self): return len(self.imgidx) class SyntheticDataset(Dataset): def __init__(self, local_rank): super(SyntheticDataset, self).__init__() img = np.random.randint(0, 255, size=(112, 112, 3), 
dtype=np.int32) img = np.transpose(img, (2, 0, 1)) img = torch.from_numpy(img).squeeze(0).float() img = ((img / 255) - 0.5) / 0.5 self.img = img self.label = 1 def __getitem__(self, index): return self.img, self.label def __len__(self): return 1000000 ================================================ FILE: src/face3d/models/arcface_torch/docs/eval.md ================================================ ## Eval on ICCV2021-MFR coming soon. ## Eval IJBC You can eval ijbc with pytorch or onnx. 1. Eval IJBC With Onnx ```shell CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50 ``` 2. Eval IJBC With Pytorch ```shell CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \ --model-prefix ms1mv3_arcface_r50/backbone.pth \ --image-path IJB_release/IJBC \ --result-dir ms1mv3_arcface_r50 \ --batch-size 128 \ --job ms1mv3_arcface_r50 \ --target IJBC \ --network iresnet50 ``` ## Inference ```shell python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50 ``` ================================================ FILE: src/face3d/models/arcface_torch/docs/install.md ================================================ ## v1.8.0 ### Linux and Windows ```shell # CUDA 11.1 pip --default-timeout=100 install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html # CUDA 10.2 pip --default-timeout=100 install torch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 # CPU only pip --default-timeout=100 install torch==1.8.0+cpu torchvision==0.9.0+cpu torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html ``` ## v1.7.1 ### Linux and Windows ```shell # CUDA 11.0 pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html # CUDA 10.2 pip install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 # CUDA 10.1 pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 
torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html # CUDA 9.2 pip install torch==1.7.1+cu92 torchvision==0.8.2+cu92 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html # CPU only pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html ``` ## v1.6.0 ### Linux and Windows ```shell # CUDA 10.2 pip install torch==1.6.0 torchvision==0.7.0 # CUDA 10.1 pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html # CUDA 9.2 pip install torch==1.6.0+cu92 torchvision==0.7.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html # CPU only pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html ``` ================================================ FILE: src/face3d/models/arcface_torch/docs/modelzoo.md ================================================ ================================================ FILE: src/face3d/models/arcface_torch/docs/speed_benchmark.md ================================================ ## Test Training Speed - Test Commands You need to use the following two commands to test the Partial FC training performance. The number of identities is **3 million** (synthetic data), mixed precision training is turned on, the backbone is resnet50, and the batch size is 1024.
```shell # Model Parallel python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions # Partial FC 0.1 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions_pfc ``` - GPU Memory ``` # (Model Parallel) gpustat -i [0] Tesla V100-SXM2-32GB | 64'C, 94 % | 30338 / 32510 MB [1] Tesla V100-SXM2-32GB | 60'C, 99 % | 28876 / 32510 MB [2] Tesla V100-SXM2-32GB | 60'C, 99 % | 28872 / 32510 MB [3] Tesla V100-SXM2-32GB | 69'C, 99 % | 28872 / 32510 MB [4] Tesla V100-SXM2-32GB | 66'C, 99 % | 28888 / 32510 MB [5] Tesla V100-SXM2-32GB | 60'C, 99 % | 28932 / 32510 MB [6] Tesla V100-SXM2-32GB | 68'C, 100 % | 28916 / 32510 MB [7] Tesla V100-SXM2-32GB | 65'C, 99 % | 28860 / 32510 MB # (Partial FC 0.1) gpustat -i [0] Tesla V100-SXM2-32GB | 60'C, 95 % | 10488 / 32510 MB │······················· [1] Tesla V100-SXM2-32GB | 60'C, 97 % | 10344 / 32510 MB │······················· [2] Tesla V100-SXM2-32GB | 61'C, 95 % | 10340 / 32510 MB │······················· [3] Tesla V100-SXM2-32GB | 66'C, 95 % | 10340 / 32510 MB │······················· [4] Tesla V100-SXM2-32GB | 65'C, 94 % | 10356 / 32510 MB │······················· [5] Tesla V100-SXM2-32GB | 61'C, 95 % | 10400 / 32510 MB │······················· [6] Tesla V100-SXM2-32GB | 68'C, 96 % | 10384 / 32510 MB │······················· [7] Tesla V100-SXM2-32GB | 64'C, 95 % | 10328 / 32510 MB │······················· ``` - Training Speed ```python # (Model Parallel) trainging.log Training: Speed 2271.33 samples/sec Loss 1.1624 LearningRate 0.2000 Epoch: 0 Global Step: 100 Training: Speed 2269.94 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150 Training: Speed 2272.67 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200 Training: Speed 2266.55 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250 Training: Speed 
2272.54 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300 # (Partial FC 0.1) trainging.log Training: Speed 5299.56 samples/sec Loss 1.0965 LearningRate 0.2000 Epoch: 0 Global Step: 100 Training: Speed 5296.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150 Training: Speed 5304.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200 Training: Speed 5274.43 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250 Training: Speed 5300.10 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300 ``` In this test case, Partial FC 0.1 uses only about 1/3 of the GPU memory of model parallel, and trains about 2.3 times as fast as model parallel (≈5300 vs. ≈2270 samples/sec). ## Speed Benchmark 1. Training speed of different parallel methods (samples/second), Tesla V100 32GB * 8. (Larger is better) | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | | :--- | :--- | :--- | :--- | |125000 | 4681 | 4824 | 5004 | |250000 | 4047 | 4521 | 4976 | |500000 | 3087 | 4013 | 4900 | |1000000 | 2090 | 3449 | 4803 | |1400000 | 1672 | 3043 | 4738 | |2000000 | - | 2593 | 4626 | |4000000 | - | 1748 | 4208 | |5500000 | - | 1389 | 3975 | |8000000 | - | - | 3565 | |16000000 | - | - | 2679 | |29000000 | - | - | 1855 | 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8.
(Smaller is better) | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | | :--- | :--- | :--- | :--- | |125000 | 7358 | 5306 | 4868 | |250000 | 9940 | 5826 | 5004 | |500000 | 14220 | 7114 | 5202 | |1000000 | 23708 | 9966 | 5620 | |1400000 | 32252 | 11178 | 6056 | |2000000 | - | 13978 | 6472 | |4000000 | - | 23238 | 8284 | |5500000 | - | 32188 | 9854 | |8000000 | - | - | 12310 | |16000000 | - | - | 19950 | |29000000 | - | - | 32324 | ================================================ FILE: src/face3d/models/arcface_torch/eval/__init__.py ================================================ ================================================ FILE: src/face3d/models/arcface_torch/eval/verification.py ================================================ """Helper for evaluation on the Labeled Faces in the Wild dataset """ # MIT License # # Copyright (c) 2016 David Sandberg # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
import datetime import os import pickle import mxnet as mx import numpy as np import sklearn import torch from mxnet import ndarray as nd from scipy import interpolate from sklearn.decomposition import PCA from sklearn.model_selection import KFold class LFold: def __init__(self, n_splits=2, shuffle=False): self.n_splits = n_splits if self.n_splits > 1: self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle) def split(self, indices): if self.n_splits > 1: return self.k_fold.split(indices) else: return [(indices, indices)] def calculate_roc(thresholds, embeddings1, embeddings2, actual_issame, nrof_folds=10, pca=0): assert (embeddings1.shape[0] == embeddings2.shape[0]) assert (embeddings1.shape[1] == embeddings2.shape[1]) nrof_pairs = min(len(actual_issame), embeddings1.shape[0]) nrof_thresholds = len(thresholds) k_fold = LFold(n_splits=nrof_folds, shuffle=False) tprs = np.zeros((nrof_folds, nrof_thresholds)) fprs = np.zeros((nrof_folds, nrof_thresholds)) accuracy = np.zeros((nrof_folds)) indices = np.arange(nrof_pairs) if pca == 0: diff = np.subtract(embeddings1, embeddings2) dist = np.sum(np.square(diff), 1) for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): if pca > 0: print('doing pca on', fold_idx) embed1_train = embeddings1[train_set] embed2_train = embeddings2[train_set] _embed_train = np.concatenate((embed1_train, embed2_train), axis=0) pca_model = PCA(n_components=pca) pca_model.fit(_embed_train) embed1 = pca_model.transform(embeddings1) embed2 = pca_model.transform(embeddings2) embed1 = sklearn.preprocessing.normalize(embed1) embed2 = sklearn.preprocessing.normalize(embed2) diff = np.subtract(embed1, embed2) dist = np.sum(np.square(diff), 1) # Find the best threshold for the fold acc_train = np.zeros((nrof_thresholds)) for threshold_idx, threshold in enumerate(thresholds): _, _, acc_train[threshold_idx] = calculate_accuracy( threshold, dist[train_set], actual_issame[train_set]) best_threshold_index = np.argmax(acc_train) for 
threshold_idx, threshold in enumerate(thresholds): tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy( threshold, dist[test_set], actual_issame[test_set]) _, _, accuracy[fold_idx] = calculate_accuracy( thresholds[best_threshold_index], dist[test_set], actual_issame[test_set]) tpr = np.mean(tprs, 0) fpr = np.mean(fprs, 0) return tpr, fpr, accuracy def calculate_accuracy(threshold, dist, actual_issame): predict_issame = np.less(dist, threshold) tp = np.sum(np.logical_and(predict_issame, actual_issame)) fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame))) tn = np.sum( np.logical_and(np.logical_not(predict_issame), np.logical_not(actual_issame))) fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame)) tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn) fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn) acc = float(tp + tn) / dist.size return tpr, fpr, acc def calculate_val(thresholds, embeddings1, embeddings2, actual_issame, far_target, nrof_folds=10): assert (embeddings1.shape[0] == embeddings2.shape[0]) assert (embeddings1.shape[1] == embeddings2.shape[1]) nrof_pairs = min(len(actual_issame), embeddings1.shape[0]) nrof_thresholds = len(thresholds) k_fold = LFold(n_splits=nrof_folds, shuffle=False) val = np.zeros(nrof_folds) far = np.zeros(nrof_folds) diff = np.subtract(embeddings1, embeddings2) dist = np.sum(np.square(diff), 1) indices = np.arange(nrof_pairs) for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): # Find the threshold that gives FAR = far_target far_train = np.zeros(nrof_thresholds) for threshold_idx, threshold in enumerate(thresholds): _, far_train[threshold_idx] = calculate_val_far( threshold, dist[train_set], actual_issame[train_set]) if np.max(far_train) >= far_target: f = interpolate.interp1d(far_train, thresholds, kind='slinear') threshold = f(far_target) else: threshold = 0.0 val[fold_idx], far[fold_idx] = calculate_val_far( threshold, 
dist[test_set], actual_issame[test_set]) val_mean = np.mean(val) far_mean = np.mean(far) val_std = np.std(val) return val_mean, val_std, far_mean def calculate_val_far(threshold, dist, actual_issame): predict_issame = np.less(dist, threshold) true_accept = np.sum(np.logical_and(predict_issame, actual_issame)) false_accept = np.sum( np.logical_and(predict_issame, np.logical_not(actual_issame))) n_same = np.sum(actual_issame) n_diff = np.sum(np.logical_not(actual_issame)) # print(true_accept, false_accept) # print(n_same, n_diff) val = float(true_accept) / float(n_same) far = float(false_accept) / float(n_diff) return val, far def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0): # Calculate evaluation metrics thresholds = np.arange(0, 4, 0.01) embeddings1 = embeddings[0::2] embeddings2 = embeddings[1::2] tpr, fpr, accuracy = calculate_roc(thresholds, embeddings1, embeddings2, np.asarray(actual_issame), nrof_folds=nrof_folds, pca=pca) thresholds = np.arange(0, 4, 0.001) val, val_std, far = calculate_val(thresholds, embeddings1, embeddings2, np.asarray(actual_issame), 1e-3, nrof_folds=nrof_folds) return tpr, fpr, accuracy, val, val_std, far @torch.no_grad() def load_bin(path, image_size): try: with open(path, 'rb') as f: bins, issame_list = pickle.load(f) # py2 except UnicodeDecodeError as e: with open(path, 'rb') as f: bins, issame_list = pickle.load(f, encoding='bytes') # py3 data_list = [] for flip in [0, 1]: data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1])) data_list.append(data) for idx in range(len(issame_list) * 2): _bin = bins[idx] img = mx.image.imdecode(_bin) if img.shape[1] != image_size[0]: img = mx.image.resize_short(img, image_size[0]) img = nd.transpose(img, axes=(2, 0, 1)) for flip in [0, 1]: if flip == 1: img = mx.ndarray.flip(data=img, axis=2) data_list[flip][idx][:] = torch.from_numpy(img.asnumpy()) if idx % 1000 == 0: print('loading bin', idx) print(data_list[0].shape) return data_list, issame_list 
@torch.no_grad() def test(data_set, backbone, batch_size, nfolds=10): print('testing verification..') data_list = data_set[0] issame_list = data_set[1] embeddings_list = [] time_consumed = 0.0 for i in range(len(data_list)): data = data_list[i] embeddings = None ba = 0 while ba < data.shape[0]: bb = min(ba + batch_size, data.shape[0]) count = bb - ba _data = data[bb - batch_size: bb] time0 = datetime.datetime.now() img = ((_data / 255) - 0.5) / 0.5 net_out: torch.Tensor = backbone(img) _embeddings = net_out.detach().cpu().numpy() time_now = datetime.datetime.now() diff = time_now - time0 time_consumed += diff.total_seconds() if embeddings is None: embeddings = np.zeros((data.shape[0], _embeddings.shape[1])) embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :] ba = bb embeddings_list.append(embeddings) _xnorm = 0.0 _xnorm_cnt = 0 for embed in embeddings_list: for i in range(embed.shape[0]): _em = embed[i] _norm = np.linalg.norm(_em) _xnorm += _norm _xnorm_cnt += 1 _xnorm /= _xnorm_cnt acc1 = 0.0 std1 = 0.0 embeddings = embeddings_list[0] + embeddings_list[1] embeddings = sklearn.preprocessing.normalize(embeddings) print(embeddings.shape) print('infer time', time_consumed) _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds) acc2, std2 = np.mean(accuracy), np.std(accuracy) return acc1, std1, acc2, std2, _xnorm, embeddings_list def dumpR(data_set, backbone, batch_size, name='', data_extra=None, label_shape=None): print('dump verification embedding..') data_list = data_set[0] issame_list = data_set[1] embeddings_list = [] time_consumed = 0.0 for i in range(len(data_list)): data = data_list[i] embeddings = None ba = 0 while ba < data.shape[0]: bb = min(ba + batch_size, data.shape[0]) count = bb - ba _data = nd.slice_axis(data, axis=0, begin=bb - batch_size, end=bb) time0 = datetime.datetime.now() if data_extra is None: db = mx.io.DataBatch(data=(_data,), label=(_label,)) else: db = mx.io.DataBatch(data=(_data, _data_extra), 
label=(_label,)) model.forward(db, is_train=False) net_out = model.get_outputs() _embeddings = net_out[0].asnumpy() time_now = datetime.datetime.now() diff = time_now - time0 time_consumed += diff.total_seconds() if embeddings is None: embeddings = np.zeros((data.shape[0], _embeddings.shape[1])) embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :] ba = bb embeddings_list.append(embeddings) embeddings = embeddings_list[0] + embeddings_list[1] embeddings = sklearn.preprocessing.normalize(embeddings) actual_issame = np.asarray(issame_list) outname = os.path.join('temp.bin') with open(outname, 'wb') as f: pickle.dump((embeddings, issame_list), f, protocol=pickle.HIGHEST_PROTOCOL) # if __name__ == '__main__': # # parser = argparse.ArgumentParser(description='do verification') # # general # parser.add_argument('--data-dir', default='', help='') # parser.add_argument('--model', # default='../model/softmax,50', # help='path to load model.') # parser.add_argument('--target', # default='lfw,cfp_ff,cfp_fp,agedb_30', # help='test targets.') # parser.add_argument('--gpu', default=0, type=int, help='gpu id') # parser.add_argument('--batch-size', default=32, type=int, help='') # parser.add_argument('--max', default='', type=str, help='') # parser.add_argument('--mode', default=0, type=int, help='') # parser.add_argument('--nfolds', default=10, type=int, help='') # args = parser.parse_args() # image_size = [112, 112] # print('image_size', image_size) # ctx = mx.gpu(args.gpu) # nets = [] # vec = args.model.split(',') # prefix = args.model.split(',')[0] # epochs = [] # if len(vec) == 1: # pdir = os.path.dirname(prefix) # for fname in os.listdir(pdir): # if not fname.endswith('.params'): # continue # _file = os.path.join(pdir, fname) # if _file.startswith(prefix): # epoch = int(fname.split('.')[0].split('-')[1]) # epochs.append(epoch) # epochs = sorted(epochs, reverse=True) # if len(args.max) > 0: # _max = [int(x) for x in args.max.split(',')] # assert len(_max) == 2 # if 
len(epochs) > _max[1]: # epochs = epochs[_max[0]:_max[1]] # # else: # epochs = [int(x) for x in vec[1].split('|')] # print('model number', len(epochs)) # time0 = datetime.datetime.now() # for epoch in epochs: # print('loading', prefix, epoch) # sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) # # arg_params, aux_params = ch_dev(arg_params, aux_params, ctx) # all_layers = sym.get_internals() # sym = all_layers['fc1_output'] # model = mx.mod.Module(symbol=sym, context=ctx, label_names=None) # # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))]) # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], # image_size[1]))]) # model.set_params(arg_params, aux_params) # nets.append(model) # time_now = datetime.datetime.now() # diff = time_now - time0 # print('model loading time', diff.total_seconds()) # # ver_list = [] # ver_name_list = [] # for name in args.target.split(','): # path = os.path.join(args.data_dir, name + ".bin") # if os.path.exists(path): # print('loading.. 
', name) # data_set = load_bin(path, image_size) # ver_list.append(data_set) # ver_name_list.append(name) # # if args.mode == 0: # for i in range(len(ver_list)): # results = [] # for model in nets: # acc1, std1, acc2, std2, xnorm, embeddings_list = test( # ver_list[i], model, args.batch_size, args.nfolds) # print('[%s]XNorm: %f' % (ver_name_list[i], xnorm)) # print('[%s]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], acc1, std1)) # print('[%s]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], acc2, std2)) # results.append(acc2) # print('Max of [%s] is %1.5f' % (ver_name_list[i], np.max(results))) # elif args.mode == 1: # raise ValueError # else: # model = nets[0] # dumpR(ver_list[0], model, args.batch_size, args.target) ================================================ FILE: src/face3d/models/arcface_torch/eval_ijbc.py ================================================ # coding: utf-8 import os import pickle import matplotlib import pandas as pd matplotlib.use('Agg') import matplotlib.pyplot as plt import timeit import sklearn import argparse import cv2 import numpy as np import torch from skimage import transform as trans from backbones import get_model from sklearn.metrics import roc_curve, auc from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap from prettytable import PrettyTable from pathlib import Path import sys import warnings sys.path.insert(0, "../") warnings.filterwarnings("ignore") parser = argparse.ArgumentParser(description='do ijb test') # general parser.add_argument('--model-prefix', default='', help='path to load model.') parser.add_argument('--image-path', default='', type=str, help='') parser.add_argument('--result-dir', default='.', type=str, help='') parser.add_argument('--batch-size', default=128, type=int, help='') parser.add_argument('--network', default='iresnet50', type=str, help='') parser.add_argument('--job', default='insightface', type=str, help='job name') parser.add_argument('--target', default='IJBC', type=str, 
help='target, set to IJBC or IJBB') args = parser.parse_args() target = args.target model_path = args.model_prefix image_path = args.image_path result_dir = args.result_dir gpu_id = None use_norm_score = True # if True, TestMode(N1) use_detector_score = True # if True, TestMode(D1) use_flip_test = True # if True, TestMode(F1) job = args.job batch_size = args.batch_size class Embedding(object): # Backbone wrapper for IJB feature extraction: get() warps a face to 112x112 via a 5-point similarity transform and returns [image, horizontally flipped image]; forward_db() normalizes a batch to [-1, 1] and concatenates the two features per face def __init__(self, prefix, data_shape, batch_size=1): image_size = (112, 112) self.image_size = image_size weight = torch.load(prefix) resnet = get_model(args.network, dropout=0, fp16=False).cuda() resnet.load_state_dict(weight) model = torch.nn.DataParallel(resnet) self.model = model self.model.eval() src = np.array([ [30.2946, 51.6963], [65.5318, 51.5014], [48.0252, 71.7366], [33.5493, 92.3655], [62.7299, 92.2041]], dtype=np.float32) src[:, 0] += 8.0 self.src = src self.batch_size = batch_size self.data_shape = data_shape def get(self, rimg, landmark): assert landmark.shape[0] == 68 or landmark.shape[0] == 5 assert landmark.shape[1] == 2 if landmark.shape[0] == 68: landmark5 = np.zeros((5, 2), dtype=np.float32) landmark5[0] = (landmark[36] + landmark[39]) / 2 landmark5[1] = (landmark[42] + landmark[45]) / 2 landmark5[2] = landmark[30] landmark5[3] = landmark[48] landmark5[4] = landmark[54] else: landmark5 = landmark tform = trans.SimilarityTransform() tform.estimate(landmark5, self.src) M = tform.params[0:2, :] img = cv2.warpAffine(rimg, M, (self.image_size[1], self.image_size[0]), borderValue=0.0) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img_flip = np.fliplr(img) img = np.transpose(img, (2, 0, 1)) # 3*112*112, RGB img_flip = np.transpose(img_flip, (2, 0, 1)) input_blob = np.zeros((2, 3, self.image_size[1], self.image_size[0]), dtype=np.uint8) input_blob[0] = img input_blob[1] = img_flip return input_blob @torch.no_grad() def forward_db(self, batch_data): imgs = torch.Tensor(batch_data).cuda() imgs.div_(255).sub_(0.5).div_(0.5) feat = self.model(imgs) feat =
feat.reshape([self.batch_size, 2 * feat.shape[1]]) return feat.cpu().numpy() # 将一个list尽量均分成n份,限制len(list)==n,份数大于原list内元素个数则分配空list[] def divideIntoNstrand(listTemp, n): twoList = [[] for i in range(n)] for i, e in enumerate(listTemp): twoList[i % n].append(e) return twoList def read_template_media_list(path): # ijb_meta = np.loadtxt(path, dtype=str) ijb_meta = pd.read_csv(path, sep=' ', header=None).values templates = ijb_meta[:, 1].astype(np.int) medias = ijb_meta[:, 2].astype(np.int) return templates, medias # In[ ]: def read_template_pair_list(path): # pairs = np.loadtxt(path, dtype=str) pairs = pd.read_csv(path, sep=' ', header=None).values # print(pairs.shape) # print(pairs[:, 0].astype(np.int)) t1 = pairs[:, 0].astype(np.int) t2 = pairs[:, 1].astype(np.int) label = pairs[:, 2].astype(np.int) return t1, t2, label # In[ ]: def read_image_feature(path): with open(path, 'rb') as fid: img_feats = pickle.load(fid) return img_feats # In[ ]: def get_image_feature(img_path, files_list, model_path, epoch, gpu_id): batch_size = args.batch_size data_shape = (3, 112, 112) files = files_list print('files:', len(files)) rare_size = len(files) % batch_size faceness_scores = [] batch = 0 img_feats = np.empty((len(files), 1024), dtype=np.float32) batch_data = np.empty((2 * batch_size, 3, 112, 112)) embedding = Embedding(model_path, data_shape, batch_size) for img_index, each_line in enumerate(files[:len(files) - rare_size]): name_lmk_score = each_line.strip().split(' ') img_name = os.path.join(img_path, name_lmk_score[0]) img = cv2.imread(img_name) lmk = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32) lmk = lmk.reshape((5, 2)) input_blob = embedding.get(img, lmk) batch_data[2 * (img_index - batch * batch_size)][:] = input_blob[0] batch_data[2 * (img_index - batch * batch_size) + 1][:] = input_blob[1] if (img_index + 1) % batch_size == 0: print('batch', batch) img_feats[batch * batch_size:batch * batch_size + batch_size][:] = 
embedding.forward_db(batch_data) batch += 1 faceness_scores.append(name_lmk_score[-1]) batch_data = np.empty((2 * rare_size, 3, 112, 112)) embedding = Embedding(model_path, data_shape, rare_size) for img_index, each_line in enumerate(files[len(files) - rare_size:]): name_lmk_score = each_line.strip().split(' ') img_name = os.path.join(img_path, name_lmk_score[0]) img = cv2.imread(img_name) lmk = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32) lmk = lmk.reshape((5, 2)) input_blob = embedding.get(img, lmk) batch_data[2 * img_index][:] = input_blob[0] batch_data[2 * img_index + 1][:] = input_blob[1] if (img_index + 1) % rare_size == 0: print('batch', batch) img_feats[len(files) - rare_size:][:] = embedding.forward_db(batch_data) batch += 1 faceness_scores.append(name_lmk_score[-1]) faceness_scores = np.array(faceness_scores).astype(np.float32) # img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01 # faceness_scores = np.ones( (len(files), ), dtype=np.float32 ) return img_feats, faceness_scores # In[ ]: def image2template_feature(img_feats=None, templates=None, medias=None): # ========================================================== # 1. face image feature l2 normalization. img_feats:[number_image x feats_dim] # 2. compute media feature. # 3. compute template feature. 
# ========================================================== unique_templates = np.unique(templates) template_feats = np.zeros((len(unique_templates), img_feats.shape[1])) for count_template, uqt in enumerate(unique_templates): (ind_t,) = np.where(templates == uqt) face_norm_feats = img_feats[ind_t] face_medias = medias[ind_t] unique_medias, unique_media_counts = np.unique(face_medias, return_counts=True) media_norm_feats = [] for u, ct in zip(unique_medias, unique_media_counts): (ind_m,) = np.where(face_medias == u) if ct == 1: media_norm_feats += [face_norm_feats[ind_m]] else: # image features from the same video will be aggregated into one feature media_norm_feats += [ np.mean(face_norm_feats[ind_m], axis=0, keepdims=True) ] media_norm_feats = np.array(media_norm_feats) # media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True)) template_feats[count_template] = np.sum(media_norm_feats, axis=0) if count_template % 2000 == 0: print('Finish Calculating {} template features.'.format( count_template)) # template_norm_feats = template_feats / np.sqrt(np.sum(template_feats ** 2, -1, keepdims=True)) template_norm_feats = sklearn.preprocessing.normalize(template_feats) # print(template_norm_feats.shape) return template_norm_feats, unique_templates # In[ ]: def verification(template_norm_feats=None, unique_templates=None, p1=None, p2=None): # ========================================================== # Compute set-to-set Similarity Score. 
# ========================================================== template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) for count_template, uqt in enumerate(unique_templates): template2id[uqt] = count_template score = np.zeros((len(p1),)) # save cosine distance between pairs total_pairs = np.array(range(len(p1))) batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation sublists = [ total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize) ] total_sublists = len(sublists) for c, s in enumerate(sublists): feat1 = template_norm_feats[template2id[p1[s]]] feat2 = template_norm_feats[template2id[p2[s]]] similarity_score = np.sum(feat1 * feat2, -1) score[s] = similarity_score.flatten() if c % 10 == 0: print('Finish {}/{} pairs.'.format(c, total_sublists)) return score # In[ ]: def verification2(template_norm_feats=None, unique_templates=None, p1=None, p2=None): template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) for count_template, uqt in enumerate(unique_templates): template2id[uqt] = count_template score = np.zeros((len(p1),)) # save cosine distance between pairs total_pairs = np.array(range(len(p1))) batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation sublists = [ total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize) ] total_sublists = len(sublists) for c, s in enumerate(sublists): feat1 = template_norm_feats[template2id[p1[s]]] feat2 = template_norm_feats[template2id[p2[s]]] similarity_score = np.sum(feat1 * feat2, -1) score[s] = similarity_score.flatten() if c % 10 == 0: print('Finish {}/{} pairs.'.format(c, total_sublists)) return score def read_score(path): with open(path, 'rb') as fid: img_feats = pickle.load(fid) return img_feats # # Step1: Load Meta Data # In[ ]: assert target == 'IJBC' or target == 'IJBB' # ============================================================= # load image and template relationships for template feature 
embedding # tid --> template id, mid --> media id # format: # image_name tid mid # ============================================================= start = timeit.default_timer() templates, medias = read_template_media_list( os.path.join('%s/meta' % image_path, '%s_face_tid_mid.txt' % target.lower())) stop = timeit.default_timer() print('Time: %.2f s. ' % (stop - start)) # In[ ]: # ============================================================= # load template pairs for template-to-template verification # tid : template id, label : 1/0 # format: # tid_1 tid_2 label # ============================================================= start = timeit.default_timer() p1, p2, label = read_template_pair_list( os.path.join('%s/meta' % image_path, '%s_template_pair_label.txt' % target.lower())) stop = timeit.default_timer() print('Time: %.2f s. ' % (stop - start)) # # Step 2: Get Image Features # In[ ]: # ============================================================= # load image features # format: # img_feats: [image_num x feats_dim] (227630, 512) # ============================================================= start = timeit.default_timer() img_path = '%s/loose_crop' % image_path img_list_path = '%s/meta/%s_name_5pts_score.txt' % (image_path, target.lower()) img_list = open(img_list_path) files = img_list.readlines() # files_list = divideIntoNstrand(files, rank_size) files_list = files # img_feats # for i in range(rank_size): img_feats, faceness_scores = get_image_feature(img_path, files_list, model_path, 0, gpu_id) stop = timeit.default_timer() print('Time: %.2f s. ' % (stop - start)) print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], img_feats.shape[1])) # # Step3: Get Template Features # In[ ]: # ============================================================= # compute template features from image features. 
# ============================================================= start = timeit.default_timer() # ========================================================== # Norm feature before aggregation into template feature? # Feature norm from embedding network and faceness score are able to decrease weights for noise samples (not face). # ========================================================== # 1. FaceScore (Feature Norm) # 2. FaceScore (Detector) if use_flip_test: # concat --- F1 # img_input_feats = img_feats # add --- F2 img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + img_feats[:, img_feats.shape[1] // 2:] else: img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] if use_norm_score: img_input_feats = img_input_feats else: # normalise features to remove norm information img_input_feats = img_input_feats / np.sqrt( np.sum(img_input_feats ** 2, -1, keepdims=True)) if use_detector_score: print(img_input_feats.shape, faceness_scores.shape) img_input_feats = img_input_feats * faceness_scores[:, np.newaxis] else: img_input_feats = img_input_feats template_norm_feats, unique_templates = image2template_feature( img_input_feats, templates, medias) stop = timeit.default_timer() print('Time: %.2f s. ' % (stop - start)) # # Step 4: Get Template Similarity Scores # In[ ]: # ============================================================= # compute verification scores between template pairs. # ============================================================= start = timeit.default_timer() score = verification(template_norm_feats, unique_templates, p1, p2) stop = timeit.default_timer() print('Time: %.2f s. 
' % (stop - start)) # In[ ]: save_path = os.path.join(result_dir, args.job) # save_path = result_dir + '/%s_result' % target if not os.path.exists(save_path): os.makedirs(save_path) score_save_file = os.path.join(save_path, "%s.npy" % target.lower()) np.save(score_save_file, score) # # Step 5: Get ROC Curves and TPR@FPR Table # In[ ]: files = [score_save_file] methods = [] scores = [] for file in files: methods.append(Path(file).stem) scores.append(np.load(file)) methods = np.array(methods) scores = dict(zip(methods, scores)) colours = dict( zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2'))) x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels]) fig = plt.figure() for method in methods: fpr, tpr, _ = roc_curve(label, scores[method]) roc_auc = auc(fpr, tpr) fpr = np.flipud(fpr) tpr = np.flipud(tpr) # select largest tpr at same fpr plt.plot(fpr, tpr, color=colours[method], lw=1, label=('[%s (AUC = %0.4f %%)]' % (method.split('-')[-1], roc_auc * 100))) tpr_fpr_row = [] tpr_fpr_row.append("%s-%s" % (method, target)) for fpr_iter in np.arange(len(x_labels)): _, min_index = min( list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) tpr_fpr_table.add_row(tpr_fpr_row) plt.xlim([10 ** -6, 0.1]) plt.ylim([0.3, 1.0]) plt.grid(linestyle='--', linewidth=1) plt.xticks(x_labels) plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) plt.xscale('log') plt.xlabel('False Positive Rate') plt.ylabel('True Positive Rate') plt.title('ROC on IJB') plt.legend(loc="lower right") fig.savefig(os.path.join(save_path, '%s.pdf' % target.lower())) print(tpr_fpr_table) ================================================ FILE: src/face3d/models/arcface_torch/inference.py ================================================ import argparse import cv2 import numpy as np import torch from backbones import get_model @torch.no_grad() def 
inference(weight, name, img): if img is None: img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8) else: img = cv2.imread(img) img = cv2.resize(img, (112, 112)) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img = np.transpose(img, (2, 0, 1)) img = torch.from_numpy(img).unsqueeze(0).float() img.div_(255).sub_(0.5).div_(0.5) net = get_model(name, fp16=False) net.load_state_dict(torch.load(weight)) net.eval() feat = net(img).numpy() print(feat) if __name__ == "__main__": parser = argparse.ArgumentParser(description='PyTorch ArcFace Training') parser.add_argument('--network', type=str, default='r50', help='backbone network') parser.add_argument('--weight', type=str, default='') parser.add_argument('--img', type=str, default=None) args = parser.parse_args() inference(args.weight, args.network, args.img) ================================================ FILE: src/face3d/models/arcface_torch/losses.py ================================================ import torch from torch import nn def get_loss(name): if name == "cosface": return CosFace() elif name == "arcface": return ArcFace() else: raise ValueError() class CosFace(nn.Module): def __init__(self, s=64.0, m=0.40): super(CosFace, self).__init__() self.s = s self.m = m def forward(self, cosine, label): index = torch.where(label != -1)[0] m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device) m_hot.scatter_(1, label[index, None], self.m) cosine[index] -= m_hot ret = cosine * self.s return ret class ArcFace(nn.Module): def __init__(self, s=64.0, m=0.5): super(ArcFace, self).__init__() self.s = s self.m = m def forward(self, cosine: torch.Tensor, label): index = torch.where(label != -1)[0] m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device) m_hot.scatter_(1, label[index, None], self.m) cosine.acos_() cosine[index] += m_hot cosine.cos_().mul_(self.s) return cosine ================================================ FILE: src/face3d/models/arcface_torch/onnx_helper.py 
================================================ from __future__ import division import datetime import os import os.path as osp import glob import numpy as np import cv2 import sys import onnxruntime import onnx import argparse from onnx import numpy_helper from insightface.data import get_image class ArcFaceORT: def __init__(self, model_path, cpu=False): self.model_path = model_path # providers = None will use available provider, for onnxruntime-gpu it will be "CUDAExecutionProvider" self.providers = ['CPUExecutionProvider'] if cpu else None #input_size is (w,h), return error message, return None if success def check(self, track='cfat', test_img = None): #default is cfat max_model_size_mb=1024 max_feat_dim=512 max_time_cost=15 if track.startswith('ms1m'): max_model_size_mb=1024 max_feat_dim=512 max_time_cost=10 elif track.startswith('glint'): max_model_size_mb=1024 max_feat_dim=1024 max_time_cost=20 elif track.startswith('cfat'): max_model_size_mb = 1024 max_feat_dim = 512 max_time_cost = 15 elif track.startswith('unconstrained'): max_model_size_mb=1024 max_feat_dim=1024 max_time_cost=30 else: return "track not found" if not os.path.exists(self.model_path): return "model_path not exists" if not os.path.isdir(self.model_path): return "model_path should be directory" onnx_files = [] for _file in os.listdir(self.model_path): if _file.endswith('.onnx'): onnx_files.append(osp.join(self.model_path, _file)) if len(onnx_files)==0: return "do not have onnx files" self.model_file = sorted(onnx_files)[-1] print('use onnx-model:', self.model_file) try: session = onnxruntime.InferenceSession(self.model_file, providers=self.providers) except: return "load onnx failed" input_cfg = session.get_inputs()[0] input_shape = input_cfg.shape print('input-shape:', input_shape) if len(input_shape)!=4: return "length of input_shape should be 4" if not isinstance(input_shape[0], str): #return "input_shape[0] should be str to support batch-inference" print('reset input-shape[0] to None') 
model = onnx.load(self.model_file) model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None' new_model_file = osp.join(self.model_path, 'zzzzrefined.onnx') onnx.save(model, new_model_file) self.model_file = new_model_file print('use new onnx-model:', self.model_file) try: session = onnxruntime.InferenceSession(self.model_file, providers=self.providers) except: return "load onnx failed" input_cfg = session.get_inputs()[0] input_shape = input_cfg.shape print('new-input-shape:', input_shape) self.image_size = tuple(input_shape[2:4][::-1]) #print('image_size:', self.image_size) input_name = input_cfg.name outputs = session.get_outputs() output_names = [] for o in outputs: output_names.append(o.name) #print(o.name, o.shape) if len(output_names)!=1: return "number of output nodes should be 1" self.session = session self.input_name = input_name self.output_names = output_names #print(self.output_names) model = onnx.load(self.model_file) graph = model.graph if len(graph.node)<8: return "too small onnx graph" input_size = (112,112) self.crop = None if track=='cfat': crop_file = osp.join(self.model_path, 'crop.txt') if osp.exists(crop_file): lines = open(crop_file,'r').readlines() if len(lines)!=6: return "crop.txt should contain 6 lines" lines = [int(x) for x in lines] self.crop = lines[:4] input_size = tuple(lines[4:6]) if input_size!=self.image_size: return "input-size is inconsistant with onnx model input, %s vs %s"%(input_size, self.image_size) self.model_size_mb = os.path.getsize(self.model_file) / float(1024*1024) if self.model_size_mb > max_model_size_mb: return "max model size exceed, given %.3f-MB"%self.model_size_mb input_mean = None input_std = None if track=='cfat': pn_file = osp.join(self.model_path, 'pixel_norm.txt') if osp.exists(pn_file): lines = open(pn_file,'r').readlines() if len(lines)!=2: return "pixel_norm.txt should contain 2 lines" input_mean = float(lines[0]) input_std = float(lines[1]) if input_mean is not None or input_std is not None: 
if input_mean is None or input_std is None: return "please set input_mean and input_std simultaneously" else: find_sub = False find_mul = False for nid, node in enumerate(graph.node[:8]): print(nid, node.name) if node.name.startswith('Sub') or node.name.startswith('_minus'): find_sub = True if node.name.startswith('Mul') or node.name.startswith('_mul') or node.name.startswith('Div'): find_mul = True if find_sub and find_mul: print("find sub and mul") #mxnet arcface model input_mean = 0.0 input_std = 1.0 else: input_mean = 127.5 input_std = 127.5 self.input_mean = input_mean self.input_std = input_std for initn in graph.initializer: weight_array = numpy_helper.to_array(initn) dt = weight_array.dtype if dt.itemsize<4: return 'invalid weight type - (%s:%s)' % (initn.name, dt.name) if test_img is None: test_img = get_image('Tom_Hanks_54745') test_img = cv2.resize(test_img, self.image_size) else: test_img = cv2.resize(test_img, self.image_size) feat, cost = self.benchmark(test_img) batch_result = self.check_batch(test_img) batch_result_sum = float(np.sum(batch_result)) if batch_result_sum in [float('inf'), -float('inf')] or batch_result_sum != batch_result_sum: print(batch_result) print(batch_result_sum) return "batch result output contains NaN!" 
if len(feat.shape) < 2: return "the shape of the feature must be two, but get {}".format(str(feat.shape)) if feat.shape[1] > max_feat_dim: return "max feat dim exceed, given %d"%feat.shape[1] self.feat_dim = feat.shape[1] cost_ms = cost*1000 if cost_ms>max_time_cost: return "max time cost exceed, given %.4f"%cost_ms self.cost_ms = cost_ms print('check stat:, model-size-mb: %.4f, feat-dim: %d, time-cost-ms: %.4f, input-mean: %.3f, input-std: %.3f'%(self.model_size_mb, self.feat_dim, self.cost_ms, self.input_mean, self.input_std)) return None def check_batch(self, img): if not isinstance(img, list): imgs = [img, ] * 32 if self.crop is not None: nimgs = [] for img in imgs: nimg = img[self.crop[1]:self.crop[3], self.crop[0]:self.crop[2], :] if nimg.shape[0] != self.image_size[1] or nimg.shape[1] != self.image_size[0]: nimg = cv2.resize(nimg, self.image_size) nimgs.append(nimg) imgs = nimgs blob = cv2.dnn.blobFromImages( images=imgs, scalefactor=1.0 / self.input_std, size=self.image_size, mean=(self.input_mean, self.input_mean, self.input_mean), swapRB=True) net_out = self.session.run(self.output_names, {self.input_name: blob})[0] return net_out def meta_info(self): return {'model-size-mb':self.model_size_mb, 'feature-dim':self.feat_dim, 'infer': self.cost_ms} def forward(self, imgs): if not isinstance(imgs, list): imgs = [imgs] input_size = self.image_size if self.crop is not None: nimgs = [] for img in imgs: nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:] if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]: nimg = cv2.resize(nimg, input_size) nimgs.append(nimg) imgs = nimgs blob = cv2.dnn.blobFromImages(imgs, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) net_out = self.session.run(self.output_names, {self.input_name : blob})[0] return net_out def benchmark(self, img): input_size = self.image_size if self.crop is not None: nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:] 
if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]: nimg = cv2.resize(nimg, input_size) img = nimg blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) costs = [] for _ in range(50): ta = datetime.datetime.now() net_out = self.session.run(self.output_names, {self.input_name : blob})[0] tb = datetime.datetime.now() cost = (tb-ta).total_seconds() costs.append(cost) costs = sorted(costs) cost = costs[5] return net_out, cost if __name__ == '__main__': parser = argparse.ArgumentParser(description='') # general parser.add_argument('workdir', help='submitted work dir', type=str) parser.add_argument('--track', help='track name, for different challenge', type=str, default='cfat') args = parser.parse_args() handler = ArcFaceORT(args.workdir) err = handler.check(args.track) print('err:', err) ================================================ FILE: src/face3d/models/arcface_torch/onnx_ijbc.py ================================================ import argparse import os import pickle import timeit import cv2 import mxnet as mx import numpy as np import pandas as pd import prettytable import skimage.transform from sklearn.metrics import roc_curve from sklearn.preprocessing import normalize from onnx_helper import ArcFaceORT SRC = np.array( [ [30.2946, 51.6963], [65.5318, 51.5014], [48.0252, 71.7366], [33.5493, 92.3655], [62.7299, 92.2041]] , dtype=np.float32) SRC[:, 0] += 8.0 class AlignedDataSet(mx.gluon.data.Dataset): def __init__(self, root, lines, align=True): self.lines = lines self.root = root self.align = align def __len__(self): return len(self.lines) def __getitem__(self, idx): each_line = self.lines[idx] name_lmk_score = each_line.strip().split(' ') name = os.path.join(self.root, name_lmk_score[0]) img = cv2.cvtColor(cv2.imread(name), cv2.COLOR_BGR2RGB) landmark5 = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32).reshape((5, 2)) st = 
skimage.transform.SimilarityTransform() st.estimate(landmark5, SRC) img = cv2.warpAffine(img, st.params[0:2, :], (112, 112), borderValue=0.0) img_1 = np.expand_dims(img, 0) img_2 = np.expand_dims(np.fliplr(img), 0) output = np.concatenate((img_1, img_2), axis=0).astype(np.float32) output = np.transpose(output, (0, 3, 1, 2)) output = mx.nd.array(output) return output def extract(model_root, dataset): model = ArcFaceORT(model_path=model_root) model.check() feat_mat = np.zeros(shape=(len(dataset), 2 * model.feat_dim)) def batchify_fn(data): return mx.nd.concat(*data, dim=0) data_loader = mx.gluon.data.DataLoader( dataset, 128, last_batch='keep', num_workers=4, thread_pool=True, prefetch=16, batchify_fn=batchify_fn) num_iter = 0 for batch in data_loader: batch = batch.asnumpy() batch = (batch - model.input_mean) / model.input_std feat = model.session.run(model.output_names, {model.input_name: batch})[0] feat = np.reshape(feat, (-1, model.feat_dim * 2)) feat_mat[128 * num_iter: 128 * num_iter + feat.shape[0], :] = feat num_iter += 1 if num_iter % 50 == 0: print(num_iter) return feat_mat def read_template_media_list(path): ijb_meta = pd.read_csv(path, sep=' ', header=None).values templates = ijb_meta[:, 1].astype(np.int) medias = ijb_meta[:, 2].astype(np.int) return templates, medias def read_template_pair_list(path): pairs = pd.read_csv(path, sep=' ', header=None).values t1 = pairs[:, 0].astype(np.int) t2 = pairs[:, 1].astype(np.int) label = pairs[:, 2].astype(np.int) return t1, t2, label def read_image_feature(path): with open(path, 'rb') as fid: img_feats = pickle.load(fid) return img_feats def image2template_feature(img_feats=None, templates=None, medias=None): unique_templates = np.unique(templates) template_feats = np.zeros((len(unique_templates), img_feats.shape[1])) for count_template, uqt in enumerate(unique_templates): (ind_t,) = np.where(templates == uqt) face_norm_feats = img_feats[ind_t] face_medias = medias[ind_t] unique_medias, unique_media_counts = 
np.unique(face_medias, return_counts=True) media_norm_feats = [] for u, ct in zip(unique_medias, unique_media_counts): (ind_m,) = np.where(face_medias == u) if ct == 1: media_norm_feats += [face_norm_feats[ind_m]] else: # image features from the same video will be aggregated into one feature media_norm_feats += [np.mean(face_norm_feats[ind_m], axis=0, keepdims=True), ] media_norm_feats = np.array(media_norm_feats) template_feats[count_template] = np.sum(media_norm_feats, axis=0) if count_template % 2000 == 0: print('Finish Calculating {} template features.'.format( count_template)) template_norm_feats = normalize(template_feats) return template_norm_feats, unique_templates def verification(template_norm_feats=None, unique_templates=None, p1=None, p2=None): template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) for count_template, uqt in enumerate(unique_templates): template2id[uqt] = count_template score = np.zeros((len(p1),)) total_pairs = np.array(range(len(p1))) batchsize = 100000 sublists = [total_pairs[i: i + batchsize] for i in range(0, len(p1), batchsize)] total_sublists = len(sublists) for c, s in enumerate(sublists): feat1 = template_norm_feats[template2id[p1[s]]] feat2 = template_norm_feats[template2id[p2[s]]] similarity_score = np.sum(feat1 * feat2, -1) score[s] = similarity_score.flatten() if c % 10 == 0: print('Finish {}/{} pairs.'.format(c, total_sublists)) return score def verification2(template_norm_feats=None, unique_templates=None, p1=None, p2=None): template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) for count_template, uqt in enumerate(unique_templates): template2id[uqt] = count_template score = np.zeros((len(p1),)) # save cosine distance between pairs total_pairs = np.array(range(len(p1))) batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation sublists = [total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)] total_sublists = len(sublists) for c, s in 
enumerate(sublists): feat1 = template_norm_feats[template2id[p1[s]]] feat2 = template_norm_feats[template2id[p2[s]]] similarity_score = np.sum(feat1 * feat2, -1) score[s] = similarity_score.flatten() if c % 10 == 0: print('Finish {}/{} pairs.'.format(c, total_sublists)) return score def main(args): use_norm_score = True # if Ture, TestMode(N1) use_detector_score = True # if Ture, TestMode(D1) use_flip_test = True # if Ture, TestMode(F1) assert args.target == 'IJBC' or args.target == 'IJBB' start = timeit.default_timer() templates, medias = read_template_media_list( os.path.join('%s/meta' % args.image_path, '%s_face_tid_mid.txt' % args.target.lower())) stop = timeit.default_timer() print('Time: %.2f s. ' % (stop - start)) start = timeit.default_timer() p1, p2, label = read_template_pair_list( os.path.join('%s/meta' % args.image_path, '%s_template_pair_label.txt' % args.target.lower())) stop = timeit.default_timer() print('Time: %.2f s. ' % (stop - start)) start = timeit.default_timer() img_path = '%s/loose_crop' % args.image_path img_list_path = '%s/meta/%s_name_5pts_score.txt' % (args.image_path, args.target.lower()) img_list = open(img_list_path) files = img_list.readlines() dataset = AlignedDataSet(root=img_path, lines=files, align=True) img_feats = extract(args.model_root, dataset) faceness_scores = [] for each_line in files: name_lmk_score = each_line.split() faceness_scores.append(name_lmk_score[-1]) faceness_scores = np.array(faceness_scores).astype(np.float32) stop = timeit.default_timer() print('Time: %.2f s. 
' % (stop - start)) print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], img_feats.shape[1])) start = timeit.default_timer() if use_flip_test: img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + img_feats[:, img_feats.shape[1] // 2:] else: img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] if use_norm_score: img_input_feats = img_input_feats else: img_input_feats = img_input_feats / np.sqrt(np.sum(img_input_feats ** 2, -1, keepdims=True)) if use_detector_score: print(img_input_feats.shape, faceness_scores.shape) img_input_feats = img_input_feats * faceness_scores[:, np.newaxis] else: img_input_feats = img_input_feats template_norm_feats, unique_templates = image2template_feature( img_input_feats, templates, medias) stop = timeit.default_timer() print('Time: %.2f s. ' % (stop - start)) start = timeit.default_timer() score = verification(template_norm_feats, unique_templates, p1, p2) stop = timeit.default_timer() print('Time: %.2f s. ' % (stop - start)) save_path = os.path.join(args.result_dir, "{}_result".format(args.target)) if not os.path.exists(save_path): os.makedirs(save_path) score_save_file = os.path.join(save_path, "{}.npy".format(args.model_root)) np.save(score_save_file, score) files = [score_save_file] methods = [] scores = [] for file in files: methods.append(os.path.basename(file)) scores.append(np.load(file)) methods = np.array(methods) scores = dict(zip(methods, scores)) x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] tpr_fpr_table = prettytable.PrettyTable(['Methods'] + [str(x) for x in x_labels]) for method in methods: fpr, tpr, _ = roc_curve(label, scores[method]) fpr = np.flipud(fpr) tpr = np.flipud(tpr) tpr_fpr_row = [] tpr_fpr_row.append("%s-%s" % (method, args.target)) for fpr_iter in np.arange(len(x_labels)): _, min_index = min( list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) tpr_fpr_table.add_row(tpr_fpr_row) print(tpr_fpr_table) if 
__name__ == '__main__': parser = argparse.ArgumentParser(description='do ijb test') # general parser.add_argument('--model-root', default='', help='path to load model.') parser.add_argument('--image-path', default='', type=str, help='') parser.add_argument('--result-dir', default='.', type=str, help='') parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB') main(parser.parse_args()) ================================================ FILE: src/face3d/models/arcface_torch/partial_fc.py ================================================ import logging import os import torch import torch.distributed as dist from torch.nn import Module from torch.nn.functional import normalize, linear from torch.nn.parameter import Parameter class PartialFC(Module): """ Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint, Partial FC: Training 10 Million Identities on a Single Machine See the original paper: https://arxiv.org/abs/2010.05222 """ @torch.no_grad() def __init__(self, rank, local_rank, world_size, batch_size, resume, margin_softmax, num_classes, sample_rate=1.0, embedding_size=512, prefix="./"): """ rank: int Unique process(GPU) ID from 0 to world_size - 1. local_rank: int Unique process(GPU) ID within the server from 0 to 7. world_size: int Number of GPU. batch_size: int Batch size on current rank(GPU). resume: bool Select whether to restore the weight of softmax. margin_softmax: callable A function of margin softmax, eg: cosface, arcface. num_classes: int The number of class center storage in current rank(CPU/GPU), usually is total_classes // world_size, required. sample_rate: float The partial fc sampling rate, when the number of classes increases to more than 2 millions, Sampling can greatly speed up training, and reduce a lot of GPU memory, default is 1.0. embedding_size: int The feature dimension, default is 512. prefix: str Path for save checkpoint, default is './'. 
""" super(PartialFC, self).__init__() # self.num_classes: int = num_classes self.rank: int = rank self.local_rank: int = local_rank self.device: torch.device = torch.device("cuda:{}".format(self.local_rank)) self.world_size: int = world_size self.batch_size: int = batch_size self.margin_softmax: callable = margin_softmax self.sample_rate: float = sample_rate self.embedding_size: int = embedding_size self.prefix: str = prefix self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size) self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size) self.num_sample: int = int(self.sample_rate * self.num_local) self.weight_name = os.path.join(self.prefix, "rank_{}_softmax_weight.pt".format(self.rank)) self.weight_mom_name = os.path.join(self.prefix, "rank_{}_softmax_weight_mom.pt".format(self.rank)) if resume: try: self.weight: torch.Tensor = torch.load(self.weight_name) self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name) if self.weight.shape[0] != self.num_local or self.weight_mom.shape[0] != self.num_local: raise IndexError logging.info("softmax weight resume successfully!") logging.info("softmax weight mom resume successfully!") except (FileNotFoundError, KeyError, IndexError): self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device) self.weight_mom: torch.Tensor = torch.zeros_like(self.weight) logging.info("softmax weight init!") logging.info("softmax weight mom init!") else: self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device) self.weight_mom: torch.Tensor = torch.zeros_like(self.weight) logging.info("softmax weight init successfully!") logging.info("softmax weight mom init successfully!") self.stream: torch.cuda.Stream = torch.cuda.Stream(local_rank) self.index = None if int(self.sample_rate) == 1: self.update = lambda: 0 self.sub_weight = Parameter(self.weight) self.sub_weight_mom = self.weight_mom else: 
self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank)) def save_params(self): """ Save softmax weight for each rank on prefix """ torch.save(self.weight.data, self.weight_name) torch.save(self.weight_mom, self.weight_mom_name) @torch.no_grad() def sample(self, total_label): """ Sample all positive class centers in each rank, and random select neg class centers to filling a fixed `num_sample`. total_label: tensor Label after all gather, which cross all GPUs. """ index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local) total_label[~index_positive] = -1 total_label[index_positive] -= self.class_start if int(self.sample_rate) != 1: positive = torch.unique(total_label[index_positive], sorted=True) if self.num_sample - positive.size(0) >= 0: perm = torch.rand(size=[self.num_local], device=self.device) perm[positive] = 2.0 index = torch.topk(perm, k=self.num_sample)[1] index = index.sort()[0] else: index = positive self.index = index total_label[index_positive] = torch.searchsorted(index, total_label[index_positive]) self.sub_weight = Parameter(self.weight[index]) self.sub_weight_mom = self.weight_mom[index] def forward(self, total_features, norm_weight): """ Partial fc forward, `logits = X * sample(W)` """ torch.cuda.current_stream().wait_stream(self.stream) logits = linear(total_features, norm_weight) return logits @torch.no_grad() def update(self): """ Set updated weight and weight_mom to memory bank. """ self.weight_mom[self.index] = self.sub_weight_mom self.weight[self.index] = self.sub_weight def prepare(self, label, optimizer): """ get sampled class centers for cal softmax. label: tensor Label tensor on each rank. optimizer: opt Optimizer for partial fc, which need to get weight mom. 
""" with torch.cuda.stream(self.stream): total_label = torch.zeros( size=[self.batch_size * self.world_size], device=self.device, dtype=torch.long) dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label) self.sample(total_label) optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None) optimizer.param_groups[-1]['params'][0] = self.sub_weight optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom norm_weight = normalize(self.sub_weight) return total_label, norm_weight def forward_backward(self, label, features, optimizer): """ Partial fc forward and backward with model parallel label: tensor Label tensor on each rank(GPU) features: tensor Features tensor on each rank(GPU) optimizer: optimizer Optimizer for partial fc Returns: -------- x_grad: tensor The gradient of features. loss_v: tensor Loss value for cross entropy. """ total_label, norm_weight = self.prepare(label, optimizer) total_features = torch.zeros( size=[self.batch_size * self.world_size, self.embedding_size], device=self.device) dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data) total_features.requires_grad = True logits = self.forward(total_features, norm_weight) logits = self.margin_softmax(logits, total_label) with torch.no_grad(): max_fc = torch.max(logits, dim=1, keepdim=True)[0] dist.all_reduce(max_fc, dist.ReduceOp.MAX) # calculate exp(logits) and all-reduce logits_exp = torch.exp(logits - max_fc) logits_sum_exp = logits_exp.sum(dim=1, keepdims=True) dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM) # calculate prob logits_exp.div_(logits_sum_exp) # get one-hot grad = logits_exp index = torch.where(total_label != -1)[0] one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device) one_hot.scatter_(1, total_label[index, None], 1) # calculate loss loss = torch.zeros(grad.size()[0], 1, device=grad.device) loss[index] = grad[index].gather(1, total_label[index, None]) dist.all_reduce(loss, 
dist.ReduceOp.SUM) loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1) # calculate grad grad[index] -= one_hot grad.div_(self.batch_size * self.world_size) logits.backward(grad) if total_features.grad is not None: total_features.grad.detach_() x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True) # feature gradient all-reduce dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0))) x_grad = x_grad * self.world_size # backward backbone return x_grad, loss_v ================================================ FILE: src/face3d/models/arcface_torch/requirement.txt ================================================ tensorboard easydict mxnet onnx sklearn ================================================ FILE: src/face3d/models/arcface_torch/run.sh ================================================ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50 ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh ================================================ FILE: src/face3d/models/arcface_torch/torch2onnx.py ================================================ import numpy as np import onnx import torch def convert_onnx(net, path_module, output, opset=11, simplify=False): assert isinstance(net, torch.nn.Module) img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) img = img.astype(np.float) img = (img / 255. 
- 0.5) / 0.5 # torch style norm img = img.transpose((2, 0, 1)) img = torch.from_numpy(img).unsqueeze(0).float() weight = torch.load(path_module) net.load_state_dict(weight) net.eval() torch.onnx.export(net, img, output, keep_initializers_as_inputs=False, verbose=False, opset_version=opset) model = onnx.load(output) graph = model.graph graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None' if simplify: from onnxsim import simplify model, check = simplify(model) assert check, "Simplified ONNX model could not be validated" onnx.save(model, output) if __name__ == '__main__': import os import argparse from backbones import get_model parser = argparse.ArgumentParser(description='ArcFace PyTorch to onnx') parser.add_argument('input', type=str, help='input backbone.pth file or path') parser.add_argument('--output', type=str, default=None, help='output onnx path') parser.add_argument('--network', type=str, default=None, help='backbone network') parser.add_argument('--simplify', type=bool, default=False, help='onnx simplify') args = parser.parse_args() input_file = args.input if os.path.isdir(input_file): input_file = os.path.join(input_file, "backbone.pth") assert os.path.exists(input_file) model_name = os.path.basename(os.path.dirname(input_file)).lower() params = model_name.split("_") if len(params) >= 3 and params[1] in ('arcface', 'cosface'): if args.network is None: args.network = params[2] assert args.network is not None print(args) backbone_onnx = get_model(args.network, dropout=0) output_path = args.output if output_path is None: output_path = os.path.join(os.path.dirname(__file__), 'onnx') if not os.path.exists(output_path): os.makedirs(output_path) assert os.path.isdir(output_path) output_file = os.path.join(output_path, "%s.onnx" % model_name) convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify) ================================================ FILE: src/face3d/models/arcface_torch/train.py 
================================================ import argparse import logging import os import torch import torch.distributed as dist import torch.nn.functional as F import torch.utils.data.distributed from torch.nn.utils import clip_grad_norm_ import losses from backbones import get_model from dataset import MXFaceDataset, SyntheticDataset, DataLoaderX from partial_fc import PartialFC from utils.utils_amp import MaxClipGradScaler from utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint from utils.utils_config import get_config from utils.utils_logging import AverageMeter, init_logging def main(args): cfg = get_config(args.config) try: world_size = int(os.environ['WORLD_SIZE']) rank = int(os.environ['RANK']) dist.init_process_group('nccl') except KeyError: world_size = 1 rank = 0 dist.init_process_group(backend='nccl', init_method="tcp://127.0.0.1:12584", rank=rank, world_size=world_size) local_rank = args.local_rank torch.cuda.set_device(local_rank) os.makedirs(cfg.output, exist_ok=True) init_logging(rank, cfg.output) if cfg.rec == "synthetic": train_set = SyntheticDataset(local_rank=local_rank) else: train_set = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank) train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True) train_loader = DataLoaderX( local_rank=local_rank, dataset=train_set, batch_size=cfg.batch_size, sampler=train_sampler, num_workers=2, pin_memory=True, drop_last=True) backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).to(local_rank) if cfg.resume: try: backbone_pth = os.path.join(cfg.output, "backbone.pth") backbone.load_state_dict(torch.load(backbone_pth, map_location=torch.device(local_rank))) if rank == 0: logging.info("backbone resume successfully!") except (FileNotFoundError, KeyError, IndexError, RuntimeError): if rank == 0: logging.info("resume fail, backbone init successfully!") backbone = 
torch.nn.parallel.DistributedDataParallel( module=backbone, broadcast_buffers=False, device_ids=[local_rank]) backbone.train() margin_softmax = losses.get_loss(cfg.loss) module_partial_fc = PartialFC( rank=rank, local_rank=local_rank, world_size=world_size, resume=cfg.resume, batch_size=cfg.batch_size, margin_softmax=margin_softmax, num_classes=cfg.num_classes, sample_rate=cfg.sample_rate, embedding_size=cfg.embedding_size, prefix=cfg.output) opt_backbone = torch.optim.SGD( params=[{'params': backbone.parameters()}], lr=cfg.lr / 512 * cfg.batch_size * world_size, momentum=0.9, weight_decay=cfg.weight_decay) opt_pfc = torch.optim.SGD( params=[{'params': module_partial_fc.parameters()}], lr=cfg.lr / 512 * cfg.batch_size * world_size, momentum=0.9, weight_decay=cfg.weight_decay) num_image = len(train_set) total_batch_size = cfg.batch_size * world_size cfg.warmup_step = num_image // total_batch_size * cfg.warmup_epoch cfg.total_step = num_image // total_batch_size * cfg.num_epoch def lr_step_func(current_step): cfg.decay_step = [x * num_image // total_batch_size for x in cfg.decay_epoch] if current_step < cfg.warmup_step: return current_step / cfg.warmup_step else: return 0.1 ** len([m for m in cfg.decay_step if m <= current_step]) scheduler_backbone = torch.optim.lr_scheduler.LambdaLR( optimizer=opt_backbone, lr_lambda=lr_step_func) scheduler_pfc = torch.optim.lr_scheduler.LambdaLR( optimizer=opt_pfc, lr_lambda=lr_step_func) for key, value in cfg.items(): num_space = 25 - len(key) logging.info(": " + key + " " * num_space + str(value)) val_target = cfg.val_targets callback_verification = CallBackVerification(2000, rank, val_target, cfg.rec) callback_logging = CallBackLogging(50, rank, cfg.total_step, cfg.batch_size, world_size, None) callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output) loss = AverageMeter() start_epoch = 0 global_step = 0 grad_amp = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None for epoch 
in range(start_epoch, cfg.num_epoch): train_sampler.set_epoch(epoch) for step, (img, label) in enumerate(train_loader): global_step += 1 features = F.normalize(backbone(img)) x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc) if cfg.fp16: features.backward(grad_amp.scale(x_grad)) grad_amp.unscale_(opt_backbone) clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2) grad_amp.step(opt_backbone) grad_amp.update() else: features.backward(x_grad) clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2) opt_backbone.step() opt_pfc.step() module_partial_fc.update() opt_backbone.zero_grad() opt_pfc.zero_grad() loss.update(loss_v, 1) callback_logging(global_step, loss, epoch, cfg.fp16, scheduler_backbone.get_last_lr()[0], grad_amp) callback_verification(global_step, backbone) scheduler_backbone.step() scheduler_pfc.step() callback_checkpoint(global_step, backbone, module_partial_fc) dist.destroy_process_group() if __name__ == "__main__": torch.backends.cudnn.benchmark = True parser = argparse.ArgumentParser(description='PyTorch ArcFace Training') parser.add_argument('config', type=str, help='py config file') parser.add_argument('--local_rank', type=int, default=0, help='local_rank') main(parser.parse_args()) ================================================ FILE: src/face3d/models/arcface_torch/utils/__init__.py ================================================ ================================================ FILE: src/face3d/models/arcface_torch/utils/plot.py ================================================ # coding: utf-8 import os from pathlib import Path import matplotlib.pyplot as plt import numpy as np import pandas as pd from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap from prettytable import PrettyTable from sklearn.metrics import roc_curve, auc image_path = "/data/anxiang/IJB_release/IJBC" files = [ "./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy" ] def read_template_pair_list(path): pairs = 
pd.read_csv(path, sep=' ', header=None).values t1 = pairs[:, 0].astype(np.int) t2 = pairs[:, 1].astype(np.int) label = pairs[:, 2].astype(np.int) return t1, t2, label p1, p2, label = read_template_pair_list( os.path.join('%s/meta' % image_path, '%s_template_pair_label.txt' % 'ijbc')) methods = [] scores = [] for file in files: methods.append(file.split('/')[-2]) scores.append(np.load(file)) methods = np.array(methods) scores = dict(zip(methods, scores)) colours = dict( zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2'))) x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels]) fig = plt.figure() for method in methods: fpr, tpr, _ = roc_curve(label, scores[method]) roc_auc = auc(fpr, tpr) fpr = np.flipud(fpr) tpr = np.flipud(tpr) # select largest tpr at same fpr plt.plot(fpr, tpr, color=colours[method], lw=1, label=('[%s (AUC = %0.4f %%)]' % (method.split('-')[-1], roc_auc * 100))) tpr_fpr_row = [] tpr_fpr_row.append("%s-%s" % (method, "IJBC")) for fpr_iter in np.arange(len(x_labels)): _, min_index = min( list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) tpr_fpr_table.add_row(tpr_fpr_row) plt.xlim([10 ** -6, 0.1]) plt.ylim([0.3, 1.0]) plt.grid(linestyle='--', linewidth=1) plt.xticks(x_labels) plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) plt.xscale('log') plt.xlabel('False Positive Rate') plt.ylabel('True Positive Rate') plt.title('ROC on IJB') plt.legend(loc="lower right") print(tpr_fpr_table) ================================================ FILE: src/face3d/models/arcface_torch/utils/utils_amp.py ================================================ from typing import Dict, List import torch if torch.__version__ < '1.9': Iterable = torch._six.container_abcs.Iterable else: import collections Iterable = collections.abc.Iterable from torch.cuda.amp import GradScaler class _MultiDeviceReplicator(object): 
""" Lazily serves copies of a tensor to requested devices. Copies are cached per-device. """ def __init__(self, master_tensor: torch.Tensor) -> None: assert master_tensor.is_cuda self.master = master_tensor self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} def get(self, device) -> torch.Tensor: retval = self._per_device_tensors.get(device, None) if retval is None: retval = self.master.to(device=device, non_blocking=True, copy=True) self._per_device_tensors[device] = retval return retval class MaxClipGradScaler(GradScaler): def __init__(self, init_scale, max_scale: float, growth_interval=100): GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval) self.max_scale = max_scale def scale_clip(self): if self.get_scale() == self.max_scale: self.set_growth_factor(1) elif self.get_scale() < self.max_scale: self.set_growth_factor(2) elif self.get_scale() > self.max_scale: self._scale.fill_(self.max_scale) self.set_growth_factor(1) def scale(self, outputs): """ Multiplies ('scales') a tensor or list of tensors by the scale factor. Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned unmodified. Arguments: outputs (Tensor or iterable of Tensors): Outputs to scale. """ if not self._enabled: return outputs self.scale_clip() # Short-circuit for the common case. if isinstance(outputs, torch.Tensor): assert outputs.is_cuda if self._scale is None: self._lazy_init_scale_growth_tracker(outputs.device) assert self._scale is not None return outputs * self._scale.to(device=outputs.device, non_blocking=True) # Invoke the more complex machinery only if we're treating multiple outputs. 
stash: List[_MultiDeviceReplicator] = [] # holds a reference that can be overwritten by apply_scale def apply_scale(val): if isinstance(val, torch.Tensor): assert val.is_cuda if len(stash) == 0: if self._scale is None: self._lazy_init_scale_growth_tracker(val.device) assert self._scale is not None stash.append(_MultiDeviceReplicator(self._scale)) return val * stash[0].get(val.device) elif isinstance(val, Iterable): iterable = map(apply_scale, val) if isinstance(val, list) or isinstance(val, tuple): return type(val)(iterable) else: return iterable else: raise ValueError("outputs must be a Tensor or an iterable of Tensors") return apply_scale(outputs) ================================================ FILE: src/face3d/models/arcface_torch/utils/utils_callbacks.py ================================================ import logging import os import time from typing import List import torch from eval import verification from utils.utils_logging import AverageMeter class CallBackVerification(object): def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112)): self.frequent: int = frequent self.rank: int = rank self.highest_acc: float = 0.0 self.highest_acc_list: List[float] = [0.0] * len(val_targets) self.ver_list: List[object] = [] self.ver_name_list: List[str] = [] if self.rank is 0: self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size) def ver_test(self, backbone: torch.nn.Module, global_step: int): results = [] for i in range(len(self.ver_list)): acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test( self.ver_list[i], backbone, 10, 10) logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm)) logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2)) if acc2 > self.highest_acc_list[i]: self.highest_acc_list[i] = acc2 logging.info( '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i])) 
results.append(acc2) def init_dataset(self, val_targets, data_dir, image_size): for name in val_targets: path = os.path.join(data_dir, name + ".bin") if os.path.exists(path): data_set = verification.load_bin(path, image_size) self.ver_list.append(data_set) self.ver_name_list.append(name) def __call__(self, num_update, backbone: torch.nn.Module): if self.rank is 0 and num_update > 0 and num_update % self.frequent == 0: backbone.eval() self.ver_test(backbone, num_update) backbone.train() class CallBackLogging(object): def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None): self.frequent: int = frequent self.rank: int = rank self.time_start = time.time() self.total_step: int = total_step self.batch_size: int = batch_size self.world_size: int = world_size self.writer = writer self.init = False self.tic = 0 def __call__(self, global_step: int, loss: AverageMeter, epoch: int, fp16: bool, learning_rate: float, grad_scaler: torch.cuda.amp.GradScaler): if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0: if self.init: try: speed: float = self.frequent * self.batch_size / (time.time() - self.tic) speed_total = speed * self.world_size except ZeroDivisionError: speed_total = float('inf') time_now = (time.time() - self.time_start) / 3600 time_total = time_now / ((global_step + 1) / self.total_step) time_for_end = time_total - time_now if self.writer is not None: self.writer.add_scalar('time_for_end', time_for_end, global_step) self.writer.add_scalar('learning_rate', learning_rate, global_step) self.writer.add_scalar('loss', loss.avg, global_step) if fp16: msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \ "Fp16 Grad Scale: %2.f Required: %1.f hours" % ( speed_total, loss.avg, learning_rate, epoch, global_step, grad_scaler.get_scale(), time_for_end ) else: msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \ "Required: %1.f hours" % ( speed_total, loss.avg, 
learning_rate, epoch, global_step, time_for_end ) logging.info(msg) loss.reset() self.tic = time.time() else: self.init = True self.tic = time.time() class CallBackModelCheckpoint(object): def __init__(self, rank, output="./"): self.rank: int = rank self.output: str = output def __call__(self, global_step, backbone, partial_fc, ): if global_step > 100 and self.rank == 0: path_module = os.path.join(self.output, "backbone.pth") torch.save(backbone.module.state_dict(), path_module) logging.info("Pytorch Model Saved in '{}'".format(path_module)) if global_step > 100 and partial_fc is not None: partial_fc.save_params() ================================================ FILE: src/face3d/models/arcface_torch/utils/utils_config.py ================================================ import importlib import os.path as osp def get_config(config_file): assert config_file.startswith('configs/'), 'config file setting must start with configs/' temp_config_name = osp.basename(config_file) temp_module_name = osp.splitext(temp_config_name)[0] config = importlib.import_module("configs.base") cfg = config.config config = importlib.import_module("configs.%s" % temp_module_name) job_cfg = config.config cfg.update(job_cfg) if cfg.output is None: cfg.output = osp.join('work_dirs', temp_module_name) return cfg ================================================ FILE: src/face3d/models/arcface_torch/utils/utils_logging.py ================================================ import logging import os import sys class AverageMeter(object): """Computes and stores the average and current value """ def __init__(self): self.val = None self.avg = None self.sum = None self.count = None self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count def init_logging(rank, models_root): if rank == 0: log_root = logging.getLogger() log_root.setLevel(logging.INFO) formatter = 
logging.Formatter("Training: %(asctime)s-%(message)s") handler_file = logging.FileHandler(os.path.join(models_root, "training.log")) handler_stream = logging.StreamHandler(sys.stdout) handler_file.setFormatter(formatter) handler_stream.setFormatter(formatter) log_root.addHandler(handler_file) log_root.addHandler(handler_stream) log_root.info('rank_id: %d' % rank) ================================================ FILE: src/face3d/models/arcface_torch/utils/utils_os.py ================================================ ================================================ FILE: src/face3d/models/base_model.py ================================================ """This script defines the base network model for Deep3DFaceRecon_pytorch """ import os import numpy as np import torch from collections import OrderedDict from abc import ABC, abstractmethod from . import networks class BaseModel(ABC): """This class is an abstract base class (ABC) for models. To create a subclass, you need to implement the following five functions: -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). -- : unpack data from dataset and apply preprocessing. -- : produce intermediate results. -- : calculate losses, gradients, and update network weights. -- : (optionally) add model-specific options and set default options. """ def __init__(self, opt): """Initialize the BaseModel class. Parameters: opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions When creating your custom class, you need to implement your own initialization. In this fucntion, you should first call Then, you need to define four lists: -- self.loss_names (str list): specify the training losses that you want to plot and save. -- self.model_names (str list): specify the images that you want to display and save. -- self.visual_names (str list): define networks used in our training. -- self.optimizers (optimizer list): define and initialize optimizers. 
You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. """ self.opt = opt self.isTrain = False self.device = torch.device('cpu') self.save_dir = " " # os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir self.loss_names = [] self.model_names = [] self.visual_names = [] self.parallel_names = [] self.optimizers = [] self.image_paths = [] self.metric = 0 # used for learning rate policy 'plateau' @staticmethod def dict_grad_hook_factory(add_func=lambda x: x): saved_dict = dict() def hook_gen(name): def grad_hook(grad): saved_vals = add_func(grad) saved_dict[name] = saved_vals return grad_hook return hook_gen, saved_dict @staticmethod def modify_commandline_options(parser, is_train): """Add new model-specific options, and rewrite default values for existing options. Parameters: parser -- original option parser is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. Returns: the modified parser. """ return parser @abstractmethod def set_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. Parameters: input (dict): includes the data itself and its metadata information. 
""" pass @abstractmethod def forward(self): """Run forward pass; called by both functions and .""" pass @abstractmethod def optimize_parameters(self): """Calculate losses, gradients, and update network weights; called in every training iteration""" pass def setup(self, opt): """Load and print networks; create schedulers Parameters: opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions """ if self.isTrain: self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] if not self.isTrain or opt.continue_train: load_suffix = opt.epoch self.load_networks(load_suffix) # self.print_networks(opt.verbose) def parallelize(self, convert_sync_batchnorm=True): if not self.opt.use_ddp: for name in self.parallel_names: if isinstance(name, str): module = getattr(self, name) setattr(self, name, module.to(self.device)) else: for name in self.model_names: if isinstance(name, str): module = getattr(self, name) if convert_sync_batchnorm: module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module) setattr(self, name, torch.nn.parallel.DistributedDataParallel(module.to(self.device), device_ids=[self.device.index], find_unused_parameters=True, broadcast_buffers=True)) # DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient. 
for name in self.parallel_names: if isinstance(name, str) and name not in self.model_names: module = getattr(self, name) setattr(self, name, module.to(self.device)) # put state_dict of optimizer to gpu device if self.opt.phase != 'test': if self.opt.continue_train: for optim in self.optimizers: for state in optim.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.to(self.device) def data_dependent_initialize(self, data): pass def train(self): """Make models train mode""" for name in self.model_names: if isinstance(name, str): net = getattr(self, name) net.train() def eval(self): """Make models eval mode""" for name in self.model_names: if isinstance(name, str): net = getattr(self, name) net.eval() def test(self): """Forward function used in test time. This function wraps function in no_grad() so we don't save intermediate steps for backprop It also calls to produce additional visualization results """ with torch.no_grad(): self.forward() self.compute_visuals() def compute_visuals(self): """Calculate additional output images for visdom and HTML visualization""" pass def get_image_paths(self, name='A'): """ Return image paths that are used to load current data""" return self.image_paths if name =='A' else self.image_paths_B def update_learning_rate(self): """Update learning rates for all the networks; called at the end of every epoch""" for scheduler in self.schedulers: if self.opt.lr_policy == 'plateau': scheduler.step(self.metric) else: scheduler.step() lr = self.optimizers[0].param_groups[0]['lr'] print('learning rate = %.7f' % lr) def get_current_visuals(self): """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" visual_ret = OrderedDict() for name in self.visual_names: if isinstance(name, str): visual_ret[name] = getattr(self, name)[:, :3, ...] return visual_ret def get_current_losses(self): """Return traning losses / errors. 
train.py will print out these errors on console, and save them to a file""" errors_ret = OrderedDict() for name in self.loss_names: if isinstance(name, str): errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number return errors_ret def save_networks(self, epoch): """Save all the networks to the disk. Parameters: epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) """ if not os.path.isdir(self.save_dir): os.makedirs(self.save_dir) save_filename = 'epoch_%s.pth' % (epoch) save_path = os.path.join(self.save_dir, save_filename) save_dict = {} for name in self.model_names: if isinstance(name, str): net = getattr(self, name) if isinstance(net, torch.nn.DataParallel) or isinstance(net, torch.nn.parallel.DistributedDataParallel): net = net.module save_dict[name] = net.state_dict() for i, optim in enumerate(self.optimizers): save_dict['opt_%02d'%i] = optim.state_dict() for i, sched in enumerate(self.schedulers): save_dict['sched_%02d'%i] = sched.state_dict() torch.save(save_dict, save_path) def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" key = keys[i] if i + 1 == len(keys): # at the end, pointing to a parameter/buffer if module.__class__.__name__.startswith('InstanceNorm') and \ (key == 'running_mean' or key == 'running_var'): if getattr(module, key) is None: state_dict.pop('.'.join(keys)) if module.__class__.__name__.startswith('InstanceNorm') and \ (key == 'num_batches_tracked'): state_dict.pop('.'.join(keys)) else: self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) def load_networks(self, epoch): """Load all the networks from the disk. 
Parameters: epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) """ if self.opt.isTrain and self.opt.pretrained_name is not None: load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name) else: load_dir = self.save_dir load_filename = 'epoch_%s.pth' % (epoch) load_path = os.path.join(load_dir, load_filename) state_dict = torch.load(load_path, map_location=self.device) print('loading the model from %s' % load_path) for name in self.model_names: if isinstance(name, str): net = getattr(self, name) if isinstance(net, torch.nn.DataParallel): net = net.module net.load_state_dict(state_dict[name]) if self.opt.phase != 'test': if self.opt.continue_train: print('loading the optim from %s' % load_path) for i, optim in enumerate(self.optimizers): optim.load_state_dict(state_dict['opt_%02d'%i]) try: print('loading the sched from %s' % load_path) for i, sched in enumerate(self.schedulers): sched.load_state_dict(state_dict['sched_%02d'%i]) except: print('Failed to load schedulers, set schedulers according to epoch count manually') for i, sched in enumerate(self.schedulers): sched.last_epoch = self.opt.epoch_count - 1 def print_networks(self, verbose): """Print the total number of parameters in the network and (if verbose) network architecture Parameters: verbose (bool) -- if verbose: print the network architecture """ print('---------- Networks initialized -------------') for name in self.model_names: if isinstance(name, str): net = getattr(self, name) num_params = 0 for param in net.parameters(): num_params += param.numel() if verbose: print(net) print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) print('-----------------------------------------------') def set_requires_grad(self, nets, requires_grad=False): """Set requies_grad=Fasle for all the networks to avoid unnecessary computations Parameters: nets (network list) -- a list of networks requires_grad (bool) -- whether the networks require 
gradients or not """ if not isinstance(nets, list): nets = [nets] for net in nets: if net is not None: for param in net.parameters(): param.requires_grad = requires_grad def generate_visuals_for_evaluation(self, data, mode): return {} ================================================ FILE: src/face3d/models/bfm.py ================================================ """This script defines the parametric 3d face model for Deep3DFaceRecon_pytorch """ import numpy as np import torch import torch.nn.functional as F from scipy.io import loadmat from src.face3d.util.load_mats import transferBFM09 import os def perspective_projection(focal, center): # return p.T (N, 3) @ (3, 3) return np.array([ focal, 0, center, 0, focal, center, 0, 0, 1 ]).reshape([3, 3]).astype(np.float32).transpose() class SH: def __init__(self): self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)] self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) / np.sqrt(12 * np.pi)] class ParametricFaceModel: def __init__(self, bfm_folder='./BFM', recenter=True, camera_distance=10., init_lit=np.array([ 0.8, 0, 0, 0, 0, 0, 0, 0, 0 ]), focal=1015., center=112., is_train=True, default_name='BFM_model_front.mat'): if not os.path.isfile(os.path.join(bfm_folder, default_name)): transferBFM09(bfm_folder) model = loadmat(os.path.join(bfm_folder, default_name)) # mean face shape. [3*N,1] self.mean_shape = model['meanshape'].astype(np.float32) # identity basis. [3*N,80] self.id_base = model['idBase'].astype(np.float32) # expression basis. [3*N,64] self.exp_base = model['exBase'].astype(np.float32) # mean face texture. [3*N,1] (0-255) self.mean_tex = model['meantex'].astype(np.float32) # texture basis. [3*N,80] self.tex_base = model['texBase'].astype(np.float32) # face indices for each vertex that lies in. starts from 0. [N,8] self.point_buf = model['point_buf'].astype(np.int64) - 1 # vertex indices for each face. starts from 0. 
[F,3] self.face_buf = model['tri'].astype(np.int64) - 1 # vertex indices for 68 landmarks. starts from 0. [68,1] self.keypoints = np.squeeze(model['keypoints']).astype(np.int64) - 1 if is_train: # vertex indices for small face region to compute photometric error. starts from 0. self.front_mask = np.squeeze(model['frontmask2_idx']).astype(np.int64) - 1 # vertex indices for each face from small face region. starts from 0. [f,3] self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1 # vertex indices for pre-defined skin region to compute reflectance loss self.skin_mask = np.squeeze(model['skinmask']) if recenter: mean_shape = self.mean_shape.reshape([-1, 3]) mean_shape = mean_shape - np.mean(mean_shape, axis=0, keepdims=True) self.mean_shape = mean_shape.reshape([-1, 1]) self.persc_proj = perspective_projection(focal, center) self.device = 'cpu' self.camera_distance = camera_distance self.SH = SH() self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32) def to(self, device): self.device = device for key, value in self.__dict__.items(): if type(value).__module__ == np.__name__: setattr(self, key, torch.tensor(value).to(device)) def compute_shape(self, id_coeff, exp_coeff): """ Return: face_shape -- torch.tensor, size (B, N, 3) Parameters: id_coeff -- torch.tensor, size (B, 80), identity coeffs exp_coeff -- torch.tensor, size (B, 64), expression coeffs """ batch_size = id_coeff.shape[0] id_part = torch.einsum('ij,aj->ai', self.id_base, id_coeff) exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff) face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1]) return face_shape.reshape([batch_size, -1, 3]) def compute_texture(self, tex_coeff, normalize=True): """ Return: face_texture -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.) 
Parameters: tex_coeff -- torch.tensor, size (B, 80) """ batch_size = tex_coeff.shape[0] face_texture = torch.einsum('ij,aj->ai', self.tex_base, tex_coeff) + self.mean_tex if normalize: face_texture = face_texture / 255. return face_texture.reshape([batch_size, -1, 3]) def compute_norm(self, face_shape): """ Return: vertex_norm -- torch.tensor, size (B, N, 3) Parameters: face_shape -- torch.tensor, size (B, N, 3) """ v1 = face_shape[:, self.face_buf[:, 0]] v2 = face_shape[:, self.face_buf[:, 1]] v3 = face_shape[:, self.face_buf[:, 2]] e1 = v1 - v2 e2 = v2 - v3 face_norm = torch.cross(e1, e2, dim=-1) face_norm = F.normalize(face_norm, dim=-1, p=2) face_norm = torch.cat([face_norm, torch.zeros(face_norm.shape[0], 1, 3).to(self.device)], dim=1) vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2) vertex_norm = F.normalize(vertex_norm, dim=-1, p=2) return vertex_norm def compute_color(self, face_texture, face_norm, gamma): """ Return: face_color -- torch.tensor, size (B, N, 3), range (0, 1.) Parameters: face_texture -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.) face_norm -- torch.tensor, size (B, N, 3), rotated face normal gamma -- torch.tensor, size (B, 27), SH coeffs """ batch_size = gamma.shape[0] v_num = face_texture.shape[1] a, c = self.SH.a, self.SH.c gamma = gamma.reshape([batch_size, 3, 9]) gamma = gamma + self.init_lit gamma = gamma.permute(0, 2, 1) Y = torch.cat([ a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.device), -a[1] * c[1] * face_norm[..., 1:2], a[1] * c[1] * face_norm[..., 2:], -a[1] * c[1] * face_norm[..., :1], a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2], -a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:], 0.5 * a[2] * c[2] / np.sqrt(3.) 
* (3 * face_norm[..., 2:] ** 2 - 1), -a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:], 0.5 * a[2] * c[2] * (face_norm[..., :1] ** 2 - face_norm[..., 1:2] ** 2) ], dim=-1) r = Y @ gamma[..., :1] g = Y @ gamma[..., 1:2] b = Y @ gamma[..., 2:] face_color = torch.cat([r, g, b], dim=-1) * face_texture return face_color def compute_rotation(self, angles): """ Return: rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat Parameters: angles -- torch.tensor, size (B, 3), radian """ batch_size = angles.shape[0] ones = torch.ones([batch_size, 1]).to(self.device) zeros = torch.zeros([batch_size, 1]).to(self.device) x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:], rot_x = torch.cat([ ones, zeros, zeros, zeros, torch.cos(x), -torch.sin(x), zeros, torch.sin(x), torch.cos(x) ], dim=1).reshape([batch_size, 3, 3]) rot_y = torch.cat([ torch.cos(y), zeros, torch.sin(y), zeros, ones, zeros, -torch.sin(y), zeros, torch.cos(y) ], dim=1).reshape([batch_size, 3, 3]) rot_z = torch.cat([ torch.cos(z), -torch.sin(z), zeros, torch.sin(z), torch.cos(z), zeros, zeros, zeros, ones ], dim=1).reshape([batch_size, 3, 3]) rot = rot_z @ rot_y @ rot_x return rot.permute(0, 2, 1) def to_camera(self, face_shape): face_shape[..., -1] = self.camera_distance - face_shape[..., -1] return face_shape def to_image(self, face_shape): """ Return: face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction Parameters: face_shape -- torch.tensor, size (B, N, 3) """ # to image_plane face_proj = face_shape @ self.persc_proj face_proj = face_proj[..., :2] / face_proj[..., 2:] return face_proj def transform(self, face_shape, rot, trans): """ Return: face_shape -- torch.tensor, size (B, N, 3) pts @ rot + trans Parameters: face_shape -- torch.tensor, size (B, N, 3) rot -- torch.tensor, size (B, 3, 3) trans -- torch.tensor, size (B, 3) """ return face_shape @ rot + trans.unsqueeze(1) def get_landmarks(self, face_proj): """ Return: face_lms -- torch.tensor, size (B, 68, 2) Parameters: 
face_proj -- torch.tensor, size (B, N, 2) """ return face_proj[:, self.keypoints] def split_coeff(self, coeffs): """ Return: coeffs_dict -- a dict of torch.tensors Parameters: coeffs -- torch.tensor, size (B, 256) """ id_coeffs = coeffs[:, :80] exp_coeffs = coeffs[:, 80: 144] tex_coeffs = coeffs[:, 144: 224] angles = coeffs[:, 224: 227] gammas = coeffs[:, 227: 254] translations = coeffs[:, 254:] return { 'id': id_coeffs, 'exp': exp_coeffs, 'tex': tex_coeffs, 'angle': angles, 'gamma': gammas, 'trans': translations } def compute_for_render(self, coeffs): """ Return: face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate face_color -- torch.tensor, size (B, N, 3), in RGB order landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction Parameters: coeffs -- torch.tensor, size (B, 257) """ coef_dict = self.split_coeff(coeffs) face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp']) rotation = self.compute_rotation(coef_dict['angle']) face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans']) face_vertex = self.to_camera(face_shape_transformed) face_proj = self.to_image(face_vertex) landmark = self.get_landmarks(face_proj) face_texture = self.compute_texture(coef_dict['tex']) face_norm = self.compute_norm(face_shape) face_norm_roted = face_norm @ rotation face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma']) return face_vertex, face_texture, face_color, landmark def compute_for_render_woRotation(self, coeffs): """ Return: face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate face_color -- torch.tensor, size (B, N, 3), in RGB order landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction Parameters: coeffs -- torch.tensor, size (B, 257) """ coef_dict = self.split_coeff(coeffs) face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp']) #rotation = self.compute_rotation(coef_dict['angle']) #face_shape_transformed = 
self.transform(face_shape, rotation, coef_dict['trans']) face_vertex = self.to_camera(face_shape) face_proj = self.to_image(face_vertex) landmark = self.get_landmarks(face_proj) face_texture = self.compute_texture(coef_dict['tex']) face_norm = self.compute_norm(face_shape) face_norm_roted = face_norm # @ rotation face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma']) return face_vertex, face_texture, face_color, landmark if __name__ == '__main__': transferBFM09() ================================================ FILE: src/face3d/models/facerecon_model.py ================================================ """This script defines the face reconstruction model for Deep3DFaceRecon_pytorch """ import numpy as np import torch from src.face3d.models.base_model import BaseModel from src.face3d.models import networks from src.face3d.models.bfm import ParametricFaceModel from src.face3d.models.losses import perceptual_loss, photo_loss, reg_loss, reflectance_loss, landmark_loss from src.face3d.util import util from src.face3d.util.nvdiffrast import MeshRenderer # from src.face3d.util.preprocess import estimate_norm_torch import trimesh from scipy.io import savemat class FaceReconModel(BaseModel): @staticmethod def modify_commandline_options(parser, is_train=False): """ Configures options specific for CUT model """ # net structure and parameters parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='network structure') parser.add_argument('--init_path', type=str, default='./checkpoints/init_model/resnet50-0676ba61.pth') parser.add_argument('--use_last_fc', type=util.str2bool, nargs='?', const=True, default=False, help='zero initialize the last fc') parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/') parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model') # renderer parameters parser.add_argument('--focal', type=float, 
default=1015.) parser.add_argument('--center', type=float, default=112.) parser.add_argument('--camera_d', type=float, default=10.) parser.add_argument('--z_near', type=float, default=5.) parser.add_argument('--z_far', type=float, default=15.) if is_train: # training parameters parser.add_argument('--net_recog', type=str, default='r50', choices=['r18', 'r43', 'r50'], help='face recog network structure') parser.add_argument('--net_recog_path', type=str, default='checkpoints/recog_model/ms1mv3_arcface_r50_fp16/backbone.pth') parser.add_argument('--use_crop_face', type=util.str2bool, nargs='?', const=True, default=False, help='use crop mask for photo loss') parser.add_argument('--use_predef_M', type=util.str2bool, nargs='?', const=True, default=False, help='use predefined M for predicted face') # augmentation parameters parser.add_argument('--shift_pixs', type=float, default=10., help='shift pixels') parser.add_argument('--scale_delta', type=float, default=0.1, help='delta scale factor') parser.add_argument('--rot_angle', type=float, default=10., help='rot angles, degree') # loss weights parser.add_argument('--w_feat', type=float, default=0.2, help='weight for feat loss') parser.add_argument('--w_color', type=float, default=1.92, help='weight for loss loss') parser.add_argument('--w_reg', type=float, default=3.0e-4, help='weight for reg loss') parser.add_argument('--w_id', type=float, default=1.0, help='weight for id_reg loss') parser.add_argument('--w_exp', type=float, default=0.8, help='weight for exp_reg loss') parser.add_argument('--w_tex', type=float, default=1.7e-2, help='weight for tex_reg loss') parser.add_argument('--w_gamma', type=float, default=10.0, help='weight for gamma loss') parser.add_argument('--w_lm', type=float, default=1.6e-3, help='weight for lm loss') parser.add_argument('--w_reflc', type=float, default=5.0, help='weight for reflc loss') opt, _ = parser.parse_known_args() parser.set_defaults( focal=1015., center=112., camera_d=10., 
use_last_fc=False, z_near=5., z_far=15. ) if is_train: parser.set_defaults( use_crop_face=True, use_predef_M=False ) return parser def __init__(self, opt): """Initialize this model class. Parameters: opt -- training/test options A few things can be done here. - (required) call the initialization function of BaseModel - define loss function, visualization images, model names, and optimizers """ BaseModel.__init__(self, opt) # call the initialization method of BaseModel self.visual_names = ['output_vis'] self.model_names = ['net_recon'] self.parallel_names = self.model_names + ['renderer'] self.facemodel = ParametricFaceModel( bfm_folder=opt.bfm_folder, camera_distance=opt.camera_d, focal=opt.focal, center=opt.center, is_train=self.isTrain, default_name=opt.bfm_model ) fov = 2 * np.arctan(opt.center / opt.focal) * 180 / np.pi self.renderer = MeshRenderer( rasterize_fov=fov, znear=opt.z_near, zfar=opt.z_far, rasterize_size=int(2 * opt.center) ) if self.isTrain: self.loss_names = ['all', 'feat', 'color', 'lm', 'reg', 'gamma', 'reflc'] self.net_recog = networks.define_net_recog( net_recog=opt.net_recog, pretrained_path=opt.net_recog_path ) # loss func name: (compute_%s_loss) % loss_name self.compute_feat_loss = perceptual_loss self.comupte_color_loss = photo_loss self.compute_lm_loss = landmark_loss self.compute_reg_loss = reg_loss self.compute_reflc_loss = reflectance_loss self.optimizer = torch.optim.Adam(self.net_recon.parameters(), lr=opt.lr) self.optimizers = [self.optimizer] self.parallel_names += ['net_recog'] # Our program will automatically call to define schedulers, load networks, and print networks def set_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. Parameters: input: a dictionary that contains the data itself and its metadata information. 
""" self.input_img = input['imgs'].to(self.device) self.atten_mask = input['msks'].to(self.device) if 'msks' in input else None self.gt_lm = input['lms'].to(self.device) if 'lms' in input else None self.trans_m = input['M'].to(self.device) if 'M' in input else None self.image_paths = input['im_paths'] if 'im_paths' in input else None def forward(self, output_coeff, device): self.facemodel.to(device) self.pred_vertex, self.pred_tex, self.pred_color, self.pred_lm = \ self.facemodel.compute_for_render(output_coeff) self.pred_mask, _, self.pred_face = self.renderer( self.pred_vertex, self.facemodel.face_buf, feat=self.pred_color) self.pred_coeffs_dict = self.facemodel.split_coeff(output_coeff) def compute_losses(self): """Calculate losses, gradients, and update network weights; called in every training iteration""" assert self.net_recog.training == False trans_m = self.trans_m if not self.opt.use_predef_M: trans_m = estimate_norm_torch(self.pred_lm, self.input_img.shape[-2]) pred_feat = self.net_recog(self.pred_face, trans_m) gt_feat = self.net_recog(self.input_img, self.trans_m) self.loss_feat = self.opt.w_feat * self.compute_feat_loss(pred_feat, gt_feat) face_mask = self.pred_mask if self.opt.use_crop_face: face_mask, _, _ = self.renderer(self.pred_vertex, self.facemodel.front_face_buf) face_mask = face_mask.detach() self.loss_color = self.opt.w_color * self.comupte_color_loss( self.pred_face, self.input_img, self.atten_mask * face_mask) loss_reg, loss_gamma = self.compute_reg_loss(self.pred_coeffs_dict, self.opt) self.loss_reg = self.opt.w_reg * loss_reg self.loss_gamma = self.opt.w_gamma * loss_gamma self.loss_lm = self.opt.w_lm * self.compute_lm_loss(self.pred_lm, self.gt_lm) self.loss_reflc = self.opt.w_reflc * self.compute_reflc_loss(self.pred_tex, self.facemodel.skin_mask) self.loss_all = self.loss_feat + self.loss_color + self.loss_reg + self.loss_gamma \ + self.loss_lm + self.loss_reflc def optimize_parameters(self, isTrain=True): self.forward() 
self.compute_losses() """Update network weights; it will be called in every training iteration.""" if isTrain: self.optimizer.zero_grad() self.loss_all.backward() self.optimizer.step() def compute_visuals(self): with torch.no_grad(): input_img_numpy = 255. * self.input_img.detach().cpu().permute(0, 2, 3, 1).numpy() output_vis = self.pred_face * self.pred_mask + (1 - self.pred_mask) * self.input_img output_vis_numpy_raw = 255. * output_vis.detach().cpu().permute(0, 2, 3, 1).numpy() if self.gt_lm is not None: gt_lm_numpy = self.gt_lm.cpu().numpy() pred_lm_numpy = self.pred_lm.detach().cpu().numpy() output_vis_numpy = util.draw_landmarks(output_vis_numpy_raw, gt_lm_numpy, 'b') output_vis_numpy = util.draw_landmarks(output_vis_numpy, pred_lm_numpy, 'r') output_vis_numpy = np.concatenate((input_img_numpy, output_vis_numpy_raw, output_vis_numpy), axis=-2) else: output_vis_numpy = np.concatenate((input_img_numpy, output_vis_numpy_raw), axis=-2) self.output_vis = torch.tensor( output_vis_numpy / 255., dtype=torch.float32 ).permute(0, 3, 1, 2).to(self.device) def save_mesh(self, name): recon_shape = self.pred_vertex # get reconstructed shape recon_shape[..., -1] = 10 - recon_shape[..., -1] # from camera space to world space recon_shape = recon_shape.cpu().numpy()[0] recon_color = self.pred_color recon_color = recon_color.cpu().numpy()[0] tri = self.facemodel.face_buf.cpu().numpy() mesh = trimesh.Trimesh(vertices=recon_shape, faces=tri, vertex_colors=np.clip(255. 
* recon_color, 0, 255).astype(np.uint8)) mesh.export(name) def save_coeff(self,name): pred_coeffs = {key:self.pred_coeffs_dict[key].cpu().numpy() for key in self.pred_coeffs_dict} pred_lm = self.pred_lm.cpu().numpy() pred_lm = np.stack([pred_lm[:,:,0],self.input_img.shape[2]-1-pred_lm[:,:,1]],axis=2) # transfer to image coordinate pred_coeffs['lm68'] = pred_lm savemat(name,pred_coeffs) ================================================ FILE: src/face3d/models/losses.py ================================================ import numpy as np import torch import torch.nn as nn from kornia.geometry import warp_affine import torch.nn.functional as F def resize_n_crop(image, M, dsize=112): # image: (b, c, h, w) # M : (b, 2, 3) return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True) ### perceptual level loss class PerceptualLoss(nn.Module): def __init__(self, recog_net, input_size=112): super(PerceptualLoss, self).__init__() self.recog_net = recog_net self.preprocess = lambda x: 2 * x - 1 self.input_size=input_size def forward(imageA, imageB, M): """ 1 - cosine distance Parameters: imageA --torch.tensor (B, 3, H, W), range (0, 1) , RGB order imageB --same as imageA """ imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size)) imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size)) # freeze bn self.recog_net.eval() id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2) id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2) cosine_d = torch.sum(id_featureA * id_featureB, dim=-1) # assert torch.sum((cosine_d > 1).float()) == 0 return torch.sum(1 - cosine_d) / cosine_d.shape[0] def perceptual_loss(id_featureA, id_featureB): cosine_d = torch.sum(id_featureA * id_featureB, dim=-1) # assert torch.sum((cosine_d > 1).float()) == 0 return torch.sum(1 - cosine_d) / cosine_d.shape[0] ### image level loss def photo_loss(imageA, imageB, mask, eps=1e-6): """ l2 norm (with sqrt, to ensure backward stabililty, use eps, otherwise Nan may 
occur) Parameters: imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order imageB --same as imageA """ loss = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdims=True)) * mask loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device)) return loss def landmark_loss(predict_lm, gt_lm, weight=None): """ weighted mse loss Parameters: predict_lm --torch.tensor (B, 68, 2) gt_lm --torch.tensor (B, 68, 2) weight --numpy.array (1, 68) """ if not weight: weight = np.ones([68]) weight[28:31] = 20 weight[-8:] = 20 weight = np.expand_dims(weight, 0) weight = torch.tensor(weight).to(predict_lm.device) loss = torch.sum((predict_lm - gt_lm)**2, dim=-1) * weight loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1]) return loss ### regulization def reg_loss(coeffs_dict, opt=None): """ l2 norm without the sqrt, from yu's implementation (mse) tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss Parameters: coeffs_dict -- a dict of torch.tensors , keys: id, exp, tex, angle, gamma, trans """ # coefficient regularization to ensure plausible 3d faces if opt: w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex else: w_id, w_exp, w_tex = 1, 1, 1, 1 creg_loss = w_id * torch.sum(coeffs_dict['id'] ** 2) + \ w_exp * torch.sum(coeffs_dict['exp'] ** 2) + \ w_tex * torch.sum(coeffs_dict['tex'] ** 2) creg_loss = creg_loss / coeffs_dict['id'].shape[0] # gamma regularization to ensure a nearly-monochromatic light gamma = coeffs_dict['gamma'].reshape([-1, 3, 9]) gamma_mean = torch.mean(gamma, dim=1, keepdims=True) gamma_loss = torch.mean((gamma - gamma_mean) ** 2) return creg_loss, gamma_loss def reflectance_loss(texture, mask): """ minimize texture variance (mse), albedo regularization to ensure an uniform skin albedo Parameters: texture --torch.tensor, (B, N, 3) mask --torch.tensor, (N), 1 or 0 """ mask = mask.reshape([1, mask.shape[0], 1]) texture_mean = torch.sum(mask * texture, dim=1, keepdims=True) / torch.sum(mask) 
loss = torch.sum(((texture - texture_mean) * mask)**2) / (texture.shape[0] * torch.sum(mask)) return loss ================================================ FILE: src/face3d/models/networks.py ================================================ """This script defines deep neural networks for Deep3DFaceRecon_pytorch """ import os import numpy as np import torch.nn.functional as F from torch.nn import init import functools from torch.optim import lr_scheduler import torch from torch import Tensor import torch.nn as nn try: from torch.hub import load_state_dict_from_url except ImportError: from torch.utils.model_zoo import load_url as load_state_dict_from_url from typing import Type, Any, Callable, Union, List, Optional from .arcface_torch.backbones import get_model from kornia.geometry import warp_affine def resize_n_crop(image, M, dsize=112): # image: (b, c, h, w) # M : (b, 2, 3) return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True) def filter_state_dict(state_dict, remove_name='fc'): new_state_dict = {} for key in state_dict: if remove_name in key: continue new_state_dict[key] = state_dict[key] return new_state_dict def get_scheduler(optimizer, opt): """Return a learning rate scheduler Parameters: optimizer -- the optimizer of the network opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. See https://pytorch.org/docs/stable/optim.html for more details. 
""" if opt.lr_policy == 'linear': def lambda_rule(epoch): lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs + 1) return lr_l scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) elif opt.lr_policy == 'step': scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_epochs, gamma=0.2) elif opt.lr_policy == 'plateau': scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) elif opt.lr_policy == 'cosine': scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) else: return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) return scheduler def define_net_recon(net_recon, use_last_fc=False, init_path=None): return ReconNetWrapper(net_recon, use_last_fc=use_last_fc, init_path=init_path) def define_net_recog(net_recog, pretrained_path=None): net = RecogNetWrapper(net_recog=net_recog, pretrained_path=pretrained_path) net.eval() return net class ReconNetWrapper(nn.Module): fc_dim=257 def __init__(self, net_recon, use_last_fc=False, init_path=None): super(ReconNetWrapper, self).__init__() self.use_last_fc = use_last_fc if net_recon not in func_dict: return NotImplementedError('network [%s] is not implemented', net_recon) func, last_dim = func_dict[net_recon] backbone = func(use_last_fc=use_last_fc, num_classes=self.fc_dim) if init_path and os.path.isfile(init_path): state_dict = filter_state_dict(torch.load(init_path, map_location='cpu')) backbone.load_state_dict(state_dict) print("loading init net_recon %s from %s" %(net_recon, init_path)) self.backbone = backbone if not use_last_fc: self.final_layers = nn.ModuleList([ conv1x1(last_dim, 80, bias=True), # id layer conv1x1(last_dim, 64, bias=True), # exp layer conv1x1(last_dim, 80, bias=True), # tex layer conv1x1(last_dim, 3, bias=True), # angle layer conv1x1(last_dim, 27, bias=True), # gamma layer conv1x1(last_dim, 2, bias=True), # tx, ty conv1x1(last_dim, 1, 
bias=True) # tz ]) for m in self.final_layers: nn.init.constant_(m.weight, 0.) nn.init.constant_(m.bias, 0.) def forward(self, x): """Run the backbone; unless use_last_fc, concatenate the seven 1x1-conv heads along dim 1 (80+64+80+3+27+2+1 = 257 channels) and flatten from dim 1.""" x = self.backbone(x) if not self.use_last_fc: output = [] for layer in self.final_layers: output.append(layer(x)) x = torch.flatten(torch.cat(output, dim=1), 1) return x class RecogNetWrapper(nn.Module): def __init__(self, net_recog, pretrained_path=None, input_size=112): super(RecogNetWrapper, self).__init__() net = get_model(name=net_recog, fp16=False) if pretrained_path: state_dict = torch.load(pretrained_path, map_location='cpu') net.load_state_dict(state_dict) print("loading pretrained net_recog %s from %s" %(net_recog, pretrained_path)) for param in net.parameters(): param.requires_grad = False self.net = net self.preprocess = lambda x: 2 * x - 1 self.input_size=input_size def forward(self, image, M): """Crop `image` with the affine matrix M to input_size x input_size, rescale with 2x - 1 and return the L2-normalised features of the frozen recognition net.""" image = self.preprocess(resize_n_crop(image, M, self.input_size)) id_feature = F.normalize(self.net(image), dim=-1, p=2) return id_feature # adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'] model_urls = { 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', } def conv3x3(in_planes:
int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation) def conv1x1(in_planes: int, out_planes: int, stride: int = 1, bias: bool = False) -> nn.Conv2d: """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias) class BasicBlock(nn.Module): expansion: int = 1 def __init__( self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None, groups: int = 1, base_width: int = 64, dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None ) -> None: super(BasicBlock, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d if groups != 1 or base_width != 64: raise ValueError('BasicBlock only supports groups=1 and base_width=64') if dilation > 1: raise NotImplementedError("Dilation > 1 not supported in BasicBlock") # Both self.conv1 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = norm_layer(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = conv3x3(planes, planes) self.bn2 = norm_layer(planes) self.downsample = downsample self.stride = stride def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out class Bottleneck(nn.Module): # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) # while original implementation places the stride at the first 1x1 convolution(self.conv1) # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
# This variant is also known as ResNet V1.5 and improves accuracy according to # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. expansion: int = 4 def __init__( self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None, groups: int = 1, base_width: int = 64, dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None ) -> None: super(Bottleneck, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d width = int(planes * (base_width / 64.)) * groups # Both self.conv2 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv1x1(inplanes, width) self.bn1 = norm_layer(width) self.conv2 = conv3x3(width, width, stride, groups, dilation) self.bn2 = norm_layer(width) self.conv3 = conv1x1(width, planes * self.expansion) self.bn3 = norm_layer(planes * self.expansion) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out class ResNet(nn.Module): def __init__( self, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], num_classes: int = 1000, zero_init_residual: bool = False, use_last_fc: bool = False, groups: int = 1, width_per_group: int = 64, replace_stride_with_dilation: Optional[List[bool]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None ) -> None: super(ResNet, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d self._norm_layer = norm_layer self.inplanes = 64 self.dilation = 1 if replace_stride_with_dilation is None: # each element in the tuple indicates if we should replace # the 2x2 stride with a dilated convolution instead replace_stride_with_dilation = [False,
False, False] if len(replace_stride_with_dilation) != 3: raise ValueError("replace_stride_with_dilation should be None " "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) self.use_last_fc = use_last_fc self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) if self.use_last_fc: self.fc = nn.Linear(512 * block.expansion, num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) # Zero-initialize the last BN in each residual branch, # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 if zero_init_residual: for m in self.modules(): if isinstance(m, Bottleneck): nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] elif isinstance(m, BasicBlock): nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, stride: int = 1, dilate: bool = False) -> nn.Sequential: norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation if dilate: self.dilation *= stride stride = 1 if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( conv1x1(self.inplanes, planes * block.expansion, stride), norm_layer(planes * block.expansion), ) layers = [] layers.append(block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer)) self.inplanes = planes * block.expansion for _ in range(1, blocks): layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=self.dilation, norm_layer=norm_layer)) return nn.Sequential(*layers) def _forward_impl(self, x: Tensor) -> Tensor: # See note [TorchScript super()] x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.avgpool(x) if self.use_last_fc: x = torch.flatten(x, 1) x = self.fc(x) return x def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) def _resnet( arch: str, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], pretrained: bool, progress: bool, **kwargs: Any ) -> ResNet: model = ResNet(block, layers, **kwargs) if pretrained: state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) model.load_state_dict(state_dict) return model def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-18 model from `"Deep Residual Learning 
for Image Recognition" `_. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-34 model from `"Deep Residual Learning for Image Recognition" `_. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs) def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-50 model from `"Deep Residual Learning for Image Recognition" `_. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-101 model from `"Deep Residual Learning for Image Recognition" `_. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-152 model from `"Deep Residual Learning for Image Recognition" `_. 
Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs) def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNeXt-50 32x4d model from `"Aggregated Residual Transformation for Deep Neural Networks" `_. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ kwargs['groups'] = 32 kwargs['width_per_group'] = 4 return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNeXt-101 32x8d model from `"Aggregated Residual Transformation for Deep Neural Networks" `_. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ kwargs['groups'] = 32 kwargs['width_per_group'] = 8 return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""Wide ResNet-50-2 model from `"Wide Residual Networks" `_. The model is the same as ResNet except for the bottleneck number of channels which is twice larger in every block. The number of channels in outer 1x1 convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ kwargs['width_per_group'] = 64 * 2 return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""Wide ResNet-101-2 model from `"Wide Residual Networks" `_. The model is the same as ResNet except for the bottleneck number of channels which is twice larger in every block. The number of channels in outer 1x1 convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 channels, and in Wide ResNet-50-2 has 2048-1024-2048. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ kwargs['width_per_group'] = 64 * 2 return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) func_dict = { 'resnet18': (resnet18, 512), 'resnet50': (resnet50, 2048) } ================================================ FILE: src/face3d/models/template_model.py ================================================ """Model class template This module provides a template for users to implement custom models. You can specify '--model template' to use this model. The class name should be consistent with both the filename and its model option. The filename should be _dataset.py The class name should be Dataset.py It implements a simple image-to-image translation baseline based on regression loss. Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss: min_ ||netG(data_A) - data_B||_1 You need to implement the following functions: : Add model-specific options and rewrite default values for existing options. <__init__>: Initialize this model class. : Unpack input data and perform data pre-processing. : Run forward pass. 
This will be called by both and . : Update network weights; it will be called in every training iteration. """ import numpy as np import torch from .base_model import BaseModel from . import networks class TemplateModel(BaseModel): @staticmethod def modify_commandline_options(parser, is_train=True): """Add new model-specific options and rewrite default values for existing options. Parameters: parser -- the option parser is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options. Returns: the modified parser. """ parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset. if is_train: parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model. return parser def __init__(self, opt): """Initialize this model class. Parameters: opt -- training/test options A few things can be done here. - (required) call the initialization function of BaseModel - define loss function, visualization images, model names, and optimizers """ BaseModel.__init__(self, opt) # call the initialization method of BaseModel # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk. self.loss_names = ['loss_G'] # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images. self.visual_names = ['data_A', 'data_B', 'output'] # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks. # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them. 
self.model_names = ['G'] # define networks; you can use opt.isTrain to specify different behaviors for training and test. self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids) if self.isTrain: # only defined during training time # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss. # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device) self.criterionLoss = torch.nn.L1Loss() # define and initialize optimizers. You can define one optimizer for each network. # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers = [self.optimizer] # Our program will automatically call to define schedulers, load networks, and print networks def set_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. Parameters: input: a dictionary that contains the data itself and its metadata information. """ AtoB = self.opt.direction == 'AtoB' # use to swap data_A and data_B self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths def forward(self): """Run forward pass. 
This will be called by both functions and .""" self.output = self.netG(self.data_A) # generate output image given the input data_A def backward(self): """Calculate losses, gradients, and update network weights; called in every training iteration""" # caculate the intermediate results if necessary; here self.output has been computed during function # calculate loss given the input and intermediate results self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G def optimize_parameters(self): """Update network weights; it will be called in every training iteration.""" self.forward() # first call forward to calculate intermediate results self.optimizer.zero_grad() # clear network G's existing gradients self.backward() # calculate gradients for network G self.optimizer.step() # update gradients for network G ================================================ FILE: src/face3d/options/__init__.py ================================================ """This package options includes option modules: training options, test options, and basic options (used in both training and test).""" ================================================ FILE: src/face3d/options/base_options.py ================================================ """This script contains base options for Deep3DFaceRecon_pytorch """ import argparse import os from util import util import numpy as np import torch import face3d.models as models import face3d.data as data class BaseOptions(): """This class defines options used during both training and test time. It also implements several helper functions such as parsing, printing, and saving the options. It also gathers additional options defined in functions in both dataset class and model class. 
""" def __init__(self, cmd_line=None): """Reset the class; indicates the class hasn't been initailized""" self.initialized = False self.cmd_line = None if cmd_line is not None: self.cmd_line = cmd_line.split() def initialize(self, parser): """Define the common options that are used in both training and test.""" # basic parameters parser.add_argument('--name', type=str, default='face_recon', help='name of the experiment. It decides where to store samples and models') parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') parser.add_argument('--vis_batch_nums', type=float, default=1, help='batch nums of images for visulization') parser.add_argument('--eval_batch_nums', type=float, default=float('inf'), help='batch nums of images for evaluation') parser.add_argument('--use_ddp', type=util.str2bool, nargs='?', const=True, default=True, help='whether use distributed data parallel') parser.add_argument('--ddp_port', type=str, default='12355', help='ddp port') parser.add_argument('--display_per_batch', type=util.str2bool, nargs='?', const=True, default=True, help='whether use batch to show losses') parser.add_argument('--add_image', type=util.str2bool, nargs='?', const=True, default=True, help='whether add image to tensorboard') parser.add_argument('--world_size', type=int, default=1, help='batch nums of images for evaluation') # model parameters parser.add_argument('--model', type=str, default='facerecon', help='chooses which model to use.') # additional parameters parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? 
set to latest to use latest cached model') parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') self.initialized = True return parser def gather_options(self): """Initialize our parser with basic options(only once). Add additional model-specific and dataset-specific options. These options are defined in the function in model and dataset classes. """ if not self.initialized: # check if it has been initialized parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser = self.initialize(parser) # get the basic options if self.cmd_line is None: opt, _ = parser.parse_known_args() else: opt, _ = parser.parse_known_args(self.cmd_line) # set cuda visible devices os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids # modify model-related parser options model_name = opt.model model_option_setter = models.get_option_setter(model_name) parser = model_option_setter(parser, self.isTrain) if self.cmd_line is None: opt, _ = parser.parse_known_args() # parse again with new defaults else: opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults # modify dataset-related parser options if opt.dataset_mode: dataset_name = opt.dataset_mode dataset_option_setter = data.get_option_setter(dataset_name) parser = dataset_option_setter(parser, self.isTrain) # save and return the parser self.parser = parser if self.cmd_line is None: return parser.parse_args() else: return parser.parse_args(self.cmd_line) def print_options(self, opt): """Print and save options It will print both current options and default values(if different). 
It will save options into a text file / [checkpoints_dir] / opt.txt """ message = '' message += '----------------- Options ---------------\n' for k, v in sorted(vars(opt).items()): comment = '' default = self.parser.get_default(k) if v != default: comment = '\t[default: %s]' % str(default) message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) message += '----------------- End -------------------' print(message) # save to the disk expr_dir = os.path.join(opt.checkpoints_dir, opt.name) util.mkdirs(expr_dir) file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) try: with open(file_name, 'wt') as opt_file: opt_file.write(message) opt_file.write('\n') except PermissionError as error: print("permission error {}".format(error)) pass def parse(self): """Parse our options, create checkpoints directory suffix, and set up gpu device.""" opt = self.gather_options() opt.isTrain = self.isTrain # train or test # process opt.suffix if opt.suffix: suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' opt.name = opt.name + suffix # set gpu ids str_ids = opt.gpu_ids.split(',') gpu_ids = [] for str_id in str_ids: id = int(str_id) if id >= 0: gpu_ids.append(id) opt.world_size = len(gpu_ids) # if len(opt.gpu_ids) > 0: # torch.cuda.set_device(gpu_ids[0]) if opt.world_size == 1: opt.use_ddp = False if opt.phase != 'test': # set continue_train automatically if opt.pretrained_name is None: model_dir = os.path.join(opt.checkpoints_dir, opt.name) else: model_dir = os.path.join(opt.checkpoints_dir, opt.pretrained_name) if os.path.isdir(model_dir): model_pths = [i for i in os.listdir(model_dir) if i.endswith('pth')] if os.path.isdir(model_dir) and len(model_pths) != 0: opt.continue_train= True # update the latest epoch count if opt.continue_train: if opt.epoch == 'latest': epoch_counts = [int(i.split('.')[0].split('_')[-1]) for i in model_pths if 'latest' not in i] if len(epoch_counts) != 0: opt.epoch_count = max(epoch_counts) + 1 else: 
opt.epoch_count = int(opt.epoch) + 1 self.print_options(opt) self.opt = opt return self.opt ================================================ FILE: src/face3d/options/inference_options.py ================================================ from face3d.options.base_options import BaseOptions class InferenceOptions(BaseOptions): """This class includes test options. It also includes shared options defined in BaseOptions. """ def initialize(self, parser): parser = BaseOptions.initialize(self, parser) # define shared options parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]') parser.add_argument('--input_dir', type=str, help='the folder of the input files') parser.add_argument('--keypoint_dir', type=str, help='the folder of the keypoint files') parser.add_argument('--output_dir', type=str, default='mp4', help='the output dir to save the extracted coefficients') parser.add_argument('--save_split_files', action='store_true', help='save split files or not') parser.add_argument('--inference_batch_size', type=int, default=8) # Dropout and Batchnorm has different behavior during training and test. self.isTrain = False return parser ================================================ FILE: src/face3d/options/test_options.py ================================================ """This script contains the test options for Deep3DFaceRecon_pytorch """ from .base_options import BaseOptions class TestOptions(BaseOptions): """This class includes test options. It also includes shared options defined in BaseOptions. """ def initialize(self, parser): parser = BaseOptions.initialize(self, parser) # define shared options parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. 
[None | flist]') parser.add_argument('--img_folder', type=str, default='examples', help='folder for test images.') # Dropout and Batchnorm has different behavior during training and test. self.isTrain = False return parser ================================================ FILE: src/face3d/options/train_options.py ================================================ """This script contains the training options for Deep3DFaceRecon_pytorch """ from .base_options import BaseOptions from util import util class TrainOptions(BaseOptions): """This class includes training options. It also includes shared options defined in BaseOptions. """ def initialize(self, parser): parser = BaseOptions.initialize(self, parser) # dataset parameters # for train parser.add_argument('--data_root', type=str, default='./', help='dataset root') parser.add_argument('--flist', type=str, default='datalist/train/masks.txt', help='list of mask names of training set') parser.add_argument('--batch_size', type=int, default=32) parser.add_argument('--dataset_mode', type=str, default='flist', help='chooses how datasets are loaded. [None | flist]') parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') parser.add_argument('--preprocess', type=str, default='shift_scale_rot_flip', help='scaling and cropping of images at load time [shift_scale_rot_flip | shift_scale | shift | shift_rot_flip ]') parser.add_argument('--use_aug', type=util.str2bool, nargs='?', const=True, default=True, help='whether use data augmentation') # for val parser.add_argument('--flist_val', type=str, default='datalist/val/masks.txt', help='list of mask names of val set') parser.add_argument('--batch_size_val', type=int, default=32) # visualization parameters parser.add_argument('--display_freq', type=int, default=1000, help='frequency of showing training results on screen') parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') # network saving and loading parameters parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq') parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint') # training parameters parser.add_argument('--n_epochs', type=int, default=20, help='number of epochs with the initial learning rate') parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam') 
parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]') parser.add_argument('--lr_decay_epochs', type=int, default=10, help='multiply by a gamma every lr_decay_epochs epoches') self.isTrain = True return parser ================================================ FILE: src/face3d/util/__init__.py ================================================ """This package includes a miscellaneous collection of useful helper functions.""" from src.face3d.util import * ================================================ FILE: src/face3d/util/detect_lm68.py ================================================ import os import cv2 import numpy as np from scipy.io import loadmat import tensorflow as tf from util.preprocess import align_for_lm from shutil import move mean_face = np.loadtxt('util/test_mean_face.txt') mean_face = mean_face.reshape([68, 2]) def save_label(labels, save_path): np.savetxt(save_path, labels) def draw_landmarks(img, landmark, save_name): landmark = landmark lm_img = np.zeros([img.shape[0], img.shape[1], 3]) lm_img[:] = img.astype(np.float32) landmark = np.round(landmark).astype(np.int32) for i in range(len(landmark)): for j in range(-1, 1): for k in range(-1, 1): if img.shape[0] - 1 - landmark[i, 1]+j > 0 and \ img.shape[0] - 1 - landmark[i, 1]+j < img.shape[0] and \ landmark[i, 0]+k > 0 and \ landmark[i, 0]+k < img.shape[1]: lm_img[img.shape[0] - 1 - landmark[i, 1]+j, landmark[i, 0]+k, :] = np.array([0, 0, 255]) lm_img = lm_img.astype(np.uint8) cv2.imwrite(save_name, lm_img) def load_data(img_name, txt_name): return cv2.imread(img_name), np.loadtxt(txt_name) # create tensorflow graph for landmark detector def load_lm_graph(graph_filename): with tf.gfile.GFile(graph_filename, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) with tf.Graph().as_default() as graph: tf.import_graph_def(graph_def, name='net') img_224 = graph.get_tensor_by_name('net/input_imgs:0') output_lm = 
graph.get_tensor_by_name('net/lm:0') lm_sess = tf.Session(graph=graph) return lm_sess,img_224,output_lm # landmark detection def detect_68p(img_path,sess,input_op,output_op): print('detecting landmarks......') names = [i for i in sorted(os.listdir( img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i] vis_path = os.path.join(img_path, 'vis') remove_path = os.path.join(img_path, 'remove') save_path = os.path.join(img_path, 'landmarks') if not os.path.isdir(vis_path): os.makedirs(vis_path) if not os.path.isdir(remove_path): os.makedirs(remove_path) if not os.path.isdir(save_path): os.makedirs(save_path) for i in range(0, len(names)): name = names[i] print('%05d' % (i), ' ', name) full_image_name = os.path.join(img_path, name) txt_name = '.'.join(name.split('.')[:-1]) + '.txt' full_txt_name = os.path.join(img_path, 'detections', txt_name) # 5 facial landmark path for each image # if an image does not have detected 5 facial landmarks, remove it from the training list if not os.path.isfile(full_txt_name): move(full_image_name, os.path.join(remove_path, name)) continue # load data img, five_points = load_data(full_image_name, full_txt_name) input_img, scale, bbox = align_for_lm(img, five_points) # align for 68 landmark detection # if the alignment fails, remove corresponding image from the training list if scale == 0: move(full_txt_name, os.path.join( remove_path, txt_name)) move(full_image_name, os.path.join(remove_path, name)) continue # detect landmarks input_img = np.reshape( input_img, [1, 224, 224, 3]).astype(np.float32) landmark = sess.run( output_op, feed_dict={input_op: input_img}) # transform back to original image coordinate landmark = landmark.reshape([68, 2]) + mean_face landmark[:, 1] = 223 - landmark[:, 1] landmark = landmark / scale landmark[:, 0] = landmark[:, 0] + bbox[0] landmark[:, 1] = landmark[:, 1] + bbox[1] landmark[:, 1] = img.shape[0] - 1 - landmark[:, 1] if i % 100 == 0: draw_landmarks(img, landmark, os.path.join(vis_path, name)) 
save_label(landmark, os.path.join(save_path, txt_name)) ================================================ FILE: src/face3d/util/generate_list.py ================================================ """This script is to generate training list files for Deep3DFaceRecon_pytorch """ import os # save path to training data def write_list(lms_list, imgs_list, msks_list, mode='train',save_folder='datalist', save_name=''): save_path = os.path.join(save_folder, mode) if not os.path.isdir(save_path): os.makedirs(save_path) with open(os.path.join(save_path, save_name + 'landmarks.txt'), 'w') as fd: fd.writelines([i + '\n' for i in lms_list]) with open(os.path.join(save_path, save_name + 'images.txt'), 'w') as fd: fd.writelines([i + '\n' for i in imgs_list]) with open(os.path.join(save_path, save_name + 'masks.txt'), 'w') as fd: fd.writelines([i + '\n' for i in msks_list]) # check if the path is valid def check_list(rlms_list, rimgs_list, rmsks_list): lms_list, imgs_list, msks_list = [], [], [] for i in range(len(rlms_list)): flag = 'false' lm_path = rlms_list[i] im_path = rimgs_list[i] msk_path = rmsks_list[i] if os.path.isfile(lm_path) and os.path.isfile(im_path) and os.path.isfile(msk_path): flag = 'true' lms_list.append(rlms_list[i]) imgs_list.append(rimgs_list[i]) msks_list.append(rmsks_list[i]) print(i, rlms_list[i], flag) return lms_list, imgs_list, msks_list ================================================ FILE: src/face3d/util/html.py ================================================ import dominate from dominate.tags import meta, h3, table, tr, td, p, a, img, br import os class HTML: """This HTML class allows us to save images and write texts into a single HTML file. It consists of functions such as (add a text header to the HTML file), (add a row of images to the HTML file), and (save the HTML to the disk). It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. 
""" def __init__(self, web_dir, title, refresh=0): """Initialize the HTML classes Parameters: web_dir (str) -- a directory that stores the webpage. HTML file will be created at /index.html; images will be saved at 0: with self.doc.head: meta(http_equiv="refresh", content=str(refresh)) def get_image_dir(self): """Return the directory that stores images""" return self.img_dir def add_header(self, text): """Insert a header to the HTML file Parameters: text (str) -- the header text """ with self.doc: h3(text) def add_images(self, ims, txts, links, width=400): """add images to the HTML file Parameters: ims (str list) -- a list of image paths txts (str list) -- a list of image names shown on the website links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page """ self.t = table(border=1, style="table-layout: fixed;") # Insert a table self.doc.add(self.t) with self.t: with tr(): for im, txt, link in zip(ims, txts, links): with td(style="word-wrap: break-word;", halign="center", valign="top"): with p(): with a(href=os.path.join('images', link)): img(style="width:%dpx" % width, src=os.path.join('images', im)) br() p(txt) def save(self): """save the current content to the HMTL file""" html_file = '%s/index.html' % self.web_dir f = open(html_file, 'wt') f.write(self.doc.render()) f.close() if __name__ == '__main__': # we show an example usage here. 
html = HTML('web/', 'test_html') html.add_header('hello world') ims, txts, links = [], [], [] for n in range(4): ims.append('image_%d.png' % n) txts.append('text_%d' % n) links.append('image_%d.png' % n) html.add_images(ims, txts, links) html.save() ================================================ FILE: src/face3d/util/load_mats.py ================================================ """This script is to load 3D face model for Deep3DFaceRecon_pytorch """ import numpy as np from PIL import Image from scipy.io import loadmat, savemat from array import array import os.path as osp # load expression basis def LoadExpBasis(bfm_folder='BFM'): n_vertex = 53215 Expbin = open(osp.join(bfm_folder, 'Exp_Pca.bin'), 'rb') exp_dim = array('i') exp_dim.fromfile(Expbin, 1) expMU = array('f') expPC = array('f') expMU.fromfile(Expbin, 3*n_vertex) expPC.fromfile(Expbin, 3*exp_dim[0]*n_vertex) Expbin.close() expPC = np.array(expPC) expPC = np.reshape(expPC, [exp_dim[0], -1]) expPC = np.transpose(expPC) expEV = np.loadtxt(osp.join(bfm_folder, 'std_exp.txt')) return expPC, expEV # transfer original BFM09 to our face model def transferBFM09(bfm_folder='BFM'): print('Transfer BFM09 to BFM_model_front......') original_BFM = loadmat(osp.join(bfm_folder, '01_MorphableModel.mat')) shapePC = original_BFM['shapePC'] # shape basis shapeEV = original_BFM['shapeEV'] # corresponding eigen value shapeMU = original_BFM['shapeMU'] # mean face texPC = original_BFM['texPC'] # texture basis texEV = original_BFM['texEV'] # eigen value texMU = original_BFM['texMU'] # mean texture expPC, expEV = LoadExpBasis(bfm_folder) # transfer BFM09 to our face model idBase = shapePC*np.reshape(shapeEV, [-1, 199]) idBase = idBase/1e5 # unify the scale to decimeter idBase = idBase[:, :80] # use only first 80 basis exBase = expPC*np.reshape(expEV, [-1, 79]) exBase = exBase/1e5 # unify the scale to decimeter exBase = exBase[:, :64] # use only first 64 basis texBase = texPC*np.reshape(texEV, [-1, 199]) texBase = texBase[:, :80] 
# use only first 80 basis # our face model is cropped along face landmarks and contains only 35709 vertex. # original BFM09 contains 53490 vertex, and expression basis provided by Guo et al. contains 53215 vertex. # thus we select corresponding vertex to get our face model. index_exp = loadmat(osp.join(bfm_folder, 'BFM_front_idx.mat')) index_exp = index_exp['idx'].astype(np.int32) - 1 # starts from 0 (to 53215) index_shape = loadmat(osp.join(bfm_folder, 'BFM_exp_idx.mat')) index_shape = index_shape['trimIndex'].astype( np.int32) - 1 # starts from 0 (to 53490) index_shape = index_shape[index_exp] idBase = np.reshape(idBase, [-1, 3, 80]) idBase = idBase[index_shape, :, :] idBase = np.reshape(idBase, [-1, 80]) texBase = np.reshape(texBase, [-1, 3, 80]) texBase = texBase[index_shape, :, :] texBase = np.reshape(texBase, [-1, 80]) exBase = np.reshape(exBase, [-1, 3, 64]) exBase = exBase[index_exp, :, :] exBase = np.reshape(exBase, [-1, 64]) meanshape = np.reshape(shapeMU, [-1, 3])/1e5 meanshape = meanshape[index_shape, :] meanshape = np.reshape(meanshape, [1, -1]) meantex = np.reshape(texMU, [-1, 3]) meantex = meantex[index_shape, :] meantex = np.reshape(meantex, [1, -1]) # other info contains triangles, region used for computing photometric loss, # region used for skin texture regularization, and 68 landmarks index etc. 
other_info = loadmat(osp.join(bfm_folder, 'facemodel_info.mat')) frontmask2_idx = other_info['frontmask2_idx'] skinmask = other_info['skinmask'] keypoints = other_info['keypoints'] point_buf = other_info['point_buf'] tri = other_info['tri'] tri_mask2 = other_info['tri_mask2'] # save our face model savemat(osp.join(bfm_folder, 'BFM_model_front.mat'), {'meanshape': meanshape, 'meantex': meantex, 'idBase': idBase, 'exBase': exBase, 'texBase': texBase, 'tri': tri, 'point_buf': point_buf, 'tri_mask2': tri_mask2, 'keypoints': keypoints, 'frontmask2_idx': frontmask2_idx, 'skinmask': skinmask}) # load landmarks for standard face, which is used for image preprocessing def load_lm3d(bfm_folder): Lm3D = loadmat(osp.join(bfm_folder, 'similarity_Lm3D_all.mat')) Lm3D = Lm3D['lm'] # calculate 5 facial landmarks using 68 landmarks lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 Lm3D = np.stack([Lm3D[lm_idx[0], :], np.mean(Lm3D[lm_idx[[1, 2]], :], 0), np.mean( Lm3D[lm_idx[[3, 4]], :], 0), Lm3D[lm_idx[5], :], Lm3D[lm_idx[6], :]], axis=0) Lm3D = Lm3D[[1, 2, 0, 3, 4], :] return Lm3D if __name__ == '__main__': transferBFM09() ================================================ FILE: src/face3d/util/my_awing_arch.py ================================================ import cv2 import numpy as np import torch import torch.nn as nn import torch.nn.functional as F def calculate_points(heatmaps): # change heatmaps to landmarks B, N, H, W = heatmaps.shape HW = H * W BN_range = np.arange(B * N) heatline = heatmaps.reshape(B, N, HW) indexes = np.argmax(heatline, axis=2) preds = np.stack((indexes % W, indexes // W), axis=2) preds = preds.astype(np.float, copy=False) inr = indexes.ravel() heatline = heatline.reshape(B * N, HW) x_up = heatline[BN_range, inr + 1] x_down = heatline[BN_range, inr - 1] # y_up = heatline[BN_range, inr + W] if any((inr + W) >= 4096): y_up = heatline[BN_range, 4095] else: y_up = heatline[BN_range, inr + W] if any((inr - W) <= 0): y_down = heatline[BN_range, 0] else: 
y_down = heatline[BN_range, inr - W] think_diff = np.sign(np.stack((x_up - x_down, y_up - y_down), axis=1)) think_diff *= .25 preds += think_diff.reshape(B, N, 2) preds += .5 return preds class AddCoordsTh(nn.Module): def __init__(self, x_dim=64, y_dim=64, with_r=False, with_boundary=False): super(AddCoordsTh, self).__init__() self.x_dim = x_dim self.y_dim = y_dim self.with_r = with_r self.with_boundary = with_boundary def forward(self, input_tensor, heatmap=None): """ input_tensor: (batch, c, x_dim, y_dim) """ batch_size_tensor = input_tensor.shape[0] xx_ones = torch.ones([1, self.y_dim], dtype=torch.int32, device=input_tensor.device) xx_ones = xx_ones.unsqueeze(-1) xx_range = torch.arange(self.x_dim, dtype=torch.int32, device=input_tensor.device).unsqueeze(0) xx_range = xx_range.unsqueeze(1) xx_channel = torch.matmul(xx_ones.float(), xx_range.float()) xx_channel = xx_channel.unsqueeze(-1) yy_ones = torch.ones([1, self.x_dim], dtype=torch.int32, device=input_tensor.device) yy_ones = yy_ones.unsqueeze(1) yy_range = torch.arange(self.y_dim, dtype=torch.int32, device=input_tensor.device).unsqueeze(0) yy_range = yy_range.unsqueeze(-1) yy_channel = torch.matmul(yy_range.float(), yy_ones.float()) yy_channel = yy_channel.unsqueeze(-1) xx_channel = xx_channel.permute(0, 3, 2, 1) yy_channel = yy_channel.permute(0, 3, 2, 1) xx_channel = xx_channel / (self.x_dim - 1) yy_channel = yy_channel / (self.y_dim - 1) xx_channel = xx_channel * 2 - 1 yy_channel = yy_channel * 2 - 1 xx_channel = xx_channel.repeat(batch_size_tensor, 1, 1, 1) yy_channel = yy_channel.repeat(batch_size_tensor, 1, 1, 1) if self.with_boundary and heatmap is not None: boundary_channel = torch.clamp(heatmap[:, -1:, :, :], 0.0, 1.0) zero_tensor = torch.zeros_like(xx_channel) xx_boundary_channel = torch.where(boundary_channel > 0.05, xx_channel, zero_tensor) yy_boundary_channel = torch.where(boundary_channel > 0.05, yy_channel, zero_tensor) if self.with_boundary and heatmap is not None: xx_boundary_channel = 
xx_boundary_channel.to(input_tensor.device) yy_boundary_channel = yy_boundary_channel.to(input_tensor.device) ret = torch.cat([input_tensor, xx_channel, yy_channel], dim=1) if self.with_r: rr = torch.sqrt(torch.pow(xx_channel, 2) + torch.pow(yy_channel, 2)) rr = rr / torch.max(rr) ret = torch.cat([ret, rr], dim=1) if self.with_boundary and heatmap is not None: ret = torch.cat([ret, xx_boundary_channel, yy_boundary_channel], dim=1) return ret class CoordConvTh(nn.Module): """CoordConv layer as in the paper.""" def __init__(self, x_dim, y_dim, with_r, with_boundary, in_channels, first_one=False, *args, **kwargs): super(CoordConvTh, self).__init__() self.addcoords = AddCoordsTh(x_dim=x_dim, y_dim=y_dim, with_r=with_r, with_boundary=with_boundary) in_channels += 2 if with_r: in_channels += 1 if with_boundary and not first_one: in_channels += 2 self.conv = nn.Conv2d(in_channels=in_channels, *args, **kwargs) def forward(self, input_tensor, heatmap=None): ret = self.addcoords(input_tensor, heatmap) last_channel = ret[:, -2:, :, :] ret = self.conv(ret) return ret, last_channel def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False, dilation=1): '3x3 convolution with padding' return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=strd, padding=padding, bias=bias, dilation=dilation) class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None): super(BasicBlock, self).__init__() self.conv1 = conv3x3(inplanes, planes, stride) # self.bn1 = nn.BatchNorm2d(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = conv3x3(planes, planes) # self.bn2 = nn.BatchNorm2d(planes) self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.conv1(x) out = self.relu(out) out = self.conv2(out) if self.downsample is not None: residual = self.downsample(x) out += residual out = self.relu(out) return out class ConvBlock(nn.Module): def __init__(self, in_planes, out_planes): super(ConvBlock, 
self).__init__() self.bn1 = nn.BatchNorm2d(in_planes) self.conv1 = conv3x3(in_planes, int(out_planes / 2)) self.bn2 = nn.BatchNorm2d(int(out_planes / 2)) self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4), padding=1, dilation=1) self.bn3 = nn.BatchNorm2d(int(out_planes / 4)) self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4), padding=1, dilation=1) if in_planes != out_planes: self.downsample = nn.Sequential( nn.BatchNorm2d(in_planes), nn.ReLU(True), nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False), ) else: self.downsample = None def forward(self, x): residual = x out1 = self.bn1(x) out1 = F.relu(out1, True) out1 = self.conv1(out1) out2 = self.bn2(out1) out2 = F.relu(out2, True) out2 = self.conv2(out2) out3 = self.bn3(out2) out3 = F.relu(out3, True) out3 = self.conv3(out3) out3 = torch.cat((out1, out2, out3), 1) if self.downsample is not None: residual = self.downsample(residual) out3 += residual return out3 class HourGlass(nn.Module): def __init__(self, num_modules, depth, num_features, first_one=False): super(HourGlass, self).__init__() self.num_modules = num_modules self.depth = depth self.features = num_features self.coordconv = CoordConvTh( x_dim=64, y_dim=64, with_r=True, with_boundary=True, in_channels=256, first_one=first_one, out_channels=256, kernel_size=1, stride=1, padding=0) self._generate_network(self.depth) def _generate_network(self, level): self.add_module('b1_' + str(level), ConvBlock(256, 256)) self.add_module('b2_' + str(level), ConvBlock(256, 256)) if level > 1: self._generate_network(level - 1) else: self.add_module('b2_plus_' + str(level), ConvBlock(256, 256)) self.add_module('b3_' + str(level), ConvBlock(256, 256)) def _forward(self, level, inp): # Upper branch up1 = inp up1 = self._modules['b1_' + str(level)](up1) # Lower branch low1 = F.avg_pool2d(inp, 2, stride=2) low1 = self._modules['b2_' + str(level)](low1) if level > 1: low2 = self._forward(level - 1, low1) else: low2 = low1 low2 = 
self._modules['b2_plus_' + str(level)](low2) low3 = low2 low3 = self._modules['b3_' + str(level)](low3) up2 = F.interpolate(low3, scale_factor=2, mode='nearest') return up1 + up2 def forward(self, x, heatmap): x, last_channel = self.coordconv(x, heatmap) return self._forward(self.depth, x), last_channel class FAN(nn.Module): def __init__(self, num_modules=1, end_relu=False, gray_scale=False, num_landmarks=68, device='cuda'): super(FAN, self).__init__() self.device = device self.num_modules = num_modules self.gray_scale = gray_scale self.end_relu = end_relu self.num_landmarks = num_landmarks # Base part if self.gray_scale: self.conv1 = CoordConvTh( x_dim=256, y_dim=256, with_r=True, with_boundary=False, in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3) else: self.conv1 = CoordConvTh( x_dim=256, y_dim=256, with_r=True, with_boundary=False, in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3) self.bn1 = nn.BatchNorm2d(64) self.conv2 = ConvBlock(64, 128) self.conv3 = ConvBlock(128, 128) self.conv4 = ConvBlock(128, 256) # Stacking part for hg_module in range(self.num_modules): if hg_module == 0: first_one = True else: first_one = False self.add_module('m' + str(hg_module), HourGlass(1, 4, 256, first_one)) self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256)) self.add_module('conv_last' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256)) self.add_module('l' + str(hg_module), nn.Conv2d(256, num_landmarks + 1, kernel_size=1, stride=1, padding=0)) if hg_module < self.num_modules - 1: self.add_module('bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) self.add_module('al' + str(hg_module), nn.Conv2d(num_landmarks + 1, 256, kernel_size=1, stride=1, padding=0)) def forward(self, x): x, _ = self.conv1(x) x = F.relu(self.bn1(x), True) # x = F.relu(self.bn1(self.conv1(x)), True) x = F.avg_pool2d(self.conv2(x), 2, stride=2) x 
= self.conv3(x) x = self.conv4(x) previous = x outputs = [] boundary_channels = [] tmp_out = None for i in range(self.num_modules): hg, boundary_channel = self._modules['m' + str(i)](previous, tmp_out) ll = hg ll = self._modules['top_m_' + str(i)](ll) ll = F.relu(self._modules['bn_end' + str(i)](self._modules['conv_last' + str(i)](ll)), True) # Predict heatmaps tmp_out = self._modules['l' + str(i)](ll) if self.end_relu: tmp_out = F.relu(tmp_out) # HACK: Added relu outputs.append(tmp_out) boundary_channels.append(boundary_channel) if i < self.num_modules - 1: ll = self._modules['bl' + str(i)](ll) tmp_out_ = self._modules['al' + str(i)](tmp_out) previous = previous + ll + tmp_out_ return outputs, boundary_channels def get_landmarks(self, img): H, W, _ = img.shape offset = W / 64, H / 64, 0, 0 img = cv2.resize(img, (256, 256)) inp = img[..., ::-1] inp = torch.from_numpy(np.ascontiguousarray(inp.transpose((2, 0, 1)))).float() inp = inp.to(self.device) inp.div_(255.0).unsqueeze_(0) outputs, _ = self.forward(inp) out = outputs[-1][:, :-1, :, :] heatmaps = out.detach().cpu().numpy() pred = calculate_points(heatmaps).reshape(-1, 2) pred *= offset[:2] pred += offset[-2:] return pred ================================================ FILE: src/face3d/util/nvdiffrast.py ================================================ """This script is the differentiable renderer for Deep3DFaceRecon_pytorch Attention, antialiasing step is missing in current version. 
""" import pytorch3d.ops import torch import torch.nn.functional as F import kornia from kornia.geometry.camera import pixel2cam import numpy as np from typing import List from scipy.io import loadmat from torch import nn from pytorch3d.structures import Meshes from pytorch3d.renderer import ( look_at_view_transform, FoVPerspectiveCameras, DirectionalLights, RasterizationSettings, MeshRenderer, MeshRasterizer, SoftPhongShader, TexturesUV, ) # def ndc_projection(x=0.1, n=1.0, f=50.0): # return np.array([[n/x, 0, 0, 0], # [ 0, n/-x, 0, 0], # [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], # [ 0, 0, -1, 0]]).astype(np.float32) class MeshRenderer(nn.Module): def __init__(self, rasterize_fov, znear=0.1, zfar=10, rasterize_size=224): super(MeshRenderer, self).__init__() # x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear # self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul( # torch.diag(torch.tensor([1., -1, -1, 1]))) self.rasterize_size = rasterize_size self.fov = rasterize_fov self.znear = znear self.zfar = zfar self.rasterizer = None def forward(self, vertex, tri, feat=None): """ Return: mask -- torch.tensor, size (B, 1, H, W) depth -- torch.tensor, size (B, 1, H, W) features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None Parameters: vertex -- torch.tensor, size (B, N, 3) tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles feat(optional) -- torch.tensor, size (B, N ,C), features """ device = vertex.device rsize = int(self.rasterize_size) # ndc_proj = self.ndc_proj.to(device) # trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v if vertex.shape[-1] == 3: vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1) vertex[..., 0] = -vertex[..., 0] # vertex_ndc = vertex @ ndc_proj.t() if self.rasterizer is None: self.rasterizer = MeshRasterizer() print("create rasterizer on device cuda:%d"%device.index) # ranges = None # if isinstance(tri, List) or len(tri.shape) == 3: # vum = 
vertex_ndc.shape[1] # fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device) # fstartidx = torch.cumsum(fnum, dim=0) - fnum # ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu() # for i in range(tri.shape[0]): # tri[i] = tri[i] + i*vum # vertex_ndc = torch.cat(vertex_ndc, dim=0) # tri = torch.cat(tri, dim=0) # for range_mode vetex: [B*N, 4], tri: [B*M, 3], for instance_mode vetex: [B, N, 4], tri: [M, 3] tri = tri.type(torch.int32).contiguous() # rasterize cameras = FoVPerspectiveCameras( device=device, fov=self.fov, znear=self.znear, zfar=self.zfar, ) raster_settings = RasterizationSettings( image_size=rsize ) # print(vertex.shape, tri.shape) mesh = Meshes(vertex.contiguous()[...,:3], tri.unsqueeze(0).repeat((vertex.shape[0],1,1))) fragments = self.rasterizer(mesh, cameras = cameras, raster_settings = raster_settings) rast_out = fragments.pix_to_face.squeeze(-1) depth = fragments.zbuf # render depth depth = depth.permute(0, 3, 1, 2) mask = (rast_out > 0).float().unsqueeze(1) depth = mask * depth image = None if feat is not None: attributes = feat.reshape(-1,3)[mesh.faces_packed()] image = pytorch3d.ops.interpolate_face_attributes(fragments.pix_to_face, fragments.bary_coords, attributes) # print(image.shape) image = image.squeeze(-2).permute(0, 3, 1, 2) image = mask * image return mask, depth, image ================================================ FILE: src/face3d/util/preprocess.py ================================================ """This script contains the image preprocessing code for Deep3DFaceRecon_pytorch """ import numpy as np from scipy.io import loadmat from PIL import Image import cv2 import os from skimage import transform as trans import torch import warnings warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning) warnings.filterwarnings("ignore", category=FutureWarning) # calculating least square problem for image alignment def POS(xp, x): npts = xp.shape[1] A = np.zeros([2*npts, 8]) A[0:2*npts-1:2, 0:3] 
= x.transpose() A[0:2*npts-1:2, 3] = 1 A[1:2*npts:2, 4:7] = x.transpose() A[1:2*npts:2, 7] = 1 b = np.reshape(xp.transpose(), [2*npts, 1]) k, _, _, _ = np.linalg.lstsq(A, b) R1 = k[0:3] R2 = k[4:7] sTx = k[3] sTy = k[7] s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2 t = np.stack([sTx, sTy], axis=0) return t, s # resize and crop images for face reconstruction def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None): w0, h0 = img.size w = (w0*s).astype(np.int32) h = (h0*s).astype(np.int32) left = (w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32) right = left + target_size up = (h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32) below = up + target_size img = img.resize((w, h), resample=Image.BICUBIC) img = img.crop((left, up, right, below)) if mask is not None: mask = mask.resize((w, h), resample=Image.BICUBIC) mask = mask.crop((left, up, right, below)) lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] - t[1] + h0/2], axis=1)*s lm = lm - np.reshape( np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2]) return img, lm, mask # utils for face reconstruction def extract_5p(lm): lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 lm5p = np.stack([lm[lm_idx[0], :], np.mean(lm[lm_idx[[1, 2]], :], 0), np.mean( lm[lm_idx[[3, 4]], :], 0), lm[lm_idx[5], :], lm[lm_idx[6], :]], axis=0) lm5p = lm5p[[1, 2, 0, 3, 4], :] return lm5p # utils for face reconstruction def align_img(img, lm, lm3D, mask=None, target_size=224., rescale_factor=102.): """ Return: transparams --numpy.array (raw_W, raw_H, scale, tx, ty) img_new --PIL.Image (target_size, target_size, 3) lm_new --numpy.array (68, 2), y direction is opposite to v direction mask_new --PIL.Image (target_size, target_size) Parameters: img --PIL.Image (raw_H, raw_W, 3) lm --numpy.array (68, 2), y direction is opposite to v direction lm3D --numpy.array (5, 3) mask --PIL.Image (raw_H, raw_W, 3) """ w0, h0 = img.size if lm.shape[0] != 5: lm5p = extract_5p(lm) else: lm5p = lm # calculate 
translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face t, s = POS(lm5p.transpose(), lm3D.transpose()) s = rescale_factor/s # processing the image img_new, lm_new, mask_new = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask) trans_params = np.array([w0, h0, s, t[0], t[1]]) return trans_params, img_new, lm_new, mask_new ================================================ FILE: src/face3d/util/skin_mask.py ================================================ """This script is to generate skin attention mask for Deep3DFaceRecon_pytorch """ import math import numpy as np import os import cv2 class GMM: def __init__(self, dim, num, w, mu, cov, cov_det, cov_inv): self.dim = dim # feature dimension self.num = num # number of Gaussian components self.w = w # weights of Gaussian components (a list of scalars) self.mu= mu # mean of Gaussian components (a list of 1xdim vectors) self.cov = cov # covariance matrix of Gaussian components (a list of dimxdim matrices) self.cov_det = cov_det # pre-computed determinet of covariance matrices (a list of scalars) self.cov_inv = cov_inv # pre-computed inverse covariance matrices (a list of dimxdim matrices) self.factor = [0]*num for i in range(self.num): self.factor[i] = (2*math.pi)**(self.dim/2) * self.cov_det[i]**0.5 def likelihood(self, data): assert(data.shape[1] == self.dim) N = data.shape[0] lh = np.zeros(N) for i in range(self.num): data_ = data - self.mu[i] tmp = np.matmul(data_,self.cov_inv[i]) * data_ tmp = np.sum(tmp,axis=1) power = -0.5 * tmp p = np.array([math.exp(power[j]) for j in range(N)]) p = p/self.factor[i] lh += p*self.w[i] return lh def _rgb2ycbcr(rgb): m = np.array([[65.481, 128.553, 24.966], [-37.797, -74.203, 112], [112, -93.786, -18.214]]) shape = rgb.shape rgb = rgb.reshape((shape[0] * shape[1], 3)) ycbcr = np.dot(rgb, m.transpose() / 255.) ycbcr[:, 0] += 16. ycbcr[:, 1:] += 128. 
return ycbcr.reshape(shape) def _bgr2ycbcr(bgr): rgb = bgr[..., ::-1] return _rgb2ycbcr(rgb) gmm_skin_w = [0.24063933, 0.16365987, 0.26034665, 0.33535415] gmm_skin_mu = [np.array([113.71862, 103.39613, 164.08226]), np.array([150.19858, 105.18467, 155.51428]), np.array([183.92976, 107.62468, 152.71820]), np.array([114.90524, 113.59782, 151.38217])] gmm_skin_cov_det = [5692842.5, 5851930.5, 2329131., 1585971.] gmm_skin_cov_inv = [np.array([[0.0019472069, 0.0020450759, -0.00060243998],[0.0020450759, 0.017700525, 0.0051420014],[-0.00060243998, 0.0051420014, 0.0081308950]]), np.array([[0.0027110141, 0.0011036990, 0.0023122299],[0.0011036990, 0.010707724, 0.010742856],[0.0023122299, 0.010742856, 0.017481629]]), np.array([[0.0048026871, 0.00022935172, 0.0077668377],[0.00022935172, 0.011729696, 0.0081661865],[0.0077668377, 0.0081661865, 0.025374353]]), np.array([[0.0011989699, 0.0022453172, -0.0010748957],[0.0022453172, 0.047758564, 0.020332102],[-0.0010748957, 0.020332102, 0.024502251]])] gmm_skin = GMM(3, 4, gmm_skin_w, gmm_skin_mu, [], gmm_skin_cov_det, gmm_skin_cov_inv) gmm_nonskin_w = [0.12791070, 0.31130761, 0.34245777, 0.21832393] gmm_nonskin_mu = [np.array([99.200851, 112.07533, 140.20602]), np.array([110.91392, 125.52969, 130.19237]), np.array([129.75864, 129.96107, 126.96808]), np.array([112.29587, 128.85121, 129.05431])] gmm_nonskin_cov_det = [458703648., 6466488., 90611376., 133097.63] gmm_nonskin_cov_inv = [np.array([[0.00085371657, 0.00071197288, 0.00023958916],[0.00071197288, 0.0025935620, 0.00076557708],[0.00023958916, 0.00076557708, 0.0015042332]]), np.array([[0.00024650150, 0.00045542428, 0.00015019422],[0.00045542428, 0.026412144, 0.018419769],[0.00015019422, 0.018419769, 0.037497383]]), np.array([[0.00037054974, 0.00038146760, 0.00040408765],[0.00038146760, 0.0085505722, 0.0079136286],[0.00040408765, 0.0079136286, 0.010982352]]), np.array([[0.00013709733, 0.00051228428, 0.00012777430],[0.00051228428, 0.28237113, 0.10528370],[0.00012777430, 0.10528370, 
0.23468947]])] gmm_nonskin = GMM(3, 4, gmm_nonskin_w, gmm_nonskin_mu, [], gmm_nonskin_cov_det, gmm_nonskin_cov_inv) prior_skin = 0.8 prior_nonskin = 1 - prior_skin # calculate skin attention mask def skinmask(imbgr): im = _bgr2ycbcr(imbgr) data = im.reshape((-1,3)) lh_skin = gmm_skin.likelihood(data) lh_nonskin = gmm_nonskin.likelihood(data) tmp1 = prior_skin * lh_skin tmp2 = prior_nonskin * lh_nonskin post_skin = tmp1 / (tmp1+tmp2) # posterior probability post_skin = post_skin.reshape((im.shape[0],im.shape[1])) post_skin = np.round(post_skin*255) post_skin = post_skin.astype(np.uint8) post_skin = np.tile(np.expand_dims(post_skin,2),[1,1,3]) # reshape to H*W*3 return post_skin def get_skin_mask(img_path): print('generating skin masks......') names = [i for i in sorted(os.listdir( img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i] save_path = os.path.join(img_path, 'mask') if not os.path.isdir(save_path): os.makedirs(save_path) for i in range(0, len(names)): name = names[i] print('%05d' % (i), ' ', name) full_image_name = os.path.join(img_path, name) img = cv2.imread(full_image_name).astype(np.float32) skin_img = skinmask(img) cv2.imwrite(os.path.join(save_path, name), skin_img.astype(np.uint8)) ================================================ FILE: src/face3d/util/test_mean_face.txt ================================================ -5.228591537475585938e+01 2.078247070312500000e-01 -5.064269638061523438e+01 -1.315765380859375000e+01 -4.952939224243164062e+01 -2.592591094970703125e+01 -4.793047332763671875e+01 -3.832135772705078125e+01 -4.512159729003906250e+01 -5.059623336791992188e+01 -3.917720794677734375e+01 -6.043736648559570312e+01 -2.929953765869140625e+01 -6.861183166503906250e+01 -1.719801330566406250e+01 -7.572736358642578125e+01 -1.961936950683593750e+00 -7.862001037597656250e+01 1.467941284179687500e+01 -7.607844543457031250e+01 2.744073486328125000e+01 -6.915261840820312500e+01 3.855677795410156250e+01 -5.950350570678710938e+01 
4.478240966796875000e+01 -4.867547225952148438e+01 4.714337158203125000e+01 -3.800830078125000000e+01 4.940315246582031250e+01 -2.496297454833984375e+01 5.117234802246093750e+01 -1.241538238525390625e+01 5.190507507324218750e+01 8.244247436523437500e-01 -4.150688934326171875e+01 2.386329650878906250e+01 -3.570307159423828125e+01 3.017010498046875000e+01 -2.790358734130859375e+01 3.212951660156250000e+01 -1.941773223876953125e+01 3.156523132324218750e+01 -1.138106536865234375e+01 2.841992187500000000e+01 5.993263244628906250e+00 2.895182800292968750e+01 1.343590545654296875e+01 3.189880371093750000e+01 2.203153991699218750e+01 3.302221679687500000e+01 2.992478942871093750e+01 3.099150085449218750e+01 3.628388977050781250e+01 2.765748596191406250e+01 -1.933914184570312500e+00 1.405374145507812500e+01 -2.153038024902343750e+00 5.772636413574218750e+00 -2.270050048828125000e+00 -2.121643066406250000e+00 -2.218330383300781250e+00 -1.068978118896484375e+01 -1.187252044677734375e+01 -1.997912597656250000e+01 -6.879402160644531250e+00 -2.143579864501953125e+01 -1.227821350097656250e+00 -2.193494415283203125e+01 4.623237609863281250e+00 -2.152721405029296875e+01 9.721397399902343750e+00 -1.953671264648437500e+01 -3.648714447021484375e+01 9.811126708984375000e+00 -3.130242919921875000e+01 1.422447967529296875e+01 -2.212834930419921875e+01 1.493019866943359375e+01 -1.500880432128906250e+01 1.073588562011718750e+01 -2.095037078857421875e+01 9.054298400878906250e+00 -3.050099182128906250e+01 8.704177856445312500e+00 1.173237609863281250e+01 1.054329681396484375e+01 1.856353759765625000e+01 1.535009765625000000e+01 2.893331909179687500e+01 1.451992797851562500e+01 3.452944946289062500e+01 1.065280151367187500e+01 2.875990295410156250e+01 8.654792785644531250e+00 1.942100524902343750e+01 9.422447204589843750e+00 -2.204488372802734375e+01 -3.983994293212890625e+01 -1.324458312988281250e+01 -3.467377471923828125e+01 -6.749649047851562500e+00 -3.092894744873046875e+01 
-9.183349609375000000e-01 -3.196458435058593750e+01 4.220649719238281250e+00 -3.090406036376953125e+01 1.089889526367187500e+01 -3.497008514404296875e+01 1.874589538574218750e+01 -4.065438079833984375e+01 1.124106597900390625e+01 -4.438417816162109375e+01 5.181709289550781250e+00 -4.649170684814453125e+01 -1.158607482910156250e+00 -4.680406951904296875e+01 -7.918922424316406250e+00 -4.671575164794921875e+01 -1.452505493164062500e+01 -4.416526031494140625e+01 -2.005007171630859375e+01 -3.997841644287109375e+01 -1.054919433593750000e+01 -3.849683380126953125e+01 -1.051826477050781250e+00 -3.794863128662109375e+01 6.412681579589843750e+00 -3.804645538330078125e+01 1.627674865722656250e+01 -4.039697265625000000e+01 6.373878479003906250e+00 -4.087213897705078125e+01 -8.551712036132812500e-01 -4.157129669189453125e+01 -1.014953613281250000e+01 -4.128469085693359375e+01 ================================================ FILE: src/face3d/util/util.py ================================================ """This script contains basic utilities for Deep3DFaceRecon_pytorch """ from __future__ import print_function import numpy as np import torch from PIL import Image import os import importlib import argparse from argparse import Namespace import torchvision def str2bool(v): if isinstance(v, bool): return v if v.lower() in ('yes', 'true', 't', 'y', '1'): return True elif v.lower() in ('no', 'false', 'f', 'n', '0'): return False else: raise argparse.ArgumentTypeError('Boolean value expected.') def copyconf(default_opt, **kwargs): conf = Namespace(**vars(default_opt)) for key in kwargs: setattr(conf, key, kwargs[key]) return conf def genvalconf(train_opt, **kwargs): conf = Namespace(**vars(train_opt)) attr_dict = train_opt.__dict__ for key, value in attr_dict.items(): if 'val' in key and key.split('_')[0] in attr_dict: setattr(conf, key.split('_')[0], value) for key in kwargs: setattr(conf, key, kwargs[key]) return conf def find_class_in_module(target_cls_name, module): 
target_cls_name = target_cls_name.replace('_', '').lower() clslib = importlib.import_module(module) cls = None for name, clsobj in clslib.__dict__.items(): if name.lower() == target_cls_name: cls = clsobj assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name) return cls def tensor2im(input_image, imtype=np.uint8): """"Converts a Tensor array into a numpy image array. Parameters: input_image (tensor) -- the input image tensor array, range(0, 1) imtype (type) -- the desired type of the converted numpy array """ if not isinstance(input_image, np.ndarray): if isinstance(input_image, torch.Tensor): # get the data from a variable image_tensor = input_image.data else: return input_image image_numpy = image_tensor.clamp(0.0, 1.0).cpu().float().numpy() # convert it into a numpy array if image_numpy.shape[0] == 1: # grayscale to RGB image_numpy = np.tile(image_numpy, (3, 1, 1)) image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 # post-processing: tranpose and scaling else: # if it is a numpy array, do nothing image_numpy = input_image return image_numpy.astype(imtype) def diagnose_network(net, name='network'): """Calculate and print the mean of average absolute(gradients) Parameters: net (torch network) -- Torch network name (str) -- the name of the network """ mean = 0.0 count = 0 for param in net.parameters(): if param.grad is not None: mean += torch.mean(torch.abs(param.grad.data)) count += 1 if count > 0: mean = mean / count print(name) print(mean) def save_image(image_numpy, image_path, aspect_ratio=1.0): """Save a numpy image to the disk Parameters: image_numpy (numpy array) -- input numpy array image_path (str) -- the path of the image """ image_pil = Image.fromarray(image_numpy) h, w, _ = image_numpy.shape if aspect_ratio is None: pass elif aspect_ratio > 1.0: image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) elif aspect_ratio < 1.0: image_pil = 
image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) image_pil.save(image_path) def print_numpy(x, val=True, shp=False): """Print the mean, min, max, median, std, and size of a numpy array Parameters: val (bool) -- if print the values of the numpy array shp (bool) -- if print the shape of the numpy array """ x = x.astype(np.float64) if shp: print('shape,', x.shape) if val: x = x.flatten() print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) def mkdirs(paths): """create empty directories if they don't exist Parameters: paths (str list) -- a list of directory paths """ if isinstance(paths, list) and not isinstance(paths, str): for path in paths: mkdir(path) else: mkdir(paths) def mkdir(path): """create a single empty directory if it didn't exist Parameters: path (str) -- a single directory path """ if not os.path.exists(path): os.makedirs(path) def correct_resize_label(t, size): device = t.device t = t.detach().cpu() resized = [] for i in range(t.size(0)): one_t = t[i, :1] one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0)) one_np = one_np[:, :, 0] one_image = Image.fromarray(one_np).resize(size, Image.NEAREST) resized_t = torch.from_numpy(np.array(one_image)).long() resized.append(resized_t) return torch.stack(resized, dim=0).to(device) def correct_resize(t, size, mode=Image.BICUBIC): device = t.device t = t.detach().cpu() resized = [] for i in range(t.size(0)): one_t = t[i:i + 1] one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC) resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0 resized.append(resized_t) return torch.stack(resized, dim=0).to(device) def draw_landmarks(img, landmark, color='r', step=2): """ Return: img -- numpy.array, (B, H, W, 3) img with landmark, RGB order, range (0, 255) Parameters: img -- numpy.array, (B, H, W, 3), RGB order, range (0, 255) landmark -- numpy.array, (B, 68, 2), y direction is 
opposite to v direction color -- str, 'r' or 'b' (red or blue) """ if color =='r': c = np.array([255., 0, 0]) else: c = np.array([0, 0, 255.]) _, H, W, _ = img.shape img, landmark = img.copy(), landmark.copy() landmark[..., 1] = H - 1 - landmark[..., 1] landmark = np.round(landmark).astype(np.int32) for i in range(landmark.shape[1]): x, y = landmark[:, i, 0], landmark[:, i, 1] for j in range(-step, step): for k in range(-step, step): u = np.clip(x + j, 0, W - 1) v = np.clip(y + k, 0, H - 1) for m in range(landmark.shape[0]): img[m, v[m], u[m]] = c return img ================================================ FILE: src/face3d/util/visualizer.py ================================================ """This script defines the visualizer for Deep3DFaceRecon_pytorch """ import numpy as np import os import sys import ntpath import time from . import util, html from subprocess import Popen, PIPE from torch.utils.tensorboard import SummaryWriter def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): """Save images to the disk. Parameters: webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details) visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs image_path (str) -- the string is used to create image paths aspect_ratio (float) -- the aspect ratio of saved images width (int) -- the images will be resized to width x width This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. 
""" image_dir = webpage.get_image_dir() short_path = ntpath.basename(image_path[0]) name = os.path.splitext(short_path)[0] webpage.add_header(name) ims, txts, links = [], [], [] for label, im_data in visuals.items(): im = util.tensor2im(im_data) image_name = '%s/%s.png' % (label, name) os.makedirs(os.path.join(image_dir, label), exist_ok=True) save_path = os.path.join(image_dir, image_name) util.save_image(im, save_path, aspect_ratio=aspect_ratio) ims.append(image_name) txts.append(label) links.append(image_name) webpage.add_images(ims, txts, links, width=width) class Visualizer(): """This class includes several functions that can display/save images and print/save logging information. It uses a Python library tensprboardX for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images. """ def __init__(self, opt): """Initialize the Visualizer class Parameters: opt -- stores all the experiment flags; needs to be a subclass of BaseOptions Step 1: Cache the training/test options Step 2: create a tensorboard writer Step 3: create an HTML object for saveing HTML filters Step 4: create a logging file to store training losses """ self.opt = opt # cache the option self.use_html = opt.isTrain and not opt.no_html self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, 'logs', opt.name)) self.win_size = opt.display_winsize self.name = opt.name self.saved = False if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/ self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') self.img_dir = os.path.join(self.web_dir, 'images') print('create web directory %s...' 
% self.web_dir) util.mkdirs([self.web_dir, self.img_dir]) # create a logging file to store training losses self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') with open(self.log_name, "a") as log_file: now = time.strftime("%c") log_file.write('================ Training Loss (%s) ================\n' % now) def reset(self): """Reset the self.saved status""" self.saved = False def display_current_results(self, visuals, total_iters, epoch, save_result): """Display current results on tensorboad; save current results to an HTML file. Parameters: visuals (OrderedDict) - - dictionary of images to display or save total_iters (int) -- total iterations epoch (int) - - the current epoch save_result (bool) - - if save the current results to an HTML file """ for label, image in visuals.items(): self.writer.add_image(label, util.tensor2im(image), total_iters, dataformats='HWC') if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved. 
self.saved = True # save images to the disk for label, image in visuals.items(): image_numpy = util.tensor2im(image) img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) util.save_image(image_numpy, img_path) # update website webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0) for n in range(epoch, 0, -1): webpage.add_header('epoch [%d]' % n) ims, txts, links = [], [], [] for label, image_numpy in visuals.items(): image_numpy = util.tensor2im(image) img_path = 'epoch%.3d_%s.png' % (n, label) ims.append(img_path) txts.append(label) links.append(img_path) webpage.add_images(ims, txts, links, width=self.win_size) webpage.save() def plot_current_losses(self, total_iters, losses): # G_loss_collection = {} # D_loss_collection = {} # for name, value in losses.items(): # if 'G' in name or 'NCE' in name or 'idt' in name: # G_loss_collection[name] = value # else: # D_loss_collection[name] = value # self.writer.add_scalars('G_collec', G_loss_collection, total_iters) # self.writer.add_scalars('D_collec', D_loss_collection, total_iters) for name, value in losses.items(): self.writer.add_scalar(name, value, total_iters) # losses: same format as |losses| of plot_current_losses def print_current_losses(self, epoch, iters, losses, t_comp, t_data): """print current losses on console; also save the losses to the disk Parameters: epoch (int) -- current epoch iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) losses (OrderedDict) -- training losses stored in the format of (name, float) pairs t_comp (float) -- computational time per data point (normalized by batch_size) t_data (float) -- data loading time per data point (normalized by batch_size) """ message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data) for k, v in losses.items(): message += '%s: %.3f ' % (k, v) print(message) # print the message with open(self.log_name, "a") as log_file: 
log_file.write('%s\n' % message) # save the message class MyVisualizer: def __init__(self, opt): """Initialize the Visualizer class Parameters: opt -- stores all the experiment flags; needs to be a subclass of BaseOptions Step 1: Cache the training/test options Step 2: create a tensorboard writer Step 3: create an HTML object for saveing HTML filters Step 4: create a logging file to store training losses """ self.opt = opt # cache the optio self.name = opt.name self.img_dir = os.path.join(opt.checkpoints_dir, opt.name, 'results') if opt.phase != 'test': self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, 'logs')) # create a logging file to store training losses self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') with open(self.log_name, "a") as log_file: now = time.strftime("%c") log_file.write('================ Training Loss (%s) ================\n' % now) def display_current_results(self, visuals, total_iters, epoch, dataset='train', save_results=False, count=0, name=None, add_image=True): """Display current results on tensorboad; save current results to an HTML file. 
Parameters: visuals (OrderedDict) - - dictionary of images to display or save total_iters (int) -- total iterations epoch (int) - - the current epoch dataset (str) - - 'train' or 'val' or 'test' """ # if (not add_image) and (not save_results): return for label, image in visuals.items(): for i in range(image.shape[0]): image_numpy = util.tensor2im(image[i]) if add_image: self.writer.add_image(label + '%s_%02d'%(dataset, i + count), image_numpy, total_iters, dataformats='HWC') if save_results: save_path = os.path.join(self.img_dir, dataset, 'epoch_%s_%06d'%(epoch, total_iters)) if not os.path.isdir(save_path): os.makedirs(save_path) if name is not None: img_path = os.path.join(save_path, '%s.png' % name) else: img_path = os.path.join(save_path, '%s_%03d.png' % (label, i + count)) util.save_image(image_numpy, img_path) def plot_current_losses(self, total_iters, losses, dataset='train'): for name, value in losses.items(): self.writer.add_scalar(name + '/%s'%dataset, value, total_iters) # losses: same format as |losses| of plot_current_losses def print_current_losses(self, epoch, iters, losses, t_comp, t_data, dataset='train'): """print current losses on console; also save the losses to the disk Parameters: epoch (int) -- current epoch iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) losses (OrderedDict) -- training losses stored in the format of (name, float) pairs t_comp (float) -- computational time per data point (normalized by batch_size) t_data (float) -- data loading time per data point (normalized by batch_size) """ message = '(dataset: %s, epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % ( dataset, epoch, iters, t_comp, t_data) for k, v in losses.items(): message += '%s: %.3f ' % (k, v) print(message) # print the message with open(self.log_name, "a") as log_file: log_file.write('%s\n' % message) # save the message ================================================ FILE: src/face3d/visualize.py 
================================================ # check the sync of 3dmm feature and the audio import cv2 import numpy as np from src.face3d.models.bfm import ParametricFaceModel from src.face3d.models.facerecon_model import FaceReconModel import torch import subprocess, platform import scipy.io as scio from tqdm import tqdm # draft def gen_composed_video(args, device, first_frame_coeff, coeff_path, audio_path, save_path, exp_dim=64): coeff_first = scio.loadmat(first_frame_coeff)['full_3dmm'] coeff_pred = scio.loadmat(coeff_path)['coeff_3dmm'] coeff_full = np.repeat(coeff_first, coeff_pred.shape[0], axis=0) # 257 coeff_full[:, 80:144] = coeff_pred[:, 0:64] coeff_full[:, 224:227] = coeff_pred[:, 64:67] # 3 dim translation coeff_full[:, 254:] = coeff_pred[:, 67:] # 3 dim translation tmp_video_path = '/tmp/face3dtmp.mp4' facemodel = FaceReconModel(args) video = cv2.VideoWriter(tmp_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (224, 224)) for k in tqdm(range(coeff_pred.shape[0]), 'face3d rendering:'): cur_coeff_full = torch.tensor(coeff_full[k:k+1], device=device) facemodel.forward(cur_coeff_full, device) predicted_landmark = facemodel.pred_lm # TODO. predicted_landmark = predicted_landmark.cpu().numpy().squeeze() rendered_img = facemodel.pred_face rendered_img = 255. 
* rendered_img.cpu().numpy().squeeze().transpose(1,2,0) out_img = rendered_img[:, :, :3].astype(np.uint8) video.write(np.uint8(out_img[:,:,::-1])) video.release() command = 'ffmpeg -v quiet -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, tmp_video_path, save_path) subprocess.call(command, shell=platform.system() != 'Windows') ================================================ FILE: src/facerender/animate.py ================================================ import os import cv2 import yaml import numpy as np import warnings from skimage import img_as_ubyte import safetensors import safetensors.torch warnings.filterwarnings('ignore') import imageio import torch import torchvision from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector from src.facerender.modules.mapping import MappingNet from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator from src.facerender.modules.make_animation import make_animation from pydub import AudioSegment from src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list from src.utils.paste_pic import paste_pic from src.utils.videoio import save_video_with_watermark try: import webui # in webui in_webui = True except: in_webui = False class AnimateFromCoeff(): def __init__(self, sadtalker_path, device): with open(sadtalker_path['facerender_yaml']) as f: config = yaml.safe_load(f) generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'], **config['model_params']['common_params']) kp_extractor = KPDetector(**config['model_params']['kp_detector_params'], **config['model_params']['common_params']) he_estimator = HEEstimator(**config['model_params']['he_estimator_params'], **config['model_params']['common_params']) mapping = MappingNet(**config['model_params']['mapping_params']) generator.to(device) kp_extractor.to(device) he_estimator.to(device) mapping.to(device) for param in generator.parameters(): param.requires_grad = False 
for param in kp_extractor.parameters(): param.requires_grad = False for param in he_estimator.parameters(): param.requires_grad = False for param in mapping.parameters(): param.requires_grad = False if sadtalker_path is not None: if 'checkpoint' in sadtalker_path: # use safe tensor self.load_cpk_facevid2vid_safetensor(sadtalker_path['checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=None) else: self.load_cpk_facevid2vid(sadtalker_path['free_view_checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator) else: raise AttributeError("Checkpoint should be specified for video head pose estimator.") if sadtalker_path['mappingnet_checkpoint'] is not None: self.load_cpk_mapping(sadtalker_path['mappingnet_checkpoint'], mapping=mapping) else: raise AttributeError("Checkpoint should be specified for video head pose estimator.") self.kp_extractor = kp_extractor self.generator = generator self.he_estimator = he_estimator self.mapping = mapping self.kp_extractor.eval() self.generator.eval() self.he_estimator.eval() self.mapping.eval() self.device = device def load_cpk_facevid2vid_safetensor(self, checkpoint_path, generator=None, kp_detector=None, he_estimator=None, device="cpu"): checkpoint = safetensors.torch.load_file(checkpoint_path) if generator is not None: x_generator = {} for k,v in checkpoint.items(): if 'generator' in k: x_generator[k.replace('generator.', '')] = v generator.load_state_dict(x_generator) if kp_detector is not None: x_generator = {} for k,v in checkpoint.items(): if 'kp_extractor' in k: x_generator[k.replace('kp_extractor.', '')] = v kp_detector.load_state_dict(x_generator) if he_estimator is not None: x_generator = {} for k,v in checkpoint.items(): if 'he_estimator' in k: x_generator[k.replace('he_estimator.', '')] = v he_estimator.load_state_dict(x_generator) return None def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None, kp_detector=None, he_estimator=None, 
optimizer_generator=None, optimizer_discriminator=None, optimizer_kp_detector=None, optimizer_he_estimator=None, device="cpu"): checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) if generator is not None: generator.load_state_dict(checkpoint['generator']) if kp_detector is not None: kp_detector.load_state_dict(checkpoint['kp_detector']) if he_estimator is not None: he_estimator.load_state_dict(checkpoint['he_estimator']) if discriminator is not None: try: discriminator.load_state_dict(checkpoint['discriminator']) except: print ('No discriminator in the state-dict. Dicriminator will be randomly initialized') if optimizer_generator is not None: optimizer_generator.load_state_dict(checkpoint['optimizer_generator']) if optimizer_discriminator is not None: try: optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator']) except RuntimeError as e: print ('No discriminator optimizer in the state-dict. Optimizer will be not initialized') if optimizer_kp_detector is not None: optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector']) if optimizer_he_estimator is not None: optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator']) return checkpoint['epoch'] def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None, optimizer_mapping=None, optimizer_discriminator=None, device='cpu'): checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) if mapping is not None: mapping.load_state_dict(checkpoint['mapping']) if discriminator is not None: discriminator.load_state_dict(checkpoint['discriminator']) if optimizer_mapping is not None: optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping']) if optimizer_discriminator is not None: optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator']) return checkpoint['epoch'] def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', 
img_size=256): source_image=x['source_image'].type(torch.FloatTensor) source_semantics=x['source_semantics'].type(torch.FloatTensor) target_semantics=x['target_semantics_list'].type(torch.FloatTensor) source_image=source_image.to(self.device) source_semantics=source_semantics.to(self.device) target_semantics=target_semantics.to(self.device) if 'yaw_c_seq' in x: yaw_c_seq = x['yaw_c_seq'].type(torch.FloatTensor) yaw_c_seq = x['yaw_c_seq'].to(self.device) else: yaw_c_seq = None if 'pitch_c_seq' in x: pitch_c_seq = x['pitch_c_seq'].type(torch.FloatTensor) pitch_c_seq = x['pitch_c_seq'].to(self.device) else: pitch_c_seq = None if 'roll_c_seq' in x: roll_c_seq = x['roll_c_seq'].type(torch.FloatTensor) roll_c_seq = x['roll_c_seq'].to(self.device) else: roll_c_seq = None frame_num = x['frame_num'] predictions_video = make_animation(source_image, source_semantics, target_semantics, self.generator, self.kp_extractor, self.he_estimator, self.mapping, yaw_c_seq, pitch_c_seq, roll_c_seq, use_exp = True) predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:]) predictions_video = predictions_video[:frame_num] video = [] for idx in range(predictions_video.shape[0]): image = predictions_video[idx] image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32) video.append(image) result = img_as_ubyte(video) ### the generated video is 256x256, so we keep the aspect ratio, original_size = crop_info[0] if original_size: result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ] video_name = x['video_name'] + '.mp4' path = os.path.join(video_save_dir, 'temp_'+video_name) imageio.mimsave(path, result, fps=float(25)) av_path = os.path.join(video_save_dir, video_name) return_path = av_path audio_path = x['audio_path'] audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0] new_audio_path = os.path.join(video_save_dir, audio_name+'.wav') start_time = 0 # cog will not keep the .mp3 
filename sound = AudioSegment.from_file(audio_path) frames = frame_num end_time = start_time + frames*1/25*1000 word1=sound.set_frame_rate(16000) word = word1[start_time:end_time] word.export(new_audio_path, format="wav") save_video_with_watermark(path, new_audio_path, av_path, watermark= False) print(f'The generated video is named {video_save_dir}/{video_name}') if 'full' in preprocess.lower(): # only add watermark to the full image. video_name_full = x['video_name'] + '_full.mp4' full_video_path = os.path.join(video_save_dir, video_name_full) return_path = full_video_path paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop= True if 'ext' in preprocess.lower() else False) print(f'The generated video is named {video_save_dir}/{video_name_full}') else: full_video_path = av_path #### paste back then enhancers if enhancer: video_name_enhancer = x['video_name'] + '_enhanced.mp4' enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer) av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer) return_path = av_path_enhancer try: enhanced_images_gen_with_len = enhancer_generator_with_len(full_video_path, method=enhancer, bg_upsampler=background_enhancer) imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25)) except: enhanced_images_gen_with_len = enhancer_list(full_video_path, method=enhancer, bg_upsampler=background_enhancer) imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25)) save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= False) print(f'The generated video is named {video_save_dir}/{video_name_enhancer}') os.remove(enhanced_path) os.remove(path) os.remove(new_audio_path) return return_path ================================================ FILE: src/facerender/modules/dense_motion.py ================================================ from torch import nn import torch.nn.functional as F import torch from src.facerender.modules.util 
import Hourglass, make_coordinate_grid, kp2gaussian from src.facerender.sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d class DenseMotionNetwork(nn.Module): """ Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving """ def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress, estimate_occlusion_map=False): super(DenseMotionNetwork, self).__init__() # self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(feature_channel+1), max_features=max_features, num_blocks=num_blocks) self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks) self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3) self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1) self.norm = BatchNorm3d(compress, affine=True) if estimate_occlusion_map: # self.occlusion = nn.Conv2d(reshape_channel*reshape_depth, 1, kernel_size=7, padding=3) self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3) else: self.occlusion = None self.num_kp = num_kp def create_sparse_motions(self, feature, kp_driving, kp_source): bs, _, d, h, w = feature.shape identity_grid = make_coordinate_grid((d, h, w), type=kp_source['value'].type()) identity_grid = identity_grid.view(1, 1, d, h, w, 3) coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 1, 3) # if 'jacobian' in kp_driving: if 'jacobian' in kp_driving and kp_driving['jacobian'] is not None: jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian'])) jacobian = jacobian.unsqueeze(-3).unsqueeze(-3).unsqueeze(-3) jacobian = jacobian.repeat(1, 1, d, h, w, 1, 1) coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1)) coordinate_grid = coordinate_grid.squeeze(-1) driving_to_source = 
coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3) #adding background feature identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1) sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) #bs num_kp+1 d h w 3 # sparse_motions = driving_to_source return sparse_motions def create_deformed_feature(self, feature, sparse_motions): bs, _, d, h, w = feature.shape feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w) feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w) sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1)) # (bs*(num_kp+1), d, h, w, 3) !!!! sparse_deformed = F.grid_sample(feature_repeat, sparse_motions) sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w)) # (bs, num_kp+1, c, d, h, w) return sparse_deformed def create_heatmap_representations(self, feature, kp_driving, kp_source): spatial_size = feature.shape[3:] gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01) gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01) heatmap = gaussian_driving - gaussian_source # adding background feature zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type()) heatmap = torch.cat([zeros, heatmap], dim=1) heatmap = heatmap.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w) return heatmap def forward(self, feature, kp_driving, kp_source): bs, _, d, h, w = feature.shape feature = self.compress(feature) feature = self.norm(feature) feature = F.relu(feature) out_dict = dict() sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source) deformed_feature = self.create_deformed_feature(feature, sparse_motion) heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source) input_ = torch.cat([heatmap, 
deformed_feature], dim=2) input_ = input_.view(bs, -1, d, h, w) # input = deformed_feature.view(bs, -1, d, h, w) # (bs, num_kp+1 * c, d, h, w) prediction = self.hourglass(input_) mask = self.mask(prediction) mask = F.softmax(mask, dim=1) out_dict['mask'] = mask mask = mask.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w) zeros_mask = torch.zeros_like(mask) mask = torch.where(mask < 1e-3, zeros_mask, mask) sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4) # (bs, num_kp+1, 3, d, h, w) deformation = (sparse_motion * mask).sum(dim=1) # (bs, 3, d, h, w) deformation = deformation.permute(0, 2, 3, 4, 1) # (bs, d, h, w, 3) out_dict['deformation'] = deformation if self.occlusion: bs, c, d, h, w = prediction.shape prediction = prediction.view(bs, -1, h, w) occlusion_map = torch.sigmoid(self.occlusion(prediction)) out_dict['occlusion_map'] = occlusion_map return out_dict ================================================ FILE: src/facerender/modules/discriminator.py ================================================ from torch import nn import torch.nn.functional as F from facerender.modules.util import kp2gaussian import torch class DownBlock2d(nn.Module): """ Simple block for processing video (encoder). 
""" def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False): super(DownBlock2d, self).__init__() self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size) if sn: self.conv = nn.utils.spectral_norm(self.conv) if norm: self.norm = nn.InstanceNorm2d(out_features, affine=True) else: self.norm = None self.pool = pool def forward(self, x): out = x out = self.conv(out) if self.norm: out = self.norm(out) out = F.leaky_relu(out, 0.2) if self.pool: out = F.avg_pool2d(out, (2, 2)) return out class Discriminator(nn.Module): """ Discriminator similar to Pix2Pix """ def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512, sn=False, **kwargs): super(Discriminator, self).__init__() down_blocks = [] for i in range(num_blocks): down_blocks.append( DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)), min(max_features, block_expansion * (2 ** (i + 1))), norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn)) self.down_blocks = nn.ModuleList(down_blocks) self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1) if sn: self.conv = nn.utils.spectral_norm(self.conv) def forward(self, x): feature_maps = [] out = x for down_block in self.down_blocks: feature_maps.append(down_block(out)) out = feature_maps[-1] prediction_map = self.conv(out) return feature_maps, prediction_map class MultiScaleDiscriminator(nn.Module): """ Multi-scale (scale) discriminator """ def __init__(self, scales=(), **kwargs): super(MultiScaleDiscriminator, self).__init__() self.scales = scales discs = {} for scale in scales: discs[str(scale).replace('.', '-')] = Discriminator(**kwargs) self.discs = nn.ModuleDict(discs) def forward(self, x): out_dict = {} for scale, disc in self.discs.items(): scale = str(scale).replace('-', '.') key = 'prediction_' + scale feature_maps, prediction_map = disc(x[key]) out_dict['feature_maps_' + scale] 
= feature_maps out_dict['prediction_map_' + scale] = prediction_map return out_dict ================================================ FILE: src/facerender/modules/generator.py ================================================ import torch from torch import nn import torch.nn.functional as F from src.facerender.modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d, ResBlock3d, SPADEResnetBlock from src.facerender.modules.dense_motion import DenseMotionNetwork class OcclusionAwareGenerator(nn.Module): """ Generator follows NVIDIA architecture. """ def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth, num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False): super(OcclusionAwareGenerator, self).__init__() if dense_motion_params is not None: self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel, estimate_occlusion_map=estimate_occlusion_map, **dense_motion_params) else: self.dense_motion_network = None self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(7, 7), padding=(3, 3)) down_blocks = [] for i in range(num_down_blocks): in_features = min(max_features, block_expansion * (2 ** i)) out_features = min(max_features, block_expansion * (2 ** (i + 1))) down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) self.down_blocks = nn.ModuleList(down_blocks) self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1) self.reshape_channel = reshape_channel self.reshape_depth = reshape_depth self.resblocks_3d = torch.nn.Sequential() for i in range(num_resblocks): self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1)) out_features = block_expansion * (2 ** (num_down_blocks)) self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), 
lrelu=True) self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1) self.resblocks_2d = torch.nn.Sequential() for i in range(num_resblocks): self.resblocks_2d.add_module('2dr' + str(i), ResBlock2d(out_features, kernel_size=3, padding=1)) up_blocks = [] for i in range(num_down_blocks): in_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i))) out_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i - 1))) up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) self.up_blocks = nn.ModuleList(up_blocks) self.final = nn.Conv2d(block_expansion, image_channel, kernel_size=(7, 7), padding=(3, 3)) self.estimate_occlusion_map = estimate_occlusion_map self.image_channel = image_channel def deform_input(self, inp, deformation): _, d_old, h_old, w_old, _ = deformation.shape _, _, d, h, w = inp.shape if d_old != d or h_old != h or w_old != w: deformation = deformation.permute(0, 4, 1, 2, 3) deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear') deformation = deformation.permute(0, 2, 3, 4, 1) return F.grid_sample(inp, deformation) def forward(self, source_image, kp_driving, kp_source): # Encoding (downsampling) part out = self.first(source_image) for i in range(len(self.down_blocks)): out = self.down_blocks[i](out) out = self.second(out) bs, c, h, w = out.shape # print(out.shape) feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h ,w) feature_3d = self.resblocks_3d(feature_3d) # Transforming feature representation according to deformation and occlusion output_dict = {} if self.dense_motion_network is not None: dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving, kp_source=kp_source) output_dict['mask'] = dense_motion['mask'] if 'occlusion_map' in dense_motion: occlusion_map = dense_motion['occlusion_map'] output_dict['occlusion_map'] = occlusion_map else: occlusion_map = None 
deformation = dense_motion['deformation'] out = self.deform_input(feature_3d, deformation) bs, c, d, h, w = out.shape out = out.view(bs, c*d, h, w) out = self.third(out) out = self.fourth(out) if occlusion_map is not None: if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]: occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear') out = out * occlusion_map # output_dict["deformed"] = self.deform_input(source_image, deformation) # 3d deformation cannot deform 2d image # Decoding part out = self.resblocks_2d(out) for i in range(len(self.up_blocks)): out = self.up_blocks[i](out) out = self.final(out) out = F.sigmoid(out) output_dict["prediction"] = out return output_dict class SPADEDecoder(nn.Module): def __init__(self): super().__init__() ic = 256 oc = 64 norm_G = 'spadespectralinstance' label_nc = 256 self.fc = nn.Conv2d(ic, 2 * ic, 3, padding=1) self.G_middle_0 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) self.G_middle_1 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) self.G_middle_2 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) self.G_middle_3 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) self.G_middle_4 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) self.G_middle_5 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) self.up_0 = SPADEResnetBlock(2 * ic, ic, norm_G, label_nc) self.up_1 = SPADEResnetBlock(ic, oc, norm_G, label_nc) self.conv_img = nn.Conv2d(oc, 3, 3, padding=1) self.up = nn.Upsample(scale_factor=2) def forward(self, feature): seg = feature x = self.fc(feature) x = self.G_middle_0(x, seg) x = self.G_middle_1(x, seg) x = self.G_middle_2(x, seg) x = self.G_middle_3(x, seg) x = self.G_middle_4(x, seg) x = self.G_middle_5(x, seg) x = self.up(x) x = self.up_0(x, seg) # 256, 128, 128 x = self.up(x) x = self.up_1(x, seg) # 64, 256, 256 x = self.conv_img(F.leaky_relu(x, 2e-1)) # x = torch.tanh(x) x = F.sigmoid(x) return x class OcclusionAwareSPADEGenerator(nn.Module): 
def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth, num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False): super(OcclusionAwareSPADEGenerator, self).__init__() if dense_motion_params is not None: self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel, estimate_occlusion_map=estimate_occlusion_map, **dense_motion_params) else: self.dense_motion_network = None self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1)) down_blocks = [] for i in range(num_down_blocks): in_features = min(max_features, block_expansion * (2 ** i)) out_features = min(max_features, block_expansion * (2 ** (i + 1))) down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) self.down_blocks = nn.ModuleList(down_blocks) self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1) self.reshape_channel = reshape_channel self.reshape_depth = reshape_depth self.resblocks_3d = torch.nn.Sequential() for i in range(num_resblocks): self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1)) out_features = block_expansion * (2 ** (num_down_blocks)) self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True) self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1) self.estimate_occlusion_map = estimate_occlusion_map self.image_channel = image_channel self.decoder = SPADEDecoder() def deform_input(self, inp, deformation): _, d_old, h_old, w_old, _ = deformation.shape _, _, d, h, w = inp.shape if d_old != d or h_old != h or w_old != w: deformation = deformation.permute(0, 4, 1, 2, 3) deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear') deformation = deformation.permute(0, 2, 3, 4, 1) 
return F.grid_sample(inp, deformation) def forward(self, source_image, kp_driving, kp_source): # Encoding (downsampling) part out = self.first(source_image) for i in range(len(self.down_blocks)): out = self.down_blocks[i](out) out = self.second(out) bs, c, h, w = out.shape # print(out.shape) feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h ,w) feature_3d = self.resblocks_3d(feature_3d) # Transforming feature representation according to deformation and occlusion output_dict = {} if self.dense_motion_network is not None: dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving, kp_source=kp_source) output_dict['mask'] = dense_motion['mask'] # import pdb; pdb.set_trace() if 'occlusion_map' in dense_motion: occlusion_map = dense_motion['occlusion_map'] output_dict['occlusion_map'] = occlusion_map else: occlusion_map = None deformation = dense_motion['deformation'] out = self.deform_input(feature_3d, deformation) bs, c, d, h, w = out.shape out = out.view(bs, c*d, h, w) out = self.third(out) out = self.fourth(out) # occlusion_map = torch.where(occlusion_map < 0.95, 0, occlusion_map) if occlusion_map is not None: if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]: occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear') out = out * occlusion_map # Decoding part out = self.decoder(out) output_dict["prediction"] = out return output_dict ================================================ FILE: src/facerender/modules/keypoint_detector.py ================================================ from torch import nn import torch import torch.nn.functional as F from src.facerender.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d from src.facerender.modules.util import KPHourglass, make_coordinate_grid, AntiAliasInterpolation2d, ResBottleneck class KPDetector(nn.Module): """ Detecting canonical keypoints. Return keypoint position and jacobian near each keypoint. 
""" def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, reshape_channel, reshape_depth, num_blocks, temperature, estimate_jacobian=False, scale_factor=1, single_jacobian_map=False): super(KPDetector, self).__init__() self.predictor = KPHourglass(block_expansion, in_features=image_channel, max_features=max_features, reshape_features=reshape_channel, reshape_depth=reshape_depth, num_blocks=num_blocks) # self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=7, padding=3) self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=3, padding=1) if estimate_jacobian: self.num_jacobian_maps = 1 if single_jacobian_map else num_kp # self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=7, padding=3) self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=3, padding=1) ''' initial as: [[1 0 0] [0 1 0] [0 0 1]] ''' self.jacobian.weight.data.zero_() self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float)) else: self.jacobian = None self.temperature = temperature self.scale_factor = scale_factor if self.scale_factor != 1: self.down = AntiAliasInterpolation2d(image_channel, self.scale_factor) def gaussian2kp(self, heatmap): """ Extract the mean from a heatmap """ shape = heatmap.shape heatmap = heatmap.unsqueeze(-1) grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) value = (heatmap * grid).sum(dim=(2, 3, 4)) kp = {'value': value} return kp def forward(self, x): if self.scale_factor != 1: x = self.down(x) feature_map = self.predictor(x) prediction = self.kp(feature_map) final_shape = prediction.shape heatmap = prediction.view(final_shape[0], final_shape[1], -1) heatmap = F.softmax(heatmap / self.temperature, dim=2) heatmap = heatmap.view(*final_shape) 
out = self.gaussian2kp(heatmap) if self.jacobian is not None: jacobian_map = self.jacobian(feature_map) jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 9, final_shape[2], final_shape[3], final_shape[4]) heatmap = heatmap.unsqueeze(2) jacobian = heatmap * jacobian_map jacobian = jacobian.view(final_shape[0], final_shape[1], 9, -1) jacobian = jacobian.sum(dim=-1) jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 3, 3) out['jacobian'] = jacobian return out class HEEstimator(nn.Module): """ Estimating head pose and expression. """ def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, num_bins=66, estimate_jacobian=True): super(HEEstimator, self).__init__() self.conv1 = nn.Conv2d(in_channels=image_channel, out_channels=block_expansion, kernel_size=7, padding=3, stride=2) self.norm1 = BatchNorm2d(block_expansion, affine=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.conv2 = nn.Conv2d(in_channels=block_expansion, out_channels=256, kernel_size=1) self.norm2 = BatchNorm2d(256, affine=True) self.block1 = nn.Sequential() for i in range(3): self.block1.add_module('b1_'+ str(i), ResBottleneck(in_features=256, stride=1)) self.conv3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1) self.norm3 = BatchNorm2d(512, affine=True) self.block2 = ResBottleneck(in_features=512, stride=2) self.block3 = nn.Sequential() for i in range(3): self.block3.add_module('b3_'+ str(i), ResBottleneck(in_features=512, stride=1)) self.conv4 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1) self.norm4 = BatchNorm2d(1024, affine=True) self.block4 = ResBottleneck(in_features=1024, stride=2) self.block5 = nn.Sequential() for i in range(5): self.block5.add_module('b5_'+ str(i), ResBottleneck(in_features=1024, stride=1)) self.conv5 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=1) self.norm5 = BatchNorm2d(2048, affine=True) self.block6 = ResBottleneck(in_features=2048, 
stride=2) self.block7 = nn.Sequential() for i in range(2): self.block7.add_module('b7_'+ str(i), ResBottleneck(in_features=2048, stride=1)) self.fc_roll = nn.Linear(2048, num_bins) self.fc_pitch = nn.Linear(2048, num_bins) self.fc_yaw = nn.Linear(2048, num_bins) self.fc_t = nn.Linear(2048, 3) self.fc_exp = nn.Linear(2048, 3*num_kp) def forward(self, x): out = self.conv1(x) out = self.norm1(out) out = F.relu(out) out = self.maxpool(out) out = self.conv2(out) out = self.norm2(out) out = F.relu(out) out = self.block1(out) out = self.conv3(out) out = self.norm3(out) out = F.relu(out) out = self.block2(out) out = self.block3(out) out = self.conv4(out) out = self.norm4(out) out = F.relu(out) out = self.block4(out) out = self.block5(out) out = self.conv5(out) out = self.norm5(out) out = F.relu(out) out = self.block6(out) out = self.block7(out) out = F.adaptive_avg_pool2d(out, 1) out = out.view(out.shape[0], -1) yaw = self.fc_roll(out) pitch = self.fc_pitch(out) roll = self.fc_yaw(out) t = self.fc_t(out) exp = self.fc_exp(out) return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp} ================================================ FILE: src/facerender/modules/make_animation.py ================================================ from scipy.spatial import ConvexHull import torch import torch.nn.functional as F import numpy as np from tqdm import tqdm def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False, use_relative_movement=False, use_relative_jacobian=False): if adapt_movement_scale: source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area) else: adapt_movement_scale = 1 kp_new = {k: v for k, v in kp_driving.items()} if use_relative_movement: kp_value_diff = (kp_driving['value'] - kp_driving_initial['value']) kp_value_diff *= adapt_movement_scale kp_new['value'] = 
kp_value_diff + kp_source['value'] if use_relative_jacobian: jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian'])) kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian']) return kp_new def headpose_pred_to_degree(pred): device = pred.device idx_tensor = [idx for idx in range(66)] idx_tensor = torch.FloatTensor(idx_tensor).type_as(pred).to(device) pred = F.softmax(pred) degree = torch.sum(pred*idx_tensor, 1) * 3 - 99 return degree def get_rotation_matrix(yaw, pitch, roll): yaw = yaw / 180 * 3.14 pitch = pitch / 180 * 3.14 roll = roll / 180 * 3.14 roll = roll.unsqueeze(1) pitch = pitch.unsqueeze(1) yaw = yaw.unsqueeze(1) pitch_mat = torch.cat([torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch), torch.cos(pitch), -torch.sin(pitch), torch.zeros_like(pitch), torch.sin(pitch), torch.cos(pitch)], dim=1) pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3) yaw_mat = torch.cat([torch.cos(yaw), torch.zeros_like(yaw), torch.sin(yaw), torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw), -torch.sin(yaw), torch.zeros_like(yaw), torch.cos(yaw)], dim=1) yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3) roll_mat = torch.cat([torch.cos(roll), -torch.sin(roll), torch.zeros_like(roll), torch.sin(roll), torch.cos(roll), torch.zeros_like(roll), torch.zeros_like(roll), torch.zeros_like(roll), torch.ones_like(roll)], dim=1) roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3) rot_mat = torch.einsum('bij,bjk,bkm->bim', pitch_mat, yaw_mat, roll_mat) return rot_mat def keypoint_transformation(kp_canonical, he, wo_exp=False): kp = kp_canonical['value'] # (bs, k, 3) yaw, pitch, roll= he['yaw'], he['pitch'], he['roll'] yaw = headpose_pred_to_degree(yaw) pitch = headpose_pred_to_degree(pitch) roll = headpose_pred_to_degree(roll) if 'yaw_in' in he: yaw = he['yaw_in'] if 'pitch_in' in he: pitch = he['pitch_in'] if 'roll_in' in he: roll = he['roll_in'] rot_mat = 
get_rotation_matrix(yaw, pitch, roll) # (bs, 3, 3) t, exp = he['t'], he['exp'] if wo_exp: exp = exp*0 # keypoint rotation kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp) # keypoint translation t[:, 0] = t[:, 0]*0 t[:, 2] = t[:, 2]*0 t = t.unsqueeze(1).repeat(1, kp.shape[1], 1) kp_t = kp_rotated + t # add expression deviation exp = exp.view(exp.shape[0], -1, 3) kp_transformed = kp_t + exp return {'value': kp_transformed} def make_animation(source_image, source_semantics, target_semantics, generator, kp_detector, he_estimator, mapping, yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None, use_exp=True, use_half=False): with torch.no_grad(): predictions = [] kp_canonical = kp_detector(source_image) he_source = mapping(source_semantics) kp_source = keypoint_transformation(kp_canonical, he_source) for frame_idx in tqdm(range(target_semantics.shape[1]), 'Face Renderer:'): # still check the dimension # print(target_semantics.shape, source_semantics.shape) target_semantics_frame = target_semantics[:, frame_idx] he_driving = mapping(target_semantics_frame) if yaw_c_seq is not None: he_driving['yaw_in'] = yaw_c_seq[:, frame_idx] if pitch_c_seq is not None: he_driving['pitch_in'] = pitch_c_seq[:, frame_idx] if roll_c_seq is not None: he_driving['roll_in'] = roll_c_seq[:, frame_idx] kp_driving = keypoint_transformation(kp_canonical, he_driving) kp_norm = kp_driving out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm) ''' source_image_new = out['prediction'].squeeze(1) kp_canonical_new = kp_detector(source_image_new) he_source_new = he_estimator(source_image_new) kp_source_new = keypoint_transformation(kp_canonical_new, he_source_new, wo_exp=True) kp_driving_new = keypoint_transformation(kp_canonical_new, he_driving, wo_exp=True) out = generator(source_image_new, kp_source=kp_source_new, kp_driving=kp_driving_new) ''' predictions.append(out['prediction']) predictions_ts = torch.stack(predictions, dim=1) return predictions_ts class 
AnimateModel(torch.nn.Module): """ Merge all generator related updates into single model for better multi-gpu usage """ def __init__(self, generator, kp_extractor, mapping): super(AnimateModel, self).__init__() self.kp_extractor = kp_extractor self.generator = generator self.mapping = mapping self.kp_extractor.eval() self.generator.eval() self.mapping.eval() def forward(self, x): source_image = x['source_image'] source_semantics = x['source_semantics'] target_semantics = x['target_semantics'] yaw_c_seq = x['yaw_c_seq'] pitch_c_seq = x['pitch_c_seq'] roll_c_seq = x['roll_c_seq'] predictions_video = make_animation(source_image, source_semantics, target_semantics, self.generator, self.kp_extractor, self.mapping, use_exp = True, yaw_c_seq=yaw_c_seq, pitch_c_seq=pitch_c_seq, roll_c_seq=roll_c_seq) return predictions_video ================================================ FILE: src/facerender/modules/mapping.py ================================================ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F class MappingNet(nn.Module): def __init__(self, coeff_nc, descriptor_nc, layer, num_kp, num_bins): super( MappingNet, self).__init__() self.layer = layer nonlinearity = nn.LeakyReLU(0.1) self.first = nn.Sequential( torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True)) for i in range(layer): net = nn.Sequential(nonlinearity, torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3)) setattr(self, 'encoder' + str(i), net) self.pooling = nn.AdaptiveAvgPool1d(1) self.output_nc = descriptor_nc self.fc_roll = nn.Linear(descriptor_nc, num_bins) self.fc_pitch = nn.Linear(descriptor_nc, num_bins) self.fc_yaw = nn.Linear(descriptor_nc, num_bins) self.fc_t = nn.Linear(descriptor_nc, 3) self.fc_exp = nn.Linear(descriptor_nc, 3*num_kp) def forward(self, input_3dmm): out = self.first(input_3dmm) for i in range(self.layer): model = getattr(self, 'encoder' + str(i)) out = model(out) + out[:,:,3:-3] 
out = self.pooling(out) out = out.view(out.shape[0], -1) #print('out:', out.shape) yaw = self.fc_yaw(out) pitch = self.fc_pitch(out) roll = self.fc_roll(out) t = self.fc_t(out) exp = self.fc_exp(out) return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp} ================================================ FILE: src/facerender/modules/util.py ================================================ from torch import nn import torch.nn.functional as F import torch from src.facerender.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d from src.facerender.sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d import torch.nn.utils.spectral_norm as spectral_norm def kp2gaussian(kp, spatial_size, kp_variance): """ Transform a keypoint into gaussian like representation """ mean = kp['value'] coordinate_grid = make_coordinate_grid(spatial_size, mean.type()) number_of_leading_dimensions = len(mean.shape) - 1 shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape coordinate_grid = coordinate_grid.view(*shape) repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1) coordinate_grid = coordinate_grid.repeat(*repeats) # Preprocess kp shape shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3) mean = mean.view(*shape) mean_sub = (coordinate_grid - mean) out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance) return out def make_coordinate_grid_2d(spatial_size, type): """ Create a meshgrid [-1,1] x [-1,1] of given spatial_size. 
""" h, w = spatial_size x = torch.arange(w).type(type) y = torch.arange(h).type(type) x = (2 * (x / (w - 1)) - 1) y = (2 * (y / (h - 1)) - 1) yy = y.view(-1, 1).repeat(1, w) xx = x.view(1, -1).repeat(h, 1) meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2) return meshed def make_coordinate_grid(spatial_size, type): d, h, w = spatial_size x = torch.arange(w).type(type) y = torch.arange(h).type(type) z = torch.arange(d).type(type) x = (2 * (x / (w - 1)) - 1) y = (2 * (y / (h - 1)) - 1) z = (2 * (z / (d - 1)) - 1) yy = y.view(1, -1, 1).repeat(d, 1, w) xx = x.view(1, 1, -1).repeat(d, h, 1) zz = z.view(-1, 1, 1).repeat(1, h, w) meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3) return meshed class ResBottleneck(nn.Module): def __init__(self, in_features, stride): super(ResBottleneck, self).__init__() self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features//4, kernel_size=1) self.conv2 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features//4, kernel_size=3, padding=1, stride=stride) self.conv3 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features, kernel_size=1) self.norm1 = BatchNorm2d(in_features//4, affine=True) self.norm2 = BatchNorm2d(in_features//4, affine=True) self.norm3 = BatchNorm2d(in_features, affine=True) self.stride = stride if self.stride != 1: self.skip = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=1, stride=stride) self.norm4 = BatchNorm2d(in_features, affine=True) def forward(self, x): out = self.conv1(x) out = self.norm1(out) out = F.relu(out) out = self.conv2(out) out = self.norm2(out) out = F.relu(out) out = self.conv3(out) out = self.norm3(out) if self.stride != 1: x = self.skip(x) x = self.norm4(x) out += x out = F.relu(out) return out class ResBlock2d(nn.Module): """ Res block, preserve spatial resolution. 
""" def __init__(self, in_features, kernel_size, padding): super(ResBlock2d, self).__init__() self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding) self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding) self.norm1 = BatchNorm2d(in_features, affine=True) self.norm2 = BatchNorm2d(in_features, affine=True) def forward(self, x): out = self.norm1(x) out = F.relu(out) out = self.conv1(out) out = self.norm2(out) out = F.relu(out) out = self.conv2(out) out += x return out class ResBlock3d(nn.Module): """ Res block, preserve spatial resolution. """ def __init__(self, in_features, kernel_size, padding): super(ResBlock3d, self).__init__() self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding) self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding) self.norm1 = BatchNorm3d(in_features, affine=True) self.norm2 = BatchNorm3d(in_features, affine=True) def forward(self, x): out = self.norm1(x) out = F.relu(out) out = self.conv1(out) out = self.norm2(out) out = F.relu(out) out = self.conv2(out) out += x return out class UpBlock2d(nn.Module): """ Upsampling block for use in decoder. """ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): super(UpBlock2d, self).__init__() self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) self.norm = BatchNorm2d(out_features, affine=True) def forward(self, x): out = F.interpolate(x, scale_factor=2) out = self.conv(out) out = self.norm(out) out = F.relu(out) return out class UpBlock3d(nn.Module): """ Upsampling block for use in decoder. 
""" def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): super(UpBlock3d, self).__init__() self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) self.norm = BatchNorm3d(out_features, affine=True) def forward(self, x): # out = F.interpolate(x, scale_factor=(1, 2, 2), mode='trilinear') out = F.interpolate(x, scale_factor=(1, 2, 2)) out = self.conv(out) out = self.norm(out) out = F.relu(out) return out class DownBlock2d(nn.Module): """ Downsampling block for use in encoder. """ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): super(DownBlock2d, self).__init__() self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) self.norm = BatchNorm2d(out_features, affine=True) self.pool = nn.AvgPool2d(kernel_size=(2, 2)) def forward(self, x): out = self.conv(x) out = self.norm(out) out = F.relu(out) out = self.pool(out) return out class DownBlock3d(nn.Module): """ Downsampling block for use in encoder. """ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): super(DownBlock3d, self).__init__() ''' self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups, stride=(1, 2, 2)) ''' self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) self.norm = BatchNorm3d(out_features, affine=True) self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2)) def forward(self, x): out = self.conv(x) out = self.norm(out) out = F.relu(out) out = self.pool(out) return out class SameBlock2d(nn.Module): """ Simple block, preserve spatial resolution. 
""" def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False): super(SameBlock2d, self).__init__() self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) self.norm = BatchNorm2d(out_features, affine=True) if lrelu: self.ac = nn.LeakyReLU() else: self.ac = nn.ReLU() def forward(self, x): out = self.conv(x) out = self.norm(out) out = self.ac(out) return out class Encoder(nn.Module): """ Hourglass Encoder """ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): super(Encoder, self).__init__() down_blocks = [] for i in range(num_blocks): down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), min(max_features, block_expansion * (2 ** (i + 1))), kernel_size=3, padding=1)) self.down_blocks = nn.ModuleList(down_blocks) def forward(self, x): outs = [x] for down_block in self.down_blocks: outs.append(down_block(outs[-1])) return outs class Decoder(nn.Module): """ Hourglass Decoder """ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): super(Decoder, self).__init__() up_blocks = [] for i in range(num_blocks)[::-1]: in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1))) out_filters = min(max_features, block_expansion * (2 ** i)) up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1)) self.up_blocks = nn.ModuleList(up_blocks) # self.out_filters = block_expansion self.out_filters = block_expansion + in_features self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1) self.norm = BatchNorm3d(self.out_filters, affine=True) def forward(self, x): out = x.pop() # for up_block in self.up_blocks[:-1]: for up_block in self.up_blocks: out = up_block(out) skip = x.pop() out = torch.cat([out, skip], dim=1) # out = self.up_blocks[-1](out) out = 
self.conv(out) out = self.norm(out) out = F.relu(out) return out class Hourglass(nn.Module): """ Hourglass architecture. """ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): super(Hourglass, self).__init__() self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features) self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features) self.out_filters = self.decoder.out_filters def forward(self, x): return self.decoder(self.encoder(x)) class KPHourglass(nn.Module): """ Hourglass architecture. """ def __init__(self, block_expansion, in_features, reshape_features, reshape_depth, num_blocks=3, max_features=256): super(KPHourglass, self).__init__() self.down_blocks = nn.Sequential() for i in range(num_blocks): self.down_blocks.add_module('down'+ str(i), DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), min(max_features, block_expansion * (2 ** (i + 1))), kernel_size=3, padding=1)) in_filters = min(max_features, block_expansion * (2 ** num_blocks)) self.conv = nn.Conv2d(in_channels=in_filters, out_channels=reshape_features, kernel_size=1) self.up_blocks = nn.Sequential() for i in range(num_blocks): in_filters = min(max_features, block_expansion * (2 ** (num_blocks - i))) out_filters = min(max_features, block_expansion * (2 ** (num_blocks - i - 1))) self.up_blocks.add_module('up'+ str(i), UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1)) self.reshape_depth = reshape_depth self.out_filters = out_filters def forward(self, x): out = self.down_blocks(x) out = self.conv(out) bs, c, h, w = out.shape out = out.view(bs, c//self.reshape_depth, self.reshape_depth, h, w) out = self.up_blocks(out) return out class AntiAliasInterpolation2d(nn.Module): """ Band-limited downsampling, for better preservation of the input signal. 
""" def __init__(self, channels, scale): super(AntiAliasInterpolation2d, self).__init__() sigma = (1 / scale - 1) / 2 kernel_size = 2 * round(sigma * 4) + 1 self.ka = kernel_size // 2 self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka kernel_size = [kernel_size, kernel_size] sigma = [sigma, sigma] # The gaussian kernel is the product of the # gaussian function of each dimension. kernel = 1 meshgrids = torch.meshgrid( [ torch.arange(size, dtype=torch.float32) for size in kernel_size ] ) for size, std, mgrid in zip(kernel_size, sigma, meshgrids): mean = (size - 1) / 2 kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2)) # Make sure sum of values in gaussian kernel equals 1. kernel = kernel / torch.sum(kernel) # Reshape to depthwise convolutional weight kernel = kernel.view(1, 1, *kernel.size()) kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) self.register_buffer('weight', kernel) self.groups = channels self.scale = scale inv_scale = 1 / scale self.int_inv_scale = int(inv_scale) def forward(self, input): if self.scale == 1.0: return input out = F.pad(input, (self.ka, self.kb, self.ka, self.kb)) out = F.conv2d(out, weight=self.weight, groups=self.groups) out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale] return out class SPADE(nn.Module): def __init__(self, norm_nc, label_nc): super().__init__() self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) nhidden = 128 self.mlp_shared = nn.Sequential( nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1), nn.ReLU()) self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1) self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1) def forward(self, x, segmap): normalized = self.param_free_norm(x) segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') actv = self.mlp_shared(segmap) gamma = self.mlp_gamma(actv) beta = self.mlp_beta(actv) out = normalized * (1 + gamma) + beta return out class SPADEResnetBlock(nn.Module): def __init__(self, fin, 
fout, norm_G, label_nc, use_se=False, dilation=1): super().__init__() # Attributes self.learned_shortcut = (fin != fout) fmiddle = min(fin, fout) self.use_se = use_se # create conv layers self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation) self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation) if self.learned_shortcut: self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) # apply spectral norm if specified if 'spectral' in norm_G: self.conv_0 = spectral_norm(self.conv_0) self.conv_1 = spectral_norm(self.conv_1) if self.learned_shortcut: self.conv_s = spectral_norm(self.conv_s) # define normalization layers self.norm_0 = SPADE(fin, label_nc) self.norm_1 = SPADE(fmiddle, label_nc) if self.learned_shortcut: self.norm_s = SPADE(fin, label_nc) def forward(self, x, seg1): x_s = self.shortcut(x, seg1) dx = self.conv_0(self.actvn(self.norm_0(x, seg1))) dx = self.conv_1(self.actvn(self.norm_1(dx, seg1))) out = x_s + dx return out def shortcut(self, x, seg1): if self.learned_shortcut: x_s = self.conv_s(self.norm_s(x, seg1)) else: x_s = x return x_s def actvn(self, x): return F.leaky_relu(x, 2e-1) class audio2image(nn.Module): def __init__(self, generator, kp_extractor, he_estimator_video, he_estimator_audio, train_params): super().__init__() # Attributes self.generator = generator self.kp_extractor = kp_extractor self.he_estimator_video = he_estimator_video self.he_estimator_audio = he_estimator_audio self.train_params = train_params def headpose_pred_to_degree(self, pred): device = pred.device idx_tensor = [idx for idx in range(66)] idx_tensor = torch.FloatTensor(idx_tensor).to(device) pred = F.softmax(pred) degree = torch.sum(pred*idx_tensor, 1) * 3 - 99 return degree def get_rotation_matrix(self, yaw, pitch, roll): yaw = yaw / 180 * 3.14 pitch = pitch / 180 * 3.14 roll = roll / 180 * 3.14 roll = roll.unsqueeze(1) pitch = pitch.unsqueeze(1) yaw = yaw.unsqueeze(1) roll_mat = 
torch.cat([torch.ones_like(roll), torch.zeros_like(roll), torch.zeros_like(roll), torch.zeros_like(roll), torch.cos(roll), -torch.sin(roll), torch.zeros_like(roll), torch.sin(roll), torch.cos(roll)], dim=1) roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3) pitch_mat = torch.cat([torch.cos(pitch), torch.zeros_like(pitch), torch.sin(pitch), torch.zeros_like(pitch), torch.ones_like(pitch), torch.zeros_like(pitch), -torch.sin(pitch), torch.zeros_like(pitch), torch.cos(pitch)], dim=1) pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3) yaw_mat = torch.cat([torch.cos(yaw), -torch.sin(yaw), torch.zeros_like(yaw), torch.sin(yaw), torch.cos(yaw), torch.zeros_like(yaw), torch.zeros_like(yaw), torch.zeros_like(yaw), torch.ones_like(yaw)], dim=1) yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3) rot_mat = torch.einsum('bij,bjk,bkm->bim', roll_mat, pitch_mat, yaw_mat) return rot_mat def keypoint_transformation(self, kp_canonical, he): kp = kp_canonical['value'] # (bs, k, 3) yaw, pitch, roll = he['yaw'], he['pitch'], he['roll'] t, exp = he['t'], he['exp'] yaw = self.headpose_pred_to_degree(yaw) pitch = self.headpose_pred_to_degree(pitch) roll = self.headpose_pred_to_degree(roll) rot_mat = self.get_rotation_matrix(yaw, pitch, roll) # (bs, 3, 3) # keypoint rotation kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp) # keypoint translation t = t.unsqueeze_(1).repeat(1, kp.shape[1], 1) kp_t = kp_rotated + t # add expression deviation exp = exp.view(exp.shape[0], -1, 3) kp_transformed = kp_t + exp return {'value': kp_transformed} def forward(self, source_image, target_audio): pose_source = self.he_estimator_video(source_image) pose_generated = self.he_estimator_audio(target_audio) kp_canonical = self.kp_extractor(source_image) kp_source = self.keypoint_transformation(kp_canonical, pose_source) kp_transformed_generated = self.keypoint_transformation(kp_canonical, pose_generated) generated = self.generator(source_image, kp_source=kp_source, kp_driving=kp_transformed_generated) return 
generated ================================================ FILE: src/facerender/sync_batchnorm/__init__.py ================================================ # -*- coding: utf-8 -*- # File : __init__.py # Author : Jiayuan Mao # Email : maojiayuan@gmail.com # Date : 27/01/2018 # # This file is part of Synchronized-BatchNorm-PyTorch. # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # Distributed under MIT License. from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d from .replicate import DataParallelWithCallback, patch_replication_callback ================================================ FILE: src/facerender/sync_batchnorm/batchnorm.py ================================================ # -*- coding: utf-8 -*- # File : batchnorm.py # Author : Jiayuan Mao # Email : maojiayuan@gmail.com # Date : 27/01/2018 # # This file is part of Synchronized-BatchNorm-PyTorch. # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # Distributed under MIT License. 
import collections import torch import torch.nn.functional as F from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast from .comm import SyncMaster __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] def _sum_ft(tensor): """sum over the first and last dimention""" return tensor.sum(dim=0).sum(dim=-1) def _unsqueeze_ft(tensor): """add new dementions at the front and the tail""" return tensor.unsqueeze(0).unsqueeze(-1) _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) class _SynchronizedBatchNorm(_BatchNorm): def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) self._sync_master = SyncMaster(self._data_parallel_master) self._is_parallel = False self._parallel_id = None self._slave_pipe = None def forward(self, input): # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. if not (self._is_parallel and self.training): return F.batch_norm( input, self.running_mean, self.running_var, self.weight, self.bias, self.training, self.momentum, self.eps) # Resize the input to (B, C, -1). input_shape = input.size() input = input.view(input.size(0), self.num_features, -1) # Compute the sum and square-sum. sum_size = input.size(0) * input.size(2) input_sum = _sum_ft(input) input_ssum = _sum_ft(input ** 2) # Reduce-and-broadcast the statistics. if self._parallel_id == 0: mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) else: mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) # Compute the output. if self.affine: # MJY:: Fuse the multiplication for speed. 
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) else: output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) # Reshape it. return output.view(input_shape) def __data_parallel_replicate__(self, ctx, copy_id): self._is_parallel = True self._parallel_id = copy_id # parallel_id == 0 means master device. if self._parallel_id == 0: ctx.sync_master = self._sync_master else: self._slave_pipe = ctx.sync_master.register_slave(copy_id) def _data_parallel_master(self, intermediates): """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" # Always using same "device order" makes the ReduceAdd operation faster. # Thanks to:: Tete Xiao (http://tetexiao.com/) intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) to_reduce = [i[1][:2] for i in intermediates] to_reduce = [j for i in to_reduce for j in i] # flatten target_gpus = [i[1].sum.get_device() for i in intermediates] sum_size = sum([i[1].sum_size for i in intermediates]) sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) broadcasted = Broadcast.apply(target_gpus, mean, inv_std) outputs = [] for i, rec in enumerate(intermediates): outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) return outputs def _compute_mean_std(self, sum_, ssum, size): """Compute the mean and standard-deviation with sum and square-sum. This method also maintains the moving average on the master device.""" assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 
mean = sum_ / size sumvar = ssum - sum_ * mean unbias_var = sumvar / (size - 1) bias_var = sumvar / size self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data return mean, bias_var.clamp(self.eps) ** -0.5 class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a mini-batch. .. math:: y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta This module differs from the built-in PyTorch BatchNorm1d as the mean and standard-deviation are reduced across all devices during training. For example, when one uses `nn.DataParallel` to wrap the network during training, PyTorch's implementation normalize the tensor on each device using the statistics only on that device, which accelerated the computation and is also easy to implement, but the statistics might be inaccurate. Instead, in this synchronized version, the statistics will be computed over all training samples distributed on multiple devices. Note that, for one-GPU or CPU-only case, this module behaves exactly same as the built-in PyTorch implementation. The mean and standard-deviation are calculated per-dimension over the mini-batches and gamma and beta are learnable parameter vectors of size C (where C is the input size). During training, this layer keeps a running estimate of its computed mean and variance. The running sum is kept with a default momentum of 0.1. During evaluation, this running mean/variance is used for normalization. Because the BatchNorm is done over the `C` dimension, computing statistics on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm Args: num_features: num_features from an expected input of size `batch_size x num_features [x width]` eps: a value added to the denominator for numerical stability. 
Default: 1e-5 momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, gives the layer learnable affine parameters. Default: ``True`` Shape: - Input: :math:`(N, C)` or :math:`(N, C, L)` - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) Examples: >>> # With Learnable Parameters >>> m = SynchronizedBatchNorm1d(100) >>> # Without Learnable Parameters >>> m = SynchronizedBatchNorm1d(100, affine=False) >>> input = torch.autograd.Variable(torch.randn(20, 100)) >>> output = m(input) """ def _check_input_dim(self, input): if input.dim() != 2 and input.dim() != 3: raise ValueError('expected 2D or 3D input (got {}D input)' .format(input.dim())) super(SynchronizedBatchNorm1d, self)._check_input_dim(input) class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch of 3d inputs .. math:: y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta This module differs from the built-in PyTorch BatchNorm2d as the mean and standard-deviation are reduced across all devices during training. For example, when one uses `nn.DataParallel` to wrap the network during training, PyTorch's implementation normalize the tensor on each device using the statistics only on that device, which accelerated the computation and is also easy to implement, but the statistics might be inaccurate. Instead, in this synchronized version, the statistics will be computed over all training samples distributed on multiple devices. Note that, for one-GPU or CPU-only case, this module behaves exactly same as the built-in PyTorch implementation. The mean and standard-deviation are calculated per-dimension over the mini-batches and gamma and beta are learnable parameter vectors of size C (where C is the input size). During training, this layer keeps a running estimate of its computed mean and variance. 
The running sum is kept with a default momentum of 0.1. During evaluation, this running mean/variance is used for normalization. Because the BatchNorm is done over the `C` dimension, computing statistics on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm Args: num_features: num_features from an expected input of size batch_size x num_features x height x width eps: a value added to the denominator for numerical stability. Default: 1e-5 momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, gives the layer learnable affine parameters. Default: ``True`` Shape: - Input: :math:`(N, C, H, W)` - Output: :math:`(N, C, H, W)` (same shape as input) Examples: >>> # With Learnable Parameters >>> m = SynchronizedBatchNorm2d(100) >>> # Without Learnable Parameters >>> m = SynchronizedBatchNorm2d(100, affine=False) >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) >>> output = m(input) """ def _check_input_dim(self, input): if input.dim() != 4: raise ValueError('expected 4D input (got {}D input)' .format(input.dim())) super(SynchronizedBatchNorm2d, self)._check_input_dim(input) class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch of 4d inputs .. math:: y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta This module differs from the built-in PyTorch BatchNorm3d as the mean and standard-deviation are reduced across all devices during training. For example, when one uses `nn.DataParallel` to wrap the network during training, PyTorch's implementation normalize the tensor on each device using the statistics only on that device, which accelerated the computation and is also easy to implement, but the statistics might be inaccurate. Instead, in this synchronized version, the statistics will be computed over all training samples distributed on multiple devices. 
Note that, for one-GPU or CPU-only case, this module behaves exactly same as the built-in PyTorch implementation. The mean and standard-deviation are calculated per-dimension over the mini-batches and gamma and beta are learnable parameter vectors of size C (where C is the input size). During training, this layer keeps a running estimate of its computed mean and variance. The running sum is kept with a default momentum of 0.1. During evaluation, this running mean/variance is used for normalization. Because the BatchNorm is done over the `C` dimension, computing statistics on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm or Spatio-temporal BatchNorm Args: num_features: num_features from an expected input of size batch_size x num_features x depth x height x width eps: a value added to the denominator for numerical stability. Default: 1e-5 momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, gives the layer learnable affine parameters. Default: ``True`` Shape: - Input: :math:`(N, C, D, H, W)` - Output: :math:`(N, C, D, H, W)` (same shape as input) Examples: >>> # With Learnable Parameters >>> m = SynchronizedBatchNorm3d(100) >>> # Without Learnable Parameters >>> m = SynchronizedBatchNorm3d(100, affine=False) >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) >>> output = m(input) """ def _check_input_dim(self, input): if input.dim() != 5: raise ValueError('expected 5D input (got {}D input)' .format(input.dim())) super(SynchronizedBatchNorm3d, self)._check_input_dim(input) ================================================ FILE: src/facerender/sync_batchnorm/comm.py ================================================ # -*- coding: utf-8 -*- # File : comm.py # Author : Jiayuan Mao # Email : maojiayuan@gmail.com # Date : 27/01/2018 # # This file is part of Synchronized-BatchNorm-PyTorch. 
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # Distributed under MIT License. import queue import collections import threading __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] class FutureResult(object): """A thread-safe future implementation. Used only as one-to-one pipe.""" def __init__(self): self._result = None self._lock = threading.Lock() self._cond = threading.Condition(self._lock) def put(self, result): with self._lock: assert self._result is None, 'Previous result has\'t been fetched.' self._result = result self._cond.notify() def get(self): with self._lock: if self._result is None: self._cond.wait() res = self._result self._result = None return res _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) class SlavePipe(_SlavePipeBase): """Pipe for master-slave communication.""" def run_slave(self, msg): self.queue.put((self.identifier, msg)) ret = self.result.get() self.queue.put(True) return ret class SyncMaster(object): """An abstract `SyncMaster` object. - During the replication, as the data parallel will trigger an callback of each module, all slave devices should call `register(id)` and obtain an `SlavePipe` to communicate with the master. - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, and passed to a registered callback. - After receiving the messages, the master device should gather the information and determine to message passed back to each slave devices. """ def __init__(self, master_callback): """ Args: master_callback: a callback to be invoked after having collected messages from slave devices. 
""" self._master_callback = master_callback self._queue = queue.Queue() self._registry = collections.OrderedDict() self._activated = False def __getstate__(self): return {'master_callback': self._master_callback} def __setstate__(self, state): self.__init__(state['master_callback']) def register_slave(self, identifier): """ Register an slave device. Args: identifier: an identifier, usually is the device id. Returns: a `SlavePipe` object which can be used to communicate with the master device. """ if self._activated: assert self._queue.empty(), 'Queue is not clean before next initialization.' self._activated = False self._registry.clear() future = FutureResult() self._registry[identifier] = _MasterRegistry(future) return SlavePipe(identifier, self._queue, future) def run_master(self, master_msg): """ Main entry for the master device in each forward pass. The messages were first collected from each devices (including the master device), and then an callback will be invoked to compute the message to be sent back to each devices (including the master device). Args: master_msg: the message that the master want to send to itself. This will be placed as the first message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. Returns: the message to be sent back to the master device. """ self._activated = True intermediates = [(0, master_msg)] for i in range(self.nr_slaves): intermediates.append(self._queue.get()) results = self._master_callback(intermediates) assert results[0][0] == 0, 'The first result should belongs to the master.' 
for i, res in results: if i == 0: continue self._registry[i].result.put(res) for i in range(self.nr_slaves): assert self._queue.get() is True return results[0][1] @property def nr_slaves(self): return len(self._registry) ================================================ FILE: src/facerender/sync_batchnorm/replicate.py ================================================ # -*- coding: utf-8 -*- # File : replicate.py # Author : Jiayuan Mao # Email : maojiayuan@gmail.com # Date : 27/01/2018 # # This file is part of Synchronized-BatchNorm-PyTorch. # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # Distributed under MIT License. import functools from torch.nn.parallel.data_parallel import DataParallel __all__ = [ 'CallbackContext', 'execute_replication_callbacks', 'DataParallelWithCallback', 'patch_replication_callback' ] class CallbackContext(object): pass def execute_replication_callbacks(modules): """ Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` Note that, as all modules are isomorphism, we assign each sub-module with a context (shared among multiple copies of this module on different devices). Through this context, different copies can share some information. We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback of any slave copies. """ master_copy = modules[0] nr_modules = len(list(master_copy.modules())) ctxs = [CallbackContext() for _ in range(nr_modules)] for i, module in enumerate(modules): for j, m in enumerate(module.modules()): if hasattr(m, '__data_parallel_replicate__'): m.__data_parallel_replicate__(ctxs[j], i) class DataParallelWithCallback(DataParallel): """ Data Parallel with a replication callback. An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by original `replicate` function. 
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` Examples: > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) # sync_bn.__data_parallel_replicate__ will be invoked. """ def replicate(self, module, device_ids): modules = super(DataParallelWithCallback, self).replicate(module, device_ids) execute_replication_callbacks(modules) return modules def patch_replication_callback(data_parallel): """ Monkey-patch an existing `DataParallel` object. Add the replication callback. Useful when you have customized `DataParallel` implementation. Examples: > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) > patch_replication_callback(sync_bn) # this is equivalent to > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) """ assert isinstance(data_parallel, DataParallel) old_replicate = data_parallel.replicate @functools.wraps(old_replicate) def new_replicate(module, device_ids): modules = old_replicate(module, device_ids) execute_replication_callbacks(modules) return modules data_parallel.replicate = new_replicate ================================================ FILE: src/facerender/sync_batchnorm/unittest.py ================================================ # -*- coding: utf-8 -*- # File : unittest.py # Author : Jiayuan Mao # Email : maojiayuan@gmail.com # Date : 27/01/2018 # # This file is part of Synchronized-BatchNorm-PyTorch. # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # Distributed under MIT License. 
import unittest import numpy as np from torch.autograd import Variable def as_numpy(v): if isinstance(v, Variable): v = v.data return v.cpu().numpy() class TorchTestCase(unittest.TestCase): def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): npa, npb = as_numpy(a), as_numpy(b) self.assertTrue( np.allclose(npa, npb, atol=atol), 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) ) ================================================ FILE: src/generate_batch.py ================================================ import os from tqdm import tqdm import torch import numpy as np import random import scipy.io as scio import src.utils.audio as audio def crop_pad_audio(wav, audio_length): if len(wav) > audio_length: wav = wav[:audio_length] elif len(wav) < audio_length: wav = np.pad(wav, [0, audio_length - len(wav)], mode='constant', constant_values=0) return wav def parse_audio_length(audio_length, sr, fps): bit_per_frames = sr / fps num_frames = int(audio_length / bit_per_frames) audio_length = int(num_frames * bit_per_frames) return audio_length, num_frames def generate_blink_seq(num_frames): ratio = np.zeros((num_frames,1)) frame_id = 0 while frame_id in range(num_frames): start = 80 if frame_id+start+9<=num_frames - 1: ratio[frame_id+start:frame_id+start+9, 0] = [0.5,0.6,0.7,0.9,1, 0.9, 0.7,0.6,0.5] frame_id = frame_id+start+9 else: break return ratio def generate_blink_seq_randomly(num_frames): ratio = np.zeros((num_frames,1)) if num_frames<=20: return ratio frame_id = 0 while frame_id in range(num_frames): start = random.choice(range(min(10,num_frames), min(int(num_frames/2), 70))) if frame_id+start+5<=num_frames - 1: ratio[frame_id+start:frame_id+start+5, 0] = [0.5, 0.9, 1.0, 0.9, 0.5] frame_id = frame_id+start+5 else: break return ratio def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=False, idlemode=False, length_of_audio=False, use_blink=True): 
syncnet_mel_step_size = 16 fps = 25 pic_name = os.path.splitext(os.path.split(first_coeff_path)[-1])[0] audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0] if idlemode: num_frames = int(length_of_audio * 25) indiv_mels = np.zeros((num_frames, 80, 16)) else: wav = audio.load_wav(audio_path, 16000) wav_length, num_frames = parse_audio_length(len(wav), 16000, 25) wav = crop_pad_audio(wav, wav_length) orig_mel = audio.melspectrogram(wav).T spec = orig_mel.copy() # nframes 80 indiv_mels = [] for i in tqdm(range(num_frames), 'mel:'): start_frame_num = i-2 start_idx = int(80. * (start_frame_num / float(fps))) end_idx = start_idx + syncnet_mel_step_size seq = list(range(start_idx, end_idx)) seq = [ min(max(item, 0), orig_mel.shape[0]-1) for item in seq ] m = spec[seq, :] indiv_mels.append(m.T) indiv_mels = np.asarray(indiv_mels) # T 80 16 ratio = generate_blink_seq_randomly(num_frames) # T source_semantics_path = first_coeff_path source_semantics_dict = scio.loadmat(source_semantics_path) ref_coeff = source_semantics_dict['coeff_3dmm'][:1,:70] #1 70 ref_coeff = np.repeat(ref_coeff, num_frames, axis=0) if ref_eyeblink_coeff_path is not None: ratio[:num_frames] = 0 refeyeblink_coeff_dict = scio.loadmat(ref_eyeblink_coeff_path) refeyeblink_coeff = refeyeblink_coeff_dict['coeff_3dmm'][:,:64] refeyeblink_num_frames = refeyeblink_coeff.shape[0] if refeyeblink_num_frames frame_num: new_degree_list = new_degree_list[:frame_num] elif len(new_degree_list) < frame_num: for _ in range(frame_num-len(new_degree_list)): new_degree_list.append(new_degree_list[-1]) print(len(new_degree_list)) print(frame_num) remainder = frame_num%batch_size if remainder!=0: for _ in range(batch_size-remainder): new_degree_list.append(new_degree_list[-1]) new_degree_np = np.array(new_degree_list).reshape(batch_size, -1) return new_degree_np ================================================ FILE: src/gradio_demo.py ================================================ import torch, uuid import os, 
sys, shutil from src.utils.preprocess import CropAndExtract from src.test_audio2coeff import Audio2Coeff from src.facerender.animate import AnimateFromCoeff from src.generate_batch import get_data from src.generate_facerender_batch import get_facerender_data from src.utils.init_path import init_path from pydub import AudioSegment def mp3_to_wav(mp3_filename,wav_filename,frame_rate): mp3_file = AudioSegment.from_file(file=mp3_filename) mp3_file.set_frame_rate(frame_rate).export(wav_filename,format="wav") class SadTalker(): def __init__(self, checkpoint_path='checkpoints', config_path='src/config', lazy_load=False): if torch.cuda.is_available() : device = "cuda" else: device = "cpu" self.device = device os.environ['TORCH_HOME']= checkpoint_path self.checkpoint_path = checkpoint_path self.config_path = config_path def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False, batch_size=1, size=256, pose_style = 0, exp_scale=1.0, use_ref_video = False, ref_video = None, ref_info = None, use_idle_mode = False, length_of_audio = 0, use_blink=True, result_dir='./results/'): self.sadtalker_paths = init_path(self.checkpoint_path, self.config_path, size, False, preprocess) print(self.sadtalker_paths) self.audio_to_coeff = Audio2Coeff(self.sadtalker_paths, self.device) self.preprocess_model = CropAndExtract(self.sadtalker_paths, self.device) self.animate_from_coeff = AnimateFromCoeff(self.sadtalker_paths, self.device) time_tag = str(uuid.uuid4()) save_dir = os.path.join(result_dir, time_tag) os.makedirs(save_dir, exist_ok=True) input_dir = os.path.join(save_dir, 'input') os.makedirs(input_dir, exist_ok=True) print(source_image) pic_path = os.path.join(input_dir, os.path.basename(source_image)) shutil.move(source_image, input_dir) if driven_audio is not None and os.path.isfile(driven_audio): audio_path = os.path.join(input_dir, os.path.basename(driven_audio)) #### mp3 to wav if '.mp3' in audio_path: mp3_to_wav(driven_audio, 
audio_path.replace('.mp3', '.wav'), 16000) audio_path = audio_path.replace('.mp3', '.wav') else: shutil.move(driven_audio, input_dir) elif use_idle_mode: audio_path = os.path.join(input_dir, 'idlemode_'+str(length_of_audio)+'.wav') ## generate audio from this new audio_path from pydub import AudioSegment one_sec_segment = AudioSegment.silent(duration=1000*length_of_audio) #duration in milliseconds one_sec_segment.export(audio_path, format="wav") else: print(use_ref_video, ref_info) assert use_ref_video == True and ref_info == 'all' if use_ref_video and ref_info == 'all': # full ref mode ref_video_videoname = os.path.basename(ref_video) audio_path = os.path.join(save_dir, ref_video_videoname+'.wav') print('new audiopath:',audio_path) # if ref_video contains audio, set the audio from ref_video. cmd = r"ffmpeg -y -hide_banner -loglevel error -i %s %s"%(ref_video, audio_path) os.system(cmd) os.makedirs(save_dir, exist_ok=True) #crop image and extract 3dmm from image first_frame_dir = os.path.join(save_dir, 'first_frame_dir') os.makedirs(first_frame_dir, exist_ok=True) first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(pic_path, first_frame_dir, preprocess, True, size) if first_coeff_path is None: raise AttributeError("No face is detected") if use_ref_video: print('using ref video for genreation') ref_video_videoname = os.path.splitext(os.path.split(ref_video)[-1])[0] ref_video_frame_dir = os.path.join(save_dir, ref_video_videoname) os.makedirs(ref_video_frame_dir, exist_ok=True) print('3DMM Extraction for the reference video providing pose') ref_video_coeff_path, _, _ = self.preprocess_model.generate(ref_video, ref_video_frame_dir, preprocess, source_image_flag=False) else: ref_video_coeff_path = None if use_ref_video: if ref_info == 'pose': ref_pose_coeff_path = ref_video_coeff_path ref_eyeblink_coeff_path = None elif ref_info == 'blink': ref_pose_coeff_path = None ref_eyeblink_coeff_path = ref_video_coeff_path elif ref_info == 'pose+blink': 
ref_pose_coeff_path = ref_video_coeff_path ref_eyeblink_coeff_path = ref_video_coeff_path elif ref_info == 'all': ref_pose_coeff_path = None ref_eyeblink_coeff_path = None else: raise('error in refinfo') else: ref_pose_coeff_path = None ref_eyeblink_coeff_path = None #audio2ceoff if use_ref_video and ref_info == 'all': coeff_path = ref_video_coeff_path # self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path) else: batch = get_data(first_coeff_path, audio_path, self.device, ref_eyeblink_coeff_path=ref_eyeblink_coeff_path, still=still_mode, idlemode=use_idle_mode, length_of_audio=length_of_audio, use_blink=use_blink) # longer audio? coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path) #coeff2video data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode, preprocess=preprocess, size=size, expression_scale = exp_scale) return_path = self.animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None, preprocess=preprocess, img_size=size) video_name = data['video_name'] print(f'The generated video is named {video_name} in {save_dir}') del self.preprocess_model del self.audio_to_coeff del self.animate_from_coeff if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() import gc; gc.collect() return return_path ================================================ FILE: src/test_audio2coeff.py ================================================ import os import torch import numpy as np from scipy.io import savemat, loadmat from yacs.config import CfgNode as CN from scipy.signal import savgol_filter import safetensors import safetensors.torch from src.audio2pose_models.audio2pose import Audio2Pose from src.audio2exp_models.networks import SimpleWrapperV2 from src.audio2exp_models.audio2exp import Audio2Exp from src.utils.safetensor_helper import load_x_from_safetensor def 
load_cpk(checkpoint_path, model=None, optimizer=None, device="cpu"): checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) if model is not None: model.load_state_dict(checkpoint['model']) if optimizer is not None: optimizer.load_state_dict(checkpoint['optimizer']) return checkpoint['epoch'] class Audio2Coeff(): def __init__(self, sadtalker_path, device): #load config fcfg_pose = open(sadtalker_path['audio2pose_yaml_path']) cfg_pose = CN.load_cfg(fcfg_pose) cfg_pose.freeze() fcfg_exp = open(sadtalker_path['audio2exp_yaml_path']) cfg_exp = CN.load_cfg(fcfg_exp) cfg_exp.freeze() # load audio2pose_model self.audio2pose_model = Audio2Pose(cfg_pose, None, device=device) self.audio2pose_model = self.audio2pose_model.to(device) self.audio2pose_model.eval() for param in self.audio2pose_model.parameters(): param.requires_grad = False try: if sadtalker_path['use_safetensor']: checkpoints = safetensors.torch.load_file(sadtalker_path['checkpoint']) self.audio2pose_model.load_state_dict(load_x_from_safetensor(checkpoints, 'audio2pose')) else: load_cpk(sadtalker_path['audio2pose_checkpoint'], model=self.audio2pose_model, device=device) except: raise Exception("Failed in loading audio2pose_checkpoint") # load audio2exp_model netG = SimpleWrapperV2() netG = netG.to(device) for param in netG.parameters(): netG.requires_grad = False netG.eval() try: if sadtalker_path['use_safetensor']: checkpoints = safetensors.torch.load_file(sadtalker_path['checkpoint']) netG.load_state_dict(load_x_from_safetensor(checkpoints, 'audio2exp')) else: load_cpk(sadtalker_path['audio2exp_checkpoint'], model=netG, device=device) except: raise Exception("Failed in loading audio2exp_checkpoint") self.audio2exp_model = Audio2Exp(netG, cfg_exp, device=device, prepare_training_loss=False) self.audio2exp_model = self.audio2exp_model.to(device) for param in self.audio2exp_model.parameters(): param.requires_grad = False self.audio2exp_model.eval() self.device = device def generate(self, 
batch, coeff_save_dir, pose_style, ref_pose_coeff_path=None): with torch.no_grad(): #test results_dict_exp= self.audio2exp_model.test(batch) exp_pred = results_dict_exp['exp_coeff_pred'] #bs T 64 #for class_id in range(1): #class_id = 0#(i+10)%45 #class_id = random.randint(0,46) #46 styles can be selected batch['class'] = torch.LongTensor([pose_style]).to(self.device) results_dict_pose = self.audio2pose_model.test(batch) pose_pred = results_dict_pose['pose_pred'] #bs T 6 pose_len = pose_pred.shape[1] if pose_len<13: pose_len = int((pose_len-1)/2)*2+1 pose_pred = torch.Tensor(savgol_filter(np.array(pose_pred.cpu()), pose_len, 2, axis=1)).to(self.device) else: pose_pred = torch.Tensor(savgol_filter(np.array(pose_pred.cpu()), 13, 2, axis=1)).to(self.device) coeffs_pred = torch.cat((exp_pred, pose_pred), dim=-1) #bs T 70 coeffs_pred_numpy = coeffs_pred[0].clone().detach().cpu().numpy() if ref_pose_coeff_path is not None: coeffs_pred_numpy = self.using_refpose(coeffs_pred_numpy, ref_pose_coeff_path) savemat(os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name'])), {'coeff_3dmm': coeffs_pred_numpy}) return os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name'])) def using_refpose(self, coeffs_pred_numpy, ref_pose_coeff_path): num_frames = coeffs_pred_numpy.shape[0] refpose_coeff_dict = loadmat(ref_pose_coeff_path) refpose_coeff = refpose_coeff_dict['coeff_3dmm'][:,64:70] refpose_num_frames = refpose_coeff.shape[0] if refpose_num_frames= 0 if hp.symmetric_mels: return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value else: return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)) def _denormalize(D): if hp.allow_clipping_in_normalization: if hp.symmetric_mels: return (((np.clip(D, -hp.max_abs_value, hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db) else: return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / 
hp.max_abs_value) + hp.min_level_db) if hp.symmetric_mels: return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db) else: return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) ================================================ FILE: src/utils/croper.py ================================================ import os import cv2 import time import glob import argparse import scipy import numpy as np from PIL import Image import torch from tqdm import tqdm from itertools import cycle from src.face3d.extract_kp_videos_safe import KeypointExtractor from facexlib.alignment import landmark_98_to_68 import numpy as np from PIL import Image class Preprocesser: def __init__(self, device='cuda'): self.predictor = KeypointExtractor(device) def get_landmark(self, img_np): """get landmark with dlib :return: np.array shape=(68, 2) """ with torch.no_grad(): dets = self.predictor.det_net.detect_faces(img_np, 0.97) if len(dets) == 0: return None det = dets[0] img = img_np[int(det[1]):int(det[3]), int(det[0]):int(det[2]), :] lm = landmark_98_to_68(self.predictor.detector.get_landmarks(img)) # [0] #### keypoints to the original location lm[:,0] += int(det[0]) lm[:,1] += int(det[1]) return lm def align_face(self, img, lm, output_size=1024): """ :param filepath: str :return: PIL Image """ lm_chin = lm[0: 17] # left-right lm_eyebrow_left = lm[17: 22] # left-right lm_eyebrow_right = lm[22: 27] # left-right lm_nose = lm[27: 31] # top-down lm_nostrils = lm[31: 36] # top-down lm_eye_left = lm[36: 42] # left-clockwise lm_eye_right = lm[42: 48] # left-clockwise lm_mouth_outer = lm[48: 60] # left-clockwise lm_mouth_inner = lm[60: 68] # left-clockwise # Calculate auxiliary vectors. 
eye_left = np.mean(lm_eye_left, axis=0) eye_right = np.mean(lm_eye_right, axis=0) eye_avg = (eye_left + eye_right) * 0.5 eye_to_eye = eye_right - eye_left mouth_left = lm_mouth_outer[0] mouth_right = lm_mouth_outer[6] mouth_avg = (mouth_left + mouth_right) * 0.5 eye_to_mouth = mouth_avg - eye_avg # Choose oriented crop rectangle. x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] # Addition of binocular difference and double mouth difference x /= np.hypot(*x) # hypot函数计算直角三角形的斜边长,用斜边长对三角形两条直边做归一化 x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) # 双眼差和眼嘴差,选较大的作为基准尺度 y = np.flipud(x) * [-1, 1] c = eye_avg + eye_to_mouth * 0.1 quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) # 定义四边形,以面部基准位置为中心上下左右平移得到四个顶点 qsize = np.hypot(*x) * 2 # 定义四边形的大小(边长),为基准尺度的2倍 # Shrink. # 如果计算出的四边形太大了,就按比例缩小它 shrink = int(np.floor(qsize / output_size * 0.5)) if shrink > 1: rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) img = img.resize(rsize, Image.ANTIALIAS) quad /= shrink qsize /= shrink else: rsize = (int(np.rint(float(img.size[0]))), int(np.rint(float(img.size[1])))) # Crop. border = max(int(np.rint(qsize * 0.1)), 3) crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1])))) crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1])) if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: # img = img.crop(crop) quad -= crop[0:2] # Pad. 
pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1])))) pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0)) # if enable_padding and max(pad) > border - 4: # pad = np.maximum(pad, int(np.rint(qsize * 0.3))) # img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') # h, w, _ = img.shape # y, x, _ = np.ogrid[:h, :w, :1] # mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), # 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3])) # blur = qsize * 0.02 # img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) # img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) # img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') # quad += pad[:2] # Transform. quad = (quad + 0.5).flatten() lx = max(min(quad[0], quad[2]), 0) ly = max(min(quad[1], quad[7]), 0) rx = min(max(quad[4], quad[6]), img.size[0]) ry = min(max(quad[3], quad[5]), img.size[0]) # Save aligned image. 
return rsize, crop, [lx, ly, rx, ry] def crop(self, img_np_list, still=False, xsize=512): # first frame for all video img_np = img_np_list[0] lm = self.get_landmark(img_np) if lm is None: raise 'can not detect the landmark from source image' rsize, crop, quad = self.align_face(img=Image.fromarray(img_np), lm=lm, output_size=xsize) clx, cly, crx, cry = crop lx, ly, rx, ry = quad lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) for _i in range(len(img_np_list)): _inp = img_np_list[_i] _inp = cv2.resize(_inp, (rsize[0], rsize[1])) _inp = _inp[cly:cry, clx:crx] if not still: _inp = _inp[ly:ry, lx:rx] img_np_list[_i] = _inp return img_np_list, crop, quad ================================================ FILE: src/utils/face_enhancer.py ================================================ import os import torch from gfpgan import GFPGANer from tqdm import tqdm from src.utils.videoio import load_video_to_cv2 import cv2 class GeneratorWithLen(object): """ From https://stackoverflow.com/a/7460929 """ def __init__(self, gen, length): self.gen = gen self.length = length def __len__(self): return self.length def __iter__(self): return self.gen def enhancer_list(images, method='gfpgan', bg_upsampler='realesrgan'): gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler) return list(gen) def enhancer_generator_with_len(images, method='gfpgan', bg_upsampler='realesrgan'): """ Provide a generator with a __len__ method so that it can passed to functions that call len()""" if os.path.isfile(images): # handle video to images # TODO: Create a generator version of load_video_to_cv2 images = load_video_to_cv2(images) gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler) gen_with_len = GeneratorWithLen(gen, len(images)) return gen_with_len def enhancer_generator_no_len(images, method='gfpgan', bg_upsampler='realesrgan'): """ Provide a generator function so that all of the enhanced images don't need to be stored in memory at the same 
time. This can save tons of RAM compared to the enhancer function. """ print('face enhancer....') if not isinstance(images, list) and os.path.isfile(images): # handle video to images images = load_video_to_cv2(images) # ------------------------ set up GFPGAN restorer ------------------------ if method == 'gfpgan': arch = 'clean' channel_multiplier = 2 model_name = 'GFPGANv1.4' url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth' elif method == 'RestoreFormer': arch = 'RestoreFormer' channel_multiplier = 2 model_name = 'RestoreFormer' url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth' elif method == 'codeformer': # TODO: arch = 'CodeFormer' channel_multiplier = 2 model_name = 'CodeFormer' url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' else: raise ValueError(f'Wrong model version {method}.') # ------------------------ set up background upsampler ------------------------ if bg_upsampler == 'realesrgan': if not torch.cuda.is_available(): # CPU import warnings warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. 
' 'If you really want to use it, please modify the corresponding codes.') bg_upsampler = None else: from basicsr.archs.rrdbnet_arch import RRDBNet from realesrgan import RealESRGANer model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) bg_upsampler = RealESRGANer( scale=2, model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', model=model, tile=400, tile_pad=10, pre_pad=0, half=True) # need to set False in CPU mode else: bg_upsampler = None # determine model paths model_path = os.path.join('gfpgan/weights', model_name + '.pth') if not os.path.isfile(model_path): model_path = os.path.join('checkpoints', model_name + '.pth') if not os.path.isfile(model_path): # download pre-trained models from url model_path = url restorer = GFPGANer( model_path=model_path, upscale=2, arch=arch, channel_multiplier=channel_multiplier, bg_upsampler=bg_upsampler) # ------------------------ restore ------------------------ for idx in tqdm(range(len(images)), 'Face Enhancer:'): img = cv2.cvtColor(images[idx], cv2.COLOR_RGB2BGR) # restore faces and background if necessary cropped_faces, restored_faces, r_img = restorer.enhance( img, has_aligned=False, only_center_face=False, paste_back=True) r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB) yield r_img ================================================ FILE: src/utils/hparams.py ================================================ from glob import glob import os class HParams: def __init__(self, **kwargs): self.data = {} for key, value in kwargs.items(): self.data[key] = value def __getattr__(self, key): if key not in self.data: raise AttributeError("'HParams' object has no attribute %s" % key) return self.data[key] def set_hparam(self, key, value): self.data[key] = value # Default hyperparameters hparams = HParams( num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality # network rescale=True, # Whether to rescale audio prior to 
preprocessing rescaling_max=0.9, # Rescaling value # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder # Does not work if n_ffit is not multiple of hop_size!! use_lws=False, n_fft=800, # Extra window size is filled with 0 paddings to match this parameter hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i ) frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5) # Mel and Linear spectrograms normalization/scaling and clipping signal_normalization=True, # Whether to normalize mel spectrograms to some predefined range (following below parameters) allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True symmetric_mels=True, # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, # faster and cleaner convergence) max_abs_value=4., # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not # be too big to avoid gradient explosion, # not too small for fast convergence) # Contribution by @begeekmyfriend # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude # levels. Also allows for better G&L phase reconstruction) preemphasize=True, # whether to apply filter preemphasis=0.97, # filter coefficient. # Limits min_level_db=-100, ref_level_db=20, fmin=55, # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) fmax=7600, # To be increased/reduced depending on data. 
###################### Our training parameters ################################# img_size=96, fps=25, batch_size=16, initial_learning_rate=1e-4, nepochs=300000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs num_workers=20, checkpoint_interval=3000, eval_interval=3000, writer_interval=300, save_optimizer_state=True, syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. syncnet_batch_size=64, syncnet_lr=1e-4, syncnet_eval_interval=1000, syncnet_checkpoint_interval=10000, disc_wt=0.07, disc_initial_learning_rate=1e-4, ) # Default hyperparameters hparamsdebug = HParams( num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality # network rescale=True, # Whether to rescale audio prior to preprocessing rescaling_max=0.9, # Rescaling value # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder # Does not work if n_ffit is not multiple of hop_size!! use_lws=False, n_fft=800, # Extra window size is filled with 0 paddings to match this parameter hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i ) frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5) # Mel and Linear spectrograms normalization/scaling and clipping signal_normalization=True, # Whether to normalize mel spectrograms to some predefined range (following below parameters) allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True symmetric_mels=True, # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, # faster and cleaner convergence) max_abs_value=4., # max absolute value of data. 
If symmetric, data will be [-max, max] else [0, max] (Must not # be too big to avoid gradient explosion, # not too small for fast convergence) # Contribution by @begeekmyfriend # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude # levels. Also allows for better G&L phase reconstruction) preemphasize=True, # whether to apply filter preemphasis=0.97, # filter coefficient. # Limits min_level_db=-100, ref_level_db=20, fmin=55, # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) fmax=7600, # To be increased/reduced depending on data. ###################### Our training parameters ################################# img_size=96, fps=25, batch_size=2, initial_learning_rate=1e-3, nepochs=100000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs num_workers=0, checkpoint_interval=10000, eval_interval=10, writer_interval=5, save_optimizer_state=True, syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. 
syncnet_batch_size=64, syncnet_lr=1e-4, syncnet_eval_interval=10000, syncnet_checkpoint_interval=10000, disc_wt=0.07, disc_initial_learning_rate=1e-4, ) def hparams_debug_string(): values = hparams.values() hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"] return "Hyperparameters:\n" + "\n".join(hp) ================================================ FILE: src/utils/init_path.py ================================================ import os import glob def init_path(checkpoint_dir, config_dir, size=512, old_version=False, preprocess='crop'): if old_version: #### load all the checkpoint of `pth` sadtalker_paths = { 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'), 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'), 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'), 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'), 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth') } use_safetensor = False elif len(glob.glob(os.path.join(checkpoint_dir, '*.safetensors'))): print('using safetensor as default') sadtalker_paths = { "checkpoint":os.path.join(checkpoint_dir, 'SadTalker_V0.0.2_'+str(size)+'.safetensors'), } use_safetensor = True else: print("WARNING: The new version of the model will be updated by safetensor, you may need to download it mannully. 
We run the old version of the checkpoint this time!") use_safetensor = False sadtalker_paths = { 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'), 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'), 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'), 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'), 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth') } sadtalker_paths['dir_of_BFM_fitting'] = os.path.join(config_dir) # , 'BFM_Fitting' sadtalker_paths['audio2pose_yaml_path'] = os.path.join(config_dir, 'auido2pose.yaml') sadtalker_paths['audio2exp_yaml_path'] = os.path.join(config_dir, 'auido2exp.yaml') sadtalker_paths['use_safetensor'] = use_safetensor # os.path.join(config_dir, 'auido2exp.yaml') if 'full' in preprocess: sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00109-model.pth.tar') sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender_still.yaml') else: sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00229-model.pth.tar') sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender.yaml') return sadtalker_paths ================================================ FILE: src/utils/model2safetensor.py ================================================ import torch import yaml import os import safetensors from safetensors.torch import save_file from yacs.config import CfgNode as CN import sys sys.path.append('/apdcephfs/private_shadowcun/SadTalker') from src.face3d.models import networks from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector from src.facerender.modules.mapping import MappingNet from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator from src.audio2pose_models.audio2pose import Audio2Pose from src.audio2exp_models.networks import SimpleWrapperV2 from 
src.test_audio2coeff import load_cpk size = 256 ############ face vid2vid config_path = os.path.join('src', 'config', 'facerender.yaml') current_root_path = '.' path_of_net_recon_model = os.path.join(current_root_path, 'checkpoints', 'epoch_20.pth') net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='') checkpoint = torch.load(path_of_net_recon_model, map_location='cpu') net_recon.load_state_dict(checkpoint['net_recon']) with open(config_path) as f: config = yaml.safe_load(f) generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'], **config['model_params']['common_params']) kp_extractor = KPDetector(**config['model_params']['kp_detector_params'], **config['model_params']['common_params']) he_estimator = HEEstimator(**config['model_params']['he_estimator_params'], **config['model_params']['common_params']) mapping = MappingNet(**config['model_params']['mapping_params']) def load_cpk_facevid2vid(checkpoint_path, generator=None, discriminator=None, kp_detector=None, he_estimator=None, optimizer_generator=None, optimizer_discriminator=None, optimizer_kp_detector=None, optimizer_he_estimator=None, device="cpu"): checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) if generator is not None: generator.load_state_dict(checkpoint['generator']) if kp_detector is not None: kp_detector.load_state_dict(checkpoint['kp_detector']) if he_estimator is not None: he_estimator.load_state_dict(checkpoint['he_estimator']) if discriminator is not None: try: discriminator.load_state_dict(checkpoint['discriminator']) except: print ('No discriminator in the state-dict. 
Dicriminator will be randomly initialized') if optimizer_generator is not None: optimizer_generator.load_state_dict(checkpoint['optimizer_generator']) if optimizer_discriminator is not None: try: optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator']) except RuntimeError as e: print ('No discriminator optimizer in the state-dict. Optimizer will be not initialized') if optimizer_kp_detector is not None: optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector']) if optimizer_he_estimator is not None: optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator']) return checkpoint['epoch'] def load_cpk_facevid2vid_safetensor(checkpoint_path, generator=None, kp_detector=None, he_estimator=None, device="cpu"): checkpoint = safetensors.torch.load_file(checkpoint_path) if generator is not None: x_generator = {} for k,v in checkpoint.items(): if 'generator' in k: x_generator[k.replace('generator.', '')] = v generator.load_state_dict(x_generator) if kp_detector is not None: x_generator = {} for k,v in checkpoint.items(): if 'kp_extractor' in k: x_generator[k.replace('kp_extractor.', '')] = v kp_detector.load_state_dict(x_generator) if he_estimator is not None: x_generator = {} for k,v in checkpoint.items(): if 'he_estimator' in k: x_generator[k.replace('he_estimator.', '')] = v he_estimator.load_state_dict(x_generator) return None free_view_checkpoint = '/apdcephfs/private_shadowcun/SadTalker/checkpoints/facevid2vid_'+str(size)+'-model.pth.tar' load_cpk_facevid2vid(free_view_checkpoint, kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator) wav2lip_checkpoint = os.path.join(current_root_path, 'checkpoints', 'wav2lip.pth') audio2pose_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2pose_00140-model.pth') audio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml') audio2exp_checkpoint = os.path.join(current_root_path, 'checkpoints', 
'auido2exp_00300-model.pth') audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml') fcfg_pose = open(audio2pose_yaml_path) cfg_pose = CN.load_cfg(fcfg_pose) cfg_pose.freeze() audio2pose_model = Audio2Pose(cfg_pose, wav2lip_checkpoint) audio2pose_model.eval() load_cpk(audio2pose_checkpoint, model=audio2pose_model, device='cpu') # load audio2exp_model netG = SimpleWrapperV2() netG.eval() load_cpk(audio2exp_checkpoint, model=netG, device='cpu') class SadTalker(torch.nn.Module): def __init__(self, kp_extractor, generator, netG, audio2pose, face_3drecon): super(SadTalker, self).__init__() self.kp_extractor = kp_extractor self.generator = generator self.audio2exp = netG self.audio2pose = audio2pose self.face_3drecon = face_3drecon model = SadTalker(kp_extractor, generator, netG, audio2pose_model, net_recon) # here, we want to convert it to safetensor save_file(model.state_dict(), "checkpoints/SadTalker_V0.0.2_"+str(size)+".safetensors") ### test load_cpk_facevid2vid_safetensor('checkpoints/SadTalker_V0.0.2_'+str(size)+'.safetensors', kp_detector=kp_extractor, generator=generator, he_estimator=None) ================================================ FILE: src/utils/paste_pic.py ================================================ import cv2, os import numpy as np from tqdm import tqdm import uuid from src.utils.videoio import save_video_with_watermark def paste_pic(video_path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop=False): if not os.path.isfile(pic_path): raise ValueError('pic_path must be a valid path to video/image file') elif pic_path.split('.')[-1] in ['jpg', 'png', 'jpeg']: # loader for first frame full_img = cv2.imread(pic_path) else: # loader for videos video_stream = cv2.VideoCapture(pic_path) fps = video_stream.get(cv2.CAP_PROP_FPS) full_frames = [] while 1: still_reading, frame = video_stream.read() if not still_reading: video_stream.release() break break full_img = frame frame_h = full_img.shape[0] 
frame_w = full_img.shape[1] video_stream = cv2.VideoCapture(video_path) fps = video_stream.get(cv2.CAP_PROP_FPS) crop_frames = [] while 1: still_reading, frame = video_stream.read() if not still_reading: video_stream.release() break crop_frames.append(frame) if len(crop_info) != 3: print("you didn't crop the image") return else: r_w, r_h = crop_info[0] clx, cly, crx, cry = crop_info[1] lx, ly, rx, ry = crop_info[2] lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx if extended_crop: oy1, oy2, ox1, ox2 = cly, cry, clx, crx else: oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx tmp_path = str(uuid.uuid4())+'.mp4' out_tmp = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (frame_w, frame_h)) for crop_frame in tqdm(crop_frames, 'seamlessClone:'): p = cv2.resize(crop_frame.astype(np.uint8), (ox2-ox1, oy2 - oy1)) mask = 255*np.ones(p.shape, p.dtype) location = ((ox1+ox2) // 2, (oy1+oy2) // 2) gen_img = cv2.seamlessClone(p, full_img, mask, location, cv2.NORMAL_CLONE) out_tmp.write(gen_img) out_tmp.release() save_video_with_watermark(tmp_path, new_audio_path, full_video_path, watermark=False) os.remove(tmp_path) ================================================ FILE: src/utils/preprocess.py ================================================ import numpy as np import cv2, os, sys, torch from tqdm import tqdm from PIL import Image # 3dmm extraction import safetensors import safetensors.torch from src.face3d.util.preprocess import align_img from src.face3d.util.load_mats import load_lm3d from src.face3d.models import networks from scipy.io import loadmat, savemat from src.utils.croper import Preprocesser import warnings from src.utils.safetensor_helper import load_x_from_safetensor warnings.filterwarnings("ignore") def split_coeff(coeffs): """ Return: coeffs_dict -- a dict of torch.tensors Parameters: coeffs -- torch.tensor, size (B, 256) """ id_coeffs = 
coeffs[:, :80] exp_coeffs = coeffs[:, 80: 144] tex_coeffs = coeffs[:, 144: 224] angles = coeffs[:, 224: 227] gammas = coeffs[:, 227: 254] translations = coeffs[:, 254:] return { 'id': id_coeffs, 'exp': exp_coeffs, 'tex': tex_coeffs, 'angle': angles, 'gamma': gammas, 'trans': translations } class CropAndExtract(): def __init__(self, sadtalker_path, device): self.propress = Preprocesser(device) self.net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='').to(device) if sadtalker_path['use_safetensor']: checkpoint = safetensors.torch.load_file(sadtalker_path['checkpoint']) self.net_recon.load_state_dict(load_x_from_safetensor(checkpoint, 'face_3drecon')) else: checkpoint = torch.load(sadtalker_path['path_of_net_recon_model'], map_location=torch.device(device)) self.net_recon.load_state_dict(checkpoint['net_recon']) self.net_recon.eval() self.lm3d_std = load_lm3d(sadtalker_path['dir_of_BFM_fitting']) self.device = device def generate(self, input_path, save_dir, crop_or_resize='crop', source_image_flag=False, pic_size=256): pic_name = os.path.splitext(os.path.split(input_path)[-1])[0] landmarks_path = os.path.join(save_dir, pic_name+'_landmarks.txt') coeff_path = os.path.join(save_dir, pic_name+'.mat') png_path = os.path.join(save_dir, pic_name+'.png') #load input if not os.path.isfile(input_path): raise ValueError('input_path must be a valid path to video/image file') elif input_path.split('.')[-1] in ['jpg', 'png', 'jpeg']: # loader for first frame full_frames = [cv2.imread(input_path)] fps = 25 else: # loader for videos video_stream = cv2.VideoCapture(input_path) fps = video_stream.get(cv2.CAP_PROP_FPS) full_frames = [] while 1: still_reading, frame = video_stream.read() if not still_reading: video_stream.release() break full_frames.append(frame) if source_image_flag: break x_full_frames= [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in full_frames] #### crop images as the if 'crop' in crop_or_resize.lower(): # default crop 
x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in crop_or_resize.lower() else False, xsize=512) clx, cly, crx, cry = crop lx, ly, rx, ry = quad lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad) elif 'full' in crop_or_resize.lower(): x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in crop_or_resize.lower() else False, xsize=512) clx, cly, crx, cry = crop lx, ly, rx, ry = quad lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad) else: # resize mode oy1, oy2, ox1, ox2 = 0, x_full_frames[0].shape[0], 0, x_full_frames[0].shape[1] crop_info = ((ox2 - ox1, oy2 - oy1), None, None) frames_pil = [Image.fromarray(cv2.resize(frame,(pic_size, pic_size))) for frame in x_full_frames] if len(frames_pil) == 0: print('No face is detected in the input file') return None, None # save crop info for frame in frames_pil: cv2.imwrite(png_path, cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)) # 2. get the landmark according to the detected face. if not os.path.isfile(landmarks_path): lm = self.propress.predictor.extract_keypoint(frames_pil, landmarks_path) else: print(' Using saved landmarks.') lm = np.loadtxt(landmarks_path).astype(np.float32) lm = lm.reshape([len(x_full_frames), -1, 2]) if not os.path.isfile(coeff_path): # load 3dmm paramter generator from Deep3DFaceRecon_pytorch video_coeffs, full_coeffs = [], [] for idx in tqdm(range(len(frames_pil)), desc='3DMM Extraction In Video:'): frame = frames_pil[idx] W,H = frame.size lm1 = lm[idx].reshape([-1, 2]) if np.mean(lm1) == -1: lm1 = (self.lm3d_std[:, :2]+1)/2. 
lm1 = np.concatenate( [lm1[:, :1]*W, lm1[:, 1:2]*H], 1 ) else: lm1[:, -1] = H - 1 - lm1[:, -1] trans_params, im1, lm1, _ = align_img(frame, lm1, self.lm3d_std) trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)]).astype(np.float32) im_t = torch.tensor(np.array(im1)/255., dtype=torch.float32).permute(2, 0, 1).to(self.device).unsqueeze(0) with torch.no_grad(): full_coeff = self.net_recon(im_t) coeffs = split_coeff(full_coeff) pred_coeff = {key:coeffs[key].cpu().numpy() for key in coeffs} pred_coeff = np.concatenate([ pred_coeff['exp'], pred_coeff['angle'], pred_coeff['trans'], trans_params[2:][None], ], 1) video_coeffs.append(pred_coeff) full_coeffs.append(full_coeff.cpu().numpy()) semantic_npy = np.array(video_coeffs)[:,0] savemat(coeff_path, {'coeff_3dmm': semantic_npy, 'full_3dmm': np.array(full_coeffs)[0]}) return coeff_path, png_path, crop_info ================================================ FILE: src/utils/safetensor_helper.py ================================================ def load_x_from_safetensor(checkpoint, key): x_generator = {} for k,v in checkpoint.items(): if key in k: x_generator[k.replace(key+'.', '')] = v return x_generator ================================================ FILE: src/utils/text2speech.py ================================================ import os import tempfile from TTS.api import TTS class TTSTalker(): def __init__(self) -> None: model_name = TTS().list_models()[0] self.tts = TTS(model_name) def test(self, text, language='en'): tempf = tempfile.NamedTemporaryFile( delete = False, suffix = ('.'+'wav'), ) self.tts.tts_to_file(text, speaker=self.tts.speakers[0], language=language, file_path=tempf.name) return tempf.name ================================================ FILE: src/utils/videoio.py ================================================ import shutil import uuid import os import cv2 def load_video_to_cv2(input_path): video_stream = cv2.VideoCapture(input_path) fps = video_stream.get(cv2.CAP_PROP_FPS) 
full_frames = [] while 1: still_reading, frame = video_stream.read() if not still_reading: video_stream.release() break full_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) return full_frames def save_video_with_watermark(video, audio, save_path, watermark=False): temp_file = str(uuid.uuid4())+'.mp4' cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -vcodec copy "%s"' % (video, audio, temp_file) os.system(cmd) if watermark is False: shutil.move(temp_file, save_path) else: # watermark try: ##### check if stable-diffusion-webui import webui from modules import paths watarmark_path = paths.script_path+"/extensions/SadTalker/docs/sadtalker_logo.png" except: # get the root path of sadtalker. dir_path = os.path.dirname(os.path.realpath(__file__)) watarmark_path = dir_path+"/../../docs/sadtalker_logo.png" cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -filter_complex "[1]scale=100:-1[wm];[0][wm]overlay=(main_w-overlay_w)-10:10" "%s"' % (temp_file, watarmark_path, save_path) os.system(cmd) os.remove(temp_file) ================================================ FILE: webui.bat ================================================ @echo off IF NOT EXIST venv ( python -m venv venv ) ELSE ( echo venv folder already exists, skipping creation... ) call .\venv\Scripts\activate.bat set PYTHON="venv\Scripts\Python.exe" echo venv %PYTHON% %PYTHON% Launcher.py echo. echo Launch unsuccessful. Exiting. 
pause

================================================
FILE: webui.sh
================================================
#!/usr/bin/env bash
#################################################################
# Install/launch script for SadTalker + Web UI (Linux / macOS).
#
# Creates a Python venv inside the SadTalker checkout if it is missing,
# activates it and then replaces this shell with
#   "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
# All command-line arguments are forwarded unchanged (getopts below only
# inspects them; it does not shift them away).
#
# Flags:
#   -f   allow running as root
#
# Optional environment overrides:
#   python_cmd       interpreter used to create the venv     (default: python3)
#   GIT              git executable                           (default: git)
#   venv_dir         venv directory inside the checkout       (default: venv)
#   LAUNCH_SCRIPT    script run inside the venv               (default: launcher.py)
#   TORCH_COMMAND    exported torch install command; defaults are filled in
#                    below for macOS and AMD GPUs only when it is unset
#   install_dir, clone_dir
#                    parent directory / name of the checkout; only used when
#                    the current directory is not a git checkout
#   can_run_as_root  set to 0 to refuse running as root       (default: 1)
#################################################################

# macOS: default to a torch/torchvision pair pinned for that platform
if [[ "$OSTYPE" == "darwin"* && -z "${TORCH_COMMAND}" ]]; then
    export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1"
fi

# python3 executable
python_cmd=${python_cmd:-python3}

# git executable
export GIT=${GIT:-git}

# python3 venv without trailing slash (relative to ${install_dir}/${clone_dir})
venv_dir=${venv_dir:-venv}

LAUNCH_SCRIPT=${LAUNCH_SCRIPT:-launcher.py}

# Running as root is allowed by default; export can_run_as_root=0 to refuse
# it. -f always allows it.
can_run_as_root=${can_run_as_root:-1}

# read any command line flags to the webui.sh script
while getopts "f" flag > /dev/null 2>&1; do
    case ${flag} in
        f) can_run_as_root=1 ;;
        *) break ;;
    esac
done

# Disable sentry logging
export ERROR_REPORTING=FALSE

# Do not reinstall existing pip packages on Debian/Ubuntu
export PIP_IGNORE_INSTALLED=0

# Pretty print
delimiter="################################################################"

printf "\n%s\n" "${delimiter}"
printf "\e[1m\e[32mInstall script for SadTalker + Web UI\n"
printf "\e[1m\e[34mTested on Debian 11 (Bullseye)\e[0m"
printf "\n%s\n" "${delimiter}"

# Refuse root only when explicitly requested
if [[ $(id -u) -eq 0 && "${can_run_as_root}" == "0" ]]; then
    printf "\n%s\n" "${delimiter}"
    printf "\e[1m\e[31mERROR: This script must not be launched as root, aborting...\e[0m"
    printf "\n%s\n" "${delimiter}"
    exit 1
else
    printf "\n%s\n" "${delimiter}"
    printf "Running on \e[1m\e[32m%s\e[0m user" "$(whoami)"
    printf "\n%s\n" "${delimiter}"
fi

if [[ -d .git ]]; then
    printf "\n%s\n" "${delimiter}"
    printf "Repo already cloned, using it as install directory"
    printf "\n%s\n" "${delimiter}"
    install_dir="${PWD}/../"
    clone_dir="${PWD##*/}"
fi

# Outside a git checkout (e.g. an unpacked archive) nothing above sets the
# install location. Fall back to the directory holding this script (where
# launcher.py lives). Otherwise the cd below would expand to "//", which is
# the filesystem root, and the venv would be created there.
if [[ -z "${install_dir}" || -z "${clone_dir}" ]]; then
    script_dir=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)
    install_dir=${install_dir:-"${script_dir%/*}"}
    clone_dir=${clone_dir:-"${script_dir##*/}"}
fi

# Check prerequisites
gpu_info=$(lspci 2>/dev/null | grep VGA)
case "$gpu_info" in
    *"Navi 1"*|*"Navi 2"*)
        export HSA_OVERRIDE_GFX_VERSION=10.3.0
        ;;
    *"Renoir"*)
        export HSA_OVERRIDE_GFX_VERSION=9.0.0
        printf "\n%s\n" "${delimiter}"
        printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half"
        printf "\n%s\n" "${delimiter}"
        ;;
    *)
        ;;
esac

if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]; then
    export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
fi

for preq in "${GIT}" "${python_cmd}"; do
    if ! hash "${preq}" &>/dev/null; then
        printf "\n%s\n" "${delimiter}"
        printf "\e[1m\e[31mERROR: %s is not installed, aborting...\e[0m" "${preq}"
        printf "\n%s\n" "${delimiter}"
        exit 1
    fi
done

if ! "${python_cmd}" -c "import venv" &>/dev/null; then
    printf "\n%s\n" "${delimiter}"
    printf "\e[1m\e[31mERROR: python3-venv is not installed, aborting...\e[0m"
    printf "\n%s\n" "${delimiter}"
    exit 1
fi

printf "\n%s\n" "${delimiter}"
printf "Create and activate python venv"
printf "\n%s\n" "${delimiter}"
cd "${install_dir}"/"${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
if [[ ! -d "${venv_dir}" ]]; then
    "${python_cmd}" -m venv "${venv_dir}"
fi
# shellcheck source=/dev/null
if [[ -f "${venv_dir}"/bin/activate ]]; then
    source "${venv_dir}"/bin/activate
else
    printf "\n%s\n" "${delimiter}"
    printf "\e[1m\e[31mERROR: Cannot activate python venv, aborting...\e[0m"
    printf "\n%s\n" "${delimiter}"
    exit 1
fi

printf "\n%s\n" "${delimiter}"
printf "Launching %s..." "${LAUNCH_SCRIPT}"
printf "\n%s\n" "${delimiter}"
# the activated venv puts its interpreter first on PATH
exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"