Showing preview only (613K chars total). Download the full file or copy to clipboard to get everything.
Repository: OpenTalker/SadTalker
Branch: main
Commit: cd4c0465ae0b
Files: 141
Total size: 575.1 KB
Directory structure:
gitextract_pkbb6gh_/
├── .gitignore
├── LICENSE
├── README.md
├── app_sadtalker.py
├── cog.yaml
├── docs/
│ ├── FAQ.md
│ ├── best_practice.md
│ ├── changlelog.md
│ ├── face3d.md
│ ├── install.md
│ └── webui_extension.md
├── inference.py
├── launcher.py
├── predict.py
├── quick_demo.ipynb
├── req.txt
├── requirements.txt
├── requirements3d.txt
├── scripts/
│ ├── download_models.sh
│ ├── extension.py
│ └── test.sh
├── src/
│ ├── audio2exp_models/
│ │ ├── audio2exp.py
│ │ └── networks.py
│ ├── audio2pose_models/
│ │ ├── audio2pose.py
│ │ ├── audio_encoder.py
│ │ ├── cvae.py
│ │ ├── discriminator.py
│ │ ├── networks.py
│ │ └── res_unet.py
│ ├── config/
│ │ ├── auido2exp.yaml
│ │ ├── auido2pose.yaml
│ │ ├── facerender.yaml
│ │ ├── facerender_still.yaml
│ │ └── similarity_Lm3D_all.mat
│ ├── face3d/
│ │ ├── data/
│ │ │ ├── __init__.py
│ │ │ ├── base_dataset.py
│ │ │ ├── flist_dataset.py
│ │ │ ├── image_folder.py
│ │ │ └── template_dataset.py
│ │ ├── extract_kp_videos.py
│ │ ├── extract_kp_videos_safe.py
│ │ ├── models/
│ │ │ ├── __init__.py
│ │ │ ├── arcface_torch/
│ │ │ │ ├── README.md
│ │ │ │ ├── backbones/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── iresnet.py
│ │ │ │ │ ├── iresnet2060.py
│ │ │ │ │ └── mobilefacenet.py
│ │ │ │ ├── configs/
│ │ │ │ │ ├── 3millions.py
│ │ │ │ │ ├── 3millions_pfc.py
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── base.py
│ │ │ │ │ ├── glint360k_mbf.py
│ │ │ │ │ ├── glint360k_r100.py
│ │ │ │ │ ├── glint360k_r18.py
│ │ │ │ │ ├── glint360k_r34.py
│ │ │ │ │ ├── glint360k_r50.py
│ │ │ │ │ ├── ms1mv3_mbf.py
│ │ │ │ │ ├── ms1mv3_r18.py
│ │ │ │ │ ├── ms1mv3_r2060.py
│ │ │ │ │ ├── ms1mv3_r34.py
│ │ │ │ │ ├── ms1mv3_r50.py
│ │ │ │ │ └── speed.py
│ │ │ │ ├── dataset.py
│ │ │ │ ├── docs/
│ │ │ │ │ ├── eval.md
│ │ │ │ │ ├── install.md
│ │ │ │ │ ├── modelzoo.md
│ │ │ │ │ └── speed_benchmark.md
│ │ │ │ ├── eval/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── verification.py
│ │ │ │ ├── eval_ijbc.py
│ │ │ │ ├── inference.py
│ │ │ │ ├── losses.py
│ │ │ │ ├── onnx_helper.py
│ │ │ │ ├── onnx_ijbc.py
│ │ │ │ ├── partial_fc.py
│ │ │ │ ├── requirement.txt
│ │ │ │ ├── run.sh
│ │ │ │ ├── torch2onnx.py
│ │ │ │ ├── train.py
│ │ │ │ └── utils/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── plot.py
│ │ │ │ ├── utils_amp.py
│ │ │ │ ├── utils_callbacks.py
│ │ │ │ ├── utils_config.py
│ │ │ │ ├── utils_logging.py
│ │ │ │ └── utils_os.py
│ │ │ ├── base_model.py
│ │ │ ├── bfm.py
│ │ │ ├── facerecon_model.py
│ │ │ ├── losses.py
│ │ │ ├── networks.py
│ │ │ └── template_model.py
│ │ ├── options/
│ │ │ ├── __init__.py
│ │ │ ├── base_options.py
│ │ │ ├── inference_options.py
│ │ │ ├── test_options.py
│ │ │ └── train_options.py
│ │ ├── util/
│ │ │ ├── BBRegressorParam_r.mat
│ │ │ ├── __init__.py
│ │ │ ├── detect_lm68.py
│ │ │ ├── generate_list.py
│ │ │ ├── html.py
│ │ │ ├── load_mats.py
│ │ │ ├── my_awing_arch.py
│ │ │ ├── nvdiffrast.py
│ │ │ ├── preprocess.py
│ │ │ ├── skin_mask.py
│ │ │ ├── test_mean_face.txt
│ │ │ ├── util.py
│ │ │ └── visualizer.py
│ │ └── visualize.py
│ ├── facerender/
│ │ ├── animate.py
│ │ ├── modules/
│ │ │ ├── dense_motion.py
│ │ │ ├── discriminator.py
│ │ │ ├── generator.py
│ │ │ ├── keypoint_detector.py
│ │ │ ├── make_animation.py
│ │ │ ├── mapping.py
│ │ │ └── util.py
│ │ └── sync_batchnorm/
│ │ ├── __init__.py
│ │ ├── batchnorm.py
│ │ ├── comm.py
│ │ ├── replicate.py
│ │ └── unittest.py
│ ├── generate_batch.py
│ ├── generate_facerender_batch.py
│ ├── gradio_demo.py
│ ├── test_audio2coeff.py
│ └── utils/
│ ├── audio.py
│ ├── croper.py
│ ├── face_enhancer.py
│ ├── hparams.py
│ ├── init_path.py
│ ├── model2safetensor.py
│ ├── paste_pic.py
│ ├── preprocess.py
│ ├── safetensor_helper.py
│ ├── text2speech.py
│ └── videoio.py
├── webui.bat
└── webui.sh
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
examples/results/*
gfpgan/*
checkpoints/*
assets/*
results/*
Dockerfile
start_docker.sh
start.sh
checkpoints
# Mac
.DS_Store
================================================
FILE: LICENSE
================================================
Tencent is pleased to support the open source community by making SadTalker available.
Copyright (C), a Tencent company. All rights reserved.
SadTalker is licensed under the Apache 2.0 License, except for the third-party components listed below.
Terms of the Apache License Version 2.0:
---------------------------------------------
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
<div align="center">
<img src='https://user-images.githubusercontent.com/4397546/229094115-862c747e-7397-4b54-ba4a-bd368bfe2e0f.png' width='500px'/>
<!--<h2> 😭 SadTalker: <span style="font-size:12px">Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation </span> </h2> -->
<a href='https://arxiv.org/abs/2211.12194'><img src='https://img.shields.io/badge/ArXiv-PDF-red'></a> <a href='https://sadtalker.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> [](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb) [](https://huggingface.co/spaces/vinthony/SadTalker) [](https://colab.research.google.com/github/camenduru/stable-diffusion-webui-colab/blob/main/video/stable/stable_diffusion_1_5_video_webui_colab.ipynb) <br> [](https://replicate.com/cjwbw/sadtalker) [](https://discord.gg/rrayYqZ4tf)
<div>
<a target='_blank'>Wenxuan Zhang <sup>*,1,2</sup> </a> 
<a href='https://vinthony.github.io/' target='_blank'>Xiaodong Cun <sup>*,2</sup></a> 
<a href='https://xuanwangvc.github.io/' target='_blank'>Xuan Wang <sup>3</sup></a> 
<a href='https://yzhang2016.github.io/' target='_blank'>Yong Zhang <sup>2</sup></a> 
<a href='https://xishen0220.github.io/' target='_blank'>Xi Shen <sup>2</sup></a>  </br>
<a href='https://yuguo-xjtu.github.io/' target='_blank'>Yu Guo<sup>1</sup> </a> 
<a href='https://scholar.google.com/citations?hl=zh-CN&user=4oXBp9UAAAAJ' target='_blank'>Ying Shan <sup>2</sup> </a> 
<a target='_blank'>Fei Wang <sup>1</sup> </a> 
</div>
<br>
<div>
<sup>1</sup> Xi'an Jiaotong University   <sup>2</sup> Tencent AI Lab   <sup>3</sup> Ant Group  
</div>
<br>
<i><strong><a href='https://arxiv.org/abs/2211.12194' target='_blank'>CVPR 2023</a></strong></i>
<br>
<br>

<b>TL;DR: single portrait image 🙎♂️ + audio 🎤 = talking head video 🎞.</b>
<br>
</div>
## Highlights
- The license has been updated to Apache 2.0, and we've removed the non-commercial restriction
- **SadTalker has now officially been integrated into Discord, where you can use it for free by sending files. You can also generate high-quality videos from text prompts. Join: [](https://discord.gg/rrayYqZ4tf)**
- We've published a [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) extension. Check out more details [here](docs/webui_extension.md). [Demo Video](https://user-images.githubusercontent.com/4397546/231495639-5d4bb925-ea64-4a36-a519-6389917dac29.mp4)
- Full image mode is now available! [More details...](https://github.com/OpenTalker/SadTalker#full-bodyimage-generation)
| still+enhancer in v0.0.1 | still + enhancer in v0.0.2 | [input image @bagbag1815](https://twitter.com/bagbag1815/status/1642754319094108161) |
|:--------------------: |:--------------------: | :----: |
| <video src="https://user-images.githubusercontent.com/48216707/229484996-5d7be64f-2553-4c9e-a452-c5cf0b8ebafe.mp4" type="video/mp4"> </video> | <video src="https://user-images.githubusercontent.com/4397546/230717873-355b7bf3-d3de-49f9-a439-9220e623fce7.mp4" type="video/mp4"> </video> | <img src='./examples/source_image/full_body_2.png' width='380'>
- Several new modes (Still, reference, and resize modes) are now available!
- We're happy to see more community demos on [bilibili](https://search.bilibili.com/all?keyword=sadtalker), [YouTube](https://www.youtube.com/results?search_query=sadtalker) and [X (#sadtalker)](https://twitter.com/search?q=%23sadtalker&src).
## Changelog
The previous changelog can be found [here](docs/changlelog.md).
- __[2023.06.12]__: Added more new features in WebUI extension, see the discussion [here](https://github.com/OpenTalker/SadTalker/discussions/386).
- __[2023.06.05]__: Released a new 512x512px (beta) face model. Fixed some bugs and improved performance.
- __[2023.04.15]__: Added a WebUI Colab notebook by [@camenduru](https://github.com/camenduru/): [](https://colab.research.google.com/github/camenduru/stable-diffusion-webui-colab/blob/main/video/stable/stable_diffusion_1_5_video_webui_colab.ipynb)
- __[2023.04.12]__: Added a more detailed WebUI installation document and fixed a problem when reinstalling.
- __[2023.04.12]__: Fixed WebUI safety issues caused by 3rd-party packages, and optimized the output path in `sd-webui-extension`.
- __[2023.04.08]__: In v0.0.2, we added a logo watermark to the generated video to prevent abuse. _This watermark has since been removed in a later release._
- __[2023.04.08]__: In v0.0.2, we added features for full image animation and a link to download checkpoints from Baidu. We also optimized the enhancer logic.
## To-Do
We're tracking new updates in [issue #280](https://github.com/OpenTalker/SadTalker/issues/280).
## Troubleshooting
If you have any problems, please read our [FAQs](docs/FAQ.md) before opening an issue.
## 1. Installation.
Community tutorials: [中文Windows教程 (Chinese Windows tutorial)](https://www.bilibili.com/video/BV1Dc411W7V6/) | [日本語コース (Japanese tutorial)](https://br-d.fanbox.cc/posts/5685086).
### Linux/Unix
1. Install [Anaconda](https://www.anaconda.com/), Python and `git`.
2. Create the environment and install the requirements.
```bash
git clone https://github.com/OpenTalker/SadTalker.git
cd SadTalker
conda create -n sadtalker python=3.8
conda activate sadtalker
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
conda install ffmpeg
pip install -r requirements.txt
### Coqui TTS is optional for gradio demo.
### pip install TTS
```
### Windows
A video tutorial in Chinese is available [here](https://www.bilibili.com/video/BV1Dc411W7V6/). You can also follow these instructions:
1. Install [Python 3.8](https://www.python.org/downloads/windows/) and check "Add Python to PATH".
2. Install [git](https://git-scm.com/download/win) manually or using [Scoop](https://scoop.sh/): `scoop install git`.
3. Install `ffmpeg`, following [this tutorial](https://www.wikihow.com/Install-FFmpeg-on-Windows) or using [scoop](https://scoop.sh/): `scoop install ffmpeg`.
4. Download the SadTalker repository by running `git clone https://github.com/OpenTalker/SadTalker.git`.
5. Download the checkpoints and gfpgan models in the [downloads section](#2-download-models).
6. Run `webui.bat` from Windows Explorer as a normal (non-administrator) user, and a Gradio-powered WebUI demo will be started.
### macOS
A tutorial on installing SadTalker on macOS can be found [here](docs/install.md).
### Docker, WSL, etc
Please check out additional tutorials [here](docs/install.md).
## 2. Download Models
You can run the following script on Linux/macOS to automatically download all the models:
```bash
bash scripts/download_models.sh
```
We also provide an offline patch (`gfpgan/`), so no model will be downloaded when generating.
### Pre-Trained Models
* [Google Drive](https://drive.google.com/file/d/1gwWh45pF7aelNP_P78uDJL8Sycep-K7j/view?usp=sharing)
* [GitHub Releases](https://github.com/OpenTalker/SadTalker/releases)
* [Baidu (百度云盘)](https://pan.baidu.com/s/1kb1BCPaLOWX1JJb9Czbn6w?pwd=sadt) (Password: `sadt`)
<!-- TODO add Hugging Face links -->
### GFPGAN Offline Patch
* [Google Drive](https://drive.google.com/file/d/19AIBsmfcHW6BRJmeqSFlG5fL445Xmsyi?usp=sharing)
* [GitHub Releases](https://github.com/OpenTalker/SadTalker/releases)
* [Baidu (百度云盘)](https://pan.baidu.com/s/1P4fRgk9gaSutZnn8YW034Q?pwd=sadt) (Password: `sadt`)
<!-- TODO add Hugging Face links -->
<details><summary>Model Details</summary>
Model descriptions:
##### New version
| Model | Description
| :--- | :----------
|checkpoints/mapping_00229-model.pth.tar | Pre-trained MappingNet in Sadtalker.
|checkpoints/mapping_00109-model.pth.tar | Pre-trained MappingNet in Sadtalker.
|checkpoints/SadTalker_V0.0.2_256.safetensors | Packaged SadTalker checkpoints of the old version (256 face render).
|checkpoints/SadTalker_V0.0.2_512.safetensors | Packaged SadTalker checkpoints of the old version (512 face render).
|gfpgan/weights | Face detection and enhanced models used in `facexlib` and `gfpgan`.
##### Old version
| Model | Description
| :--- | :----------
|checkpoints/auido2exp_00300-model.pth | Pre-trained ExpNet in Sadtalker.
|checkpoints/auido2pose_00140-model.pth | Pre-trained PoseVAE in Sadtalker.
|checkpoints/mapping_00229-model.pth.tar | Pre-trained MappingNet in Sadtalker.
|checkpoints/mapping_00109-model.pth.tar | Pre-trained MappingNet in Sadtalker.
|checkpoints/facevid2vid_00189-model.pth.tar | Pre-trained face-vid2vid model from [the reappearance of face-vid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis).
|checkpoints/epoch_20.pth | Pre-trained 3DMM extractor in [Deep3DFaceReconstruction](https://github.com/microsoft/Deep3DFaceReconstruction).
|checkpoints/wav2lip.pth | Highly accurate lip-sync model in [Wav2lip](https://github.com/Rudrabha/Wav2Lip).
|checkpoints/shape_predictor_68_face_landmarks.dat | Face landmark model used in [dlib](http://dlib.net/).
|checkpoints/BFM | 3DMM library file.
|checkpoints/hub | Face detection models used in [face alignment](https://github.com/1adrianb/face-alignment).
|gfpgan/weights | Face detection and enhanced models used in `facexlib` and `gfpgan`.
The final folder structure should look like this:
<img width="331" alt="image" src="https://user-images.githubusercontent.com/4397546/232511411-4ca75cbf-a434-48c5-9ae0-9009e8316484.png">
</details>
## 3. Quick Start
Please read our document on [best practices and configuration tips](docs/best_practice.md).
### WebUI Demos
**Online Demo**: [HuggingFace](https://huggingface.co/spaces/vinthony/SadTalker) | [SDWebUI-Colab](https://colab.research.google.com/github/camenduru/stable-diffusion-webui-colab/blob/main/video/stable/stable_diffusion_1_5_video_webui_colab.ipynb) | [Colab](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb)
**Local WebUI extension**: Please refer to [WebUI docs](docs/webui_extension.md).
**Local gradio demo (recommended)**: A Gradio instance similar to our [Hugging Face demo](https://huggingface.co/spaces/vinthony/SadTalker) can be run locally:
```bash
## You need to manually install TTS (https://github.com/coqui-ai/TTS) via `pip install TTS` in advance.
python app_sadtalker.py
```
You can also start it more easily:
- windows: just double click `webui.bat`, the requirements will be installed automatically.
- Linux/Mac OS: run `bash webui.sh` to start the webui.
### CLI usage
##### Animating a portrait image from default config:
```bash
python inference.py --driven_audio <audio.wav> \
--source_image <video.mp4 or picture.png> \
--enhancer gfpgan
```
The results will be saved in `results/$SOME_TIMESTAMP/*.mp4`.
##### Full body/image Generation:
Use `--still` to generate a natural full-body video. You can add `--enhancer` to improve the quality of the generated video.
```bash
python inference.py --driven_audio <audio.wav> \
--source_image <video.mp4 or picture.png> \
--result_dir <a directory to store results> \
--still \
--preprocess full \
--enhancer gfpgan
```
More examples, configuration options and tips can be found in the [ >>> best practice documents <<<](docs/best_practice.md).
## Citation
If you find our work useful in your research, please consider citing:
```bibtex
@article{zhang2022sadtalker,
title={SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation},
author={Zhang, Wenxuan and Cun, Xiaodong and Wang, Xuan and Zhang, Yong and Shen, Xi and Guo, Yu and Shan, Ying and Wang, Fei},
journal={arXiv preprint arXiv:2211.12194},
year={2022}
}
```
## Acknowledgements
Facerender code borrows heavily from [zhanglonghao's reproduction of face-vid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis) and [PIRender](https://github.com/RenYurui/PIRender). We thank the authors for sharing their wonderful code. In the training process, we also used the models from [Deep3DFaceReconstruction](https://github.com/microsoft/Deep3DFaceReconstruction) and [Wav2lip](https://github.com/Rudrabha/Wav2Lip). We thank them for their wonderful work.
We also use the following 3rd-party libraries:
- **Face Utils**: https://github.com/xinntao/facexlib
- **Face Enhancement**: https://github.com/TencentARC/GFPGAN
- **Image/Video Enhancement**: https://github.com/xinntao/Real-ESRGAN
## Extensions:
- [SadTalker-Video-Lip-Sync](https://github.com/Zz-ww/SadTalker-Video-Lip-Sync) from [@Zz-ww](https://github.com/Zz-ww): SadTalker for Video Lip Editing
## Related Works
- [StyleHEAT: One-Shot High-Resolution Editable Talking Face Generation via Pre-trained StyleGAN (ECCV 2022)](https://github.com/FeiiYin/StyleHEAT)
- [CodeTalker: Speech-Driven 3D Facial Animation with Discrete Motion Prior (CVPR 2023)](https://github.com/Doubiiu/CodeTalker)
- [VideoReTalking: Audio-based Lip Synchronization for Talking Head Video Editing In the Wild (SIGGRAPH Asia 2022)](https://github.com/vinthony/video-retalking)
- [DPE: Disentanglement of Pose and Expression for General Video Portrait Editing (CVPR 2023)](https://github.com/Carlyx/DPE)
- [3D GAN Inversion with Facial Symmetry Prior (CVPR 2023)](https://github.com/FeiiYin/SPI/)
- [T2M-GPT: Generating Human Motion from Textual Descriptions with Discrete Representations (CVPR 2023)](https://github.com/Mael-zys/T2M-GPT)
## Disclaimer
This is not an official product of Tencent.
```
1. Please carefully read and comply with the open-source license applicable to this code before using it.
2. Please carefully read and comply with the intellectual property declaration applicable to this code before using it.
3. This open-source code runs completely offline and does not collect any personal information or other data. If you use this code to provide services to end-users and collect related data, please take necessary compliance measures according to applicable laws and regulations (such as publishing privacy policies, adopting necessary data security strategies, etc.). If the collected data involves personal information, user consent must be obtained (if applicable). Any legal liabilities arising from this are unrelated to Tencent.
4. Without Tencent's written permission, you are not authorized to use the names or logos legally owned by Tencent, such as "Tencent." Otherwise, you may be liable for legal responsibilities.
5. This open-source code does not have the ability to directly provide services to end-users. If you need to use this code for further model training or demos, as part of your product to provide services to end-users, or for similar use, please comply with applicable laws and regulations for your product or service. Any legal liabilities arising from this are unrelated to Tencent.
6. It is prohibited to use this open-source code for activities that harm the legitimate rights and interests of others (including but not limited to fraud, deception, infringement of others' portrait rights, reputation rights, etc.), or other behaviors that violate applicable laws and regulations or go against social ethics and good customs (including providing incorrect or false information, spreading pornographic, terrorist, and violent information, etc.). Otherwise, you may be liable for legal responsibilities.
```
LOGO: color and font suggestion: [ChatGPT](https://chat.openai.com), logo font: [Montserrat Alternates](https://fonts.google.com/specimen/Montserrat+Alternates?preview.text=SadTalker&preview.text_type=custom&query=mont).
All copyrights of the demo images and audio belong to community users or come from images generated with Stable Diffusion. Feel free to contact us if you would like us to remove them.
<!-- Spelling fixed on Tuesday, September 12, 2023 by @fakerybakery (https://github.com/fakerybakery). These changes are licensed under the Apache 2.0 license. -->
================================================
FILE: app_sadtalker.py
================================================
import os, sys
import gradio as gr
from src.gradio_demo import SadTalker
# Detect whether we are running inside a host web UI: if a top-level `webui`
# module is importable we are (presumably the Stable Diffusion WebUI — see
# docs/webui_extension.md). Catch `Exception` rather than using a bare
# `except:` so KeyboardInterrupt/SystemExit are not silently swallowed.
try:
    import webui  # noqa: F401  -- imported only to probe for the host UI
    in_webui = True
except Exception:
    in_webui = False
def toggle_audio_file(choice):
    """Return a pair of Gradio visibility updates driven by ``choice``.

    When ``choice == False`` the first component is shown and the second hidden;
    for any other value the visibilities are swapped.
    """
    # Equality (not truthiness) is intentional: only values equal to False
    # (e.g. False, 0) select the first component.
    show_first = choice == False  # noqa: E712
    return gr.update(visible=show_first), gr.update(visible=not show_first)
def ref_video_fn(path_of_ref_video):
    """Return a Gradio update whose value is True iff a reference video path was provided."""
    has_reference = path_of_ref_video is not None
    return gr.update(value=has_reference)
def sadtalker_demo(checkpoint_path='checkpoints', config_path='src/config', warpfn=None):
    """Build the SadTalker Gradio interface.

    Args:
        checkpoint_path: directory passed to ``SadTalker`` as its checkpoint location.
        config_path: directory passed to ``SadTalker`` as its config location.
        warpfn: optional callable; when given, ``warpfn(sad_talker.test)`` is used as
            the Generate-button handler instead of ``sad_talker.test`` itself.

    Returns:
        The ``gr.Blocks`` interface. The caller queues and launches it.
    """
    sad_talker = SadTalker(checkpoint_path, config_path, lazy_load=True)

    with gr.Blocks(analytics_enabled=False) as sadtalker_interface:
        # Page header (well-formed HTML: every <a> closed, no stray </span>).
        gr.Markdown(
            "<div align='center'> <h2> 😭 SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation (CVPR 2023) </h2> "
            "<a style='font-size:18px;color: #efefef' href='https://arxiv.org/abs/2211.12194'>Arxiv</a> "
            "<a style='font-size:18px;color: #efefef' href='https://sadtalker.github.io'>Homepage</a> "
            "<a style='font-size:18px;color: #efefef' href='https://github.com/OpenTalker/SadTalker'> Github </a> </div>"
        )

        with gr.Row().style(equal_height=False):
            # Left column: inputs (source image, driving audio, optional TTS).
            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="sadtalker_source_image"):
                    with gr.TabItem('Upload image'):
                        with gr.Row():
                            source_image = gr.Image(label="Source image", source="upload", type="filepath", elem_id="img2img_image").style(width=512)

                with gr.Tabs(elem_id="sadtalker_driven_audio"):
                    with gr.TabItem('Upload OR TTS'):
                        with gr.Column(variant='panel'):
                            driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath")

                        # Text-to-speech is only offered off Windows and outside the host web UI.
                        if sys.platform != 'win32' and not in_webui:
                            from src.utils.text2speech import TTSTalker
                            tts_talker = TTSTalker()
                            with gr.Column(variant='panel'):
                                input_text = gr.Textbox(label="Generating audio from text", lines=5, placeholder="please enter some text here, we generate the audio from text using @Coqui.ai TTS.")
                                tts = gr.Button('Generate audio', elem_id="sadtalker_audio_generate", variant='primary')
                                tts.click(fn=tts_talker.test, inputs=[input_text], outputs=[driven_audio])

            # Right column: generation settings and the resulting video.
            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="sadtalker_checkbox"):
                    with gr.TabItem('Settings'):
                        gr.Markdown("need help? please visit our [best practice page](https://github.com/OpenTalker/SadTalker/blob/main/docs/best_practice.md) for more details")
                        with gr.Column(variant='panel'):
                            pose_style = gr.Slider(minimum=0, maximum=46, step=1, label="Pose style", value=0)
                            size_of_image = gr.Radio([256, 512], value=256, label='face model resolution', info="use 256/512 model?")
                            preprocess_type = gr.Radio(['crop', 'resize', 'full', 'extcrop', 'extfull'], value='crop', label='preprocess', info="How to handle input image?")
                            is_still_mode = gr.Checkbox(label="Still Mode (fewer head motion, works with preprocess `full`)")
                            batch_size = gr.Slider(label="batch size in generation", step=1, maximum=10, value=2)
                            enhancer = gr.Checkbox(label="GFPGAN as Face enhancer")
                            submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')

                with gr.Tabs(elem_id="sadtalker_genearted"):
                    gen_video = gr.Video(label="Generated video", format="mp4").style(width=256)

        # A single click binding; the handler is optionally wrapped by the caller.
        # The input order must match the positional parameters of sad_talker.test.
        generate_fn = warpfn(sad_talker.test) if warpfn else sad_talker.test
        submit.click(
            fn=generate_fn,
            inputs=[source_image,
                    driven_audio,
                    preprocess_type,
                    is_still_mode,
                    enhancer,
                    batch_size,
                    size_of_image,
                    pose_style,
                    ],
            outputs=[gen_video],
        )

    return sadtalker_interface
if __name__ == "__main__":
    # Build the interface, call queue() on it, then start the local server.
    interface = sadtalker_demo()
    interface.queue()
    interface.launch()
================================================
FILE: cog.yaml
================================================
# Cog configuration: build environment plus the prediction entry point
# (the Predictor class in predict.py).
build:
  gpu: true
  cuda: "11.3"
  python_version: "3.8"
  # OS packages installed into the image.
  system_packages:
    - "ffmpeg"
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
  # Pinned Python dependencies.
  python_packages:
    - "torch==1.12.1"
    - "torchvision==0.13.1"
    - "torchaudio==0.12.1"
    - "joblib==1.1.0"
    - "scikit-image==0.19.3"
    - "basicsr==1.4.2"
    - "facexlib==0.3.0"
    - "resampy==0.3.1"
    - "pydub==0.25.1"
    - "scipy==1.10.1"
    - "kornia==0.6.8"
    - "face_alignment==1.3.5"
    - "imageio==2.19.3"
    - "imageio-ffmpeg==0.4.7"
    - "librosa==0.9.2" # pinned; the changelog notes earlier librosa-related errors
    - "tqdm==4.65.0"
    - "yacs==0.1.8"
    - "gfpgan==1.3.8"
    - "dlib-bin==19.24.1"
    - "av==10.0.0"
    - "trimesh==3.9.20"
  # Pre-fetch two "python-fan" weight files into the torch hub cache at build time
  # (presumably consumed by face_alignment; verify if those packages change).
  run:
    - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/s3fd-619a316812.pth" "https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth"
    - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/2DFAN4-cd938726ad.zip" "https://www.adrianbulat.com/downloads/python-fan/2DFAN4-cd938726ad.zip"
predict: "predict.py:Predictor"
================================================
FILE: docs/FAQ.md
================================================
## Frequently Asked Questions
**Q: `ffmpeg` is not recognized as an internal or external command**
On Linux, you can install ffmpeg via `conda install ffmpeg`. On macOS, install it via `brew install ffmpeg`. On Windows, follow [this](https://www.geeksforgeeks.org/how-to-install-ffmpeg-on-windows/) guide to install `ffmpeg`, and make sure `ffmpeg` is in your `%PATH%` as suggested in [#54](https://github.com/Winfredy/SadTalker/issues/54).
**Q: Running requirements.**
Please refer to the discussion here: https://github.com/Winfredy/SadTalker/issues/124#issuecomment-1508113989
**Q: ModuleNotFoundError: No module named 'ai'**
Please check the file size of the `epoch_20.pth` checkpoint (it may be incompletely downloaded). (https://github.com/Winfredy/SadTalker/issues/167, https://github.com/Winfredy/SadTalker/issues/113)
**Q: Illegal Hardware Error: Mac M1**
please reinstall the `dlib` by `pip install dlib` individually. (https://github.com/Winfredy/SadTalker/issues/129, https://github.com/Winfredy/SadTalker/issues/109)
**Q: FileNotFoundError: [Errno 2] No such file or directory: checkpoints\BFM_Fitting\similarity_Lm3D_all.mat**
Make sure you have downloaded the checkpoints and gfpgan as [here](https://github.com/Winfredy/SadTalker#-2-download-trained-models) and placed them in the right place.
**Q: RuntimeError: unexpected EOF, expected 237192 more bytes. The file might be corrupted.**
The files are not automatically downloaded. Please update the code and download the gfpgan folders as [here](https://github.com/Winfredy/SadTalker#-2-download-trained-models).
**Q: CUDA out of memory error**
please refer to https://stackoverflow.com/questions/73747731/runtimeerror-cuda-out-of-memory-how-setting-max-split-size-mb
```
# windows
set PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128
python inference.py ...
# linux
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128
python inference.py ...
```
**Q: Error while decoding stream #0:0: Invalid data found when processing input [mp3float @ 0000015037628c00] Header missing**
Our method only supports WAV or MP3 files as input; please make sure the input audio is in one of these formats.
================================================
FILE: docs/best_practice.md
================================================
# Best Practices and Tips for configuration
> Our model only works on REAL people or portrait images similar to REAL people. An anime talking-head generation method will be released in the future.
Advanced configuration options for `inference.py`:
| Name | Configuration | default | Explanation |
|:------------- |:------------- |:----- | :------------- |
| Enhance Mode | `--enhancer` | None | Using `gfpgan` or `RestoreFormer` to enhance the generated face via face restoration network
| Background Enhancer | `--background_enhancer` | None | Using `realesrgan` to enhance the full video.
| Still Mode | ` --still` | False | Using the same pose parameters as the original image, fewer head motion.
| Expressive Mode | `--expression_scale` | 1.0 | a larger value will make the expression motion stronger.
| save path | `--result_dir` |`./results` | The results will be saved in this directory.
| preprocess | `--preprocess` | `crop` | Run and produce the results on the cropped input image. Other choices: `resize`, where the image is resized to the specific resolution; `full`, which animates the full image (use it with `--still` to get better results).
| ref Mode (eye) | `--ref_eyeblink` | None | A video path, where we borrow the eyeblink from this reference video to provide more natural eyebrow movement.
| ref Mode (pose) | `--ref_pose` | None | A video path, where we borrow the pose from the head reference video.
| 3D Mode | `--face3dvis` | False | Needs additional installation. More details on generating the 3D face can be found [here](face3d.md).
| free-view Mode | `--input_yaw`,<br> `--input_pitch`,<br> `--input_roll` | None | Generating novel-view or free-view 4D talking heads from a single image. More details can be found [here](https://github.com/Winfredy/SadTalker#generating-4d-free-view-talking-examples-from-audio-and-a-single-image).
### About `--preprocess`
Our system automatically handles the input images via `crop`, `resize` and `full`.
In `crop` mode, we crop the image around the facial keypoints and generate the animated facial avatar. The animation of both expression and head pose is realistic.
> Still mode will stop the eyeblink and head pose movement.
| [input image @bagbag1815](https://twitter.com/bagbag1815/status/1642754319094108161) | crop | crop w/still |
|:--------------------: |:--------------------: | :----: |
| <img src='../examples/source_image/full_body_2.png' width='380'> |  |  |
In `resize` mode, we resize the whole images to generate the fully talking head video. Thus, an image similar to the ID photo can be produced. ⚠️ It will produce bad results for full person images.
| <img src='../examples/source_image/full_body_2.png' width='380'> | <img src='../examples/source_image/full4.jpeg' width='380'> |
|:--------------------: |:--------------------: |
| ❌ not suitable for resize mode | ✅ good for resize mode |
| <img src='resize_no.gif'> | <img src='resize_good.gif' width='380'> |
In `full` mode, our model automatically processes the cropped region and pastes it back into the original image. Remember to use `--still` to keep the original head pose.
| input | `--still` | `--still` & `enhancer` |
|:--------------------: |:--------------------: | :--:|
| <img src='../examples/source_image/full_body_2.png' width='380'> | <img src='./example_full.gif' width='380'> | <img src='./example_full_enhanced.gif' width='380'>
### About `--enhancer`
For higher resolution, we integrate [gfpgan](https://github.com/TencentARC/GFPGAN) and [real-esrgan](https://github.com/xinntao/Real-ESRGAN) for different purposes. Just add `--enhancer <gfpgan or RestoreFormer>` or `--background_enhancer <realesrgan>` to enhance the face or the full image, respectively.
```bash
# make sure above packages are available:
pip install gfpgan
pip install realesrgan
```
### About `--face3dvis`
This flag generates the 3D-rendered face and its 3D facial landmarks. More details can be found [here](face3d.md).
| Input | Animated 3d face |
|:-------------: | :-------------: |
| <img src='../examples/source_image/art_0.png' width='200px'> | <video src="https://user-images.githubusercontent.com/4397546/226856847-5a6a0a4d-a5ec-49e2-9b05-3206db65e8e3.mp4"></video> |
> Please unmute the video to hear the audio; GitHub does not play audio by default.
#### Reference eye-blink mode.
| Input, w/ reference video , reference video |
|:-------------: |
| |
| If the reference video is shorter than the input audio, we will loop the reference video.
#### Generating 4D free-view talking examples from audio and a single image
We use `input_yaw`, `input_pitch`, `input_roll` to control head pose. For example, `--input_yaw -20 30 10` means the input head yaw degree changes from -20 to 30 and then changes from 30 to 10.
```bash
python inference.py --driven_audio <audio.wav> \
--source_image <video.mp4 or picture.png> \
--result_dir <a file to store results> \
--input_yaw -20 30 10
```
| Results, Free-view results, Novel view results |
|:-------------: |
| |
================================================
FILE: docs/changlelog.md
================================================
## changelogs
- __[2023.04.06]__: The stable-diffusion webui extension is released.
- __[2023.04.03]__: Enable TTS in huggingface and gradio local demo.
- __[2023.03.30]__: Launch beta version of the full body mode.
- __[2023.03.30]__: Launch new feature: through using reference videos, our algorithm can generate videos with more natural eye blinking and some eyebrow movement.
- __[2023.03.29]__: `resize mode` is online via `python inference.py --preprocess resize`! It produces a larger crop of the image, as discussed in https://github.com/Winfredy/SadTalker/issues/35.
- __[2023.03.29]__: The local gradio demo is online! Run `python app.py` to start the demo. A new `requirements.txt` is used to avoid the bugs in `librosa`.
- __[2023.03.28]__: Online demo is launched in [](https://huggingface.co/spaces/vinthony/SadTalker), thanks AK!
- __[2023.03.22]__: Launch new feature: generating the 3d face animation from a single image. New applications about it will be updated.
- __[2023.03.22]__: Launch new feature: `still mode`, where only a small head pose will be produced via `python inference.py --still`.
- __[2023.03.18]__: Support `expression intensity`, now you can change the intensity of the generated motion: `python inference.py --expression_scale 1.3 (some value > 1)`.
- __[2023.03.18]__: Reconfig the data folders, now you can download the checkpoint automatically using `bash scripts/download_models.sh`.
- __[2023.03.18]__: We have officially integrated [GFPGAN](https://github.com/TencentARC/GFPGAN) for face enhancement; use `python inference.py --enhancer gfpgan` for better visualization performance.
- __[2023.03.14]__: Specify the version of package `joblib` to remove the errors in using `librosa`, [](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb) is online!
- __[2023.03.06]__: Solve some bugs in code and errors in installation
- __[2023.03.03]__: Release the test code for audio-driven single image animation!
- __[2023.02.28]__: SadTalker has been accepted by CVPR 2023!
================================================
FILE: docs/face3d.md
================================================
## 3D Face Visualization
We use `pytorch3d` to visualize the 3D faces from a single image.
The requirements for 3D visualization are difficult to install, so here's a tutorial:
```bash
git clone https://github.com/OpenTalker/SadTalker.git
cd SadTalker
conda create -n sadtalker3d python=3.8
source activate sadtalker3d
conda install ffmpeg
conda install -c fvcore -c iopath -c conda-forge fvcore iopath
conda install libgcc gmp
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
# install pytorch3d
pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu113_pyt1110/download.html
pip install -r requirements3d.txt
### install gfpgan for the enhancer
pip install git+https://github.com/TencentARC/GFPGAN
### if a gcc version problem occurs when pytorch3d runs `from pytorch3d import _C`, add the anaconda lib path to LD_LIBRARY_PATH
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/$YOUR_ANACONDA_PATH/lib/
```
Then, generate the result via:
```bash
python inference.py --driven_audio <audio.wav> \
--source_image <video.mp4 or picture.png> \
--result_dir <a file to store results> \
--face3dvis
```
The result will be saved as `3dface.mp4`.
More applications about 3D face rendering will be released soon.
================================================
FILE: docs/install.md
================================================
### macOS
This method has been tested on an M1 Mac (macOS 13.3).
```bash
git clone https://github.com/OpenTalker/SadTalker.git
cd SadTalker
conda create -n sadtalker python=3.8
conda activate sadtalker
# install pytorch 2.0
pip install torch torchvision torchaudio
conda install ffmpeg
pip install -r requirements.txt
pip install dlib # macOS needs to install the original dlib.
```
### Windows Native
- Make sure you have `ffmpeg` in your `%PATH%` as suggested in [#54](https://github.com/Winfredy/SadTalker/issues/54); follow [this](https://www.geeksforgeeks.org/how-to-install-ffmpeg-on-windows/) tutorial or use scoop to install `ffmpeg`.
### Windows WSL
- Make sure the environment: `export LD_LIBRARY_PATH=/usr/lib/wsl/lib:$LD_LIBRARY_PATH`
### Docker Installation
A community Docker image by [@thegenerativegeneration](https://github.com/thegenerativegeneration) is available on the [Docker hub](https://hub.docker.com/repository/docker/wawa9000/sadtalker), which can be used directly:
```bash
docker run --gpus "all" --rm -v $(pwd):/host_dir wawa9000/sadtalker \
--driven_audio /host_dir/deyu.wav \
--source_image /host_dir/image.jpg \
--expression_scale 1.0 \
--still \
--result_dir /host_dir
```
================================================
FILE: docs/webui_extension.md
================================================
## Run SadTalker as a Stable Diffusion WebUI Extension.
1. Install the latest version of [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) and install SadTalker via `extension`.
<img width="726" alt="image" src="https://user-images.githubusercontent.com/4397546/230698519-267d1d1f-6e99-4dd4-81e1-7b889259efbd.png">
2. Download the checkpoints manually, for Linux and Mac:
```bash
cd SOMEWHERE_YOU_LIKE
bash <(wget -qO- https://raw.githubusercontent.com/OpenTalker/SadTalker/main/scripts/download_models.sh)
```
For Windows, you can download all the checkpoints [here](https://github.com/OpenTalker/SadTalker/tree/main#2-download-models).
3.1. Option 1: put the checkpoint in `stable-diffusion-webui/models/SadTalker` or `stable-diffusion-webui/extensions/SadTalker/checkpoints/`, the checkpoints will be detected automatically.
3.2. Option 2: Set the path of `SADTALKER_CHECKPOINTS` in `webui_user.sh` (Linux) or `webui_user.bat` (Windows):
> only works if you are directly starting webui from `webui_user.sh` or `webui_user.bat`.
```bash
# Windows (webui_user.bat)
set SADTALKER_CHECKPOINTS=D:\SadTalker\checkpoints
# Linux/macOS (webui_user.sh)
export SADTALKER_CHECKPOINTS=/path/to/SadTalker/checkpoints
```
4. Start the WebUI via `webui.sh or webui_user.sh(linux)` or `webui_user.bat(windows)` or any other method. SadTalker can also be used in stable-diffusion-webui directly.
<img width="726" alt="image" src="https://user-images.githubusercontent.com/4397546/230698614-58015182-2916-4240-b324-e69022ef75b3.png">
## Questions
1. If you are running on CPU, you need to specify `--disable-safe-unpickle` in `webui_user.sh` or `webui_user.bat`.
```bash
# windows (webui_user.bat)
set COMMANDLINE_ARGS="--disable-safe-unpickle"
# linux (webui_user.sh)
export COMMANDLINE_ARGS="--disable-safe-unpickle"
```
(If you're unable to use the `full` mode, please read this [discussion](https://github.com/Winfredy/SadTalker/issues/78).)
================================================
FILE: inference.py
================================================
from glob import glob
import shutil
import torch
from time import strftime
import os, sys, time
from argparse import ArgumentParser
from src.utils.preprocess import CropAndExtract
from src.test_audio2coeff import Audio2Coeff
from src.facerender.animate import AnimateFromCoeff
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data
from src.utils.init_path import init_path
def _extract_reference_coeff(preprocess_model, video_path, save_dir, preprocess, banner):
    """Run 3DMM extraction on a reference video and return its coefficient path.

    ``save_dir/<video file stem>`` is created and passed to
    ``preprocess_model.generate`` as its working directory.
    """
    video_stem = os.path.splitext(os.path.basename(video_path))[0]
    frame_dir = os.path.join(save_dir, video_stem)
    os.makedirs(frame_dir, exist_ok=True)
    print(banner)
    coeff_path, _, _ = preprocess_model.generate(video_path, frame_dir, preprocess, source_image_flag=False)
    return coeff_path


def main(args):
    """Animate ``args.source_image`` with ``args.driven_audio``.

    The pipeline stages are:

    1. 3DMM extraction for the source image and for any reference videos.
    2. Audio -> coefficients.
    3. An optional 3D face render.
    4. Coefficients -> video.

    The final video is moved to ``<result_dir>/<timestamp>.mp4``. The
    ``<result_dir>/<timestamp>/`` working directory is removed unless
    ``args.verbose`` is set.
    """
    # Snapshot the inputs up front (args is later handed to gen_composed_video).
    pic_path, audio_path = args.source_image, args.driven_audio
    save_dir = os.path.join(args.result_dir, strftime("%Y_%m_%d_%H.%M.%S"))
    os.makedirs(save_dir, exist_ok=True)
    device, pose_style, batch_size = args.device, args.pose_style, args.batch_size
    yaw_list, pitch_list, roll_list = args.input_yaw, args.input_pitch, args.input_roll
    ref_eyeblink, ref_pose = args.ref_eyeblink, args.ref_pose

    # Config directory is resolved relative to the launched script's directory.
    config_dir = os.path.join(os.path.split(sys.argv[0])[0], 'src/config')
    sadtalker_paths = init_path(args.checkpoint_dir, config_dir, args.size, args.old_version, args.preprocess)

    # Instantiate the three pipeline stages.
    preprocess_model = CropAndExtract(sadtalker_paths, device)
    audio_to_coeff = Audio2Coeff(sadtalker_paths, device)
    animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device)

    # Crop the source image and extract its 3DMM coefficients.
    first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
    os.makedirs(first_frame_dir, exist_ok=True)
    print('3DMM Extraction for source image')
    first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(
        pic_path, first_frame_dir, args.preprocess, source_image_flag=True, pic_size=args.size)
    if first_coeff_path is None:
        print("Can't get the coeffs of the input")
        return

    # Optional reference videos for eye blinking and head pose.
    ref_eyeblink_coeff_path = None
    if ref_eyeblink is not None:
        ref_eyeblink_coeff_path = _extract_reference_coeff(
            preprocess_model, ref_eyeblink, save_dir, args.preprocess,
            '3DMM Extraction for the reference video providing eye blinking')

    ref_pose_coeff_path = None
    if ref_pose is not None:
        if ref_pose == ref_eyeblink:
            # Same video was already processed for eye blinking; reuse it.
            ref_pose_coeff_path = ref_eyeblink_coeff_path
        else:
            ref_pose_coeff_path = _extract_reference_coeff(
                preprocess_model, ref_pose, save_dir, args.preprocess,
                '3DMM Extraction for the reference video providing pose')

    # Audio -> coefficients.
    batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=args.still)
    coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)

    # Optional 3D face render (needs the extra 3D dependencies).
    if args.face3dvis:
        from src.face3d.visualize import gen_composed_video
        gen_composed_video(args, device, first_coeff_path, coeff_path, audio_path, os.path.join(save_dir, '3dface.mp4'))

    # Coefficients -> video.
    data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path,
                               batch_size, yaw_list, pitch_list, roll_list,
                               expression_scale=args.expression_scale, still_mode=args.still,
                               preprocess=args.preprocess, size=args.size)
    result = animate_from_coeff.generate(data, save_dir, pic_path, crop_info,
                                         enhancer=args.enhancer, background_enhancer=args.background_enhancer,
                                         preprocess=args.preprocess, img_size=args.size)

    final_video = save_dir + '.mp4'
    shutil.move(result, final_video)
    print('The generated video is named:', final_video)

    if not args.verbose:
        shutil.rmtree(save_dir)
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument("--driven_audio", default='./examples/driven_audio/bus_chinese.wav', help="path to driven audio")
parser.add_argument("--source_image", default='./examples/source_image/full_body_1.png', help="path to source image")
parser.add_argument("--ref_eyeblink", default=None, help="path to reference video providing eye blinking")
parser.add_argument("--ref_pose", default=None, help="path to reference video providing pose")
parser.add_argument("--checkpoint_dir", default='./checkpoints', help="path to output")
parser.add_argument("--result_dir", default='./results', help="path to output")
parser.add_argument("--pose_style", type=int, default=0, help="input pose style from [0, 46)")
parser.add_argument("--batch_size", type=int, default=2, help="the batch size of facerender")
parser.add_argument("--size", type=int, default=256, help="the image size of the facerender")
parser.add_argument("--expression_scale", type=float, default=1., help="the batch size of facerender")
parser.add_argument('--input_yaw', nargs='+', type=int, default=None, help="the input yaw degree of the user ")
parser.add_argument('--input_pitch', nargs='+', type=int, default=None, help="the input pitch degree of the user")
parser.add_argument('--input_roll', nargs='+', type=int, default=None, help="the input roll degree of the user")
parser.add_argument('--enhancer', type=str, default=None, help="Face enhancer, [gfpgan, RestoreFormer]")
parser.add_argument('--background_enhancer', type=str, default=None, help="background enhancer, [realesrgan]")
parser.add_argument("--cpu", dest="cpu", action="store_true")
parser.add_argument("--face3dvis", action="store_true", help="generate 3d face and 3d landmarks")
parser.add_argument("--still", action="store_true", help="can crop back to the original videos for the full body aniamtion")
parser.add_argument("--preprocess", default='crop', choices=['crop', 'extcrop', 'resize', 'full', 'extfull'], help="how to preprocess the images" )
parser.add_argument("--verbose",action="store_true", help="saving the intermedia output or not" )
parser.add_argument("--old_version",action="store_true", help="use the pth other than safetensor version" )
# net structure and parameters
parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='useless')
parser.add_argument('--init_path', type=str, default=None, help='Useless')
parser.add_argument('--use_last_fc',default=False, help='zero initialize the last fc')
parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/')
parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')
# default renderer parameters
parser.add_argument('--focal', type=float, default=1015.)
parser.add_argument('--center', type=float, default=112.)
parser.add_argument('--camera_d', type=float, default=10.)
parser.add_argument('--z_near', type=float, default=5.)
parser.add_argument('--z_far', type=float, default=15.)
args = parser.parse_args()
if torch.cuda.is_available() and not args.cpu:
args.device = "cuda"
else:
args.device = "cpu"
main(args)
================================================
FILE: launcher.py
================================================
# This script installs the necessary requirements and launches the SadTalker web UI (app_sadtalker.sadtalker_demo).
# borrow from : https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/launch.py
import subprocess
import os
import sys
import importlib.util
import shlex
import platform
import json
# Interpreter and tool locations; GIT and INDEX_URL may be overridden via the environment.
python = sys.executable
git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "")

# Mutable module state read/written by commit_hash() and run_pip().
stored_commit_hash = None
skip_install = False

dir_repos = "repositories"
# Two directory levels above this file (symlinks resolved).
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

# Default Gradio analytics to off, without overriding an explicit user setting.
os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
def check_python_version():
    """Raise RuntimeError unless the interpreter is a supported Python 3 minor version.

    Windows accepts only 3.10; other platforms accept 3.7 through 3.11.

    Raises:
        RuntimeError: with an explanatory message for unsupported versions.
    """
    is_windows = platform.system() == "Windows"
    major = sys.version_info.major
    minor = sys.version_info.minor
    micro = sys.version_info.micro

    if is_windows:
        supported_minors = [10]
    else:
        supported_minors = [7, 8, 9, 10, 11]

    if not (major == 3 and minor in supported_minors):
        # Must raise an exception instance: `raise "<str>"` is itself a
        # TypeError in Python 3, which would hide this message entirely.
        raise RuntimeError(f"""
INCOMPATIBLE PYTHON VERSION
This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
If you encounter an error with "RuntimeError: Couldn't install torch." message,
or any other error regarding unsuccessful package (library) installation,
please downgrade (or upgrade) to the latest version of 3.10 Python
and delete current Python and "venv" folder in WebUI's directory.
You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/
{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
Use --skip-python-version-check to suppress this warning.
""")
def commit_hash():
    """Return the HEAD commit of the current git checkout, memoised in ``stored_commit_hash``.

    If ``git rev-parse`` fails for any reason, ``"<none>"`` is stored and returned.
    """
    global stored_commit_hash

    if stored_commit_hash is None:
        try:
            stored_commit_hash = run(f"{git} rev-parse HEAD").strip()
        except Exception:
            stored_commit_hash = "<none>"

    return stored_commit_hash
def run(command, desc=None, errdesc=None, custom_env=None, live=False):
    """Run a shell command and return its decoded stdout.

    Args:
        command: shell command string (executed with ``shell=True``).
        desc: optional line printed before running.
        errdesc: prefix for the error message on failure.
        custom_env: environment mapping; defaults to ``os.environ``.
        live: if True, output is not captured (it goes straight to the
            terminal) and ``""`` is returned on success.

    Raises:
        RuntimeError: if the command exits with a non-zero status.
    """
    if desc is not None:
        print(desc)

    env = custom_env if custom_env is not None else os.environ

    if live:
        completed = subprocess.run(command, shell=True, env=env)
        if completed.returncode != 0:
            raise RuntimeError(f"""{errdesc or 'Error running command'}.
Command: {command}
Error code: {completed.returncode}""")
        return ""

    completed = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=env)
    if completed.returncode == 0:
        return completed.stdout.decode(encoding="utf8", errors="ignore")

    message = f"""{errdesc or 'Error running command'}.
Command: {command}
Error code: {completed.returncode}
stdout: {completed.stdout.decode(encoding="utf8", errors="ignore") if len(completed.stdout)>0 else '<empty>'}
stderr: {completed.stderr.decode(encoding="utf8", errors="ignore") if len(completed.stderr)>0 else '<empty>'}
"""
    raise RuntimeError(message)
def check_run(command):
    """Run ``command`` through the shell with output captured; return True iff it exits with status 0."""
    completed = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
    return not completed.returncode
def is_installed(package):
    """Return True if ``package`` can be located by ``importlib.util.find_spec``.

    A missing parent package (which makes ``find_spec`` raise
    ModuleNotFoundError for dotted names) is reported as not installed.
    """
    try:
        found = importlib.util.find_spec(package) is not None
    except ModuleNotFoundError:
        found = False
    return found
def repo_dir(name):
    """Return the path of repository ``name`` inside ``<script_path>/<dir_repos>``."""
    repos_root = os.path.join(script_path, dir_repos)
    return os.path.join(repos_root, name)
def run_python(code, desc=None, errdesc=None):
    """Execute ``code`` with the current interpreter's ``-c`` flag via run() and return its stdout."""
    command = f'"{python}" -c "{code}"'
    return run(command, desc, errdesc)
def run_pip(args, desc=None):
    """Run ``python -m pip <args> --prefer-binary`` (plus ``--index-url`` if configured).

    Returns pip's stdout, or None when ``skip_install`` is set.
    """
    if not skip_install:
        index_suffix = '' if index_url == '' else f' --index-url {index_url}'
        return run(f'"{python}" -m pip {args} --prefer-binary{index_suffix}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
def check_run_python(code):
    """Return True iff running ``code`` with the current interpreter's ``-c`` flag exits successfully."""
    command = f'"{python}" -c "{code}"'
    return check_run(command)
def git_clone(url, dir, name, commithash=None):
    """Clone ``url`` into ``dir``, or update an existing checkout, optionally pinning ``commithash``.

    Args:
        url: repository URL to clone.
        dir: target directory.
        name: human-readable name used in progress and error messages.
        commithash: commit to check out; if None, an existing checkout is left untouched.

    Raises:
        RuntimeError: propagated from run() when a git command fails.
    """
    # TODO clone into temporary dir and move if successful
    if os.path.exists(dir):
        if commithash is None:
            return

        current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
        if current_hash == commithash:
            return

        run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
        run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
        return

    run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")

    if commithash is not None:
        # f-prefix required so the error message names the repo and hash.
        run(f'"{git}" -C "{dir}" checkout {commithash}', None, f"Couldn't checkout {name}'s hash: {commithash}")
def git_pull_recursive(dir):
    """Run ``git pull --autostash`` in every directory under ``dir`` that contains a ``.git`` entry.

    A failing pull is reported on stdout and does not stop the walk.
    """
    for current, _subdirs, _files in os.walk(dir):
        if not os.path.exists(os.path.join(current, '.git')):
            continue
        try:
            output = subprocess.check_output([git, '-C', current, 'pull', '--autostash'])
        except subprocess.CalledProcessError as err:
            print(f"Couldn't perform 'git pull' on repository in '{current}':\n{err.output.decode('utf-8').strip()}\n")
        else:
            print(f"Pulled changes for repository in '{current}':\n{output.decode('utf-8').strip()}\n")
def run_extension_installer(extension_dir):
    """Run ``<extension_dir>/install.py`` if it exists, with PYTHONPATH set to the current directory.

    The installer's stdout is printed. Any exception is printed to stderr
    instead of being raised.
    """
    installer = os.path.join(extension_dir, "install.py")
    if os.path.isfile(installer):
        try:
            env = os.environ.copy()
            env['PYTHONPATH'] = os.path.abspath(".")
            print(run(f'"{python}" "{installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
        except Exception as e:
            print(e, file=sys.stderr)
def prepare_environment():
    """Install torch/torchvision (if missing), the requirements file, and TTS.

    The torch install command and the requirements file can be overridden with
    the TORCH_COMMAND and REQS_FILE environment variables. Windows defaults to
    requirements.txt; other platforms default to req.txt.
    """
    torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113")

    # check windows
    default_requirements = "requirements.txt" if sys.platform == 'win32' else "req.txt"
    requirements_file = os.environ.get('REQS_FILE', default_requirements)

    commit = commit_hash()
    print(f"Python {sys.version}")
    print(f"Commit hash: {commit}")

    torch_ready = is_installed("torch") and is_installed("torchvision")
    if not torch_ready:
        run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)

    run_pip(f"install -r \"{requirements_file}\"", "requirements for SadTalker WebUI (may take longer time in first time)")

    # NOTE(review): the Coqui TTS package imports as `TTS`; if is_installed()
    # is a case-sensitive module lookup, 'tts' may never match and TTS would be
    # reinstalled on every launch — confirm against is_installed's definition.
    if sys.platform != 'win32' and not is_installed('tts'):
        run_pip("install TTS", "install TTS individually in SadTalker, which might not work on windows.")
def start():
    """Build the SadTalker Gradio demo, enable its queue, and launch it."""
    print("Launching SadTalker Web UI")
    from app_sadtalker import sadtalker_demo
    app = sadtalker_demo()
    app.queue()
    app.launch()


if __name__ == "__main__":
    # Dependencies must be installed before app_sadtalker can be imported.
    prepare_environment()
    start()
================================================
FILE: predict.py
================================================
"""run bash scripts/download_models.sh first to prepare the weights file"""
import os
import shutil
from argparse import Namespace
from src.utils.preprocess import CropAndExtract
from src.test_audio2coeff import Audio2Coeff
from src.facerender.animate import AnimateFromCoeff
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data
from src.utils.init_path import init_path
from cog import BasePredictor, Input, Path
checkpoints = "checkpoints"


class Predictor(BasePredictor):
    """Cog predictor: animate a source image/video with a driving audio clip."""

    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        device = "cuda"
        sadtalker_paths = init_path(checkpoints, os.path.join("src", "config"))

        # init model
        self.preprocess_model = CropAndExtract(sadtalker_paths, device)
        self.audio_to_coeff = Audio2Coeff(sadtalker_paths, device)
        # NOTE(review): both renderers are built from identical arguments, so the
        # same model is loaded twice. If "full" is meant to use a different
        # renderer config, init_path likely needs a preprocess-specific call —
        # confirm against src/utils/init_path.py before consolidating.
        self.animate_from_coeff = {
            "full": AnimateFromCoeff(sadtalker_paths, device),
            "others": AnimateFromCoeff(sadtalker_paths, device),
        }

    def predict(
        self,
        source_image: Path = Input(
            description="Upload the source image, it can be video.mp4 or picture.png",
        ),
        driven_audio: Path = Input(
            description="Upload the driven audio, accepts .wav and .mp4 file",
        ),
        enhancer: str = Input(
            description="Choose a face enhancer",
            choices=["gfpgan", "RestoreFormer"],
            default="gfpgan",
        ),
        preprocess: str = Input(
            description="how to preprocess the images",
            choices=["crop", "resize", "full"],
            default="full",
        ),
        ref_eyeblink: Path = Input(
            description="path to reference video providing eye blinking",
            default=None,
        ),
        ref_pose: Path = Input(
            description="path to reference video providing pose",
            default=None,
        ),
        still: bool = Input(
            description="can crop back to the original videos for the full body aniamtion when preprocess is full",
            default=True,
        ),
    ) -> Path:
        """Run a single prediction on the model.

        Returns:
            Path to the face-enhanced output video, copied to /tmp/out.mp4.

        Raises:
            RuntimeError: if no 3DMM coefficients can be extracted from the
                source image, or if rendering produced no ``*enhanced.mp4`` file.
        """
        animate_from_coeff = (
            self.animate_from_coeff["full"]
            if preprocess == "full"
            else self.animate_from_coeff["others"]
        )

        args = load_default()
        args.pic_path = str(source_image)
        args.audio_path = str(driven_audio)
        device = "cuda"
        args.still = still
        args.ref_eyeblink = None if ref_eyeblink is None else str(ref_eyeblink)
        args.ref_pose = None if ref_pose is None else str(ref_pose)

        # Start from an empty results directory, so that the output lookup at
        # the end cannot pick up a video left by a previous prediction.
        results_dir = "results"
        if os.path.exists(results_dir):
            shutil.rmtree(results_dir)
        os.makedirs(results_dir)
        first_frame_dir = os.path.join(results_dir, "first_frame_dir")
        os.makedirs(first_frame_dir)

        # crop image and extract 3dmm from image
        print("3DMM Extraction for source image")
        first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
            args.pic_path, first_frame_dir, preprocess, source_image_flag=True
        )
        if first_coeff_path is None:
            # The signature promises a Path. Returning None here would only
            # surface later as an obscure output-validation error.
            raise RuntimeError("Can't get the coeffs of the input")

        # The reference videos are passed as plain strings, exactly like the
        # source image (args.pic_path), rather than as cog Path objects.
        if args.ref_eyeblink is not None:
            ref_eyeblink_videoname = os.path.splitext(os.path.split(args.ref_eyeblink)[-1])[0]
            ref_eyeblink_frame_dir = os.path.join(results_dir, ref_eyeblink_videoname)
            os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
            print("3DMM Extraction for the reference video providing eye blinking")
            ref_eyeblink_coeff_path, _, _ = self.preprocess_model.generate(
                args.ref_eyeblink, ref_eyeblink_frame_dir
            )
        else:
            ref_eyeblink_coeff_path = None

        if args.ref_pose is not None:
            if args.ref_pose == args.ref_eyeblink:
                # Same video for both: reuse the coefficients already extracted.
                ref_pose_coeff_path = ref_eyeblink_coeff_path
            else:
                ref_pose_videoname = os.path.splitext(os.path.split(args.ref_pose)[-1])[0]
                ref_pose_frame_dir = os.path.join(results_dir, ref_pose_videoname)
                os.makedirs(ref_pose_frame_dir, exist_ok=True)
                print("3DMM Extraction for the reference video providing pose")
                ref_pose_coeff_path, _, _ = self.preprocess_model.generate(
                    args.ref_pose, ref_pose_frame_dir
                )
        else:
            ref_pose_coeff_path = None

        # audio2ceoff
        batch = get_data(
            first_coeff_path,
            args.audio_path,
            device,
            ref_eyeblink_coeff_path,
            still=still,
        )
        coeff_path = self.audio_to_coeff.generate(
            batch, results_dir, args.pose_style, ref_pose_coeff_path
        )

        # coeff2video
        print("coeff2video")
        data = get_facerender_data(
            coeff_path,
            crop_pic_path,
            first_coeff_path,
            args.audio_path,
            args.batch_size,
            args.input_yaw,
            args.input_pitch,
            args.input_roll,
            expression_scale=args.expression_scale,
            still_mode=still,
            preprocess=preprocess,
        )
        animate_from_coeff.generate(
            data, results_dir, args.pic_path, crop_info,
            enhancer=enhancer, background_enhancer=args.background_enhancer,
            preprocess=preprocess)

        output = "/tmp/out.mp4"
        enhanced_videos = [f for f in os.listdir(results_dir) if "enhanced.mp4" in f]
        if not enhanced_videos:
            raise RuntimeError(f"No enhanced video was produced in {results_dir!r}")
        mp4_path = os.path.join(results_dir, enhanced_videos[0])
        shutil.copy(mp4_path, output)

        return Path(output)
def load_default():
    """Return a fresh ``argparse.Namespace`` holding the default inference options.

    ``Predictor.predict`` fills in the per-request fields (paths, still mode)
    on top of these defaults.
    """
    defaults = {
        "pose_style": 0,
        "batch_size": 2,
        "expression_scale": 1.0,
        "input_yaw": None,
        "input_pitch": None,
        "input_roll": None,
        "background_enhancer": None,
        "face3dvis": False,
        "net_recon": "resnet50",
        "init_path": None,
        "use_last_fc": False,
        "bfm_folder": "./src/config/",
        "bfm_model": "BFM_model_front.mat",
        "focal": 1015.0,
        "center": 112.0,
        "camera_d": 10.0,
        "z_near": 5.0,
        "z_far": 15.0,
    }
    return Namespace(**defaults)
================================================
FILE: quick_demo.ipynb
================================================
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "M74Gs_TjYl_B"
},
"source": [
"[](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "view-in-github"
},
"source": [
"### SadTalker:Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation \n",
"\n",
"[arxiv](https://arxiv.org/abs/2211.12194) | [project](https://sadtalker.github.io) | [Github](https://github.com/Winfredy/SadTalker)\n",
"\n",
"Wenxuan Zhang, Xiaodong Cun, Xuan Wang, Yong Zhang, Xi Shen, Yu Guo, Ying Shan, Fei Wang.\n",
"\n",
"Xi'an Jiaotong University, Tencent AI Lab, Ant Group\n",
"\n",
"CVPR 2023\n",
"\n",
"TL;DR: A realistic and stylized talking head video generation method from a single image and audio\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "kA89DV-sKS4i"
},
"source": [
"Installation (around 5 mins)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qJ4CplXsYl_E"
},
"outputs": [],
"source": [
"### make sure that CUDA is available in Edit -> Nootbook settings -> GPU\n",
"!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Mdq6j4E5KQAR"
},
"outputs": [],
"source": [
"!update-alternatives --install /usr/local/bin/python3 python3 /usr/bin/python3.8 2\n",
"!update-alternatives --install /usr/local/bin/python3 python3 /usr/bin/python3.9 1\n",
"!sudo apt install python3.8\n",
"\n",
"!sudo apt-get install python3.8-distutils\n",
"\n",
"!python --version\n",
"\n",
"!apt-get update\n",
"\n",
"!apt install software-properties-common\n",
"\n",
"!sudo dpkg --remove --force-remove-reinstreq python3-pip python3-setuptools python3-wheel\n",
"\n",
"!apt-get install python3-pip\n",
"\n",
"print('Git clone project and install requirements...')\n",
"!git clone https://github.com/Winfredy/SadTalker &> /dev/null\n",
"%cd SadTalker\n",
"!export PYTHONPATH=/content/SadTalker:$PYTHONPATH\n",
"!python3.8 -m pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113\n",
"!apt update\n",
"!apt install ffmpeg &> /dev/null\n",
"!python3.8 -m pip install -r requirements.txt"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "DddcKB_nKsnk"
},
"source": [
"Download models (1 mins)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eDw3_UN8K2xa"
},
"outputs": [],
"source": [
"print('Download pre-trained models...')\n",
"!rm -rf checkpoints\n",
"!bash scripts/download_models.sh"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kK7DYeo7Yl_H"
},
"outputs": [],
"source": [
"# borrow from makeittalk\n",
"import ipywidgets as widgets\n",
"import glob\n",
"import matplotlib.pyplot as plt\n",
"print(\"Choose the image name to animate: (saved in folder 'examples/')\")\n",
"img_list = glob.glob1('examples/source_image', '*.png')\n",
"img_list.sort()\n",
"img_list = [item.split('.')[0] for item in img_list]\n",
"default_head_name = widgets.Dropdown(options=img_list, value='full3')\n",
"def on_change(change):\n",
" if change['type'] == 'change' and change['name'] == 'value':\n",
" plt.imshow(plt.imread('examples/source_image/{}.png'.format(default_head_name.value)))\n",
" plt.axis('off')\n",
" plt.show()\n",
"default_head_name.observe(on_change)\n",
"display(default_head_name)\n",
"plt.imshow(plt.imread('examples/source_image/{}.png'.format(default_head_name.value)))\n",
"plt.axis('off')\n",
"plt.show()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "-khNZcnGK4UK"
},
"source": [
"Animation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ToBlDusjK5sS"
},
"outputs": [],
"source": [
"# selected audio from exmaple/driven_audio\n",
"img = 'examples/source_image/{}.png'.format(default_head_name.value)\n",
"print(img)\n",
"!python3.8 inference.py --driven_audio ./examples/driven_audio/RD_Radio31_000.wav \\\n",
" --source_image {img} \\\n",
" --result_dir ./results --still --preprocess full --enhancer gfpgan"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fAjwGmKKYl_I"
},
"outputs": [],
"source": [
"# visualize code from makeittalk\n",
"from IPython.display import HTML\n",
"from base64 import b64encode\n",
"import os, sys\n",
"\n",
"# get the last from results\n",
"\n",
"results = sorted(os.listdir('./results/'))\n",
"\n",
"mp4_name = glob.glob('./results/*.mp4')[0]\n",
"\n",
"mp4 = open('{}'.format(mp4_name),'rb').read()\n",
"data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
"\n",
"print('Display animation: {}'.format(mp4_name), file=sys.stderr)\n",
"display(HTML(\"\"\"\n",
" <video width=256 controls>\n",
" <source src=\"%s\" type=\"video/mp4\">\n",
" </video>\n",
" \"\"\" % data_url))\n"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"provenance": []
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.9.7"
},
"vscode": {
"interpreter": {
"hash": "db5031b3636a3f037ea48eb287fd3d023feb9033aefc2a9652a92e470fb0851b"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: req.txt
================================================
llvmlite==0.38.1
numpy==1.21.6
face_alignment==1.3.5
imageio==2.19.3
imageio-ffmpeg==0.4.7
librosa==0.10.0.post2
numba==0.55.1
resampy==0.3.1
pydub==0.25.1
scipy==1.10.1
kornia==0.6.8
tqdm
yacs==0.1.8
pyyaml
joblib==1.1.0
scikit-image==0.19.3
basicsr==1.4.2
facexlib==0.3.0
gradio
gfpgan
av
safetensors
================================================
FILE: requirements.txt
================================================
numpy==1.23.4
face_alignment==1.3.5
imageio==2.19.3
imageio-ffmpeg==0.4.7
librosa==0.9.2 #
numba
resampy==0.3.1
pydub==0.25.1
scipy==1.10.1
kornia==0.6.8
tqdm
yacs==0.1.8
pyyaml
joblib==1.1.0
scikit-image==0.19.3
basicsr==1.4.2
facexlib==0.3.0
gradio
gfpgan
av
safetensors
================================================
FILE: requirements3d.txt
================================================
numpy==1.23.4
face_alignment==1.3.5
imageio==2.19.3
imageio-ffmpeg==0.4.7
librosa==0.9.2 #
numba
resampy==0.3.1
pydub==0.25.1
scipy==1.5.3
kornia==0.6.8
tqdm
yacs==0.1.8
pyyaml
joblib==1.1.0
scikit-image==0.19.3
basicsr==1.4.2
facexlib==0.3.0
trimesh==3.9.20
gradio
gfpgan
safetensors
================================================
FILE: scripts/download_models.sh
================================================
# Download the SadTalker checkpoints and the GFPGAN enhancer weights.
# Safe to re-run: `mkdir -p` tolerates existing directories and `wget -nc`
# skips files that are already present.
mkdir -p ./checkpoints

# legacy download links
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2exp_00300-model.pth -O ./checkpoints/auido2exp_00300-model.pth
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2pose_00140-model.pth -O ./checkpoints/auido2pose_00140-model.pth
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/epoch_20.pth -O ./checkpoints/epoch_20.pth
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/facevid2vid_00189-model.pth.tar -O ./checkpoints/facevid2vid_00189-model.pth.tar
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/shape_predictor_68_face_landmarks.dat -O ./checkpoints/shape_predictor_68_face_landmarks.dat
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/wav2lip.pth -O ./checkpoints/wav2lip.pth
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00229-model.pth.tar -O ./checkpoints/mapping_00229-model.pth.tar
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00109-model.pth.tar -O ./checkpoints/mapping_00109-model.pth.tar
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/hub.zip -O ./checkpoints/hub.zip
# unzip -n ./checkpoints/hub.zip -d ./checkpoints/

#### download the new links.
wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00109-model.pth.tar -O ./checkpoints/mapping_00109-model.pth.tar
wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00229-model.pth.tar -O ./checkpoints/mapping_00229-model.pth.tar
wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_256.safetensors -O ./checkpoints/SadTalker_V0.0.2_256.safetensors
wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_512.safetensors -O ./checkpoints/SadTalker_V0.0.2_512.safetensors

# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/BFM_Fitting.zip -O ./checkpoints/BFM_Fitting.zip
# unzip -n ./checkpoints/BFM_Fitting.zip -d ./checkpoints/

### enhancer
mkdir -p ./gfpgan/weights
wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth -O ./gfpgan/weights/alignment_WFLW_4HG.pth
wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth -O ./gfpgan/weights/detection_Resnet50_Final.pth
wget -nc https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -O ./gfpgan/weights/GFPGANv1.4.pth
wget -nc https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth -O ./gfpgan/weights/parsing_parsenet.pth
================================================
FILE: scripts/extension.py
================================================
import os, sys
from pathlib import Path
import tempfile
import gradio as gr
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call
from modules.shared import opts, OptionInfo
from modules import shared, paths, script_callbacks
import launch
import glob
from huggingface_hub import snapshot_download
def check_all_files_safetensor(current_dir):
    """Return True if ``current_dir`` holds every file of the SafeTensor checkpoint layout.

    Returns False when ``current_dir`` is not a directory, or when any required
    file name is missing from its immediate listing. Extra files are ignored.
    """
    # Required file name -> short label (the labels are informational only).
    required = {
        "SadTalker_V0.0.2_256.safetensors": "sadtalker-256",
        "SadTalker_V0.0.2_512.safetensors": "sadtalker-512",
        "mapping_00109-model.pth.tar": "mapping-109",
        "mapping_00229-model.pth.tar": "mapping-229",
    }
    if not os.path.isdir(current_dir):
        return False
    missing = set(required) - set(os.listdir(current_dir))
    return not missing
def check_all_files(current_dir):
    """Return True if ``current_dir`` holds every file of the legacy (.pth) checkpoint layout.

    Returns False when ``current_dir`` is not a directory, or when any required
    file name is missing from its immediate listing. Extra files are ignored.
    """
    # Required file name -> short label (the labels are informational only).
    required = {
        "auido2exp_00300-model.pth": "audio2exp",
        "auido2pose_00140-model.pth": "audio2pose",
        "epoch_20.pth": "face_recon",
        "facevid2vid_00189-model.pth.tar": "face-render",
        "mapping_00109-model.pth.tar": "mapping-109",
        "mapping_00229-model.pth.tar": "mapping-229",
        "wav2lip.pth": "wav2lip",
        "shape_predictor_68_face_landmarks.dat": "dlib",
    }
    if not os.path.isdir(current_dir):
        return False
    present = set(os.listdir(current_dir))
    return all(name in present for name in required)
def download_model(local_dir='./checkpoints'):
    """Download the ``vinthony/SadTalker`` Hugging Face repo snapshot into ``local_dir``.

    Files are copied into ``local_dir`` (``local_dir_use_symlinks=False``)
    rather than symlinked from the hub cache.
    """
    snapshot_download(
        repo_id='vinthony/SadTalker',
        local_dir=local_dir,
        local_dir_use_symlinks=False,
    )
def get_source_image(image):
    """Identity pass-through: return ``image`` unchanged."""
    source = image
    return source
def _latest_output_png(subdir):
    """Return the newest ``*.png`` one level below ``<script_path>/outputs/<subdir>``.

    Raises:
        FileNotFoundError: if no PNG has been generated there yet.
    """
    output_dir = str(Path(paths.script_path) / "outputs" / subdir)
    candidates = glob.glob(os.path.join(output_dir, '*', '*.png'))
    if not candidates:
        raise FileNotFoundError(f"No generated images found under {output_dir}")
    # glob already returns paths that start with output_dir. They must not be
    # joined onto it again: with a relative script_path that doubles the prefix.
    return max(candidates, key=os.path.getmtime)


def get_img_from_txt2img(x):
    """Return the most recent txt2img PNG path twice (one per output component).

    ``x`` is ignored.

    Raises:
        FileNotFoundError: if no txt2img image exists yet.
    """
    img_from_txt_path = _latest_output_png("txt2img-images")
    return img_from_txt_path, img_from_txt_path


def get_img_from_img2img(x):
    """Return the most recent img2img PNG path twice (one per output component).

    ``x`` is ignored.

    Raises:
        FileNotFoundError: if no img2img image exists yet.
    """
    img_from_img_path = _latest_output_png("img2img-images")
    return img_from_img_path, img_from_img_path
def get_default_checkpoint_path():
    """Locate a complete SadTalker checkpoint directory inside the WebUI tree.

    Candidates are ``models/SadTalker`` and ``extensions/SadTalker/checkpoints``.
    The SafeTensor layout is preferred over the legacy .pth layout. Within each
    layout, ``models/`` wins over the extension folder.

    Returns:
        The first matching ``Path``, or None if neither directory is complete.
    """
    root = Path(paths.script_path)
    models_dir = root / "models" / "SadTalker"
    extension_dir = root / "extensions" / "SadTalker" / "checkpoints"
    for is_complete in (check_all_files_safetensor, check_all_files):
        for candidate in (models_dir, extension_dir):
            if is_complete(candidate):
                return candidate
    return None
def install():
    """Install missing Python dependencies and resolve ``SADTALKER_CHECKPOINTS``.

    Each package whose import name is not installed is installed via pip.
    Afterwards, if SADTALKER_CHECKPOINTS is unset, it is pointed at the first
    complete checkpoint directory found by ``get_default_checkpoint_path``.
    If no such directory exists, the user is told to set it manually.
    """
    # import name -> pip requirement
    kv = {
        "face_alignment": "face-alignment==1.3.5",
        "imageio": "imageio==2.19.3",
        "imageio_ffmpeg": "imageio-ffmpeg==0.4.7",
        "librosa": "librosa==0.8.0",
        "pydub": "pydub==0.25.1",
        "scipy": "scipy==1.8.1",
        "tqdm": "tqdm",
        "yacs": "yacs==0.1.8",
        "yaml": "pyyaml",
        "av": "av",
        "gfpgan": "gfpgan",
    }
    # dlib is not necessary currently, so it is intentionally not listed.

    for k, v in kv.items():
        installed = launch.is_installed(k)
        if not installed:
            print(k, installed)
            launch.run_pip("install " + v, "requirements for SadTalker")

    if os.getenv('SADTALKER_CHECKPOINTS'):
        print('load Sadtalker Checkpoints from ' + os.getenv('SADTALKER_CHECKPOINTS'))
        return

    # Computed once: the lookup lists directories on disk.
    default_checkpoint_path = get_default_checkpoint_path()
    if default_checkpoint_path is not None:
        os.environ['SADTALKER_CHECKPOINTS'] = str(default_checkpoint_path)
    else:
        # Automatic download from the Hugging Face Hub (see download_model) is
        # deliberately not triggered here because it takes a long time.
        print(
            "SadTalker will not download all the checkpoint files from Hugging Face "
            "automatically, because that would take a long time.\n"
            "Please manually set SADTALKER_CHECKPOINTS in `webui_user.bat` (Windows) "
            "or `webui_user.sh` (Linux)."
        )
def on_ui_tabs():
    """Build the SadTalker tab for the WebUI.

    Installs dependencies, makes the extension importable, and creates the
    result directory. Returns a list with one ``(component, title, id)`` tuple.
    """
    install()

    extension_root = paths.script_path + '/extensions/SadTalker'
    sys.path.append(extension_root)
    repo_dir = extension_root + '/'

    result_dir = opts.sadtalker_result_dir
    os.makedirs(result_dir, exist_ok=True)

    from app_sadtalker import sadtalker_demo

    # An explicit SADTALKER_CHECKPOINTS wins; otherwise use the bundled folder.
    checkpoint_path = os.getenv('SADTALKER_CHECKPOINTS') or repo_dir + 'checkpoints/'

    audio_to_video = sadtalker_demo(
        checkpoint_path=checkpoint_path,
        config_path=repo_dir + 'src/config',
        warpfn=wrap_queued_call,
    )
    return [(audio_to_video, "SadTalker", "extension")]
def on_ui_settings():
    """Register the ``sadtalker_result_dir`` option in the WebUI settings.

    The default is ``<script_path>/outputs/SadTalker``.
    """
    default_result_dir = str(Path(paths.script_path) / "outputs" / "SadTalker/")
    section = ('extension', "SadTalker")
    option = OptionInfo(default_result_dir, "Path to save results of sadtalker", section=section)
    opts.add_option("sadtalker_result_dir", option)


# Hook the settings entry and the tab into the WebUI at import time.
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_ui_tabs(on_ui_tabs)
================================================
FILE: scripts/test.sh
================================================
# ### some test commands to run before committing.
# Other preprocess modes worth checking manually, each at --size 256 and 512:
#   crop, extcrop, resize, full, extfull
# e.g. python inference.py --preprocess crop --size 256

# Full preprocessing with the GFPGAN enhancer: both sizes, first without and
# then with --still (the empty flag expands to no argument).
for still_flag in "" "--still"; do
    for size in 256 512; do
        python inference.py --preprocess full --size "$size" --enhancer gfpgan $still_flag
    done
done
================================================
FILE: src/audio2exp_models/audio2exp.py
================================================
from tqdm import tqdm
import torch
from torch import nn
class Audio2Exp(nn.Module):
    """Predict expression coefficients from audio by running ``netG`` chunk by chunk.

    ``cfg`` is stored but not read here. ``prepare_training_loss`` is accepted
    for interface compatibility and is unused.
    """

    # Number of frames passed to netG per call in ``test``.
    FRAMES_PER_STEP = 10

    def __init__(self, netG, cfg, device, prepare_training_loss=False):
        super(Audio2Exp, self).__init__()
        self.cfg = cfg
        self.device = device
        self.netG = netG.to(device)

    def test(self, batch):
        """Predict expression coefficients for every frame of ``batch``.

        Args:
            batch: dict with
                'indiv_mels': mel windows, bs x T x 1 x 80 x 16;
                'ref': reference coefficients, bs x T x C (first 64 used);
                'ratio_gt': per-frame ratio, bs x T (x 1 presumably — confirm).

        Returns:
            dict with 'exp_coeff_pred': netG outputs concatenated along time,
            bs x T x 64 when netG returns 64 values per frame.
        """
        mel_input = batch['indiv_mels']  # bs T 1 80 16
        num_frames = mel_input.shape[1]
        step = self.FRAMES_PER_STEP

        exp_coeff_pred = []
        for i in tqdm(range(0, num_frames, step), 'audio2exp:'):
            current_mel_input = mel_input[:, i:i + step]
            ref = batch['ref'][:, :, :64][:, i:i + step]
            ratio = batch['ratio_gt'][:, i:i + step]  # bs T
            # reshape, not view: slicing the time axis of a batched tensor yields
            # a non-contiguous tensor whenever bs > 1, and .view() rejects that.
            audiox = current_mel_input.reshape(-1, 1, 80, 16)  # bs*T 1 80 16
            curr_exp_coeff_pred = self.netG(audiox, ref, ratio)  # bs T 64
            exp_coeff_pred.append(curr_exp_coeff_pred)

        # BS x T x 64
        return {'exp_coeff_pred': torch.cat(exp_coeff_pred, dim=1)}
================================================
FILE: src/audio2exp_models/networks.py
================================================
import torch
import torch.nn.functional as F
from torch import nn
class Conv2d(nn.Module):
    """``nn.Conv2d`` followed by ``nn.BatchNorm2d``, with an optional identity
    shortcut and an optional ReLU.

    With ``residual=True`` the input is added to the normalized output, which
    requires matching shapes (cin == cout, shape-preserving stride/padding).
    The submodule names ``conv_block`` and ``act`` are kept stable so existing
    state_dicts still load.
    """

    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act=True, *args, **kwargs):
        super().__init__(*args, **kwargs)
        conv = nn.Conv2d(cin, cout, kernel_size, stride, padding)
        norm = nn.BatchNorm2d(cout)
        self.conv_block = nn.Sequential(conv, norm)
        self.act = nn.ReLU()
        self.residual = residual
        self.use_act = use_act

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x
        return self.act(out) if self.use_act else out
class SimpleWrapperV2(nn.Module):
    """Map per-frame audio windows plus reference coefficients to 64 output values.

    A stack of ``Conv2d`` blocks encodes each audio window. For an 80x16 input
    this gives 512 features. The features are concatenated with the 64
    reference coefficients and one ratio value, then projected to 64 outputs by
    ``mapping1``. Module names and order are unchanged, so checkpoints still load.
    """

    # One row per encoder layer: (cin, cout, kernel_size, stride, padding, residual)
    _ENCODER_SPEC = (
        (1, 32, 3, 1, 1, False),
        (32, 32, 3, 1, 1, True),
        (32, 32, 3, 1, 1, True),
        (32, 64, 3, (3, 1), 1, False),
        (64, 64, 3, 1, 1, True),
        (64, 64, 3, 1, 1, True),
        (64, 128, 3, 3, 1, False),
        (128, 128, 3, 1, 1, True),
        (128, 128, 3, 1, 1, True),
        (128, 256, 3, (3, 2), 1, False),
        (256, 256, 3, 1, 1, True),
        (256, 512, 3, 1, 0, False),
        (512, 512, 1, 1, 0, False),
    )

    def __init__(self) -> None:
        super().__init__()
        layers = [
            Conv2d(cin, cout, kernel_size=k, stride=s, padding=p, residual=r)
            for cin, cout, k, s, p, r in self._ENCODER_SPEC
        ]
        self.audio_encoder = nn.Sequential(*layers)
        # Pre-trained encoder weights are expected to arrive through the model
        # checkpoint; nothing is loaded here.
        self.mapping1 = nn.Linear(512 + 64 + 1, 64)
        nn.init.constant_(self.mapping1.bias, 0.)

    def forward(self, x, ref, ratio):
        """x: (N, 1, 80, 16) audio windows with N = bs*T; ref: (bs, T, 64); ratio: N values.

        Returns a (bs, T, 64) tensor. The shapes follow the reshapes below and
        the calling code's comments.
        """
        n = x.size(0)
        audio_feat = self.audio_encoder(x).view(n, -1)
        fused = torch.cat([audio_feat, ref.reshape(n, -1), ratio.reshape(n, -1)], dim=1)
        projected = self.mapping1(fused)
        return projected.reshape(ref.shape[0], ref.shape[1], -1)
================================================
FILE: src/audio2pose_models/audio2pose.py
================================================
import torch
from torch import nn
from src.audio2pose_models.cvae import CVAE
from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
from src.audio2pose_models.audio_encoder import AudioEncoder
class Audio2Pose(nn.Module):
    """Generate head-pose motion from audio with a CVAE and a frozen audio encoder.

    ``netD_motion`` (the pose-sequence discriminator) is constructed here and
    presumably used only during training.
    """

    def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
        super().__init__()
        self.cfg = cfg
        self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
        self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
        self.device = device

        # The audio encoder is frozen: eval mode and no gradients.
        self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
        self.audio_encoder.eval()
        for param in self.audio_encoder.parameters():
            param.requires_grad = False

        self.netG = CVAE(cfg)
        self.netD_motion = PoseSequenceDiscriminator(cfg)

    def forward(self, x):
        """Training pass: encode ground-truth pose motion and reconstruct it.

        Inputs are moved to ``self.device`` (previously hard-coded ``.cuda()``,
        which ignored the constructor's ``device`` argument).
        """
        batch = {}
        coeff_gt = x['gt'].to(self.device).squeeze(0)  # bs frame_len+1 73
        # Pose motion is expressed relative to the first frame's pose.
        batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70]  # bs frame_len 6
        batch['ref'] = coeff_gt[:, 0, 64:70]  # bs 6
        batch['class'] = x['class'].squeeze(0).to(self.device)  # bs
        indiv_mels = x['indiv_mels'].to(self.device).squeeze(0)  # bs seq_len+1 80 16

        # forward
        audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2))  # bs seq_len 512
        batch['audio_emb'] = audio_emb
        batch = self.netG(batch)

        pose_motion_pred = batch['pose_motion_pred']  # bs frame_len 6
        pose_gt = coeff_gt[:, 1:, 64:70].clone()  # bs frame_len 6
        pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred  # bs frame_len 6

        batch['pose_pred'] = pose_pred
        batch['pose_gt'] = pose_gt

        return batch

    def test(self, x):
        """Inference: sample pose motion for ``num_frames`` frames, seq_len at a time.

        Frame 0 is the reference frame, so its motion is zero. The remaining
        frames are decoded in full windows of ``seq_len``. A final partial window
        reuses the last ``seq_len`` audio frames (front-padded when fewer exist),
        and only its trailing predictions are kept.
        """
        batch = {}
        ref = x['ref']  # bs 1 70
        batch['ref'] = x['ref'][:, 0, -6:]
        batch['class'] = x['class']
        bs = ref.shape[0]

        indiv_mels = x['indiv_mels']  # bs T 1 80 16
        indiv_mels_use = indiv_mels[:, 1:]  # we regard the ref as the first frame
        num_frames = int(x['num_frames']) - 1

        div, remainder = divmod(num_frames, self.seq_len)

        pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype,
                                             device=batch['ref'].device)]

        for i in range(div):
            z = torch.randn(bs, self.latent_dim).to(ref.device)
            batch['z'] = z
            audio_emb = self.audio_encoder(indiv_mels_use[:, i * self.seq_len:(i + 1) * self.seq_len, :, :, :])  # bs seq_len 512
            batch['audio_emb'] = audio_emb
            batch = self.netG.test(batch)
            pose_motion_pred_list.append(batch['pose_motion_pred'])  # list of bs seq_len 6

        if remainder != 0:
            z = torch.randn(bs, self.latent_dim).to(ref.device)
            batch['z'] = z
            audio_emb = self.audio_encoder(indiv_mels_use[:, -1 * self.seq_len:, :, :, :])  # bs seq_len 512
            if audio_emb.shape[1] != self.seq_len:
                # Fewer than seq_len frames overall: pad by repeating the first embedding.
                pad_dim = self.seq_len - audio_emb.shape[1]
                pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
                audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
            batch['audio_emb'] = audio_emb
            batch = self.netG.test(batch)
            pose_motion_pred_list.append(batch['pose_motion_pred'][:, -1 * remainder:, :])

        pose_motion_pred = torch.cat(pose_motion_pred_list, dim=1)
        batch['pose_motion_pred'] = pose_motion_pred

        pose_pred = ref[:, :1, -6:] + pose_motion_pred  # bs T 6

        batch['pose_pred'] = pose_pred
        return batch
================================================
FILE: src/audio2pose_models/audio_encoder.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
class Conv2d(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU, with an optional identity shortcut before the ReLU.

    Submodule names (``conv_block``, ``act``) are part of the checkpoint format.
    """

    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm2d(cout),
        )
        self.act = nn.ReLU()
        self.residual = residual

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x
        return self.act(out)


class AudioEncoder(nn.Module):
    """Encode each (1, 80, 16) mel window of a (B, T, 1, 80, 16) batch to 512 features.

    ``wav2lip_checkpoint`` and ``device`` are accepted for interface compatibility
    but are not used: weights are expected to come from the enclosing model's
    checkpoint.
    """

    def __init__(self, wav2lip_checkpoint, device):
        super(AudioEncoder, self).__init__()

        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)

        # The pre-trained wav2lip audio encoder is intentionally not loaded here.

    def forward(self, audio_sequences):
        """Return per-frame embeddings of shape (B, T, 512) for (B, T, 1, 80, 16) input."""
        B = audio_sequences.size(0)
        T = audio_sequences.size(1)
        # Flatten batch-major, so row b*T + t is (sample b, frame t); the reshape
        # back to (B, T, ...) below assumes exactly this order. The previous
        # time-major torch.cat over the T axis interleaved samples whenever B > 1.
        frames = audio_sequences.reshape(B * T, *audio_sequences.shape[2:])

        audio_embedding = self.audio_encoder(frames)  # B*T, 512, 1, 1
        dim = audio_embedding.shape[1]
        audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))

        return audio_embedding.squeeze(-1).squeeze(-1)  # B T 512
================================================
FILE: src/audio2pose_models/cvae.py
================================================
import torch
import torch.nn.functional as F
from torch import nn
from src.audio2pose_models.res_unet import ResUnet
def class2onehot(idx, class_num):
    """Return a float one-hot matrix of shape (N, class_num) on ``idx``'s device.

    Args:
        idx: integer (int64) class indices, shaped (N, 1) as ``scatter_`` along
            dim 1 requires, or (N,), which is treated as a column.
        class_num: number of classes.

    Raises:
        ValueError: if any index lies outside ``[0, class_num)``. This is an
            explicit check rather than ``assert``, so it survives ``python -O``.
    """
    if idx.dim() == 1:
        idx = idx.unsqueeze(1)
    if idx.numel() > 0:
        lo, hi = torch.min(idx).item(), torch.max(idx).item()
        if lo < 0 or hi >= class_num:
            raise ValueError(f"class indices must be in [0, {class_num}), got range [{lo}, {hi}]")
    onehot = torch.zeros(idx.size(0), class_num, device=idx.device)
    onehot.scatter_(1, idx, 1)
    return onehot
class CVAE(nn.Module):
    """Conditional VAE pairing an ``ENCODER`` and a ``DECODER``.

    Hyper-parameters come from ``cfg.MODEL.CVAE`` (layer sizes, latent size,
    audio embedding sizes, sequence length) and ``cfg.DATASET.NUM_CLASSES``.
    """

    def __init__(self, cfg):
        super().__init__()
        cvae_cfg = cfg.MODEL.CVAE
        self.latent_size = cvae_cfg.LATENT_SIZE
        # Arguments shared by both halves, in their positional order.
        shared = (
            cfg.DATASET.NUM_CLASSES,
            cvae_cfg.AUDIO_EMB_IN_SIZE,
            cvae_cfg.AUDIO_EMB_OUT_SIZE,
            cvae_cfg.SEQ_LEN,
        )
        self.encoder = ENCODER(cvae_cfg.ENCODER_LAYER_SIZES, self.latent_size, *shared)
        self.decoder = DECODER(cvae_cfg.DECODER_LAYER_SIZES, self.latent_size, *shared)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I)."""
        std = (0.5 * logvar).exp()
        return mu + torch.randn_like(std) * std

    def forward(self, batch):
        """Encode ``batch``, sample ``z`` into it, and decode.

        The returned dict is whatever the decoder returns.
        """
        encoded = self.encoder(batch)
        encoded['z'] = self.reparameterize(encoded['mu'], encoded['logvar'])
        return self.decoder(encoded)

    def test(self, batch):
        """Decode only.

        This method does not sample ``z``; the caller must already have
        stored it in ``batch['z']``.
        """
        return self.decoder(batch)
class ENCODER(nn.Module):
    """CVAE encoder producing posterior parameters (mu, logvar).

    The ground-truth pose sequence is refined by a ResUnet and flattened. It is
    concatenated with the reference pose, the projected audio embedding and a
    learned per-class bias, then passed through an MLP and two linear heads.
    """

    def __init__(self, layer_sizes, latent_size, num_classes,
                 audio_emb_in_size, audio_emb_out_size, seq_len):
        super().__init__()
        self.resunet = ResUnet()
        self.num_classes = num_classes
        self.seq_len = seq_len

        self.MLP = nn.Sequential()
        # Work on a copy: ``layer_sizes`` typically comes straight from the
        # config, and widening it in place would corrupt any later model
        # built from the same config object.
        layer_sizes = list(layer_sizes)
        # Input = pose embedding (layer_sizes[0]) + class bias + audio + ref pose (6).
        layer_sizes[0] += latent_size + seq_len * audio_emb_out_size + 6
        for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            self.MLP.add_module(
                name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
            self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())

        self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
        self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
        self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)

        self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))

    def forward(self, batch):
        """Add ``'mu'`` and ``'logvar'`` to ``batch`` and return it.

        Reads ``batch['class']``, ``'pose_motion_gt'``, ``'ref'`` and ``'audio_emb'``.
        """
        class_id = batch['class']
        pose_motion_gt = batch['pose_motion_gt']  # bs seq_len 6
        ref = batch['ref']  # bs 6
        bs = pose_motion_gt.shape[0]
        audio_in = batch['audio_emb']  # bs seq_len audio_emb_in_size

        # pose encode
        pose_emb = self.resunet(pose_motion_gt.unsqueeze(1))  # bs 1 seq_len 6
        pose_emb = pose_emb.reshape(bs, -1)  # bs seq_len*6

        # audio mapping
        audio_out = self.linear_audio(audio_in)  # bs seq_len audio_emb_out_size
        audio_out = audio_out.reshape(bs, -1)

        class_bias = self.classbias[class_id]  # bs latent_size
        x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1)
        x_out = self.MLP(x_in)

        mu = self.linear_means(x_out)
        # The log-variance needs its own head; reusing linear_means forced
        # logvar == mu and left linear_logvar untrained.
        logvar = self.linear_logvar(x_out)  # bs latent_size

        batch.update({'mu': mu, 'logvar': logvar})
        return batch
class DECODER(nn.Module):
    """CVAE decoder that turns a latent code into a pose-motion sequence.

    The latent code (plus a learned per-class bias), the reference pose and
    the projected audio embedding are concatenated. The result goes through an
    MLP ending in a sigmoid, is reshaped to (bs, seq_len, -1), refined by a
    ResUnet and projected by a 6x6 linear layer.
    """

    def __init__(self, layer_sizes, latent_size, num_classes,
                 audio_emb_in_size, audio_emb_out_size, seq_len):
        super().__init__()
        self.resunet = ResUnet()
        self.num_classes = num_classes
        self.seq_len = seq_len

        self.MLP = nn.Sequential()
        input_size = latent_size + seq_len * audio_emb_out_size + 6
        fan_ins = [input_size] + layer_sizes[:-1]
        last_idx = len(layer_sizes) - 1
        for idx, (fan_in, fan_out) in enumerate(zip(fan_ins, layer_sizes)):
            self.MLP.add_module(name="L{:d}".format(idx), module=nn.Linear(fan_in, fan_out))
            if idx == last_idx:
                # Only the final layer is squashed to (0, 1).
                self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
            else:
                self.MLP.add_module(name="A{:d}".format(idx), module=nn.ReLU())

        self.pose_linear = nn.Linear(6, 6)
        self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)

        self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))

    def forward(self, batch):
        """Store ``'pose_motion_pred'`` (bs, seq_len, 6) in ``batch`` and return it.

        Reads ``batch['z']``, ``'class'``, ``'ref'`` and ``'audio_emb'``.
        """
        z = batch['z']  # bs latent_size
        bs = z.shape[0]

        audio_flat = self.linear_audio(batch['audio_emb']).reshape([bs, -1])
        z_biased = z + self.classbias[batch['class']]

        mlp_in = torch.cat([batch['ref'], z_biased, audio_flat], dim=-1)
        frames = self.MLP(mlp_in).reshape((bs, self.seq_len, -1))

        refined = self.resunet(frames.unsqueeze(1))  # bs 1 seq_len 6
        batch['pose_motion_pred'] = self.pose_linear(refined.squeeze(1))
        return batch
================================================
FILE: src/audio2pose_models/discriminator.py
================================================
import torch
import torch.nn.functional as F
from torch import nn
class ConvNormRelu(nn.Module):
    """Bias-free convolution -> normalization -> (Leaky)ReLU.

    Args:
        conv_type: '1d' or '2d'.
        in_channels, out_channels: channel counts of the convolution.
        downsample: only used when ``kernel_size`` is None; selects
            kernel/stride/padding = (4, 2, 1) if True, else (3, 1, 1).
        kernel_size, stride, padding: explicit conv geometry. All three are
            replaced by the ``downsample`` preset when ``kernel_size`` is None.
        norm: 'BN' (batch norm) or 'IN' (instance norm).
        leaky: use LeakyReLU(0.2) instead of ReLU.

    Raises:
        NotImplementedError: for an unsupported ``conv_type`` or ``norm``.
    """

    # conv_type -> (conv class, {norm name: norm class})
    _LAYERS = {
        '1d': (nn.Conv1d, {'BN': nn.BatchNorm1d, 'IN': nn.InstanceNorm1d}),
        '2d': (nn.Conv2d, {'BN': nn.BatchNorm2d, 'IN': nn.InstanceNorm2d}),
    }

    def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
                 kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
        super().__init__()
        if kernel_size is None:
            if downsample:
                kernel_size, stride, padding = 4, 2, 1
            else:
                kernel_size, stride, padding = 3, 1, 1

        if conv_type not in self._LAYERS:
            # Previously an unknown type left self.conv unset and failed
            # later with an unrelated AttributeError.
            raise NotImplementedError(conv_type)
        conv_cls, norm_classes = self._LAYERS[conv_type]
        self.conv = conv_cls(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=False,
        )
        if norm not in norm_classes:
            raise NotImplementedError(norm)
        self.norm = norm_classes[norm](out_channels)

        nn.init.kaiming_normal_(self.conv.weight)

        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        if isinstance(self.norm, nn.InstanceNorm1d):
            # Swap the last two axes so the 1-D instance norm runs over the
            # channel axis rather than time; then swap back.
            x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1))  # normalize on [C]
        else:
            x = self.norm(x)
        x = self.act(x)
        return x
class PoseSequenceDiscriminator(nn.Module):
    """1-D convolutional critic over a pose sequence.

    Input is (B, T, ...); trailing dims are flattened into channels, so their
    product must equal ``cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS``. The two
    downsampling layers (kernel 4, stride 2, padding 1) each halve T (floor),
    so the output is (B, T // 4) scores.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        disc_cfg = self.cfg.MODEL.DISCRIMINATOR
        leaky = disc_cfg.LEAKY_RELU

        self.seq = nn.Sequential(
            ConvNormRelu('1d', disc_cfg.INPUT_CHANNELS, 256, downsample=True, leaky=leaky),  # B, 256, T//2
            ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky),  # B, 512, T//4
            ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky),  # B, 1024, T//4
            nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True),  # B, 1, T//4
        )

    def forward(self, x):
        batch, steps = x.size(0), x.size(1)
        feats = x.reshape(batch, steps, -1).transpose(1, 2)  # B, C, T
        return self.seq(feats).squeeze(1)
================================================
FILE: src/audio2pose_models/networks.py
================================================
import torch.nn as nn
import torch
class ResidualConv(nn.Module):
    """Pre-activation residual block with a projected skip connection.

    Main path: BN -> ReLU -> 3x3 conv (``stride``, ``padding``) -> BN -> ReLU
    -> 3x3 conv (stride 1, padding 1). Skip path: 3x3 conv (``stride``,
    ``padding``) -> BN. The two paths are summed.

    Args:
        input_dim, output_dim: channel counts.
        stride: stride (int or tuple) of the first main conv and the skip conv.
        padding: padding of the first main conv and the skip conv.
    """

    def __init__(self, input_dim, output_dim, stride, padding):
        super(ResidualConv, self).__init__()

        self.conv_block = nn.Sequential(
            nn.BatchNorm2d(input_dim),
            nn.ReLU(),
            nn.Conv2d(
                input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
            ),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        # The skip conv must match the main path's stride *and* padding so both
        # branches have the same spatial size; padding was hard-coded to 1 here,
        # which broke any caller passing padding != 1.
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=padding),
            nn.BatchNorm2d(output_dim),
        )

    def forward(self, x):
        return self.conv_block(x) + self.conv_skip(x)
class Upsample(nn.Module):
    """Learned upsampling: a single ConvTranspose2d with no padding.

    Args:
        input_dim, output_dim: channel counts.
        kernel: transposed-conv kernel size.
        stride: transposed-conv stride (the upsampling factor when it equals kernel).
    """

    def __init__(self, input_dim, output_dim, kernel, stride):
        super(Upsample, self).__init__()
        self.upsample = nn.ConvTranspose2d(input_dim, output_dim, kernel_size=kernel, stride=stride)

    def forward(self, x):
        upsampled = self.upsample(x)
        return upsampled
class Squeeze_Excite_Block(nn.Module):
    """Squeeze-and-Excitation channel gating for 4-D (N, C, H, W) input.

    Each channel is global-average-pooled and passed through a bottleneck
    MLP (channel -> channel // reduction -> channel, sigmoid). The input
    channels are then rescaled by the resulting gates.
    """

    def __init__(self, channel, reduction=16):
        super(Squeeze_Excite_Block, self).__init__()
        # Keep at least one hidden unit. With channel < reduction the
        # bottleneck would otherwise have zero width, and every gate would be
        # the constant sigmoid(0) = 0.5 regardless of the input.
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling.

    Builds one branch per dilation rate: a 3x3 conv with padding == dilation
    (spatial size preserved), then ReLU, then BatchNorm. The branch outputs
    are concatenated on the channel axis and fused by a 1x1 conv.

    Args:
        in_dims: input channels.
        out_dims: channels of every branch and of the fused output.
        rate: dilation rates, one branch per entry. Branches are registered
            as ``aspp_block1``, ``aspp_block2``, ... so the default three-rate
            layout keeps its original parameter names.
    """

    def __init__(self, in_dims, out_dims, rate=(6, 12, 18)):
        super(ASPP, self).__init__()
        # Tuple default avoids a shared mutable list.
        rate = tuple(rate)
        self._num_branches = len(rate)
        for i, r in enumerate(rate, start=1):
            setattr(self, "aspp_block{}".format(i), nn.Sequential(
                nn.Conv2d(in_dims, out_dims, 3, stride=1, padding=r, dilation=r),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(out_dims),
            ))

        self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
        self._init_weights()

    def forward(self, x):
        # Run every branch. The old code hard-coded three even though the
        # fusion conv is sized from len(rate), so other rate counts crashed.
        branches = [
            getattr(self, "aspp_block{}".format(i))(x)
            for i in range(1, self._num_branches + 1)
        ]
        return self.output(torch.cat(branches, dim=1))

    def _init_weights(self):
        """Kaiming-normal conv weights; BatchNorm weight 1 and bias 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
class Upsample_(nn.Module):
    """Parameter-free bilinear upsampling by a fixed ``scale`` factor."""

    def __init__(self, scale=2):
        super(Upsample_, self).__init__()
        self.upsample = nn.Upsample(scale_factor=scale, mode="bilinear")

    def forward(self, x):
        resized = self.upsample(x)
        return resized
class AttentionBlock(nn.Module):
    """Additive attention gate between an encoder and a decoder feature map.

    ``x1`` is projected (BN -> ReLU -> 3x3 conv) and max-pooled by 2, so it
    must have twice the spatial size of ``x2``. ``x2`` is projected the same
    way without pooling. The sum is reduced to a single-channel map
    (BN -> ReLU -> 1x1 conv), which multiplies ``x2``. No sigmoid is applied
    to the map -- NOTE(review): confirm that is intended.
    """

    def __init__(self, input_encoder, input_decoder, output_dim):
        super(AttentionBlock, self).__init__()

        def _pre_act(channels):
            # BatchNorm followed by ReLU, shared prefix of all three paths.
            return [nn.BatchNorm2d(channels), nn.ReLU()]

        self.conv_encoder = nn.Sequential(
            *_pre_act(input_encoder),
            nn.Conv2d(input_encoder, output_dim, 3, padding=1),
            nn.MaxPool2d(2, 2),
        )
        self.conv_decoder = nn.Sequential(
            *_pre_act(input_decoder),
            nn.Conv2d(input_decoder, output_dim, 3, padding=1),
        )
        self.conv_attn = nn.Sequential(
            *_pre_act(output_dim),
            nn.Conv2d(output_dim, 1, 1),
        )

    def forward(self, x1, x2):
        gate = self.conv_attn(self.conv_encoder(x1) + self.conv_decoder(x2))
        return gate * x2
================================================
FILE: src/audio2pose_models/res_unet.py
================================================
import torch
import torch.nn as nn
from src.audio2pose_models.networks import ResidualConv, Upsample
class ResUnet(nn.Module):
    """Residual U-Net that down/upsamples along the height axis only.

    Input is (N, channel, H, W). The encoder halves H three times with
    stride-(2, 1) residual blocks. The decoder doubles it back with (2, 1)
    transposed convolutions. W is preserved throughout, and H must be
    divisible by 8 for the skip concatenations to line up. The output is a
    single-channel (N, 1, H, W) map passed through a sigmoid.

    Args:
        channel: number of input channels.
        filters: channel widths of the four resolution levels.
    """

    def __init__(self, channel=1, filters=(32, 64, 128, 256)):
        # ``filters`` defaults to a tuple to avoid a shared mutable default.
        super(ResUnet, self).__init__()

        self.input_layer = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(filters[0]),
            nn.ReLU(),
            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
        )
        self.input_skip = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
        )

        self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2, 1), padding=1)
        self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2, 1), padding=1)

        self.bridge = ResidualConv(filters[2], filters[3], stride=(2, 1), padding=1)

        self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2, 1), stride=(2, 1))
        self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1)

        self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2, 1), stride=(2, 1))
        self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1)

        self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2, 1), stride=(2, 1))
        self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1)

        self.output_layer = nn.Sequential(
            nn.Conv2d(filters[0], 1, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Encode
        x1 = self.input_layer(x) + self.input_skip(x)
        x2 = self.residual_conv_1(x1)
        x3 = self.residual_conv_2(x2)
        # Bridge
        x4 = self.bridge(x3)

        # Decode: upsample, concatenate the matching encoder feature, refine.
        x4 = self.upsample_1(x4)
        x5 = torch.cat([x4, x3], dim=1)
        x6 = self.up_residual_conv1(x5)

        x6 = self.upsample_2(x6)
        x7 = torch.cat([x6, x2], dim=1)
        x8 = self.up_residual_conv2(x7)

        x8 = self.upsample_3(x8)
        x9 = torch.cat([x8, x1], dim=1)
        x10 = self.up_residual_conv3(x9)

        output = self.output_layer(x10)
        return output
================================================
FILE: src/config/auido2exp.yaml
================================================
DATASET:
TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
TRAIN_BATCH_SIZE: 32
EVAL_BATCH_SIZE: 32
EXP: True
EXP_DIM: 64
FRAME_LEN: 32
COEFF_LEN: 73
NUM_CLASSES: 46
AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
DEBUG: True
NUM_REPEATS: 2
T: 40
MODEL:
FRAMEWORK: V2
AUDIOENCODER:
LEAKY_RELU: True
NORM: 'IN'
DISCRIMINATOR:
LEAKY_RELU: False
INPUT_CHANNELS: 6
CVAE:
AUDIO_EMB_IN_SIZE: 512
AUDIO_EMB_OUT_SIZE: 128
SEQ_LEN: 32
LATENT_SIZE: 256
ENCODER_LAYER_SIZES: [192, 1024]
DECODER_LAYER_SIZES: [1024, 192]
TRAIN:
MAX_EPOCH: 300
GENERATOR:
LR: 2.0e-5
DISCRIMINATOR:
LR: 1.0e-5
LOSS:
W_FEAT: 0
W_COEFF_EXP: 2
W_LM: 1.0e-2
W_LM_MOUTH: 0
W_REG: 0
W_SYNC: 0
W_COLOR: 0
W_EXPRESSION: 0
W_LIPREADING: 0.01
W_LIPREADING_VV: 0
W_EYE_BLINK: 4
TAG:
NAME: small_dataset
================================================
FILE: src/config/auido2pose.yaml
================================================
DATASET:
TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
TRAIN_BATCH_SIZE: 64
EVAL_BATCH_SIZE: 1
EXP: True
EXP_DIM: 64
FRAME_LEN: 32
COEFF_LEN: 73
NUM_CLASSES: 46
AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
DEBUG: True
MODEL:
AUDIOENCODER:
LEAKY_RELU: True
NORM: 'IN'
DISCRIMINATOR:
LEAKY_RELU: False
INPUT_CHANNELS: 6
CVAE:
AUDIO_EMB_IN_SIZE: 512
AUDIO_EMB_OUT_SIZE: 6
SEQ_LEN: 32
LATENT_SIZE: 64
ENCODER_LAYER_SIZES: [192, 128]
DECODER_LAYER_SIZES: [128, 192]
TRAIN:
MAX_EPOCH: 150
GENERATOR:
LR: 1.0e-4
DISCRIMINATOR:
LR: 1.0e-4
LOSS:
LAMBDA_REG: 1
LAMBDA_LANDMARKS: 0
LAMBDA_VERTICES: 0
LAMBDA_GAN_MOTION: 0.7
LAMBDA_GAN_COEFF: 0
LAMBDA_KL: 1
TAG:
NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder
================================================
FILE: src/config/facerender.yaml
================================================
model_params:
common_params:
num_kp: 15
image_channel: 3
feature_channel: 32
estimate_jacobian: False # True
kp_detector_params:
temperature: 0.1
block_expansion: 32
max_features: 1024
scale_factor: 0.25 # 0.25
num_blocks: 5
reshape_channel: 16384 # 16384 = 1024 * 16
reshape_depth: 16
he_estimator_params:
block_expansion: 64
max_features: 2048
num_bins: 66
generator_params:
block_expansion: 64
max_features: 512
num_down_blocks: 2
reshape_channel: 32
reshape_depth: 16 # 512 = 32 * 16
num_resblocks: 6
estimate_occlusion_map: True
dense_motion_params:
block_expansion: 32
max_features: 1024
num_blocks: 5
reshape_depth: 16
compress: 4
discriminator_params:
scales: [1]
block_expansion: 32
max_features: 512
num_blocks: 4
sn: True
mapping_params:
coeff_nc: 70
descriptor_nc: 1024
layer: 3
num_kp: 15
num_bins: 66
================================================
FILE: src/config/facerender_still.yaml
================================================
model_params:
common_params:
num_kp: 15
image_channel: 3
feature_channel: 32
estimate_jacobian: False # True
kp_detector_params:
temperature: 0.1
block_expansion: 32
max_features: 1024
scale_factor: 0.25 # 0.25
num_blocks: 5
reshape_channel: 16384 # 16384 = 1024 * 16
reshape_depth: 16
he_estimator_params:
block_expansion: 64
max_features: 2048
num_bins: 66
generator_params:
block_expansion: 64
max_features: 512
num_down_blocks: 2
reshape_channel: 32
reshape_depth: 16 # 512 = 32 * 16
num_resblocks: 6
estimate_occlusion_map: True
dense_motion_params:
block_expansion: 32
max_features: 1024
num_blocks: 5
reshape_depth: 16
compress: 4
discriminator_params:
scales: [1]
block_expansion: 32
max_features: 512
num_blocks: 4
sn: True
mapping_params:
coeff_nc: 73
descriptor_nc: 1024
layer: 3
num_kp: 15
num_bins: 66
================================================
FILE: src/face3d/data/__init__.py
================================================
"""This package includes all the modules related to data loading and preprocessing
To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
You need to implement four functions:
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
-- <__len__>: return the size of dataset.
-- <__getitem__>: get a data point from data loader.
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import numpy as np
import importlib
import torch.utils.data
from face3d.data.base_dataset import BaseDataset
def find_dataset_using_name(dataset_name):
    """Import ``data.<dataset_name>_dataset`` and return its dataset class.

    The module must define a subclass of ``BaseDataset`` whose name equals
    ``<dataset_name without underscores>dataset`` case-insensitively. If
    several names match, the last one in the module namespace wins.

    Raises:
        NotImplementedError: if no matching subclass is found.
    """
    # NOTE(review): this resolves against a top-level ``data`` package while
    # BaseDataset is imported from ``face3d.data``; if those are different
    # module objects the issubclass check below never matches -- confirm
    # the sys.path setup used by callers.
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    wanted = target_dataset_name.lower()
    matches = [
        cls for name, cls in datasetlib.__dict__.items()
        if name.lower() == wanted and issubclass(cls, BaseDataset)
    ]
    if not matches:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
    return matches[-1]
def get_option_setter(dataset_name):
    """Look up the dataset class for ``dataset_name`` and return its
    ``modify_commandline_options`` static method."""
    return find_dataset_using_name(dataset_name).modify_commandline_options
def create_dataset(opt, rank=0):
    """Build the dataset wrapper described by ``opt``.

    Thin wrapper around ``CustomDatasetDataLoader``; this is the main
    interface between this package and 'train.py'/'test.py'.

    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    loader = CustomDatasetDataLoader(opt, rank=rank)
    return loader.load_data()
class CustomDatasetDataLoader():
    """Pair a dataset with a (possibly distributed) torch DataLoader.

    The dataset class is chosen by ``opt.dataset_mode``. When both
    ``opt.use_ddp`` and ``opt.isTrain`` are true, a DistributedSampler
    splits the data over ``opt.world_size`` ranks, and batch size and worker
    count are divided per rank. Otherwise a plain DataLoader is used.
    Iteration stops before batch ``i`` once ``i * opt.batch_size`` reaches
    ``opt.max_dataset_size``.
    """

    def __init__(self, opt, rank=0):
        """Create the dataset instance and its data loader.

        Parameters:
            opt -- experiment options (dataset_mode, batch_size, num_threads,
                   serial_batches, isTrain, use_ddp, world_size, ...)
            rank (int) -- this process's rank, used for logging and the sampler
        """
        self.opt = opt
        dataset_class = find_dataset_using_name(opt.dataset_mode)
        self.dataset = dataset_class(opt)
        self.sampler = None
        print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))

        if not (opt.use_ddp and opt.isTrain):
            self.dataloader = torch.utils.data.DataLoader(
                self.dataset,
                batch_size=opt.batch_size,
                shuffle=(not opt.serial_batches) and opt.isTrain,
                num_workers=int(opt.num_threads),
                drop_last=True,
            )
            return

        world_size = opt.world_size
        self.sampler = torch.utils.data.distributed.DistributedSampler(
            self.dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=not opt.serial_batches,
        )
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            sampler=self.sampler,
            num_workers=int(opt.num_threads / world_size),
            batch_size=int(opt.batch_size / world_size),
            drop_last=True,
        )

    def set_epoch(self, epoch):
        """Record ``epoch`` on the dataset and, if present, the distributed sampler."""
        self.dataset.current_epoch = epoch
        if self.sampler is None:
            return
        self.sampler.set_epoch(epoch)

    def load_data(self):
        """Return this wrapper itself (it is directly iterable)."""
        return self

    def __len__(self):
        """Number of samples: dataset length capped at ``opt.max_dataset_size``."""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Yield batches from the data loader until the sample cap is reached."""
        for batch_idx, batch in enumerate(self.dataloader):
            if batch_idx * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield batch
================================================
FILE: src/face3d/data/base_dataset.py
================================================
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
It also includes common transformation helpers (get_transform, get_affine_mat, apply_img_affine and apply_lm_affine) that subclasses can use.
"""
import random
import numpy as np
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
from abc import ABC, abstractmethod
class BaseDataset(data.Dataset, ABC):
    """Abstract base class for datasets.

    Subclasses must implement ``__len__`` and ``__getitem__``. They should
    call ``BaseDataset.__init__(self, opt)`` first, and may override the
    static ``modify_commandline_options`` to add dataset-specific flags.
    """

    def __init__(self, opt):
        """Store the options and reset the epoch counter.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        self.current_epoch = 0
        self.opt = opt

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Hook for dataset-specific command-line options.

        Parameters:
            parser -- original option parser
            is_train (bool) -- whether this is the training phase

        Returns:
            the (possibly modified) parser; the base version returns it unchanged.
        """
        return parser

    @abstractmethod
    def __len__(self):
        """Return the total number of items in the dataset."""
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Return the data point at ``index``, usually a dict of data and metadata.

        Parameters:
            index - - an integer index into the dataset
        """
        pass
def get_transform(grayscale=False):
    """Return a torchvision transform mapping a PIL image to a tensor.

    When ``grayscale`` is truthy, a single-channel Grayscale conversion is
    applied before ToTensor.
    """
    steps = [transforms.Grayscale(1)] if grayscale else []
    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)
def get_affine_mat(opt, size):
    """Sample a random 3x3 augmentation affine (and its inverse).

    Augmentations are enabled by substrings of ``opt.preprocess``:
      * 'shift': integer offsets in [-opt.shift_pixs, opt.shift_pixs] per axis
      * 'scale': factor 1 + opt.scale_delta * u, with u in [-1, 1)
      * 'rot':   angle opt.rot_angle * u degrees, with u in [-1, 1)
      * 'flip':  horizontal flip with probability 0.5
    All of them are applied about the point (w // 2, h // 2).

    Args:
        opt: options object with the attributes above.
        size: (w, h) of the image.

    Returns:
        (affine, affine_inv, flip): the 3x3 matrix, its inverse, and whether
        a horizontal flip was sampled.
    """
    shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False
    w, h = size

    if 'shift' in opt.preprocess:
        shift_pixs = int(opt.shift_pixs)
        shift_x = random.randint(-shift_pixs, shift_pixs)
        shift_y = random.randint(-shift_pixs, shift_pixs)
    if 'scale' in opt.preprocess:
        scale = 1 + opt.scale_delta * (2 * random.random() - 1)
    if 'rot' in opt.preprocess:
        rot_angle = opt.rot_angle * (2 * random.random() - 1)
    # Derive the radians unconditionally (rot_angle stays 0 when 'rot' is off)
    # so rot_mat below is defined for every preprocess setting.
    rot_rad = -rot_angle * np.pi / 180
    if 'flip' in opt.preprocess:
        flip = random.random() > 0.5

    shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
    flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
    shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
    rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
    scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
    shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])

    # Applied right-to-left: centre -> flip -> shift -> rotate -> scale -> uncentre.
    affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
    affine_inv = np.linalg.inv(affine)
    return affine, affine_inv, flip
def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
    """Warp a PIL image with an affine transform, keeping its size.

    Args:
        img: PIL image.
        affine_inv: 3x3 inverse affine. Only its first two rows (six values)
            are passed to PIL's ``Image.transform``, which maps each output
            pixel back to an input position.
        method: PIL resampling filter. It was previously ignored (bicubic was
            always used), so callers asking for e.g. BILINEAR on masks got
            bicubic instead.

    Returns:
        The transformed PIL image.
    """
    return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=method)
def apply_lm_affine(landmark, affine, flip, size):
    """Transform (N, 2) landmarks by a 3x3 affine.

    The y coordinate is mirrored (y -> h - 1 - y) before applying
    ``affine`` and mirrored back afterwards. If ``flip`` is true, rows are
    also permuted so left/right landmark indices swap. The index ranges cover
    0..67, presumably a 68-point layout -- confirm against the landmark source.

    Args:
        landmark: (N, 2) array; not modified.
        affine: 3x3 matrix applied to homogeneous row vectors.
        flip: whether a horizontal flip was applied to the image.
        size: (w, h) of the image.

    Returns:
        New (N, 2) float array of transformed landmarks.
    """
    _, h = size
    pts = landmark.copy()
    pts[:, 1] = h - 1 - pts[:, 1]
    homogeneous = np.concatenate((pts, np.ones([pts.shape[0], 1])), -1)
    projected = homogeneous @ np.transpose(affine)
    out = projected[:, :2] / projected[:, 2:]
    out[:, 1] = h - 1 - out[:, 1]
    if not flip:
        return out

    # (destination rows, source rows) for every mirrored landmark group.
    swaps = (
        (slice(None, 17), slice(16, None, -1)),
        (slice(17, 22), slice(26, 21, -1)),
        (slice(22, 27), slice(21, 16, -1)),
        (slice(31, 36), slice(35, 30, -1)),
        (slice(36, 40), slice(45, 41, -1)),
        (slice(40, 42), slice(47, 45, -1)),
        (slice(42, 46), slice(39, 35, -1)),
        (slice(46, 48), slice(41, 39, -1)),
        (slice(48, 55), slice(54, 47, -1)),
        (slice(55, 60), slice(59, 54, -1)),
        (slice(60, 65), slice(64, 59, -1)),
        (slice(65, 68), slice(67, 64, -1)),
    )
    mirrored = out.copy()
    for dst, src in swaps:
        mirrored[dst] = out[src]
    return mirrored
================================================
FILE: src/face3d/data/flist_dataset.py
================================================
"""This script defines the custom dataset for Deep3DFaceRecon_pytorch
"""
import os.path
from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
from data.image_folder import make_dataset
from PIL import Image
import random
import util.util as util
import numpy as np
import json
import torch
from scipy.io import loadmat, savemat
import pickle
from util.preprocess import align_img, estimate_norm
from util.load_mats import load_lm3d
def default_flist_reader(flist):
    """Read a file list containing one path per line.

    No label column is parsed. Each line, stripped of surrounding whitespace,
    is returned as a path. Blank lines are skipped; previously they became
    empty paths that only failed later, when the image was opened.

    Args:
        flist: path of the list file.

    Returns:
        List of path strings in file order.
    """
    with open(flist, 'r') as rf:
        return [line.strip() for line in rf if line.strip()]
def jason_flist_reader(flist):
    """Parse the JSON file at ``flist`` and return the decoded object."""
    with open(flist, 'r') as handle:
        return json.load(handle)
def parse_label(label):
    """Convert an array-like ``label`` into a float32 torch tensor (a copy)."""
    as_array = np.array(label)
    return torch.tensor(as_array.astype(np.float32))
class FlistDataset(BaseDataset):
    """Deep3DFaceRecon dataset driven by a list of mask paths.

    ``opt.flist`` lists mask paths relative to ``opt.data_root``. For each
    mask path:
      * the image path is the mask path with ``'mask/'`` removed;
      * the landmark path replaces every ``'mask'`` with ``'landmarks'`` and
        swaps the last extension for ``.txt``.
    """

    def __init__(self, opt):
        """Load the 3D landmark template and the list of mask paths.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)

        self.lm3d_std = load_lm3d(opt.bfm_folder)

        msk_names = default_flist_reader(opt.flist)
        self.msk_paths = [os.path.join(opt.data_root, name) for name in msk_names]

        self.size = len(self.msk_paths)
        self.opt = opt

        # e.g. 'train' or 'val', suffixed with the list file's prefix before '_'.
        split_name = 'train' if opt.isTrain else 'val'
        if '_' in opt.flist:
            list_file = opt.flist.split(os.sep)[-1]
            split_name = split_name + '_' + list_file.split('_')[0]
        self.name = split_name

    def __getitem__(self, index):
        """Load, align and (optionally) augment one sample.

        Parameters:
            index (int) -- index into the file list (wrapped modulo its length)

        Returns a dict with keys:
            imgs (tensor)      -- aligned RGB image
            lms (tensor)       -- aligned landmarks (float32)
            msks (tensor)      -- first channel of the aligned mask
            M (tensor)         -- output of estimate_norm(lms, image height)
            im_paths (str)     -- image path
            aug_flag (bool)    -- whether augmentation was applied
            dataset (str)      -- this dataset's name
        """
        msk_path = self.msk_paths[index % self.size]  # keep index within range
        img_path = msk_path.replace('mask/', '')
        lm_stem = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1])
        lm_path = lm_stem + '.txt'

        raw_img = Image.open(img_path).convert('RGB')
        raw_msk = Image.open(msk_path).convert('RGB')
        raw_lm = np.loadtxt(lm_path).astype(np.float32)

        _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)

        aug_flag = self.opt.use_aug and self.opt.isTrain
        if aug_flag:
            img, lm, msk = self._augmentation(img, lm, self.opt, msk)

        _, H = img.size
        M = estimate_norm(lm, H)
        transform = get_transform()
        return {
            'imgs': transform(img),
            'lms': parse_label(lm),
            'msks': transform(msk)[:1, ...],
            'M': parse_label(M),
            'im_paths': img_path,
            'aug_flag': aug_flag,
            'dataset': self.name,
        }

    def _augmentation(self, img, lm, opt, msk=None):
        """Apply one randomly sampled affine to the image, landmarks and mask."""
        affine, affine_inv, flip = get_affine_mat(opt, img.size)
        warped_img = apply_img_affine(img, affine_inv)
        warped_lm = apply_lm_affine(lm, affine, flip, warped_img.size)
        if msk is not None:
            msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
        return warped_img, warped_lm, msk

    def __len__(self):
        """Return the number of entries in the file list."""
        return self.size
================================================
FILE: src/face3d/data/image_folder.py
================================================
"""A modified image folder class
We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both current directory and its subdirectories.
"""
import numpy as np
import torch.utils.data as data
from PIL import Image
import os
import os.path
# Recognised image suffixes. Matching is case-sensitive, so the lower- and
# upper-case spellings are both listed.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.tif', '.TIF', '.tiff', '.TIFF',
]


def is_image_file(filename):
    """Return True when ``filename`` ends with one of ``IMG_EXTENSIONS``."""
    return filename.endswith(tuple(IMG_EXTENSIONS))
def make_dataset(dir, max_dataset_size=float("inf")):
    """Recursively collect image paths under ``dir``, following symlinks.

    Directories are visited in sorted order, and file names within each
    directory are sorted too. ``os.walk`` yields them in filesystem order,
    so the previous result, and which files survived the size cap, varied
    across machines.

    Args:
        dir: root directory (or a symlink to one).
        max_dataset_size: keep at most this many paths (default: all).

    Returns:
        List of image paths.
    """
    images = []
    assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
        for fname in sorted(fnames):
            if is_image_file(fname):
                images.append(os.path.join(root, fname))
    # int() so a float cap such as 10.0 does not break slicing.
    return images[:int(min(max_dataset_size, len(images)))]
def default_loader(path):
    """Open the image at ``path`` with PIL and convert it to RGB."""
    img = Image.open(path)
    return img.convert('RGB')
class ImageFolder(data.Dataset):
    """Dataset over every image file found recursively under ``root``.

    Each item is ``loader(path)``, optionally passed through ``transform``.
    When ``return_paths`` is true, items are ``(image, path)`` tuples.

    Raises:
        RuntimeError: if no image files are found under ``root``.
    """

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if not imgs:
            raise RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return (img, path) if self.return_paths else img

    def __len__(self):
        return len(self.imgs)
================================================
FILE: src/face3d/data/template_dataset.py
================================================
"""Dataset class template
This module provides a template for users to implement custom datasets.
You can specify '--dataset_mode template' to use this dataset.
The class name should be consistent with both the filename and its dataset_mode option.
The filename should be <dataset_mode>_dataset.py
The class name should be <Dataset_mode>Dataset (e.g. TemplateDataset for dataset_mode 'template').
You need to implement the following functions:
-- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
-- <__init__>: Initialize this dataset class.
-- <__getitem__>: Return a data point and its metadata information.
-- <__len__>: Return the number of images.
"""
from data.base_dataset import BaseDataset, get_transform
# from data.image_folder import make_dataset
# from PIL import Image
class TemplateDataset(BaseDataset):
    """A template dataset class for you to implement custom datasets."""
    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
        parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0)  # specify dataset-specific default values
        return parser

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

        A few things can be done here.
        - save the options (have been done in BaseDataset)
        - get image paths and meta information of the dataset.
        - define the image transformation.
        """
        # save the option and dataset root
        BaseDataset.__init__(self, opt)
        # get the image paths of your dataset;
        self.image_paths = []  # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
        # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function.
        # get_transform's only parameter is ``grayscale``: passing ``opt`` there
        # (as this template used to) is truthy and silently adds a grayscale step.
        self.transform = get_transform()

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns:
            a dictionary of data with their names. It usually contains the data itself and its metadata information.

        Step 1: get a random image path: e.g., path = self.image_paths[index]
        Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
        Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
        Step 4: return a data point as a dictionary.
        """
        path = 'temp'  # needs to be a string
        data_A = None  # needs to be a tensor
        data_B = None  # needs to be a tensor
        return {'data_A': data_A, 'data_B': data_B, 'path': path}

    def __len__(self):
        """Return the total number of images."""
        return len(self.image_paths)
================================================
FILE: src/face3d/extract_kp_videos.py
================================================
import os
import cv2
import time
import glob
import argparse
import face_alignment
import numpy as np
from PIL import Image
from tqdm import tqdm
from itertools import cycle
from torch.multiprocessing import Pool, Process, set_start_method
class KeypointExtractor():
    """Facial landmark extraction for single images or lists of frames via ``face_alignment``.

    A frame where no landmarks are obtained is represented by a ``(68, 2)`` array
    filled with -1 (presumably matching the detector's 68-point 2D output).
    """

    def __init__(self, device='cuda'):
        """Create the ``face_alignment`` 2D landmark detector.

        Args:
            device: device string forwarded to ``face_alignment.FaceAlignment``.
                Defaults to ``'cuda'`` so that ``KeypointExtractor()`` works.
        """
        # Newer face_alignment releases name this enum member ``TWO_D``; older
        # releases use ``_2D``. Accept either spelling.
        landmarks_type = getattr(face_alignment.LandmarksType, 'TWO_D', None)
        if landmarks_type is None:
            landmarks_type = face_alignment.LandmarksType._2D
        self.detector = face_alignment.FaceAlignment(landmarks_type, device=device)

    def extract_keypoint(self, images, name=None, info=True):
        """Extract landmarks for one image or for a list of frames.

        Args:
            images: a single image (anything ``np.array`` accepts, e.g. a PIL
                image) or a list of such images.
            name: optional output path. When given, the flattened keypoints are
                written with ``np.savetxt`` to ``<name without extension>.txt``.
            info: for list input, wrap the iteration in a tqdm progress bar.

        Returns:
            For a single image, the landmark array (the -1 sentinel of shape
            ``(68, 2)`` when nothing was found). For a list of N frames, the
            per-frame arrays stacked along a new leading axis. A frame without a
            face reuses the previous frame's landmarks, while a leading undetected
            frame keeps the sentinel. An empty list yields an array of shape
            ``(0, 68, 2)``.
        """
        if isinstance(images, list):
            keypoints = []
            i_range = tqdm(images, desc='landmark Det:') if info else images
            for image in i_range:
                current_kp = self.extract_keypoint(image)
                # Sentinel frame: carry the last good landmarks forward.
                if np.all(current_kp == -1) and keypoints:
                    keypoints.append(keypoints[-1])
                else:
                    keypoints.append(current_kp[None])
            if keypoints:
                keypoints = np.concatenate(keypoints, 0)
            else:
                # np.concatenate([]) raises; return an empty, correctly-ranked array.
                keypoints = np.empty((0, 68, 2))
            if name is not None:
                np.savetxt(os.path.splitext(name)[0] + '.txt', keypoints.reshape(-1))
            return keypoints

        keypoints = None
        while True:
            try:
                keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
                break
            except RuntimeError as e:
                # CUDA errors are treated as transient OOM and retried.
                if str(e).startswith('CUDA'):
                    print("Warning: out of memory, sleep for 1s")
                    time.sleep(1)
                else:
                    print(e)
                    break
            except TypeError:
                # Presumably the detector returned None (no face), so indexing [0] failed.
                print('No face detected in this image')
                break
        if keypoints is None:
            # No face, or a non-CUDA RuntimeError: fall back to the sentinel
            # instead of reaching the code below with an unbound result.
            keypoints = -1. * np.ones([68, 2])
        if name is not None:
            np.savetxt(os.path.splitext(name)[0] + '.txt', keypoints.reshape(-1))
        return keypoints
def read_video(filename):
    """Decode every frame of ``filename`` into a list of RGB ``PIL.Image`` objects.

    Reading stops at the first frame OpenCV fails to return (end of stream or a
    decode failure). A file that cannot be opened yields an empty list.
    """
    frames = []
    cap = cv2.VideoCapture(filename)
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR; convert so downstream code receives RGB.
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    finally:
        # Release the capture even if decoding or conversion raises.
        cap.release()
    return frames
def run(data):
    """Pool worker: extract and save landmarks for one video.

    Args:
        data: a ``(filename, opt, device)`` tuple. ``device`` is a GPU id string
            written to ``CUDA_VISIBLE_DEVICES`` before the model is built, so the
            selected GPU is addressed as ``'cuda'`` inside this process.

    Output goes to ``<opt.output_dir>/<parent dir name>/<video name>``, with the
    extension replaced by ``.txt`` in ``extract_keypoint``.
    """
    filename, opt, device = data
    # NOTE(review): Pool workers are reused across tasks, so once CUDA has been
    # initialised in a worker, later changes to this variable have no effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = device
    kp_extractor = KeypointExtractor(device='cuda')
    images = read_video(filename)
    # os.path handles the platform's native separators (split('/') did not).
    parent = os.path.basename(os.path.dirname(filename))
    out_dir = os.path.join(opt.output_dir, parent)
    os.makedirs(out_dir, exist_ok=True)
    kp_extractor.extract_keypoint(
        images,
        name=os.path.join(out_dir, os.path.basename(filename))
    )
if __name__ == '__main__':
    # Workers are spawned rather than forked (presumably because they use CUDA).
    set_start_method('spawn')
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input_dir', type=str, required=True, help='the folder of the input files')
    parser.add_argument('--output_dir', type=str, required=True, help='the folder of the output files')
    parser.add_argument('--device_ids', type=str, default='0,1')
    parser.add_argument('--workers', type=int, default=4)
    opt = parser.parse_args()

    if not os.path.isdir(opt.input_dir):
        parser.error(f'--input_dir is not a directory: {opt.input_dir}')

    VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
    VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
    extensions = VIDEO_EXTENSIONS

    # Accumulate matches for every extension. The set removes duplicates that
    # occur when 'mp4' and 'MP4' both match the same file on a
    # case-insensitive filesystem.
    found = set()
    for ext in sorted(extensions):
        pattern = f'{opt.input_dir}/*.{ext}'
        print(pattern)
        found.update(glob.glob(pattern))
    filenames = sorted(found)
    print('Total number of videos:', len(filenames))

    args_list = cycle([opt])
    device_ids = cycle(opt.device_ids.split(","))
    # The context manager tears the pool down once every result is consumed.
    with Pool(opt.workers) as pool:
        for _ in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids)),
                      total=len(filenames)):
            pass
================================================
FILE: src/face3d/extract_kp_videos_safe.py
================================================
import os
import cv2
import time
import glob
import argparse
import numpy as np
from PIL import Image
import torch
from tqdm import tqdm
from itertools import cycle
from torch.multiprocessing import Pool, Process, set_start_method
from facexlib.alignment import landmark_98_to_68
from facexlib.detection import init_detection_model
from facexlib.utils import load_file_from_url
from src.face3d.util.my_awing_arch import FAN
def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None):
    """Build the 98-landmark AWing FAN aligner and load its downloaded weights.

    Args:
        model_name: only ``'awing_fan'`` is supported.
        half: accepted but unused by this function.
        device: device used for ``map_location`` and for the returned model.
        model_rootpath: ``save_dir`` passed to ``load_file_from_url``.

    Returns:
        The model in eval mode, moved to ``device``.

    Raises:
        NotImplementedError: for any other ``model_name``.
    """
    if model_name != 'awing_fan':
        raise NotImplementedError(f'{model_name} is not implemented.')
    model = FAN(num_modules=4, num_landmarks=98, device=device)
    model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth'
    model_path = load_file_from_url(
        url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'], strict=True)
    model.eval()
    return model.to(device)
class KeypointExtractor():
    """Facial landmark extraction with a RetinaFace detector plus an AWing FAN aligner.

    Landmarks are computed on the first detected face crop, converted from 98 to
    68 points and shifted back into full-image coordinates. A frame where no
    landmarks are obtained is represented by a ``(68, 2)`` array filled with -1.
    """

    def __init__(self, device='cuda'):
        """Load the alignment and detection models onto ``device``.

        Weights are stored under ``extensions/SadTalker/gfpgan/weights`` when a
        ``webui`` module is importable, otherwise under ``gfpgan/weights``.
        """
        ### gfpgan/weights
        try:
            import webui  # in webui
            root_path = 'extensions/SadTalker/gfpgan/weights'
        except Exception:
            # Not running inside the webui (normally ImportError). Unlike a bare
            # except, KeyboardInterrupt/SystemExit still propagate.
            root_path = 'gfpgan/weights'
        self.detector = init_alignment_model('awing_fan', device=device, model_rootpath=root_path)
        self.det_net = init_detection_model('retinaface_resnet50', half=False, device=device, model_rootpath=root_path)

    def extract_keypoint(self, images, name=None, info=True):
        """Extract 68-point landmarks for one image or for a list of frames.

        Args:
            images: a single image (e.g. a PIL image) or a list of them.
            name: optional output path. When given, the flattened keypoints are
                written with ``np.savetxt`` to ``<name without extension>.txt``.
            info: for list input, wrap the iteration in a tqdm progress bar.

        Returns:
            For a single image, a ``(68, 2)`` array (all -1 when no face was found).
            For a list of N frames, an ``(N, 68, 2)`` array. A frame without a face
            reuses the previous frame's landmarks, while a leading undetected frame
            keeps the sentinel. An empty list yields shape ``(0, 68, 2)``.
        """
        if isinstance(images, list):
            keypoints = []
            i_range = tqdm(images, desc='landmark Det:') if info else images
            for image in i_range:
                current_kp = self.extract_keypoint(image)
                # Sentinel frame: carry the last good landmarks forward.
                if np.all(current_kp == -1) and keypoints:
                    keypoints.append(keypoints[-1])
                else:
                    keypoints.append(current_kp[None])
            if keypoints:
                keypoints = np.concatenate(keypoints, 0)
            else:
                # np.concatenate([]) raises; return an empty, correctly-ranked array.
                keypoints = np.empty((0, 68, 2))
            if name is not None:
                np.savetxt(os.path.splitext(name)[0] + '.txt', keypoints.reshape(-1))
            return keypoints

        keypoints = None
        while True:
            try:
                with torch.no_grad():
                    # face detection -> face alignment.
                    img = np.array(images)
                    # The original object (not the array) is passed to the detector,
                    # as before; presumably it does its own colour conversion.
                    bboxes = self.det_net.detect_faces(images, 0.97)
                    if len(bboxes) == 0:
                        # Empty detection: indexing [0] would raise IndexError.
                        print('No face detected in this image')
                        break
                    bbox = bboxes[0]
                    # Clamp the top-left corner: a negative slice start would wrap
                    # around instead of cropping at the image border.
                    left = max(int(bbox[0]), 0)
                    top = max(int(bbox[1]), 0)
                    crop = img[top:int(bbox[3]), left:int(bbox[2]), :]
                    keypoints = landmark_98_to_68(self.detector.get_landmarks(crop))
                    #### keypoints to the original location
                    keypoints[:, 0] += left
                    keypoints[:, 1] += top
                break
            except RuntimeError as e:
                # CUDA errors are treated as transient OOM and retried.
                if str(e).startswith('CUDA'):
                    print("Warning: out of memory, sleep for 1s")
                    time.sleep(1)
                else:
                    print(e)
                    break
            except TypeError:
                print('No face detected in this image')
                break
        if keypoints is None:
            # No face, or a non-CUDA RuntimeError: use the sentinel instead of
            # reaching the code below with an unbound result.
            keypoints = -1. * np.ones([68, 2])
        if name is not None:
            np.savetxt(os.path.splitext(name)[0] + '.txt', keypoints.reshape(-1))
        return keypoints
def read_video(filename):
    """Decode every frame of ``filename`` into a list of RGB ``PIL.Image`` objects.

    Reading stops at the first frame OpenCV fails to return (end of stream or a
    decode failure). A file that cannot be opened yields an empty list.
    """
    frames = []
    cap = cv2.VideoCapture(filename)
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR; convert so downstream code receives RGB.
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    finally:
        # Release the capture even if decoding or conversion raises.
        cap.release()
    return frames
def run(data):
    """Pool worker: extract and save landmarks for one video.

    Args:
        data: a ``(filename, opt, device)`` tuple. ``device`` is a GPU id string
            written to ``CUDA_VISIBLE_DEVICES`` before the model is built, so the
            selected GPU is addressed as ``'cuda'`` inside this process.

    Output goes to ``<opt.output_dir>/<parent dir name>/<video name>``, with the
    extension replaced by ``.txt`` in ``extract_keypoint``.
    """
    filename, opt, device = data
    # NOTE(review): Pool workers are reused across tasks, so once CUDA has been
    # initialised in a worker, later changes to this variable have no effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = device
    kp_extractor = KeypointExtractor(device='cuda')
    images = read_video(filename)
    # os.path handles the platform's native separators (split('/') did not).
    parent = os.path.basename(os.path.dirname(filename))
    out_dir = os.path.join(opt.output_dir, parent)
    os.makedirs(out_dir, exist_ok=True)
    kp_extractor.extract_keypoint(
        images,
        name=os.path.join(out_dir, os.path.basename(filename))
    )
if __name__ == '__main__':
    # Workers are spawned rather than forked (presumably because they use CUDA).
    set_start_method('spawn')
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input_dir', type=str, required=True, help='the folder of the input files')
    parser.add_argument('--output_dir', type=str, required=True, help='the folder of the output files')
    parser.add_argument('--device_ids', type=str, default='0,1')
    parser.add_argument('--workers', type=int, default=4)
    opt = parser.parse_args()

    if not os.path.isdir(opt.input_dir):
        parser.error(f'--input_dir is not a directory: {opt.input_dir}')

    VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
    VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
    extensions = VIDEO_EXTENSIONS

    # Accumulate matches for every extension. The set removes duplicates that
    # occur when 'mp4' and 'MP4' both match the same file on a
    # case-insensitive filesystem.
    found = set()
    for ext in sorted(extensions):
        pattern = f'{opt.input_dir}/*.{ext}'
        print(pattern)
        found.update(glob.glob(pattern))
    filenames = sorted(found)
    print('Total number of videos:', len(filenames))

    args_list = cycle([opt])
    device_ids = cycle(opt.device_ids.split(","))
    # The context manager tears the pool down once every result is consumed.
    with Pool(opt.workers) as pool:
        for _ in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids)),
                      total=len(filenames)):
            pass
================================================
FILE: src/face3d/models/__init__.py
================================================
"""This package contains modules related to objective functions, optimizations, and network architectures.
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
-- <set_input>: unpack data from dataset and apply preprocessing.
-- <forward>: produce intermediate results.
-- <optimize_parameters>: calculate loss, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
In the function <__init__>, you need to define four lists:
-- self.loss_names (str list): specify the training losses that you want to plot and save.
-- self.model_names (str list): define networks used in our training.
-- self.visual_names (str list): specify the images that you want to display and save.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.
Now you can use the model class by specifying flag '--model dummy'.
See our template model class 'template_model.py' for more details.
"""
import importlib
from src.face3d.models.base_model import BaseModel
def find_model_using_name(model_name):
    """Import ``<this package>.<model_name>_model`` and return its model class.

    The returned class is the one whose name equals ``model_name`` with
    underscores removed plus ``'model'`` (compared case-insensitively) and which
    subclasses ``BaseModel``.

    Raises:
        ValueError: if the module defines no such ``BaseModel`` subclass.
    """
    # Resolve the module inside this package under whatever name the package
    # was imported as (e.g. ``src.face3d.models``). A hard-coded
    # ``face3d.models`` prefix only works when ``src`` is on sys.path.
    model_filename = __name__ + "." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        # isinstance guard: issubclass() raises TypeError for non-class objects.
        if (name.lower() == target_model_name.lower()
                and isinstance(cls, type) and issubclass(cls, BaseModel)):
            model = cls
    if model is None:
        # Raise instead of exit(0): exiting with status 0 masked the failure.
        raise ValueError(
            "In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase."
            % (model_filename, target_model_name))
    return model
def get_option_setter(model_name):
    """Look up the model class for ``model_name`` and return its option hook.

    The returned callable is the class's static ``modify_commandline_options``.
    """
    return find_model_using_name(model_name).modify_commandline_options
def create_model(opt):
    """Instantiate the model class selected by ``opt.model``.

    This is the main entry point between this package and the train/test scripts.
    The class is found with ``find_model_using_name`` and constructed with ``opt``.

    Example:
        >>> from models import create_model
        >>> model = create_model(opt)
    """
    model_class = find_model_using_name(opt.model)
    instance = model_class(opt)
    print("model [%s] was created" % type(instance).__name__)
    return instance
================================================
FILE: src/face3d/models/arcface_torch/README.md
================================================
# Distributed Arcface Training in Pytorch
This is a deep learning library that makes face recognition training efficient and effective; it can train on tens of millions of
identities on a single server.
## Requirements
- Install [pytorch](http://pytorch.org) (torch>=1.6.0); see our [install.md](docs/install.md).
- `pip install -r requirements.txt`.
- Download the dataset
from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_)
.
## How to Train
To train a model, run `train.py` with the path to the configs:
### 1. Single node, 8 GPUs:
```shell
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
```
### 2. Multiple nodes, each node 8 GPUs:
Node 0:
```shell
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
```
Node 1:
```shell
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
```
### 3. Training resnet2060 with 8 GPUs:
```shell
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py
```
## Model Zoo
- The models are available for non-commercial research purposes only.
- All models can be found here:
- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
- [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/)
The ICCV2021-MFR test set consists of non-celebrities, so we can ensure that it has very little overlap with publicly available face
recognition training sets such as MS1M and CASIA, which were mostly collected from online celebrities.
As a result, we can evaluate the FAIR performance of different algorithms.
For the **ICCV2021-MFR-ALL** set, TAR is measured with an all-to-all 1:1 protocol, with FAR less than 0.000001(e-6). The
globalised multi-racial testset contains 242,143 identities and 1,624,305 images.
For the **ICCV2021-MFR-MASK** set, TAR is measured with a mask-to-nonmask 1:1 protocol, with FAR less than 0.0001(e-4).
Mask testset contains 6,964 identities, 6,964 masked images and 13,928 non-masked images.
There are totally 13,928 positive pairs and 96,983,824 negative pairs.
| Datasets | backbone | Training throughput | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** |
| :---: | :--- | :--- | :--- |:--- |:--- |
| MS1MV3 | r18 | - | 91 | **47.85** | **68.33** |
| Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** |
| MS1MV3 | r34 | - | 130 | **58.72** | **77.36** |
| Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** |
| MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** |
| Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** |
| MS1MV3 | r100 | - | 248 | **69.09** | **84.31** |
| Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** |
| MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** |
| Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** |
### Performance on IJB-C and Verification Datasets
| Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log |
| :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- |
| MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)|
| MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)|
| MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)|
| MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)|
| MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)|
| Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)|
| Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)|
| Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)|
| Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)|
[comment]: <> (More details see [model.md](docs/modelzoo.md) in docs.)
## [Speed Benchmark](docs/speed_benchmark.md)
**Arcface Torch** can train on large-scale face recognition training sets efficiently and quickly. When the number of
classes in the training set is greater than 300K and training is sufficient, the Partial FC sampling strategy achieves the same
accuracy with several times faster training and lower GPU memory usage.
Partial FC is a sparse variant of the model-parallel architecture for large-scale face recognition. Partial FC uses a
sparse softmax, where each batch dynamically samples a subset of class centers for training. In each iteration, only a
sparse part of the parameters is updated, which greatly reduces GPU memory usage and computation. With Partial FC,
we can scale the training set to 29 million identities, the largest to date. Partial FC also supports multi-machine distributed
training and mixed-precision training.

For more details, see
[speed_benchmark.md](docs/speed_benchmark.md) in docs.
### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)
`-` means training failed because of GPU memory limitations.
| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
| :--- | :--- | :--- | :--- |
|125000 | 4681 | 4824 | 5004 |
|1400000 | **1672** | 3043 | 4738 |
|5500000 | **-** | **1389** | 3975 |
|8000000 | **-** | **-** | 3565 |
|16000000 | **-** | **-** | 2679 |
|29000000 | **-** | **-** | **1855** |
### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
| :--- | :--- | :--- | :--- |
|125000 | 7358 | 5306 | 4868 |
|1400000 | 32252 | 11178 | 6056 |
|5500000 | **-** | 32188 | 9854 |
|8000000 | **-** | **-** | 12310 |
|16000000 | **-** | **-** | 19950 |
|29000000 | **-** | **-** | 32324 |
## Evaluation ICCV2021-MFR and IJB-C
For more details, see [eval.md](docs/eval.md) in docs.
## Test
We tested many versions of PyTorch. Please create an issue if you are having trouble.
- [x] torch 1.6.0
- [x] torch 1.7.1
- [x] torch 1.8.0
- [x] torch 1.9.0
## Citation
```
@inproceedings{deng2019arcface,
title={Arcface: Additive angular margin loss for deep face recognition},
author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
pages={4690--4699},
year={2019}
}
@inproceedings{an2020partical_fc,
title={Partial FC: Training 10 Million Identities on a Single Machine},
author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and
Zhang, Debing and Fu, Ying},
booktitle={Arxiv 2010.05222},
year={2020}
}
```
================================================
FILE: src/face3d/models/arcface_torch/backbones/__init__.py
================================================
from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
from .mobilefacenet import get_mbf
def get_model(name, **kwargs):
    """Build a face-recognition backbone by short name.

    Args:
        name: one of ``"r18"``, ``"r34"``, ``"r50"``, ``"r100"``, ``"r200"``,
            ``"r2060"`` (IResNet variants, built without pretrained weights) or
            ``"mbf"`` (MobileFaceNet).
        **kwargs: forwarded to the IResNet constructor. For ``"mbf"`` only
            ``fp16`` (default False) and ``num_features`` (default 512) are used.

    Raises:
        ValueError: if ``name`` is not a known backbone.
    """
    # resnet
    if name == "r18":
        return iresnet18(False, **kwargs)
    elif name == "r34":
        return iresnet34(False, **kwargs)
    elif name == "r50":
        return iresnet50(False, **kwargs)
    elif name == "r100":
        return iresnet100(False, **kwargs)
    elif name == "r200":
        return iresnet200(False, **kwargs)
    elif name == "r2060":
        # Imported lazily: that module asserts a minimum torch version at import.
        from .iresnet2060 import iresnet2060
        return iresnet2060(False, **kwargs)
    elif name == "mbf":
        fp16 = kwargs.get("fp16", False)
        num_features = kwargs.get("num_features", 512)
        return get_mbf(fp16=fp16, num_features=num_features)
    else:
        raise ValueError(
            f"Unknown backbone name: {name!r}; expected one of "
            "'r18', 'r34', 'r50', 'r100', 'r200', 'r2060', 'mbf'")
================================================
FILE: src/face3d/models/arcface_torch/backbones/iresnet.py
================================================
import torch
from torch import nn
__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` whose padding equals ``dilation``."""
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_options)
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free pointwise (1x1) ``nn.Conv2d``."""
    conv_options = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
class IBasicBlock(nn.Module):
    """Residual block: BN -> conv3x3 -> BN -> PReLU -> conv3x3(stride) -> BN, plus shortcut.

    The stride is applied by the second convolution. ``downsample``, when given,
    maps the input onto the branch output's shape for the shortcut addition.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super().__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Submodules are created in the same order as before, which keeps
        # parameter initialisation order and state_dict layout unchanged.
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = x
        for stage in (self.bn1, self.conv1, self.bn2, self.prelu, self.conv2, self.bn3):
            out = stage(out)
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return out
class IResNet(nn.Module):
    """IResNet backbone mapping an image to an embedding of size ``num_features``.

    Pipeline: 3x3 stride-1 stem (conv1/bn1/prelu) -> four residual stages with
    64/128/256/512 planes (the first block of each uses stride 2) -> bn2 ->
    flatten -> dropout -> fc -> ``features`` (BatchNorm1d whose weight is fixed at 1).
    """
    # fc input size assumes layer4 produces a 7x7 map. With the stride-1 stem
    # and four stride-2 stages that corresponds to a 112x112 input
    # (NOTE(review): inferred from the strides; confirm with callers).
    fc_scale = 7 * 7
    def __init__(self,
                 block, layers, dropout=0, num_features=512, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
        """Build the network.

        Args:
            block: residual block class exposing an ``expansion`` attribute.
            layers: number of blocks in each of the four stages.
            dropout: dropout probability applied before ``fc``.
            num_features: output embedding size of ``fc``/``features``.
            zero_init_residual: if True, set ``bn2.weight`` of every ``IBasicBlock`` to 0.
            groups, width_per_group: forwarded to every block (``IBasicBlock``
                only accepts 1 and 64).
            replace_stride_with_dilation: three flags for stages 2-4. A True flag
                multiplies ``self.dilation`` by the stride and uses stride 1
                (``IBasicBlock`` rejects dilation > 1).
            fp16: run the convolutional trunk under CUDA autocast.

        Raises:
            ValueError: if ``replace_stride_with_dilation`` does not have 3 entries.
        """
        super(IResNet, self).__init__()
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # NOTE: module creation order fixes the random-init sequence and state_dict layout.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        # The embedding BN scale is fixed at 1 and not trained.
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` blocks; only the first block gets ``stride`` and ``downsample``.

        A 1x1 conv + BN shortcut is added when the stride or channel count changes.
        Updates ``self.inplanes`` (and ``self.dilation`` when ``dilate`` is set).
        """
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the ``num_features``-dim embedding for a batch of images."""
        with torch.cuda.amp.autocast(self.fp16):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.prelu(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.bn2(x)
            x = torch.flatten(x, 1)
            x = self.dropout(x)
        # In fp16 mode the trunk output is cast back to float32 before fc.
        x = self.fc(x.float() if self.fp16 else x)
        x = self.features(x)
        return x
def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
model = IResNet(block, layers, **kwargs)
if pretrained:
raise ValueError()
return model
def iresnet18(pretrained=False, progress=True, **kwargs):
    """IResNet-18: ``IBasicBlock`` stages with depths 2/2/2/2."""
    stage_depths = [2, 2, 2, 2]
    return _iresnet('iresnet18', IBasicBlock, stage_depths, pretrained, progress, **kwargs)
def iresnet34(pretrained=False, progress=True, **kwargs):
    """IResNet-34: ``IBasicBlock`` stages with depths 3/4/6/3."""
    stage_depths = [3, 4, 6, 3]
    return _iresnet('iresnet34', IBasicBlock, stage_depths, pretrained, progress, **kwargs)
def iresnet50(pretrained=False, progress=True, **kwargs):
    """IResNet-50: ``IBasicBlock`` stages with depths 3/4/14/3."""
    stage_depths = [3, 4, 14, 3]
    return _iresnet('iresnet50', IBasicBlock, stage_depths, pretrained, progress, **kwargs)
def iresnet100(pretrained=False, progress=True, **kwargs):
    """IResNet-100: ``IBasicBlock`` stages with depths 3/13/30/3."""
    stage_depths = [3, 13, 30, 3]
    return _iresnet('iresnet100', IBasicBlock, stage_depths, pretrained, progress, **kwargs)
def iresnet200(pretrained=False, progress=True, **kwargs):
    """IResNet-200: ``IBasicBlock`` stages with depths 6/26/60/6."""
    stage_depths = [6, 26, 60, 6]
    return _iresnet('iresnet200', IBasicBlock, stage_depths, pretrained, progress, **kwargs)
================================================
FILE: src/face3d/models/arcface_torch/backbones/iresnet2060.py
================================================
import torch
from torch import nn
assert torch.__version__ >= "1.8.1"
from torch.utils.checkpoint import checkpoint_sequential
__all__ = ['iresnet2060']
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` padded by ``dilation`` pixels."""
    padding = dilation
    return nn.Conv2d(
        in_planes, out_planes,
        kernel_size=3, stride=stride, padding=padding,
        groups=groups, bias=False, dilation=dilation,
    )
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``nn.Conv2d`` (used for shortcut projections)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class IBasicBlock(nn.Module):
    """Residual block: BN -> conv3x3 -> BN -> PReLU -> conv3x3(stride) -> BN, plus shortcut.

    The stride is applied by the second convolution. ``downsample``, when given,
    maps the input onto the branch output's shape for the shortcut addition.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super().__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Creation order is preserved, which keeps parameter initialisation
        # order and state_dict layout unchanged.
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, )
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, )
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, )
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        branch = x
        for stage in (self.bn1, self.conv1, self.bn2, self.prelu, self.conv2, self.bn3):
            branch = stage(branch)
        if self.downsample is None:
            shortcut = x
        else:
            shortcut = self.downsample(x)
        branch += shortcut
        return branch
class IResNet(nn.Module):
    """Very deep IResNet variant that uses gradient checkpointing on its two deepest stages.

    Same layout as the standard IResNet. In training mode, ``layer2`` and
    ``layer3`` run through ``torch.utils.checkpoint.checkpoint_sequential``,
    trading recomputation for activation memory.
    """
    # fc input size assumes layer4 produces a 7x7 map (NOTE(review): implies a
    # 112x112 input given the stride-1 stem and four stride-2 stages; confirm).
    fc_scale = 7 * 7
    def __init__(self,
                 block, layers, dropout=0, num_features=512, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
        """Build the network.

        Args:
            block: residual block class exposing an ``expansion`` attribute.
            layers: number of blocks in each of the four stages.
            dropout: dropout probability applied before ``fc``.
            num_features: output embedding size of ``fc``/``features``.
            zero_init_residual: if True, set ``bn2.weight`` of every ``IBasicBlock`` to 0.
            groups, width_per_group: forwarded to every block.
            replace_stride_with_dilation: three flags for stages 2-4 (see ``_make_layer``).
            fp16: run the convolutional trunk under CUDA autocast.

        Raises:
            ValueError: if ``replace_stride_with_dilation`` does not have 3 entries.
        """
        super(IResNet, self).__init__()
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # NOTE: module creation order fixes the random-init sequence and state_dict layout.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        # The embedding BN scale is fixed at 1 and not trained.
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage; only the first block gets ``stride`` and ``downsample``.

        When ``dilate`` is set, ``self.dilation`` is multiplied by ``stride`` and the
        stride becomes 1. Updates ``self.inplanes`` for the next stage.
        """
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation))

        return nn.Sequential(*layers)

    def checkpoint(self, func, num_seg, x):
        """Run ``func`` on ``x``, checkpointed in ``num_seg`` segments while training.

        In eval mode ``func`` is called directly.
        """
        if self.training:
            return checkpoint_sequential(func, num_seg, x)
        else:
            return func(x)

    def forward(self, x):
        """Return the ``num_features``-dim embedding for a batch of images."""
        with torch.cuda.amp.autocast(self.fp16):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.prelu(x)
            x = self.layer1(x)
            # The two deepest stages are checkpointed to save activation memory.
            x = self.checkpoint(self.layer2, 20, x)
            x = self.checkpoint(self.layer3, 100, x)
            x = self.layer4(x)
            x = self.bn2(x)
            x = torch.flatten(x, 1)
            x = self.dropout(x)
        # In fp16 mode the trunk output is cast back to float32 before fc.
        x = self.fc(x.float() if self.fp16 else x)
        x = self.features(x)
        return x
def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
model = IResNet(block, layers, **kwargs)
if pretrained:
raise ValueError()
return model
def iresnet2060(pretrained=False, progress=True, **kwargs):
    """IResNet-2060: ``IBasicBlock`` stages with depths 3/128/896/3."""
    stage_depths = [3, 128, 1024 - 128, 3]
    return _iresnet('iresnet2060', IBasicBlock, stage_depths, pretrained, progress, **kwargs)
================================================
FILE: src/face3d/models/arcface_torch/backbones/mobilefacenet.py
================================================
'''
Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
Original author cavalleria
'''
import torch.nn as nn
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
import torch
class Flatten(Module):
    """Flatten all dimensions after the first: ``(N, ...) -> (N, -1)``."""

    def forward(self, x):
        # reshape (unlike view) also accepts non-contiguous inputs, copying only
        # when a view is impossible; for contiguous inputs it returns a view.
        return x.reshape(x.size(0), -1)
class ConvBlock(Module):
    """Bias-free Conv2d followed by BatchNorm2d and per-channel PReLU."""

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super().__init__()
        conv = Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False)
        norm = BatchNorm2d(num_features=out_c)
        activation = PReLU(num_parameters=out_c)
        # state_dict keys (layers.0/1/2.*) depend on this Sequential layout.
        self.layers = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        out = self.layers(x)
        return out
class LinearBlock(Module):
    """Bias-free Conv2d followed by BatchNorm2d, with no activation."""

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super().__init__()
        conv = Conv2d(in_c, out_c, kernel_size=kernel, stride=stride, padding=padding,
                      groups=groups, bias=False)
        norm = BatchNorm2d(num_features=out_c)
        # state_dict keys (layers.0/1.*) depend on this Sequential layout.
        self.layers = nn.Sequential(conv, norm)

    def forward(self, x):
        out = self.layers(x)
        return out
class DepthWise(Module):
    """Pointwise expand -> grouped (depthwise) conv -> linear pointwise project.

    ``groups`` is both the expanded channel count and the group count of the
    middle convolution. With ``residual=True`` the input is added to the output.
    """

    def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
        super().__init__()
        self.residual = residual
        expand = ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        depthwise = ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride)
        project = LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.layers = nn.Sequential(expand, depthwise, project)

    def forward(self, x):
        out = self.layers(x)
        return x + out if self.residual else out
class Residual(Module):
    """A stack of ``num_block`` shape-preserving residual ``DepthWise`` units on ``c`` channels.

    Each unit uses ``groups`` bottleneck channels and the given
    ``kernel``/``stride``/``padding``.
    """

    def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
        super().__init__()
        units = [DepthWise(c, c, True, kernel, stride, padding, groups) for _ in range(num_block)]
        self.layers = Sequential(*units)

    def forward(self, x):
        return self.layers(x)
class GDC(Module):
    """Embedding head that maps 512 feature channels to ``embedding_size`` features.

    The stages, in order (their indices in ``self.layers`` fix the
    state-dict keys):
    - A 7x7, ``groups=512`` LinearBlock with no padding (a per-channel
      spatial reduction).
    - Flatten.
    - A bias-free Linear layer from 512 to ``embedding_size``.
    - BatchNorm1d.

    Presumably the incoming map is 7x7 so that the reduction yields 1x1;
    TODO confirm against the backbone's input resolution.
    """

    def __init__(self, embedding_size):
        super().__init__()
        spatial_reduce = LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0))
        self.layers = nn.Sequential(
            spatial_reduce,
            Flatten(),
            Linear(512, embedding_size, bias=False),
            BatchNorm1d(embedding_size),
        )

    def forward(self, x):
        return self.layers(x)
class MobileFaceNet(Module):
    """MobileFaceNet backbone producing ``num_features``-dimensional embeddings.

    Channel widths are the 64/128 plan multiplied by ``scale = 2``. The
    attribute names and submodule order (``layers``, ``conv_sep``,
    ``features``) determine state-dict keys and must stay fixed.

    Args:
        fp16: if True, ``self.layers`` runs under CUDA autocast and its output
            is cast back to float32 before ``conv_sep``.
        num_features: output size of the ``GDC`` head.
    """

    def __init__(self, fp16=False, num_features=512):
        super().__init__()
        scale = 2
        narrow = 64 * scale
        wide = 128 * scale
        self.fp16 = fp16
        self.layers = nn.Sequential(
            ConvBlock(3, narrow, kernel=(3, 3), stride=(2, 2), padding=(1, 1)),
            ConvBlock(narrow, narrow, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64),
            DepthWise(narrow, narrow, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
            Residual(narrow, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
            DepthWise(narrow, wide, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
            Residual(wide, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
            DepthWise(wide, wide, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
            Residual(wide, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
        )
        # 1x1 projection up to the 512 channels consumed by the GDC head.
        self.conv_sep = ConvBlock(wide, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
        self.features = GDC(num_features)
        self._initialize_weights()

    def _initialize_weights(self):
        """Apply Kaiming-normal (fan_out, relu) init to Conv2d/Linear and set BatchNorm2d to identity.

        Biases of Conv2d/Linear layers, where present, are zeroed. BatchNorm2d
        gets weight 1 and bias 0.
        """
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    module.bias.data.zero_()

    def forward(self, x):
        with torch.cuda.amp.autocast(self.fp16):
            trunk = self.layers(x)
        # Autocast may leave the trunk output in half precision; restore fp32 for the head.
        trunk = trunk.float() if self.fp16 else trunk
        return self.features(self.conv_sep(trunk))
def get_mbf(fp16, num_features):
    """Factory returning a ``MobileFaceNet`` with the given precision flag and embedding size."""
    return MobileFaceNet(fp16=fp16, num_features=num_features)
================================================
FILE: src/face3d/models/arcface_torch/configs/3millions.py
================================================
from easydict import EasyDict as edict
# Speed-test configuration: synthetic data with 3,000,000 classes, sample_rate 1.0.
config = edict(
    loss="arcface",
    network="r50",
    resume=False,
    output=None,
    embedding_size=512,
    sample_rate=1.0,
    fp16=True,
    momentum=0.9,
    weight_decay=5e-4,
    batch_size=128,
    lr=0.1,  # this lr assumes a total batch size of 512
    rec="synthetic",
    num_classes=300 * 10000,
    num_epoch=30,
    warmup_epoch=-1,
    decay_epoch=[10, 16, 22],
    val_targets=[],
)
================================================
FILE: src/face3d/models/arcface_torch/configs/3millions_pfc.py
================================================
from easydict import EasyDict as edict
# Speed-test configuration: synthetic data with 3,000,000 classes and
# sample_rate 0.1 (presumably partial-FC class sampling, per the file name; confirm).
config = edict(
    loss="arcface",
    network="r50",
    resume=False,
    output=None,
    embedding_size=512,
    sample_rate=0.1,
    fp16=True,
    momentum=0.9,
    weight_decay=5e-4,
    batch_size=128,
    lr=0.1,  # this lr assumes a total batch size of 512
    rec="synthetic",
    num_classes=300 * 10000,
    num_epoch=30,
    warmup_epoch=-1,
    decay_epoch=[10, 16, 22],
    val_targets=[],
)
================================================
FILE: src/face3d/models/arcface_torch/configs/__init__.py
================================================
================================================
FILE: src/face3d/models/arcface_torch/configs/base.py
================================================
from easydict import EasyDict as edict
# Training-speed setup: record files are read from /train_tmp, a RAM-backed
# tmpfs (the authors' machine has 256G of RAM), mounted with:
#   mount -t tmpfs -o size=140G tmpfs /train_tmp
config = edict(
    loss="arcface",
    network="r50",
    resume=False,
    output="ms1mv3_arcface_r50",
    dataset="ms1m-retinaface-t1",
    embedding_size=512,
    sample_rate=1,
    fp16=False,
    momentum=0.9,
    weight_decay=5e-4,
    batch_size=128,
    lr=0.1,  # this lr assumes a total batch size of 512
)

# Dataset-specific settings, chosen by config.dataset above.
if config.dataset == "emore":
    config.rec = "/train_tmp/faces_emore"
    config.num_classes = 85742
    config.num_image = 5822653
    config.num_epoch = 16
    config.warmup_epoch = -1
    config.decay_epoch = [8, 14, ]
    config.val_targets = ["lfw", ]
elif config.dataset == "ms1m-retinaface-t1":
    config.rec = "/train_tmp/ms1m-retinaface-t1"
    config.num_classes = 93431
    config.num_image = 5179510
    config.num_epoch = 25
    config.warmup_epoch = -1
    config.decay_epoch = [11, 17, 22]
    config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
elif config.dataset == "glint360k":
    config.rec = "/train_tmp/glint360k"
    config.num_classes = 360232
    config.num_image = 17091657
    config.num_epoch = 20
    config.warmup_epoch = -1
    config.decay_epoch = [8, 12, 15, 18]
    config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
elif config.dataset == "webface":
    config.rec = "/train_tmp/faces_webface_112x112"
    config.num_classes = 10572
    # NOTE(review): placeholder string, not an image count. If the training
    # script does arithmetic on num_image (e.g. step counts; confirm), it will
    # fail for this dataset until a real number is filled in.
    config.num_image = "forget"
    config.num_epoch = 34
    config.warmup_epoch = -1
    config.decay_epoch = [20, 28, 32]
    config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
================================================
FILE: src/face3d/models/arcface_torch/configs/glint360k_mbf.py
================================================
from easydict import EasyDict as edict
# Training-speed setup: record files are read from /train_tmp, a RAM-backed
# tmpfs (the authors' machine has 256G of RAM), mounted with:
#   mount -t tmpfs -o size=140G tmpfs /train_tmp
config = edict(
    loss="cosface",
    network="mbf",
    resume=False,
    output=None,
    embedding_size=512,
    sample_rate=0.1,
    fp16=True,
    momentum=0.9,
    weight_decay=2e-4,
    batch_size=128,
    lr=0.1,  # this lr assumes a total batch size of 512
    rec="/train_tmp/glint360k",
    num_classes=360232,
    num_image=17091657,
    num_epoch=20,
    warmup_epoch=-1,
    decay_epoch=[8, 12, 15, 18],
    val_targets=["lfw", "cfp_fp", "agedb_30"],
)
================================================
FILE: src/face3d/models/arcface_torch/configs/glint360k_r100.py
================================================
from easydict import EasyDict as edict
# Training-speed setup: record files are read from /train_tmp, a RAM-backed
# tmpfs (the authors' machine has 256G of RAM), mounted with:
#   mount -t tmpfs -o size=140G tmpfs /train_tmp
config = edict(
    loss="cosface",
    network="r100",
    resume=False,
    output=None,
    embedding_size=512,
    sample_rate=1.0,
    fp16=True,
    momentum=0.9,
    weight_decay=5e-4,
    batch_size=128,
    lr=0.1,  # this lr assumes a total batch size of 512
    rec="/train_tmp/glint360k",
    num_classes=360232,
    num_image=17091657,
    num_epoch=20,
    warmup_epoch=-1,
    decay_epoch=[8, 12, 15, 18],
    val_targets=["lfw", "cfp_fp", "agedb_30"],
)
================================================
FILE: src/face3d/models/arcface_torch/configs/glint360k_r18.py
================================================
from easydict import EasyDict as edict
# Training-speed setup: record files are read from /train_tmp, a RAM-backed
# tmpfs (the authors' machine has 256G of RAM), mounted with:
#   mount -t tmpfs -o size=140G tmpfs /train_tmp
config = edict(
    loss="cosface",
    network="r18",
    resume=False,
    output=None,
    embedding_size=512,
    sample_rate=1.0,
    fp16=True,
    momentum=0.9,
    weight_decay=5e-4,
    batch_size=128,
    lr=0.1,  # this lr assumes a total batch size of 512
    rec="/train_tmp/glint360k",
    num_classes=360232,
    num_image=17091657,
    num_epoch=20,
    warmup_epoch=-1,
    decay_epoch=[8, 12, 15, 18],
    val_targets=["lfw", "cfp_fp", "agedb_30"],
)
================================================
FILE: src/face3d/models/arcface_torch/configs/glint360k_r34.py
================================================
from easydict import EasyDict as edict
# Training-speed setup: record files are read from /train_tmp, a RAM-backed
# tmpfs (the authors' machine has 256G of RAM), mounted with:
#   mount -t tmpfs -o size=140G tmpfs /train_tmp
config = edict(
    loss="cosface",
    network="r34",
    resume=False,
    output=None,
    embedding_size=512,
    sample_rate=1.0,
    fp16=True,
    momentum=0.9,
    weight_decay=5e-4,
    batch_size=128,
    lr=0.1,  # this lr assumes a total batch size of 512
    rec="/train_tmp/glint360k",
    num_classes=360232,
    num_image=17091657,
    num_epoch=20,
    warmup_epoch=-1,
    decay_epoch=[8, 12, 15, 18],
    val_targets=["lfw", "cfp_fp", "agedb_30"],
)
================================================
FILE: src/face3d/models/arcface_torch/configs/glint360k_r50.py
================================================
from easydict import EasyDict as edict
# Training-speed setup: record files are read from /train_tmp, a RAM-backed
# tmpfs (the authors' machine has 256G of RAM), mounted with:
#   mount -t tmpfs -o size=140G tmpfs /train_tmp
config = edict(
    loss="cosface",
    network="r50",
    resume=False,
    output=None,
    embedding_size=512,
    sample_rate=1.0,
    fp16=True,
    momentum=0.9,
    weight_decay=5e-4,
    batch_size=128,
    lr=0.1,  # this lr assumes a total batch size of 512
    rec="/train_tmp/glint360k",
    num_classes=360232,
    num_image=17091657,
    num_epoch=20,
    warmup_epoch=-1,
    decay_epoch=[8, 12, 15, 18],
    val_targets=["lfw", "cfp_fp", "agedb_30"],
)
================================================
FILE: src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py
================================================
from easydict import EasyDict as edict
# Training-speed setup: record files are read from /train_tmp, a RAM-backed
# tmpfs (the authors' machine has 256G of RAM), mounted with:
#   mount -t tmpfs -o size=140G tmpfs /train_tmp
config = edict(
    loss="arcface",
    network="mbf",
    resume=False,
    output=None,
    embedding_size=512,
    sample_rate=1.0,
    fp16=True,
    momentum=0.9,
    weight_decay=2e-4,
    batch_size=128,
    lr=0.1,  # this lr assumes a total batch size of 512
    rec="/train_tmp/ms1m-retinaface-t1",
    num_classes=93431,
    num_image=5179510,
    num_epoch=30,
    warmup_epoch=-1,
    decay_epoch=[10, 20, 25],
    val_targets=["lfw", "cfp_fp", "agedb_30"],
)
================================================
FILE: src/face3d/models/arcface_torch/configs/ms1mv3_r18.py
================================================
from easydict import EasyDict as edict
# Training-speed setup: record files are read from /train_tmp, a RAM-backed
# tmpfs (the authors' machine has 256G of RAM), mounted with:
#   mount -t tmpfs -o size=140G tmpfs /train_tmp
config = edict(
    loss="arcface",
    network="r18",
    resume=False,
    output=None,
    embedding_size=512,
    sample_rate=1.0,
    fp16=True,
    momentum=0.9,
    weight_decay=5e-4,
    batch_size=128,
    lr=0.1,  # this lr assumes a total batch size of 512
    rec="/train_tmp/ms1m-retinaface-t1",
    num_classes=93431,
    num_image=5179510,
    num_epoch=25,
    warmup_epoch=-1,
    decay_epoch=[10, 16, 22],
    val_targets=["lfw", "cfp_fp", "agedb_30"],
)
================================================
FILE: src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py
================================================
from easydict import EasyDict as edict
# Training-speed setup: record files are read from /train_tmp, a RAM-backed
# tmpfs (the authors' machine has 256G of RAM), mounted with:
#   mount -t tmpfs -o size=140G tmpfs /train_tmp
config = edict(
    loss="arcface",
    network="r2060",
    resume=False,
    output=None,
    embedding_size=512,
    sample_rate=1.0,
    fp16=True,
    momentum=0.9,
    weight_decay=5e-4,
    batch_size=64,  # half the per-step batch used by the smaller ms1mv3 configs
    lr=0.1,  # this lr assumes a total batch size of 512
    rec="/train_tmp/ms1m-retinaface-t1",
    num_classes=93431,
    num_image=5179510,
    num_epoch=25,
    warmup_epoch=-1,
    decay_epoch=[10, 16, 22],
    val_targets=["lfw", "cfp_fp", "agedb_30"],
)
================================================
FILE: src/face3d/models/arcface_torch/configs/ms1mv3_r34.py
================================================
from easydict import EasyDict as edict
# Training-speed setup: record files are read from /train_tmp, a RAM-backed
# tmpfs (the authors' machine has 256G of RAM), mounted with:
#   mount -t tmpfs -o size=140G tmpfs /train_tmp
config = edict(
    loss="arcface",
    network="r34",
    resume=False,
    output=None,
    embedding_size=512,
    sample_rate=1.0,
    fp16=True,
    momentum=0.9,
    weight_decay=5e-4,
    batch_size=128,
    lr=0.1,  # this lr assumes a total batch size of 512
    rec="/train_tmp/ms1m-retinaface-t1",
    num_classes=93431,
    num_image=5179510,
    num_epoch=25,
    warmup_epoch=-1,
    decay_epoch=[10, 16, 22],
    val_targets=["lfw", "cfp_fp", "agedb_30"],
)
================================================
FILE: src/face3d/models/arcface_torch/configs/ms1mv3_r50.py
================================================
from easydict import EasyDict as edict
# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp
config = edict()
config.loss = "arcface"
config.network = "r50"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate =
gitextract_pkbb6gh_/ ├── .gitignore ├── LICENSE ├── README.md ├── app_sadtalker.py ├── cog.yaml ├── docs/ │ ├── FAQ.md │ ├── best_practice.md │ ├── changlelog.md │ ├── face3d.md │ ├── install.md │ └── webui_extension.md ├── inference.py ├── launcher.py ├── predict.py ├── quick_demo.ipynb ├── req.txt ├── requirements.txt ├── requirements3d.txt ├── scripts/ │ ├── download_models.sh │ ├── extension.py │ └── test.sh ├── src/ │ ├── audio2exp_models/ │ │ ├── audio2exp.py │ │ └── networks.py │ ├── audio2pose_models/ │ │ ├── audio2pose.py │ │ ├── audio_encoder.py │ │ ├── cvae.py │ │ ├── discriminator.py │ │ ├── networks.py │ │ └── res_unet.py │ ├── config/ │ │ ├── auido2exp.yaml │ │ ├── auido2pose.yaml │ │ ├── facerender.yaml │ │ ├── facerender_still.yaml │ │ └── similarity_Lm3D_all.mat │ ├── face3d/ │ │ ├── data/ │ │ │ ├── __init__.py │ │ │ ├── base_dataset.py │ │ │ ├── flist_dataset.py │ │ │ ├── image_folder.py │ │ │ └── template_dataset.py │ │ ├── extract_kp_videos.py │ │ ├── extract_kp_videos_safe.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── arcface_torch/ │ │ │ │ ├── README.md │ │ │ │ ├── backbones/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── iresnet.py │ │ │ │ │ ├── iresnet2060.py │ │ │ │ │ └── mobilefacenet.py │ │ │ │ ├── configs/ │ │ │ │ │ ├── 3millions.py │ │ │ │ │ ├── 3millions_pfc.py │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── base.py │ │ │ │ │ ├── glint360k_mbf.py │ │ │ │ │ ├── glint360k_r100.py │ │ │ │ │ ├── glint360k_r18.py │ │ │ │ │ ├── glint360k_r34.py │ │ │ │ │ ├── glint360k_r50.py │ │ │ │ │ ├── ms1mv3_mbf.py │ │ │ │ │ ├── ms1mv3_r18.py │ │ │ │ │ ├── ms1mv3_r2060.py │ │ │ │ │ ├── ms1mv3_r34.py │ │ │ │ │ ├── ms1mv3_r50.py │ │ │ │ │ └── speed.py │ │ │ │ ├── dataset.py │ │ │ │ ├── docs/ │ │ │ │ │ ├── eval.md │ │ │ │ │ ├── install.md │ │ │ │ │ ├── modelzoo.md │ │ │ │ │ └── speed_benchmark.md │ │ │ │ ├── eval/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── verification.py │ │ │ │ ├── eval_ijbc.py │ │ │ │ ├── inference.py │ │ │ │ ├── losses.py │ │ │ │ ├── onnx_helper.py 
│ │ │ │ ├── onnx_ijbc.py │ │ │ │ ├── partial_fc.py │ │ │ │ ├── requirement.txt │ │ │ │ ├── run.sh │ │ │ │ ├── torch2onnx.py │ │ │ │ ├── train.py │ │ │ │ └── utils/ │ │ │ │ ├── __init__.py │ │ │ │ ├── plot.py │ │ │ │ ├── utils_amp.py │ │ │ │ ├── utils_callbacks.py │ │ │ │ ├── utils_config.py │ │ │ │ ├── utils_logging.py │ │ │ │ └── utils_os.py │ │ │ ├── base_model.py │ │ │ ├── bfm.py │ │ │ ├── facerecon_model.py │ │ │ ├── losses.py │ │ │ ├── networks.py │ │ │ └── template_model.py │ │ ├── options/ │ │ │ ├── __init__.py │ │ │ ├── base_options.py │ │ │ ├── inference_options.py │ │ │ ├── test_options.py │ │ │ └── train_options.py │ │ ├── util/ │ │ │ ├── BBRegressorParam_r.mat │ │ │ ├── __init__.py │ │ │ ├── detect_lm68.py │ │ │ ├── generate_list.py │ │ │ ├── html.py │ │ │ ├── load_mats.py │ │ │ ├── my_awing_arch.py │ │ │ ├── nvdiffrast.py │ │ │ ├── preprocess.py │ │ │ ├── skin_mask.py │ │ │ ├── test_mean_face.txt │ │ │ ├── util.py │ │ │ └── visualizer.py │ │ └── visualize.py │ ├── facerender/ │ │ ├── animate.py │ │ ├── modules/ │ │ │ ├── dense_motion.py │ │ │ ├── discriminator.py │ │ │ ├── generator.py │ │ │ ├── keypoint_detector.py │ │ │ ├── make_animation.py │ │ │ ├── mapping.py │ │ │ └── util.py │ │ └── sync_batchnorm/ │ │ ├── __init__.py │ │ ├── batchnorm.py │ │ ├── comm.py │ │ ├── replicate.py │ │ └── unittest.py │ ├── generate_batch.py │ ├── generate_facerender_batch.py │ ├── gradio_demo.py │ ├── test_audio2coeff.py │ └── utils/ │ ├── audio.py │ ├── croper.py │ ├── face_enhancer.py │ ├── hparams.py │ ├── init_path.py │ ├── model2safetensor.py │ ├── paste_pic.py │ ├── preprocess.py │ ├── safetensor_helper.py │ ├── text2speech.py │ └── videoio.py ├── webui.bat └── webui.sh
SYMBOL INDEX (702 symbols across 88 files)
FILE: app_sadtalker.py
function toggle_audio_file (line 13) | def toggle_audio_file(choice):
function ref_video_fn (line 19) | def ref_video_fn(path_of_ref_video):
function sadtalker_demo (line 25) | def sadtalker_demo(checkpoint_path='checkpoints', config_path='src/confi...
FILE: inference.py
function main (line 15) | def main(args):
FILE: launcher.py
function check_python_version (line 23) | def check_python_version():
function commit_hash (line 49) | def commit_hash():
function run (line 63) | def run(command, desc=None, errdesc=None, custom_env=None, live=False):
function check_run (line 91) | def check_run(command):
function is_installed (line 96) | def is_installed(package):
function repo_dir (line 105) | def repo_dir(name):
function run_python (line 109) | def run_python(code, desc=None, errdesc=None):
function run_pip (line 113) | def run_pip(args, desc=None):
function check_run_python (line 121) | def check_run_python(code):
function git_clone (line 125) | def git_clone(url, dir, name, commithash=None):
function git_pull_recursive (line 146) | def git_pull_recursive(dir):
function run_extension_installer (line 156) | def run_extension_installer(extension_dir):
function prepare_environment (line 170) | def prepare_environment():
function start (line 195) | def start():
FILE: predict.py
class Predictor (line 16) | class Predictor(BasePredictor):
method setup (line 17) | def setup(self):
method predict (line 44) | def predict(
function load_default (line 172) | def load_default():
FILE: scripts/extension.py
function check_all_files_safetensor (line 14) | def check_all_files_safetensor(current_dir):
function check_all_files (line 33) | def check_all_files(current_dir):
function download_model (line 58) | def download_model(local_dir='./checkpoints'):
function get_source_image (line 62) | def get_source_image(image):
function get_img_from_txt2img (line 65) | def get_img_from_txt2img(x):
function get_img_from_img2img (line 73) | def get_img_from_img2img(x):
function get_default_checkpoint_path (line 81) | def get_default_checkpoint_path():
function install (line 106) | def install():
function on_ui_tabs (line 162) | def on_ui_tabs():
function on_ui_settings (line 183) | def on_ui_settings():
FILE: src/audio2exp_models/audio2exp.py
class Audio2Exp (line 6) | class Audio2Exp(nn.Module):
method __init__ (line 7) | def __init__(self, netG, cfg, device, prepare_training_loss=False):
method test (line 13) | def test(self, batch):
FILE: src/audio2exp_models/networks.py
class Conv2d (line 5) | class Conv2d(nn.Module):
method __init__ (line 6) | def __init__(self, cin, cout, kernel_size, stride, padding, residual=F...
method forward (line 16) | def forward(self, x):
class SimpleWrapperV2 (line 26) | class SimpleWrapperV2(nn.Module):
method __init__ (line 27) | def __init__(self) -> None:
method forward (line 67) | def forward(self, x, ref, ratio):
FILE: src/audio2pose_models/audio2pose.py
class Audio2Pose (line 7) | class Audio2Pose(nn.Module):
method __init__ (line 8) | def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
method forward (line 24) | def forward(self, x):
method test (line 48) | def test(self, x):
FILE: src/audio2pose_models/audio_encoder.py
class Conv2d (line 5) | class Conv2d(nn.Module):
method __init__ (line 6) | def __init__(self, cin, cout, kernel_size, stride, padding, residual=F...
method forward (line 15) | def forward(self, x):
class AudioEncoder (line 21) | class AudioEncoder(nn.Module):
method __init__ (line 22) | def __init__(self, wav2lip_checkpoint, device):
method forward (line 54) | def forward(self, audio_sequences):
FILE: src/audio2pose_models/cvae.py
function class2onehot (line 6) | def class2onehot(idx, class_num):
class CVAE (line 13) | class CVAE(nn.Module):
method __init__ (line 14) | def __init__(self, cfg):
method reparameterize (line 30) | def reparameterize(self, mu, logvar):
method forward (line 35) | def forward(self, batch):
method test (line 43) | def test(self, batch):
class ENCODER (line 51) | class ENCODER(nn.Module):
method __init__ (line 52) | def __init__(self, layer_sizes, latent_size, num_classes,
method forward (line 73) | def forward(self, batch):
class DECODER (line 99) | class DECODER(nn.Module):
method __init__ (line 100) | def __init__(self, layer_sizes, latent_size, num_classes,
method forward (line 123) | def forward(self, batch):
FILE: src/audio2pose_models/discriminator.py
class ConvNormRelu (line 5) | class ConvNormRelu(nn.Module):
method __init__ (line 6) | def __init__(self, conv_type='1d', in_channels=3, out_channels=64, dow...
method forward (line 49) | def forward(self, x):
class PoseSequenceDiscriminator (line 59) | class PoseSequenceDiscriminator(nn.Module):
method __init__ (line 60) | def __init__(self, cfg):
method forward (line 72) | def forward(self, x):
FILE: src/audio2pose_models/networks.py
class ResidualConv (line 5) | class ResidualConv(nn.Module):
method __init__ (line 6) | def __init__(self, input_dim, output_dim, stride, padding):
method forward (line 24) | def forward(self, x):
class Upsample (line 29) | class Upsample(nn.Module):
method __init__ (line 30) | def __init__(self, input_dim, output_dim, kernel, stride):
method forward (line 37) | def forward(self, x):
class Squeeze_Excite_Block (line 41) | class Squeeze_Excite_Block(nn.Module):
method __init__ (line 42) | def __init__(self, channel, reduction=16):
method forward (line 52) | def forward(self, x):
class ASPP (line 59) | class ASPP(nn.Module):
method __init__ (line 60) | def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
method forward (line 88) | def forward(self, x):
method _init_weights (line 95) | def _init_weights(self):
class Upsample_ (line 104) | class Upsample_(nn.Module):
method __init__ (line 105) | def __init__(self, scale=2):
method forward (line 110) | def forward(self, x):
class AttentionBlock (line 114) | class AttentionBlock(nn.Module):
method __init__ (line 115) | def __init__(self, input_encoder, input_decoder, output_dim):
method forward (line 137) | def forward(self, x1, x2):
FILE: src/audio2pose_models/res_unet.py
class ResUnet (line 6) | class ResUnet(nn.Module):
method __init__ (line 7) | def __init__(self, channel=1, filters=[32, 64, 128, 256]):
method forward (line 39) | def forward(self, x):
FILE: src/face3d/data/__init__.py
function find_dataset_using_name (line 19) | def find_dataset_using_name(dataset_name):
function get_option_setter (line 42) | def get_option_setter(dataset_name):
function create_dataset (line 48) | def create_dataset(opt, rank=0):
class CustomDatasetDataLoader (line 62) | class CustomDatasetDataLoader():
method __init__ (line 65) | def __init__(self, opt, rank=0):
method set_epoch (line 99) | def set_epoch(self, epoch):
method load_data (line 104) | def load_data(self):
method __len__ (line 107) | def __len__(self):
method __iter__ (line 111) | def __iter__(self):
FILE: src/face3d/data/base_dataset.py
class BaseDataset (line 13) | class BaseDataset(data.Dataset, ABC):
method __init__ (line 23) | def __init__(self, opt):
method modify_commandline_options (line 34) | def modify_commandline_options(parser, is_train):
method __len__ (line 47) | def __len__(self):
method __getitem__ (line 52) | def __getitem__(self, index):
function get_transform (line 64) | def get_transform(grayscale=False):
function get_affine_mat (line 71) | def get_affine_mat(opt, size):
function apply_img_affine (line 98) | def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
function apply_lm_affine (line 101) | def apply_lm_affine(landmark, affine, flip, size):
FILE: src/face3d/data/flist_dataset.py
function default_flist_reader (line 19) | def default_flist_reader(flist):
function jason_flist_reader (line 31) | def jason_flist_reader(flist):
function parse_label (line 36) | def parse_label(label):
class FlistDataset (line 40) | class FlistDataset(BaseDataset):
method __init__ (line 46) | def __init__(self, opt):
method __getitem__ (line 67) | def __getitem__(self, index):
method _augmentation (line 111) | def _augmentation(self, img, lm, opt, msk=None):
method __len__ (line 122) | def __len__(self):
FILE: src/face3d/data/image_folder.py
function is_image_file (line 20) | def is_image_file(filename):
function make_dataset (line 24) | def make_dataset(dir, max_dataset_size=float("inf")):
function default_loader (line 36) | def default_loader(path):
class ImageFolder (line 40) | class ImageFolder(data.Dataset):
method __init__ (line 42) | def __init__(self, root, transform=None, return_paths=False,
method __getitem__ (line 55) | def __getitem__(self, index):
method __len__ (line 65) | def __len__(self):
FILE: src/face3d/data/template_dataset.py
class TemplateDataset (line 19) | class TemplateDataset(BaseDataset):
method modify_commandline_options (line 22) | def modify_commandline_options(parser, is_train):
method __init__ (line 36) | def __init__(self, opt):
method __getitem__ (line 54) | def __getitem__(self, index):
method __len__ (line 73) | def __len__(self):
FILE: src/face3d/extract_kp_videos.py
class KeypointExtractor (line 14) | class KeypointExtractor():
method __init__ (line 15) | def __init__(self, device):
method extract_keypoint (line 19) | def extract_keypoint(self, images, name=None, info=True):
function read_video (line 58) | def read_video(filename):
function run (line 72) | def run(data):
FILE: src/face3d/extract_kp_videos_safe.py
function init_alignment_model (line 19) | def init_alignment_model(model_name, half=False, device='cuda', model_ro...
class KeypointExtractor (line 34) | class KeypointExtractor():
method __init__ (line 35) | def __init__(self, device='cuda'):
method extract_keypoint (line 48) | def extract_keypoint(self, images, name=None, info=True):
function read_video (line 101) | def read_video(filename):
function run (line 115) | def run(data):
FILE: src/face3d/models/__init__.py
function find_model_using_name (line 25) | def find_model_using_name(model_name):
function get_option_setter (line 48) | def get_option_setter(model_name):
function create_model (line 54) | def create_model(opt):
FILE: src/face3d/models/arcface_torch/backbones/__init__.py
function get_model (line 5) | def get_model(name, **kwargs):
FILE: src/face3d/models/arcface_torch/backbones/iresnet.py
function conv3x3 (line 7) | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
function conv1x1 (line 19) | def conv1x1(in_planes, out_planes, stride=1):
class IBasicBlock (line 28) | class IBasicBlock(nn.Module):
method __init__ (line 30) | def __init__(self, inplanes, planes, stride=1, downsample=None,
method forward (line 46) | def forward(self, x):
class IResNet (line 60) | class IResNet(nn.Module):
method __init__ (line 62) | def __init__(self,
method _make_layer (line 114) | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
method forward (line 140) | def forward(self, x):
function _iresnet (line 157) | def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
function iresnet18 (line 164) | def iresnet18(pretrained=False, progress=True, **kwargs):
function iresnet34 (line 169) | def iresnet34(pretrained=False, progress=True, **kwargs):
function iresnet50 (line 174) | def iresnet50(pretrained=False, progress=True, **kwargs):
function iresnet100 (line 179) | def iresnet100(pretrained=False, progress=True, **kwargs):
function iresnet200 (line 184) | def iresnet200(pretrained=False, progress=True, **kwargs):
FILE: src/face3d/models/arcface_torch/backbones/iresnet2060.py
function conv3x3 (line 10) | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
function conv1x1 (line 22) | def conv1x1(in_planes, out_planes, stride=1):
class IBasicBlock (line 31) | class IBasicBlock(nn.Module):
method __init__ (line 34) | def __init__(self, inplanes, planes, stride=1, downsample=None,
method forward (line 50) | def forward(self, x):
class IResNet (line 64) | class IResNet(nn.Module):
method __init__ (line 67) | def __init__(self,
method _make_layer (line 119) | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
method checkpoint (line 145) | def checkpoint(self, func, num_seg, x):
method forward (line 151) | def forward(self, x):
function _iresnet (line 168) | def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
function iresnet2060 (line 175) | def iresnet2060(pretrained=False, progress=True, **kwargs):
FILE: src/face3d/models/arcface_torch/backbones/mobilefacenet.py
class Flatten (line 11) | class Flatten(Module):
method forward (line 12) | def forward(self, x):
class ConvBlock (line 16) | class ConvBlock(Module):
method __init__ (line 17) | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=...
method forward (line 25) | def forward(self, x):
class LinearBlock (line 29) | class LinearBlock(Module):
method __init__ (line 30) | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=...
method forward (line 37) | def forward(self, x):
class DepthWise (line 41) | class DepthWise(Module):
method __init__ (line 42) | def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=...
method forward (line 51) | def forward(self, x):
class Residual (line 63) | class Residual(Module):
method __init__ (line 64) | def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1),...
method forward (line 71) | def forward(self, x):
class GDC (line 75) | class GDC(Module):
method __init__ (line 76) | def __init__(self, embedding_size):
method forward (line 84) | def forward(self, x):
class MobileFaceNet (line 88) | class MobileFaceNet(Module):
method __init__ (line 89) | def __init__(self, fp16=False, num_features=512):
method _initialize_weights (line 107) | def _initialize_weights(self):
method forward (line 121) | def forward(self, x):
function get_mbf (line 129) | def get_mbf(fp16, num_features):
FILE: src/face3d/models/arcface_torch/dataset.py
class BackgroundGenerator (line 13) | class BackgroundGenerator(threading.Thread):
method __init__ (line 14) | def __init__(self, generator, local_rank, max_prefetch=6):
method run (line 22) | def run(self):
method next (line 28) | def next(self):
method __next__ (line 34) | def __next__(self):
method __iter__ (line 37) | def __iter__(self):
class DataLoaderX (line 41) | class DataLoaderX(DataLoader):
method __init__ (line 43) | def __init__(self, local_rank, **kwargs):
method __iter__ (line 48) | def __iter__(self):
method preload (line 54) | def preload(self):
method __next__ (line 62) | def __next__(self):
class MXFaceDataset (line 71) | class MXFaceDataset(Dataset):
method __init__ (line 72) | def __init__(self, root_dir, local_rank):
method __getitem__ (line 93) | def __getitem__(self, index):
method __len__ (line 106) | def __len__(self):
class SyntheticDataset (line 110) | class SyntheticDataset(Dataset):
method __init__ (line 111) | def __init__(self, local_rank):
method __getitem__ (line 120) | def __getitem__(self, index):
method __len__ (line 123) | def __len__(self):
FILE: src/face3d/models/arcface_torch/eval/verification.py
class LFold (line 41) | class LFold:
method __init__ (line 42) | def __init__(self, n_splits=2, shuffle=False):
method split (line 47) | def split(self, indices):
function calculate_roc (line 54) | def calculate_roc(thresholds,
function calculate_accuracy (line 109) | def calculate_accuracy(threshold, dist, actual_issame):
function calculate_val (line 124) | def calculate_val(thresholds,
function calculate_val_far (line 165) | def calculate_val_far(threshold, dist, actual_issame):
function evaluate (line 179) | def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0):
function load_bin (line 200) | def load_bin(path, image_size):
function test (line 227) | def test(data_set, backbone, batch_size, nfolds=10):
function dumpR (line 275) | def dumpR(data_set,
FILE: src/face3d/models/arcface_torch/eval_ijbc.py
class Embedding (line 54) | class Embedding(object):
method __init__ (line 55) | def __init__(self, prefix, data_shape, batch_size=1):
method get (line 75) | def get(self, rimg, landmark):
method forward_db (line 104) | def forward_db(self, batch_data):
function divideIntoNstrand (line 113) | def divideIntoNstrand(listTemp, n):
function read_template_media_list (line 120) | def read_template_media_list(path):
function read_template_pair_list (line 131) | def read_template_pair_list(path):
function read_image_feature (line 145) | def read_image_feature(path):
function get_image_feature (line 154) | def get_image_feature(img_path, files_list, model_path, epoch, gpu_id):
function image2template_feature (line 212) | def image2template_feature(img_feats=None, templates=None, medias=None):
function verification (line 252) | def verification(template_norm_feats=None,
function verification2 (line 282) | def verification2(template_norm_feats=None,
function read_score (line 306) | def read_score(path):
FILE: src/face3d/models/arcface_torch/inference.py
function inference (line 11) | def inference(weight, name, img):
FILE: src/face3d/models/arcface_torch/losses.py
function get_loss (line 5) | def get_loss(name):
class CosFace (line 14) | class CosFace(nn.Module):
method __init__ (line 15) | def __init__(self, s=64.0, m=0.40):
method forward (line 20) | def forward(self, cosine, label):
class ArcFace (line 29) | class ArcFace(nn.Module):
method __init__ (line 30) | def __init__(self, s=64.0, m=0.5):
method forward (line 35) | def forward(self, cosine: torch.Tensor, label):
FILE: src/face3d/models/arcface_torch/onnx_helper.py
class ArcFaceORT (line 15) | class ArcFaceORT:
method __init__ (line 16) | def __init__(self, model_path, cpu=False):
method check (line 22) | def check(self, track='cfat', test_img = None):
method check_batch (line 184) | def check_batch(self, img):
method meta_info (line 202) | def meta_info(self):
method forward (line 206) | def forward(self, imgs):
method benchmark (line 222) | def benchmark(self, img):
FILE: src/face3d/models/arcface_torch/onnx_ijbc.py
class AlignedDataSet (line 28) | class AlignedDataSet(mx.gluon.data.Dataset):
method __init__ (line 29) | def __init__(self, root, lines, align=True):
method __len__ (line 34) | def __len__(self):
method __getitem__ (line 37) | def __getitem__(self, idx):
function extract (line 54) | def extract(model_root, dataset):
function read_template_media_list (line 78) | def read_template_media_list(path):
function read_template_pair_list (line 85) | def read_template_pair_list(path):
function read_image_feature (line 93) | def read_image_feature(path):
function image2template_feature (line 99) | def image2template_feature(img_feats=None,
function verification (line 125) | def verification(template_norm_feats=None,
function verification2 (line 147) | def verification2(template_norm_feats=None,
function main (line 169) | def main(args):
FILE: src/face3d/models/arcface_torch/partial_fc.py
class PartialFC (line 11) | class PartialFC(Module):
method __init__ (line 20) | def __init__(self, rank, local_rank, world_size, batch_size, resume,
method save_params (line 93) | def save_params(self):
method sample (line 100) | def sample(self, total_label):
method forward (line 125) | def forward(self, total_features, norm_weight):
method update (line 133) | def update(self):
method prepare (line 139) | def prepare(self, label, optimizer):
method forward_backward (line 159) | def forward_backward(self, label, features, optimizer):
FILE: src/face3d/models/arcface_torch/torch2onnx.py
function convert_onnx (line 6) | def convert_onnx(net, path_module, output, opset=11, simplify=False):
FILE: src/face3d/models/arcface_torch/train.py
function main (line 21) | def main(args):
FILE: src/face3d/models/arcface_torch/utils/plot.py
function read_template_pair_list (line 19) | def read_template_pair_list(path):
FILE: src/face3d/models/arcface_torch/utils/utils_amp.py
class _MultiDeviceReplicator (line 14) | class _MultiDeviceReplicator(object):
method __init__ (line 19) | def __init__(self, master_tensor: torch.Tensor) -> None:
method get (line 24) | def get(self, device) -> torch.Tensor:
class MaxClipGradScaler (line 32) | class MaxClipGradScaler(GradScaler):
method __init__ (line 33) | def __init__(self, init_scale, max_scale: float, growth_interval=100):
method scale_clip (line 37) | def scale_clip(self):
method scale (line 46) | def scale(self, outputs):
FILE: src/face3d/models/arcface_torch/utils/utils_callbacks.py
class CallBackVerification (line 12) | class CallBackVerification(object):
method __init__ (line 13) | def __init__(self, frequent, rank, val_targets, rec_prefix, image_size...
method ver_test (line 23) | def ver_test(self, backbone: torch.nn.Module, global_step: int):
method init_dataset (line 36) | def init_dataset(self, val_targets, data_dir, image_size):
method __call__ (line 44) | def __call__(self, num_update, backbone: torch.nn.Module):
class CallBackLogging (line 51) | class CallBackLogging(object):
method __init__ (line 52) | def __init__(self, frequent, rank, total_step, batch_size, world_size,...
method __call__ (line 64) | def __call__(self,
class CallBackModelCheckpoint (line 105) | class CallBackModelCheckpoint(object):
method __init__ (line 106) | def __init__(self, rank, output="./"):
method __call__ (line 110) | def __call__(self, global_step, backbone, partial_fc, ):
FILE: src/face3d/models/arcface_torch/utils/utils_config.py
function get_config (line 5) | def get_config(config_file):
FILE: src/face3d/models/arcface_torch/utils/utils_logging.py
class AverageMeter (line 6) | class AverageMeter(object):
method __init__ (line 10) | def __init__(self):
method reset (line 17) | def reset(self):
method update (line 23) | def update(self, val, n=1):
function init_logging (line 30) | def init_logging(rank, models_root):
FILE: src/face3d/models/base_model.py
class BaseModel (line 12) | class BaseModel(ABC):
method __init__ (line 22) | def __init__(self, opt):
method dict_grad_hook_factory (line 49) | def dict_grad_hook_factory(add_func=lambda x: x):
method modify_commandline_options (line 60) | def modify_commandline_options(parser, is_train):
method set_input (line 73) | def set_input(self, input):
method forward (line 82) | def forward(self):
method optimize_parameters (line 87) | def optimize_parameters(self):
method setup (line 91) | def setup(self, opt):
method parallelize (line 107) | def parallelize(self, convert_sync_batchnorm=True):
method data_dependent_initialize (line 138) | def data_dependent_initialize(self, data):
method train (line 141) | def train(self):
method eval (line 148) | def eval(self):
method test (line 155) | def test(self):
method compute_visuals (line 165) | def compute_visuals(self):
method get_image_paths (line 169) | def get_image_paths(self, name='A'):
method update_learning_rate (line 173) | def update_learning_rate(self):
method get_current_visuals (line 184) | def get_current_visuals(self):
method get_current_losses (line 192) | def get_current_losses(self):
method save_networks (line 200) | def save_networks(self, epoch):
method __patch_instance_norm_state_dict (line 230) | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i...
method load_networks (line 244) | def load_networks(self, epoch):
method print_networks (line 284) | def print_networks(self, verbose):
method set_requires_grad (line 302) | def set_requires_grad(self, nets, requires_grad=False):
method generate_visuals_for_evaluation (line 315) | def generate_visuals_for_evaluation(self, data, mode):
FILE: src/face3d/models/bfm.py
function perspective_projection (line 11) | def perspective_projection(focal, center):
class SH (line 19) | class SH:
method __init__ (line 20) | def __init__(self):
class ParametricFaceModel (line 26) | class ParametricFaceModel:
method __init__ (line 27) | def __init__(self,
method to (line 80) | def to(self, device):
method compute_shape (line 87) | def compute_shape(self, id_coeff, exp_coeff):
method compute_texture (line 103) | def compute_texture(self, tex_coeff, normalize=True):
method compute_norm (line 118) | def compute_norm(self, face_shape):
method compute_color (line 141) | def compute_color(self, face_texture, face_norm, gamma):
method compute_rotation (line 175) | def compute_rotation(self, angles):
method to_camera (line 211) | def to_camera(self, face_shape):
method to_image (line 215) | def to_image(self, face_shape):
method transform (line 230) | def transform(self, face_shape, rot, trans):
method get_landmarks (line 243) | def get_landmarks(self, face_proj):
method split_coeff (line 253) | def split_coeff(self, coeffs):
method compute_for_render (line 275) | def compute_for_render(self, coeffs):
method compute_for_render_woRotation (line 302) | def compute_for_render_woRotation(self, coeffs):
FILE: src/face3d/models/facerecon_model.py
class FaceReconModel (line 17) | class FaceReconModel(BaseModel):
method modify_commandline_options (line 20) | def modify_commandline_options(parser, is_train=False):
method __init__ (line 71) | def __init__(self, opt):
method set_input (line 115) | def set_input(self, input):
method forward (line 127) | def forward(self, output_coeff, device):
method compute_losses (line 137) | def compute_losses(self):
method optimize_parameters (line 169) | def optimize_parameters(self, isTrain=True):
method compute_visuals (line 178) | def compute_visuals(self):
method save_mesh (line 200) | def save_mesh(self, name):
method save_coeff (line 211) | def save_coeff(self,name):
FILE: src/face3d/models/losses.py
function resize_n_crop (line 7) | def resize_n_crop(image, M, dsize=112):
class PerceptualLoss (line 13) | class PerceptualLoss(nn.Module):
method __init__ (line 14) | def __init__(self, recog_net, input_size=112):
method forward (line 19) | def forward(imageA, imageB, M):
function perceptual_loss (line 39) | def perceptual_loss(id_featureA, id_featureB):
function photo_loss (line 45) | def photo_loss(imageA, imageB, mask, eps=1e-6):
function landmark_loss (line 56) | def landmark_loss(predict_lm, gt_lm, weight=None):
function reg_loss (line 76) | def reg_loss(coeffs_dict, opt=None):
function reflectance_loss (line 101) | def reflectance_loss(texture, mask):
FILE: src/face3d/models/networks.py
function resize_n_crop (line 21) | def resize_n_crop(image, M, dsize=112):
function filter_state_dict (line 26) | def filter_state_dict(state_dict, remove_name='fc'):
function get_scheduler (line 34) | def get_scheduler(optimizer, opt):
function define_net_recon (line 61) | def define_net_recon(net_recon, use_last_fc=False, init_path=None):
function define_net_recog (line 64) | def define_net_recog(net_recog, pretrained_path=None):
class ReconNetWrapper (line 69) | class ReconNetWrapper(nn.Module):
method __init__ (line 71) | def __init__(self, net_recon, use_last_fc=False, init_path=None):
method forward (line 97) | def forward(self, x):
class RecogNetWrapper (line 107) | class RecogNetWrapper(nn.Module):
method __init__ (line 108) | def __init__(self, net_recog, pretrained_path=None, input_size=112):
method forward (line 121) | def forward(self, image, M):
function conv3x3 (line 146) | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: in...
function conv1x1 (line 152) | def conv1x1(in_planes: int, out_planes: int, stride: int = 1, bias: bool...
class BasicBlock (line 157) | class BasicBlock(nn.Module):
method __init__ (line 160) | def __init__(
method forward (line 187) | def forward(self, x: Tensor) -> Tensor:
class Bottleneck (line 206) | class Bottleneck(nn.Module):
method __init__ (line 215) | def __init__(
method forward (line 241) | def forward(self, x: Tensor) -> Tensor:
class ResNet (line 264) | class ResNet(nn.Module):
method __init__ (line 266) | def __init__(
method _make_layer (line 331) | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], plan...
method _forward_impl (line 356) | def _forward_impl(self, x: Tensor) -> Tensor:
method forward (line 374) | def forward(self, x: Tensor) -> Tensor:
function _resnet (line 378) | def _resnet(
function resnet18 (line 394) | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: ...
function resnet34 (line 406) | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: ...
function resnet50 (line 418) | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: ...
function resnet101 (line 430) | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs:...
function resnet152 (line 442) | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs:...
function resnext50_32x4d (line 454) | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **k...
function resnext101_32x8d (line 468) | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **...
function wide_resnet50_2 (line 482) | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **k...
function wide_resnet101_2 (line 500) | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **...
FILE: src/face3d/models/template_model.py
class TemplateModel (line 24) | class TemplateModel(BaseModel):
method modify_commandline_options (line 26) | def modify_commandline_options(parser, is_train=True):
method __init__ (line 42) | def __init__(self, opt):
method set_input (line 73) | def set_input(self, input):
method forward (line 84) | def forward(self):
method backward (line 88) | def backward(self):
method optimize_parameters (line 95) | def optimize_parameters(self):
FILE: src/face3d/options/base_options.py
class BaseOptions (line 13) | class BaseOptions():
method __init__ (line 20) | def __init__(self, cmd_line=None):
method initialize (line 27) | def initialize(self, parser):
method gather_options (line 52) | def gather_options(self):
method print_options (line 93) | def print_options(self, opt):
method parse (line 122) | def parse(self):
FILE: src/face3d/options/inference_options.py
class InferenceOptions (line 4) | class InferenceOptions(BaseOptions):
method initialize (line 10) | def initialize(self, parser):
FILE: src/face3d/options/test_options.py
class TestOptions (line 7) | class TestOptions(BaseOptions):
method initialize (line 13) | def initialize(self, parser):
FILE: src/face3d/options/train_options.py
class TrainOptions (line 7) | class TrainOptions(BaseOptions):
method initialize (line 13) | def initialize(self, parser):
FILE: src/face3d/util/detect_lm68.py
function save_label (line 12) | def save_label(labels, save_path):
function draw_landmarks (line 15) | def draw_landmarks(img, landmark, save_name):
function load_data (line 35) | def load_data(img_name, txt_name):
function load_lm_graph (line 39) | def load_lm_graph(graph_filename):
function detect_68p (line 53) | def detect_68p(img_path,sess,input_op,output_op):
FILE: src/face3d/util/generate_list.py
function write_list (line 7) | def write_list(lms_list, imgs_list, msks_list, mode='train',save_folder=...
function check_list (line 21) | def check_list(rlms_list, rimgs_list, rmsks_list):
FILE: src/face3d/util/html.py
class HTML (line 6) | class HTML:
method __init__ (line 14) | def __init__(self, web_dir, title, refresh=0):
method get_image_dir (line 35) | def get_image_dir(self):
method add_header (line 39) | def add_header(self, text):
method add_images (line 48) | def add_images(self, ims, txts, links, width=400):
method save (line 68) | def save(self):
FILE: src/face3d/util/load_mats.py
function LoadExpBasis (line 11) | def LoadExpBasis(bfm_folder='BFM'):
function transferBFM09 (line 32) | def transferBFM09(bfm_folder='BFM'):
function load_lm3d (line 105) | def load_lm3d(bfm_folder):
FILE: src/face3d/util/my_awing_arch.py
function calculate_points (line 8) | def calculate_points(heatmaps):
class AddCoordsTh (line 44) | class AddCoordsTh(nn.Module):
method __init__ (line 46) | def __init__(self, x_dim=64, y_dim=64, with_r=False, with_boundary=Fal...
method forward (line 53) | def forward(self, input_tensor, heatmap=None):
class CoordConvTh (line 110) | class CoordConvTh(nn.Module):
method __init__ (line 113) | def __init__(self, x_dim, y_dim, with_r, with_boundary, in_channels, f...
method forward (line 123) | def forward(self, input_tensor, heatmap=None):
function conv3x3 (line 130) | def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False, dilati...
class BasicBlock (line 135) | class BasicBlock(nn.Module):
method __init__ (line 138) | def __init__(self, inplanes, planes, stride=1, downsample=None):
method forward (line 148) | def forward(self, x):
class ConvBlock (line 165) | class ConvBlock(nn.Module):
method __init__ (line 167) | def __init__(self, in_planes, out_planes):
method forward (line 185) | def forward(self, x):
class HourGlass (line 210) | class HourGlass(nn.Module):
method __init__ (line 212) | def __init__(self, num_modules, depth, num_features, first_one=False):
method _generate_network (line 230) | def _generate_network(self, level):
method _forward (line 242) | def _forward(self, level, inp):
method forward (line 264) | def forward(self, x, heatmap):
class FAN (line 269) | class FAN(nn.Module):
method __init__ (line 271) | def __init__(self, num_modules=1, end_relu=False, gray_scale=False, nu...
method forward (line 324) | def forward(self, x):
method get_landmarks (line 359) | def get_landmarks(self, img):
FILE: src/face3d/util/nvdiffrast.py
class MeshRenderer (line 32) | class MeshRenderer(nn.Module):
method __init__ (line 33) | def __init__(self,
method forward (line 50) | def forward(self, vertex, tri, feat=None):
FILE: src/face3d/util/preprocess.py
function POS (line 17) | def POS(xp, x):
function resize_n_crop_img (line 42) | def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None):
function extract_5p (line 66) | def extract_5p(lm):
function align_img (line 74) | def align_img(img, lm, lm3D, mask=None, target_size=224., rescale_factor...
FILE: src/face3d/util/skin_mask.py
class GMM (line 9) | class GMM:
method __init__ (line 10) | def __init__(self, dim, num, w, mu, cov, cov_det, cov_inv):
method likelihood (line 23) | def likelihood(self, data):
function _rgb2ycbcr (line 42) | def _rgb2ycbcr(rgb):
function _bgr2ycbcr (line 54) | def _bgr2ycbcr(bgr):
function skinmask (line 90) | def skinmask(imbgr):
function get_skin_mask (line 111) | def get_skin_mask(img_path):
FILE: src/face3d/util/util.py
function str2bool (line 14) | def str2bool(v):
function copyconf (line 25) | def copyconf(default_opt, **kwargs):
function genvalconf (line 31) | def genvalconf(train_opt, **kwargs):
function find_class_in_module (line 43) | def find_class_in_module(target_cls_name, module):
function tensor2im (line 56) | def tensor2im(input_image, imtype=np.uint8):
function diagnose_network (line 77) | def diagnose_network(net, name='network'):
function save_image (line 96) | def save_image(image_numpy, image_path, aspect_ratio=1.0):
function print_numpy (line 116) | def print_numpy(x, val=True, shp=False):
function mkdirs (line 132) | def mkdirs(paths):
function mkdir (line 145) | def mkdir(path):
function correct_resize_label (line 155) | def correct_resize_label(t, size):
function correct_resize (line 169) | def correct_resize(t, size, mode=Image.BICUBIC):
function draw_landmarks (line 180) | def draw_landmarks(img, landmark, color='r', step=2):
FILE: src/face3d/util/visualizer.py
function save_images (line 13) | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
class Visualizer (line 44) | class Visualizer():
method __init__ (line 50) | def __init__(self, opt):
method reset (line 77) | def reset(self):
method display_current_results (line 82) | def display_current_results(self, visuals, total_iters, epoch, save_re...
method plot_current_losses (line 117) | def plot_current_losses(self, total_iters, losses):
method print_current_losses (line 131) | def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
class MyVisualizer (line 150) | class MyVisualizer:
method __init__ (line 151) | def __init__(self, opt):
method display_current_results (line 174) | def display_current_results(self, visuals, total_iters, epoch, dataset...
method plot_current_losses (line 205) | def plot_current_losses(self, total_iters, losses, dataset='train'):
method print_current_losses (line 210) | def print_current_losses(self, epoch, iters, losses, t_comp, t_data, d...
FILE: src/face3d/visualize.py
function gen_composed_video (line 12) | def gen_composed_video(args, device, first_frame_coeff, coeff_path, audi...
FILE: src/facerender/animate.py
class AnimateFromCoeff (line 33) | class AnimateFromCoeff():
method __init__ (line 35) | def __init__(self, sadtalker_path, device):
method load_cpk_facevid2vid_safetensor (line 86) | def load_cpk_facevid2vid_safetensor(self, checkpoint_path, generator=N...
method load_cpk_facevid2vid (line 113) | def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discri...
method load_cpk_mapping (line 143) | def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminato...
method generate (line 157) | def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=No...
FILE: src/facerender/modules/dense_motion.py
class DenseMotionNetwork (line 9) | class DenseMotionNetwork(nn.Module):
method __init__ (line 14) | def __init__(self, block_expansion, num_blocks, max_features, num_kp, ...
method create_sparse_motions (line 34) | def create_sparse_motions(self, feature, kp_driving, kp_source):
method create_deformed_feature (line 59) | def create_deformed_feature(self, feature, sparse_motions):
method create_heatmap_representations (line 68) | def create_heatmap_representations(self, feature, kp_driving, kp_source):
method forward (line 80) | def forward(self, feature, kp_driving, kp_source):
FILE: src/facerender/modules/discriminator.py
class DownBlock2d (line 7) | class DownBlock2d(nn.Module):
method __init__ (line 12) | def __init__(self, in_features, out_features, norm=False, kernel_size=...
method forward (line 25) | def forward(self, x):
class Discriminator (line 36) | class Discriminator(nn.Module):
method __init__ (line 41) | def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, m...
method forward (line 57) | def forward(self, x):
class MultiScaleDiscriminator (line 69) | class MultiScaleDiscriminator(nn.Module):
method __init__ (line 74) | def __init__(self, scales=(), **kwargs):
method forward (line 82) | def forward(self, x):
FILE: src/facerender/modules/generator.py
class OcclusionAwareGenerator (line 8) | class OcclusionAwareGenerator(nn.Module):
method __init__ (line 13) | def __init__(self, image_channel, feature_channel, num_kp, block_expan...
method deform_input (line 61) | def deform_input(self, inp, deformation):
method forward (line 70) | def forward(self, source_image, kp_driving, kp_source):
class SPADEDecoder (line 120) | class SPADEDecoder(nn.Module):
method __init__ (line 121) | def __init__(self):
method forward (line 140) | def forward(self, feature):
class OcclusionAwareSPADEGenerator (line 161) | class OcclusionAwareSPADEGenerator(nn.Module):
method __init__ (line 163) | def __init__(self, image_channel, feature_channel, num_kp, block_expan...
method deform_input (line 201) | def deform_input(self, inp, deformation):
method forward (line 210) | def forward(self, source_image, kp_driving, kp_source):
FILE: src/facerender/modules/keypoint_detector.py
class KPDetector (line 9) | class KPDetector(nn.Module):
method __init__ (line 14) | def __init__(self, block_expansion, feature_channel, num_kp, image_cha...
method gaussian2kp (line 44) | def gaussian2kp(self, heatmap):
method forward (line 56) | def forward(self, x):
class HEEstimator (line 85) | class HEEstimator(nn.Module):
method __init__ (line 90) | def __init__(self, block_expansion, feature_channel, num_kp, image_cha...
method forward (line 136) | def forward(self, x):
FILE: src/facerender/modules/make_animation.py
function normalize_kp (line 7) | def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_moveme...
function headpose_pred_to_degree (line 29) | def headpose_pred_to_degree(pred):
function get_rotation_matrix (line 37) | def get_rotation_matrix(yaw, pitch, roll):
function keypoint_transformation (line 65) | def keypoint_transformation(kp_canonical, he, wo_exp=False):
function make_animation (line 102) | def make_animation(source_image, source_semantics, target_semantics,
class AnimateModel (line 141) | class AnimateModel(torch.nn.Module):
method __init__ (line 146) | def __init__(self, generator, kp_extractor, mapping):
method forward (line 156) | def forward(self, x):
FILE: src/facerender/modules/mapping.py
class MappingNet (line 8) | class MappingNet(nn.Module):
method __init__ (line 9) | def __init__(self, coeff_nc, descriptor_nc, layer, num_kp, num_bins):
method forward (line 32) | def forward(self, input_3dmm):
FILE: src/facerender/modules/util.py
function kp2gaussian (line 12) | def kp2gaussian(kp, spatial_size, kp_variance):
function make_coordinate_grid_2d (line 35) | def make_coordinate_grid_2d(spatial_size, type):
function make_coordinate_grid (line 54) | def make_coordinate_grid(spatial_size, type):
class ResBottleneck (line 73) | class ResBottleneck(nn.Module):
method __init__ (line 74) | def __init__(self, in_features, stride):
method forward (line 88) | def forward(self, x):
class ResBlock2d (line 105) | class ResBlock2d(nn.Module):
method __init__ (line 110) | def __init__(self, in_features, kernel_size, padding):
method forward (line 119) | def forward(self, x):
class ResBlock3d (line 130) | class ResBlock3d(nn.Module):
method __init__ (line 135) | def __init__(self, in_features, kernel_size, padding):
method forward (line 144) | def forward(self, x):
class UpBlock2d (line 155) | class UpBlock2d(nn.Module):
method __init__ (line 160) | def __init__(self, in_features, out_features, kernel_size=3, padding=1...
method forward (line 167) | def forward(self, x):
class UpBlock3d (line 174) | class UpBlock3d(nn.Module):
method __init__ (line 179) | def __init__(self, in_features, out_features, kernel_size=3, padding=1...
method forward (line 186) | def forward(self, x):
class DownBlock2d (line 195) | class DownBlock2d(nn.Module):
method __init__ (line 200) | def __init__(self, in_features, out_features, kernel_size=3, padding=1...
method forward (line 207) | def forward(self, x):
class DownBlock3d (line 215) | class DownBlock3d(nn.Module):
method __init__ (line 220) | def __init__(self, in_features, out_features, kernel_size=3, padding=1...
method forward (line 231) | def forward(self, x):
class SameBlock2d (line 239) | class SameBlock2d(nn.Module):
method __init__ (line 244) | def __init__(self, in_features, out_features, groups=1, kernel_size=3,...
method forward (line 254) | def forward(self, x):
class Encoder (line 261) | class Encoder(nn.Module):
method __init__ (line 266) | def __init__(self, block_expansion, in_features, num_blocks=3, max_fea...
method forward (line 276) | def forward(self, x):
class Decoder (line 283) | class Decoder(nn.Module):
method __init__ (line 288) | def __init__(self, block_expansion, in_features, num_blocks=3, max_fea...
method forward (line 305) | def forward(self, x):
class Hourglass (line 319) | class Hourglass(nn.Module):
method __init__ (line 324) | def __init__(self, block_expansion, in_features, num_blocks=3, max_fea...
method forward (line 330) | def forward(self, x):
class KPHourglass (line 334) | class KPHourglass(nn.Module):
method __init__ (line 339) | def __init__(self, block_expansion, in_features, reshape_features, res...
method forward (line 360) | def forward(self, x):
class AntiAliasInterpolation2d (line 371) | class AntiAliasInterpolation2d(nn.Module):
method __init__ (line 375) | def __init__(self, channels, scale):
method forward (line 409) | def forward(self, input):
class SPADE (line 420) | class SPADE(nn.Module):
method __init__ (line 421) | def __init__(self, norm_nc, label_nc):
method forward (line 433) | def forward(self, x, segmap):
class SPADEResnetBlock (line 443) | class SPADEResnetBlock(nn.Module):
method __init__ (line 444) | def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation...
method forward (line 467) | def forward(self, x, seg1):
method shortcut (line 474) | def shortcut(self, x, seg1):
method actvn (line 481) | def actvn(self, x):
class audio2image (line 484) | class audio2image(nn.Module):
method __init__ (line 485) | def __init__(self, generator, kp_extractor, he_estimator_video, he_est...
method headpose_pred_to_degree (line 494) | def headpose_pred_to_degree(self, pred):
method get_rotation_matrix (line 503) | def get_rotation_matrix(self, yaw, pitch, roll):
method keypoint_transformation (line 531) | def keypoint_transformation(self, kp_canonical, he):
method forward (line 557) | def forward(self, source_image, target_audio):
FILE: src/facerender/sync_batchnorm/batchnorm.py
function _sum_ft (line 24) | def _sum_ft(tensor):
function _unsqueeze_ft (line 29) | def _unsqueeze_ft(tensor):
class _SynchronizedBatchNorm (line 38) | class _SynchronizedBatchNorm(_BatchNorm):
method __init__ (line 39) | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
method forward (line 48) | def forward(self, input):
method __data_parallel_replicate__ (line 80) | def __data_parallel_replicate__(self, ctx, copy_id):
method _data_parallel_master (line 90) | def _data_parallel_master(self, intermediates):
method _compute_mean_std (line 113) | def _compute_mean_std(self, sum_, ssum, size):
class SynchronizedBatchNorm1d (line 128) | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
method _check_input_dim (line 184) | def _check_input_dim(self, input):
class SynchronizedBatchNorm2d (line 191) | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
method _check_input_dim (line 247) | def _check_input_dim(self, input):
class SynchronizedBatchNorm3d (line 254) | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
method _check_input_dim (line 311) | def _check_input_dim(self, input):
FILE: src/facerender/sync_batchnorm/comm.py
class FutureResult (line 18) | class FutureResult(object):
method __init__ (line 21) | def __init__(self):
method put (line 26) | def put(self, result):
method get (line 32) | def get(self):
class SlavePipe (line 46) | class SlavePipe(_SlavePipeBase):
method run_slave (line 49) | def run_slave(self, msg):
class SyncMaster (line 56) | class SyncMaster(object):
method __init__ (line 67) | def __init__(self, master_callback):
method __getstate__ (line 78) | def __getstate__(self):
method __setstate__ (line 81) | def __setstate__(self, state):
method register_slave (line 84) | def register_slave(self, identifier):
method run_master (line 102) | def run_master(self, master_msg):
method nr_slaves (line 136) | def nr_slaves(self):
FILE: src/facerender/sync_batchnorm/replicate.py
class CallbackContext (line 23) | class CallbackContext(object):
function execute_replication_callbacks (line 27) | def execute_replication_callbacks(modules):
class DataParallelWithCallback (line 50) | class DataParallelWithCallback(DataParallel):
method replicate (line 64) | def replicate(self, module, device_ids):
function patch_replication_callback (line 70) | def patch_replication_callback(data_parallel):
FILE: src/facerender/sync_batchnorm/unittest.py
function as_numpy (line 17) | def as_numpy(v):
class TorchTestCase (line 23) | class TorchTestCase(unittest.TestCase):
method assertTensorClose (line 24) | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
FILE: src/generate_batch.py
function crop_pad_audio (line 10) | def crop_pad_audio(wav, audio_length):
function parse_audio_length (line 17) | def parse_audio_length(audio_length, sr, fps):
function generate_blink_seq (line 25) | def generate_blink_seq(num_frames):
function generate_blink_seq_randomly (line 37) | def generate_blink_seq_randomly(num_frames):
function get_data (line 51) | def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_pa...
FILE: src/generate_facerender_batch.py
function get_facerender_data (line 8) | def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path,
function transform_semantic_1 (line 88) | def transform_semantic_1(semantic, semantic_radius):
function transform_semantic_target (line 93) | def transform_semantic_target(coeff_3dmm, frame_index, semantic_radius):
function gen_camera_pose (line 100) | def gen_camera_pose(camera_degree_list, frame_num, batch_size):
FILE: src/gradio_demo.py
function mp3_to_wav (line 14) | def mp3_to_wav(mp3_filename,wav_filename,frame_rate):
class SadTalker (line 19) | class SadTalker():
method __init__ (line 21) | def __init__(self, checkpoint_path='checkpoints', config_path='src/con...
method test (line 36) | def test(self, source_image, driven_audio, preprocess='crop',
FILE: src/test_audio2coeff.py
function load_cpk (line 16) | def load_cpk(checkpoint_path, model=None, optimizer=None, device="cpu"):
class Audio2Coeff (line 25) | class Audio2Coeff():
method __init__ (line 27) | def __init__(self, sadtalker_path, device):
method generate (line 74) | def generate(self, batch, coeff_save_dir, pose_style, ref_pose_coeff_p...
method using_refpose (line 107) | def using_refpose(self, coeffs_pred_numpy, ref_pose_coeff_path):
FILE: src/utils/audio.py
function load_wav (line 9) | def load_wav(path, sr):
function save_wav (line 12) | def save_wav(wav, path, sr):
function save_wavenet_wav (line 17) | def save_wavenet_wav(wav, path, sr):
function preemphasis (line 20) | def preemphasis(wav, k, preemphasize=True):
function inv_preemphasis (line 25) | def inv_preemphasis(wav, k, inv_preemphasize=True):
function get_hop_size (line 30) | def get_hop_size():
function linearspectrogram (line 37) | def linearspectrogram(wav):
function melspectrogram (line 45) | def melspectrogram(wav):
function _lws_processor (line 53) | def _lws_processor():
function _stft (line 57) | def _stft(y):
function num_frames (line 65) | def num_frames(length, fsize, fshift):
function pad_lr (line 76) | def pad_lr(x, fsize, fshift):
function librosa_pad_lr (line 86) | def librosa_pad_lr(x, fsize, fshift):
function _linear_to_mel (line 92) | def _linear_to_mel(spectogram):
function _build_mel_basis (line 98) | def _build_mel_basis():
function _amp_to_db (line 103) | def _amp_to_db(x):
function _db_to_amp (line 107) | def _db_to_amp(x):
function _normalize (line 110) | def _normalize(S):
function _denormalize (line 124) | def _denormalize(D):
FILE: src/utils/croper.py
class Preprocesser (line 19) | class Preprocesser:
method __init__ (line 20) | def __init__(self, device='cuda'):
method get_landmark (line 23) | def get_landmark(self, img_np):
method align_face (line 43) | def align_face(self, img, lm, output_size=1024):
method crop (line 126) | def crop(self, img_np_list, still=False, xsize=512): # first frame ...
FILE: src/utils/face_enhancer.py
class GeneratorWithLen (line 13) | class GeneratorWithLen(object):
method __init__ (line 16) | def __init__(self, gen, length):
method __len__ (line 20) | def __len__(self):
method __iter__ (line 23) | def __iter__(self):
function enhancer_list (line 26) | def enhancer_list(images, method='gfpgan', bg_upsampler='realesrgan'):
function enhancer_generator_with_len (line 30) | def enhancer_generator_with_len(images, method='gfpgan', bg_upsampler='r...
function enhancer_generator_no_len (line 42) | def enhancer_generator_no_len(images, method='gfpgan', bg_upsampler='rea...
FILE: src/utils/hparams.py
class HParams (line 4) | class HParams:
method __init__ (line 5) | def __init__(self, **kwargs):
method __getattr__ (line 11) | def __getattr__(self, key):
method set_hparam (line 16) | def set_hparam(self, key, value):
function hparams_debug_string (line 157) | def hparams_debug_string():
FILE: src/utils/init_path.py
function init_path (line 4) | def init_path(checkpoint_dir, config_dir, size=512, old_version=False, p...
FILE: src/utils/model2safetensor.py
function load_cpk_facevid2vid (line 43) | def load_cpk_facevid2vid(checkpoint_path, generator=None, discriminator=...
function load_cpk_facevid2vid_safetensor (line 75) | def load_cpk_facevid2vid_safetensor(checkpoint_path, generator=None,
class SadTalker (line 125) | class SadTalker(torch.nn.Module):
method __init__ (line 126) | def __init__(self, kp_extractor, generator, netG, audio2pose, face_3dr...
FILE: src/utils/paste_pic.py
function paste_pic (line 8) | def paste_pic(video_path, pic_path, crop_info, new_audio_path, full_vide...
FILE: src/utils/preprocess.py
function split_coeff (line 22) | def split_coeff(coeffs):
class CropAndExtract (line 46) | class CropAndExtract():
method __init__ (line 47) | def __init__(self, sadtalker_path, device):
method generate (line 63) | def generate(self, input_path, save_dir, crop_or_resize='crop', source...
FILE: src/utils/safetensor_helper.py
function load_x_from_safetensor (line 3) | def load_x_from_safetensor(checkpoint, key):
FILE: src/utils/text2speech.py
class TTSTalker (line 6) | class TTSTalker():
method __init__ (line 7) | def __init__(self) -> None:
method test (line 11) | def test(self, text, language='en'):
FILE: src/utils/videoio.py
function load_video_to_cv2 (line 8) | def load_video_to_cv2(input_path):
function save_video_with_watermark (line 20) | def save_video_with_watermark(video, audio, save_path, watermark=False):
Condensed preview — 141 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (617K chars).
[
{
"path": ".gitignore",
"chars": 3205,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
},
{
"path": "LICENSE",
"chars": 11692,
"preview": "Tencent is pleased to support the open source community by making SadTalker available.\n\nCopyright (C), a Tencent company"
},
{
"path": "README.md",
"chars": 17068,
"preview": "<div align=\"center\">\n\n<img src='https://user-images.githubusercontent.com/4397546/229094115-862c747e-7397-4b54-ba4a-bd36"
},
{
"path": "app_sadtalker.py",
"chars": 5836,
"preview": "import os, sys\nimport gradio as gr\nfrom src.gradio_demo import SadTalker \n\n\ntry:\n import webui # in webui\n in_we"
},
{
"path": "cog.yaml",
"chars": 1115,
"preview": "build:\n gpu: true\n cuda: \"11.3\"\n python_version: \"3.8\"\n system_packages:\n - \"ffmpeg\"\n - \"libgl1-mesa-glx\"\n "
},
{
"path": "docs/FAQ.md",
"chars": 2168,
"preview": "\n## Frequency Asked Question\n\n**Q: `ffmpeg` is not recognized as an internal or external command**\n\nIn Linux, you can in"
},
{
"path": "docs/best_practice.md",
"chars": 5334,
"preview": "# Best Practices and Tips for configuration\n\n> Our model only works on REAL people or the portrait image similar to REAL"
},
{
"path": "docs/changlelog.md",
"chars": 2230,
"preview": "## changelogs\n\n\n- __[2023.04.06]__: stable-diffiusion webui extension is release.\n\n- __[2023.04.03]__: Enable TTS in hug"
},
{
"path": "docs/face3d.md",
"chars": 1412,
"preview": "## 3D Face Visualization\n\nWe use `pytorch3d` to visualize the 3D faces from a single image.\n\nThe requirements for 3D vis"
},
{
"path": "docs/install.md",
"chars": 1235,
"preview": "### macOS\n\nThis method has been tested on a M1 Mac (13.3)\n\n```bash\ngit clone https://github.com/OpenTalker/SadTalker.git"
},
{
"path": "docs/webui_extension.md",
"chars": 2070,
"preview": "## Run SadTalker as a Stable Diffusion WebUI Extension.\n\n1. Install the lastest version of [stable-diffusion-webui](http"
},
{
"path": "inference.py",
"chars": 7610,
"preview": "from glob import glob\nimport shutil\nimport torch\nfrom time import strftime\nimport os, sys, time\nfrom argparse import Ar"
},
{
"path": "launcher.py",
"chars": 7039,
"preview": "# this scripts installs necessary requirements and launches main program in webui.py\n# borrow from : https://github.com/"
},
{
"path": "predict.py",
"chars": 6481,
"preview": "\"\"\"run bash scripts/download_models.sh first to prepare the weights file\"\"\"\nimport os\nimport shutil\nfrom argparse import"
},
{
"path": "quick_demo.ipynb",
"chars": 6917,
"preview": "{\n \"cells\": [\n {\n \"attachments\": {},\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"id\": \"M74Gs_Tj"
},
{
"path": "req.txt",
"chars": 306,
"preview": "llvmlite==0.38.1\nnumpy==1.21.6\nface_alignment==1.3.5\nimageio==2.19.3\nimageio-ffmpeg==0.4.7\nlibrosa==0.10.0.post2\nnumba=="
},
{
"path": "requirements.txt",
"chars": 277,
"preview": "numpy==1.23.4\nface_alignment==1.3.5\nimageio==2.19.3\nimageio-ffmpeg==0.4.7\nlibrosa==0.9.2 # \nnumba\nresampy==0.3.1\npydub=="
},
{
"path": "requirements3d.txt",
"chars": 288,
"preview": "numpy==1.23.4\nface_alignment==1.3.5\nimageio==2.19.3\nimageio-ffmpeg==0.4.7\nlibrosa==0.9.2 # \nnumba\nresampy==0.3.1\npydub=="
},
{
"path": "scripts/download_models.sh",
"chars": 2763,
"preview": "mkdir ./checkpoints \n\n# lagency download link\n# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2"
},
{
"path": "scripts/extension.py",
"chars": 6754,
"preview": "import os, sys\r\nfrom pathlib import Path\r\nimport tempfile\r\nimport gradio as gr\r\nfrom modules.call_queue import wrap_grad"
},
{
"path": "scripts/test.sh",
"chars": 855,
"preview": "# ### some test command before commit.\n# python inference.py --preprocess crop --size 256\n# python inference.py --prepro"
},
{
"path": "src/audio2exp_models/audio2exp.py",
"chars": 1253,
"preview": "from tqdm import tqdm\nimport torch\nfrom torch import nn\n\n\nclass Audio2Exp(nn.Module):\n def __init__(self, netG, cfg, "
},
{
"path": "src/audio2exp_models/networks.py",
"chars": 2977,
"preview": "import torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nclass Conv2d(nn.Module):\n def __init__(self, cin, "
},
{
"path": "src/audio2pose_models/audio2pose.py",
"chars": 3812,
"preview": "import torch\nfrom torch import nn\nfrom src.audio2pose_models.cvae import CVAE\nfrom src.audio2pose_models.discriminator i"
},
{
"path": "src/audio2pose_models/audio_encoder.py",
"chars": 2745,
"preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nclass Conv2d(nn.Module):\n def __init__(self, "
},
{
"path": "src/audio2pose_models/cvae.py",
"chars": 6072,
"preview": "import torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom src.audio2pose_models.res_unet import ResUnet\n\nde"
},
{
"path": "src/audio2pose_models/discriminator.py",
"chars": 2669,
"preview": "import torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nclass ConvNormRelu(nn.Module):\n def __init__(self,"
},
{
"path": "src/audio2pose_models/networks.py",
"chars": 4247,
"preview": "import torch.nn as nn\nimport torch\n\n\nclass ResidualConv(nn.Module):\n def __init__(self, input_dim, output_dim, stride"
},
{
"path": "src/audio2pose_models/res_unet.py",
"chars": 2229,
"preview": "import torch\nimport torch.nn as nn\nfrom src.audio2pose_models.networks import ResidualConv, Upsample\n\n\nclass ResUnet(nn."
},
{
"path": "src/config/auido2exp.yaml",
"chars": 1209,
"preview": "DATASET:\n TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt\n EVAL_FILE_LIST: /apdceph"
},
{
"path": "src/config/auido2pose.yaml",
"chars": 1097,
"preview": "DATASET:\n TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt\n "
},
{
"path": "src/config/facerender.yaml",
"chars": 1087,
"preview": "model_params:\n common_params:\n num_kp: 15 \n image_channel: 3 \n feature_channel: 32\n esti"
},
{
"path": "src/config/facerender_still.yaml",
"chars": 1087,
"preview": "model_params:\n common_params:\n num_kp: 15 \n image_channel: 3 \n feature_channel: 32\n esti"
},
{
"path": "src/face3d/data/__init__.py",
"chars": 4584,
"preview": "\"\"\"This package includes all the modules related to data loading and preprocessing\n\n To add a custom dataset class calle"
},
{
"path": "src/face3d/data/base_dataset.py",
"chars": 4679,
"preview": "\"\"\"This module implements an abstract base class (ABC) 'BaseDataset' for datasets.\n\nIt also includes common transformati"
},
{
"path": "src/face3d/data/flist_dataset.py",
"chars": 4093,
"preview": "\"\"\"This script defines the custom dataset for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport os.path\nfrom data.base_dataset import "
},
{
"path": "src/face3d/data/image_folder.py",
"chars": 1959,
"preview": "\"\"\"A modified image folder class\n\nWe modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/ma"
},
{
"path": "src/face3d/data/template_dataset.py",
"chars": 3506,
"preview": "\"\"\"Dataset class template\n\nThis module provides a template for users to implement custom datasets.\nYou can specify '--da"
},
{
"path": "src/face3d/extract_kp_videos.py",
"chars": 3915,
"preview": "import os\nimport cv2\nimport time\nimport glob\nimport argparse\nimport face_alignment\nimport numpy as np\nfrom PIL import Im"
},
{
"path": "src/face3d/extract_kp_videos_safe.py",
"chars": 5705,
"preview": "import os\nimport cv2\nimport time\nimport glob\nimport argparse\nimport numpy as np\nfrom PIL import Image\nimport torch\nfrom "
},
{
"path": "src/face3d/models/__init__.py",
"chars": 3090,
"preview": "\"\"\"This package contains modules related to objective functions, optimizations, and network architectures.\n\nTo add a cus"
},
{
"path": "src/face3d/models/arcface_torch/README.md",
"chars": 8597,
"preview": "# Distributed Arcface Training in Pytorch\n\nThis is a deep learning library that makes face recognition efficient, and ef"
},
{
"path": "src/face3d/models/arcface_torch/backbones/__init__.py",
"chars": 822,
"preview": "from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200\nfrom .mobilefacenet import get_mbf\n\n\ndef ge"
},
{
"path": "src/face3d/models/arcface_torch/backbones/iresnet.py",
"chars": 7149,
"preview": "import torch\nfrom torch import nn\n\n__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']\n\n\ndef c"
},
{
"path": "src/face3d/models/arcface_torch/backbones/iresnet2060.py",
"chars": 6708,
"preview": "import torch\nfrom torch import nn\n\nassert torch.__version__ >= \"1.8.1\"\nfrom torch.utils.checkpoint import checkpoint_seq"
},
{
"path": "src/face3d/models/arcface_torch/backbones/mobilefacenet.py",
"chars": 4895,
"preview": "'''\nAdapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py\nOriginal author ca"
},
{
"path": "src/face3d/models/arcface_torch/configs/3millions.py",
"chars": 519,
"preview": "from easydict import EasyDict as edict\n\n# configs for test speed\n\nconfig = edict()\nconfig.loss = \"arcface\"\nconfig.networ"
},
{
"path": "src/face3d/models/arcface_torch/configs/3millions_pfc.py",
"chars": 519,
"preview": "from easydict import EasyDict as edict\n\n# configs for test speed\n\nconfig = edict()\nconfig.loss = \"arcface\"\nconfig.networ"
},
{
"path": "src/face3d/models/arcface_torch/configs/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/face3d/models/arcface_torch/configs/base.py",
"chars": 1628,
"preview": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G tmpfs /t"
},
{
"path": "src/face3d/models/arcface_torch/configs/glint360k_mbf.py",
"chars": 647,
"preview": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G tmpfs /t"
},
{
"path": "src/face3d/models/arcface_torch/configs/glint360k_r100.py",
"chars": 648,
"preview": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G tmpfs /t"
},
{
"path": "src/face3d/models/arcface_torch/configs/glint360k_r18.py",
"chars": 647,
"preview": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G tmpfs /t"
},
{
"path": "src/face3d/models/arcface_torch/configs/glint360k_r34.py",
"chars": 647,
"preview": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G tmpfs /t"
},
{
"path": "src/face3d/models/arcface_torch/configs/glint360k_r50.py",
"chars": 647,
"preview": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G tmpfs /t"
},
{
"path": "src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py",
"chars": 651,
"preview": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G tmpfs /t"
},
{
"path": "src/face3d/models/arcface_torch/configs/ms1mv3_r18.py",
"chars": 651,
"preview": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G tmpfs /t"
},
{
"path": "src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py",
"chars": 652,
"preview": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G tmpfs /t"
},
{
"path": "src/face3d/models/arcface_torch/configs/ms1mv3_r34.py",
"chars": 651,
"preview": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G tmpfs /t"
},
{
"path": "src/face3d/models/arcface_torch/configs/ms1mv3_r50.py",
"chars": 651,
"preview": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G tmpfs /t"
},
{
"path": "src/face3d/models/arcface_torch/configs/speed.py",
"chars": 519,
"preview": "from easydict import EasyDict as edict\n\n# configs for test speed\n\nconfig = edict()\nconfig.loss = \"arcface\"\nconfig.networ"
},
{
"path": "src/face3d/models/arcface_torch/dataset.py",
"chars": 3868,
"preview": "import numbers\nimport os\nimport queue as Queue\nimport threading\n\nimport mxnet as mx\nimport numpy as np\nimport torch\nfrom"
},
{
"path": "src/face3d/models/arcface_torch/docs/eval.md",
"chars": 655,
"preview": "## Eval on ICCV2021-MFR\n\ncoming soon.\n\n\n## Eval IJBC\nYou can eval ijbc with pytorch or onnx.\n\n\n1. Eval IJBC With Onnx\n``"
},
{
"path": "src/face3d/models/arcface_torch/docs/install.md",
"chars": 1613,
"preview": "## v1.8.0 \n### Linux and Windows \n```shell\n# CUDA 11.0\npip --default-timeout=100 install torch==1.8.0+cu111 torchvision"
},
{
"path": "src/face3d/models/arcface_torch/docs/modelzoo.md",
"chars": 0,
"preview": ""
},
{
"path": "src/face3d/models/arcface_torch/docs/speed_benchmark.md",
"chars": 5465,
"preview": "## Test Training Speed\n\n- Test Commands\n\nYou need to use the following two commands to test the Partial FC training perf"
},
{
"path": "src/face3d/models/arcface_torch/eval/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/face3d/models/arcface_torch/eval/verification.py",
"chars": 16092,
"preview": "\"\"\"Helper for evaluation on the Labeled Faces in the Wild dataset \n\"\"\"\n\n# MIT License\n#\n# Copyright (c) 2016 David Sandb"
},
{
"path": "src/face3d/models/arcface_torch/eval_ijbc.py",
"chars": 17222,
"preview": "# coding: utf-8\n\nimport os\nimport pickle\n\nimport matplotlib\nimport pandas as pd\n\nmatplotlib.use('Agg')\nimport matplotlib"
},
{
"path": "src/face3d/models/arcface_torch/inference.py",
"chars": 1033,
"preview": "import argparse\n\nimport cv2\nimport numpy as np\nimport torch\n\nfrom backbones import get_model\n\n\n@torch.no_grad()\ndef infe"
},
{
"path": "src/face3d/models/arcface_torch/losses.py",
"chars": 1137,
"preview": "import torch\nfrom torch import nn\n\n\ndef get_loss(name):\n if name == \"cosface\":\n return CosFace()\n elif name"
},
{
"path": "src/face3d/models/arcface_torch/onnx_helper.py",
"chars": 10422,
"preview": "from __future__ import division\nimport datetime\nimport os\nimport os.path as osp\nimport glob\nimport numpy as np\nimport cv"
},
{
"path": "src/face3d/models/arcface_torch/onnx_ijbc.py",
"chars": 10321,
"preview": "import argparse\nimport os\nimport pickle\nimport timeit\n\nimport cv2\nimport mxnet as mx\nimport numpy as np\nimport pandas as"
},
{
"path": "src/face3d/models/arcface_torch/partial_fc.py",
"chars": 9492,
"preview": "import logging\nimport os\n\nimport torch\nimport torch.distributed as dist\nfrom torch.nn import Module\nfrom torch.nn.functi"
},
{
"path": "src/face3d/models/arcface_torch/requirement.txt",
"chars": 40,
"preview": "tensorboard\neasydict\nmxnet\nonnx\nsklearn\n"
},
{
"path": "src/face3d/models/arcface_torch/run.sh",
"chars": 260,
"preview": "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --ma"
},
{
"path": "src/face3d/models/arcface_torch/torch2onnx.py",
"chars": 2365,
"preview": "import numpy as np\nimport onnx\nimport torch\n\n\ndef convert_onnx(net, path_module, output, opset=11, simplify=False):\n "
},
{
"path": "src/face3d/models/arcface_torch/train.py",
"chars": 6023,
"preview": "import argparse\nimport logging\nimport os\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn.functional as F\n"
},
{
"path": "src/face3d/models/arcface_torch/utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/face3d/models/arcface_torch/utils/plot.py",
"chars": 2222,
"preview": "# coding: utf-8\n\nimport os\nfrom pathlib import Path\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as"
},
{
"path": "src/face3d/models/arcface_torch/utils/utils_amp.py",
"chars": 3286,
"preview": "from typing import Dict, List\n\nimport torch\n\nif torch.__version__ < '1.9':\n Iterable = torch._six.container_abcs.Iter"
},
{
"path": "src/face3d/models/arcface_torch/utils/utils_callbacks.py",
"chars": 5038,
"preview": "import logging\nimport os\nimport time\nfrom typing import List\n\nimport torch\n\nfrom eval import verification\nfrom utils.uti"
},
{
"path": "src/face3d/models/arcface_torch/utils/utils_config.py",
"chars": 571,
"preview": "import importlib\nimport os.path as osp\n\n\ndef get_config(config_file):\n assert config_file.startswith('configs/'), 'co"
},
{
"path": "src/face3d/models/arcface_torch/utils/utils_logging.py",
"chars": 1110,
"preview": "import logging\nimport os\nimport sys\n\n\nclass AverageMeter(object):\n \"\"\"Computes and stores the average and current val"
},
{
"path": "src/face3d/models/arcface_torch/utils/utils_os.py",
"chars": 0,
"preview": ""
},
{
"path": "src/face3d/models/base_model.py",
"chars": 13168,
"preview": "\"\"\"This script defines the base network model for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport os\nimport numpy as np\nimport torch"
},
{
"path": "src/face3d/models/bfm.py",
"chars": 12349,
"preview": "\"\"\"This script defines the parametric 3d face model for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport numpy as np\nimport torch\nim"
},
{
"path": "src/face3d/models/facerecon_model.py",
"chars": 10843,
"preview": "\"\"\"This script defines the face reconstruction model for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport numpy as np\nimport torch\nfr"
},
{
"path": "src/face3d/models/losses.py",
"chars": 4171,
"preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom kornia.geometry import warp_affine\nimport torch.nn.functional"
},
{
"path": "src/face3d/models/networks.py",
"chars": 20766,
"preview": "\"\"\"This script defines deep neural networks for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport os\nimport numpy as np\nimport torch.n"
},
{
"path": "src/face3d/models/template_model.py",
"chars": 5970,
"preview": "\"\"\"Model class template\n\nThis module provides a template for users to implement custom models.\nYou can specify '--model "
},
{
"path": "src/face3d/options/__init__.py",
"chars": 136,
"preview": "\"\"\"This package options includes option modules: training options, test options, and basic options (used in both trainin"
},
{
"path": "src/face3d/options/base_options.py",
"chars": 7493,
"preview": "\"\"\"This script contains base options for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport argparse\nimport os\nfrom util import util\nim"
},
{
"path": "src/face3d/options/inference_options.py",
"chars": 1171,
"preview": "from face3d.options.base_options import BaseOptions\n\n\nclass InferenceOptions(BaseOptions):\n \"\"\"This class includes te"
},
{
"path": "src/face3d/options/test_options.py",
"chars": 830,
"preview": "\"\"\"This script contains the test options for Deep3DFaceRecon_pytorch\n\"\"\"\n\nfrom .base_options import BaseOptions\n\n\nclass "
},
{
"path": "src/face3d/options/train_options.py",
"chars": 3744,
"preview": "\"\"\"This script contains the training options for Deep3DFaceRecon_pytorch\n\"\"\"\n\nfrom .base_options import BaseOptions\nfrom"
},
{
"path": "src/face3d/util/__init__.py",
"chars": 114,
"preview": "\"\"\"This package includes a miscellaneous collection of useful helper functions.\"\"\"\nfrom src.face3d.util import *\n\n"
},
{
"path": "src/face3d/util/detect_lm68.py",
"chars": 4033,
"preview": "import os\nimport cv2\nimport numpy as np\nfrom scipy.io import loadmat\nimport tensorflow as tf\nfrom util.preprocess import"
},
{
"path": "src/face3d/util/generate_list.py",
"chars": 1346,
"preview": "\"\"\"This script is to generate training list files for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport os\n\n# save path to training da"
},
{
"path": "src/face3d/util/html.py",
"chars": 3223,
"preview": "import dominate\nfrom dominate.tags import meta, h3, table, tr, td, p, a, img, br\nimport os\n\n\nclass HTML:\n \"\"\"This HTM"
},
{
"path": "src/face3d/util/load_mats.py",
"chars": 4445,
"preview": "\"\"\"This script is to load 3D face model for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport numpy as np\nfrom PIL import Image\nfrom s"
},
{
"path": "src/face3d/util/my_awing_arch.py",
"chars": 12503,
"preview": "import cv2\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef calculate_points("
},
{
"path": "src/face3d/util/nvdiffrast.py",
"chars": 4627,
"preview": "\"\"\"This script is the differentiable renderer for Deep3DFaceRecon_pytorch\n Attention, antialiasing step is missing in"
},
{
"path": "src/face3d/util/preprocess.py",
"chars": 3336,
"preview": "\"\"\"This script contains the image preprocessing code for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport numpy as np\nfrom scipy.io i"
},
{
"path": "src/face3d/util/skin_mask.py",
"chars": 5333,
"preview": "\"\"\"This script is to generate skin attention mask for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport math\nimport numpy as np\nimport"
},
{
"path": "src/face3d/util/test_mean_face.txt",
"chars": 3481,
"preview": "-5.228591537475585938e+01\n2.078247070312500000e-01\n-5.064269638061523438e+01\n-1.315765380859375000e+01\n-4.95293922424316"
},
{
"path": "src/face3d/util/util.py",
"chars": 6588,
"preview": "\"\"\"This script contains basic utilities for Deep3DFaceRecon_pytorch\n\"\"\"\nfrom __future__ import print_function\nimport num"
},
{
"path": "src/face3d/util/visualizer.py",
"chars": 10485,
"preview": "\"\"\"This script defines the visualizer for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport numpy as np\nimport os\nimport sys\nimport nt"
},
{
"path": "src/face3d/visualize.py",
"chars": 1719,
"preview": "# check the sync of 3dmm feature and the audio\nimport cv2\nimport numpy as np\nfrom src.face3d.models.bfm import Parametri"
},
{
"path": "src/facerender/animate.py",
"chars": 11529,
"preview": "import os\nimport cv2\nimport yaml\nimport numpy as np\nimport warnings\nfrom skimage import img_as_ubyte\nimport safetensors\n"
},
{
"path": "src/facerender/modules/dense_motion.py",
"chars": 5866,
"preview": "from torch import nn\nimport torch.nn.functional as F\nimport torch\nfrom src.facerender.modules.util import Hourglass, mak"
},
{
"path": "src/facerender/modules/discriminator.py",
"chars": 2872,
"preview": "from torch import nn\nimport torch.nn.functional as F\nfrom facerender.modules.util import kp2gaussian\nimport torch\n\n\nclas"
},
{
"path": "src/facerender/modules/generator.py",
"chars": 11493,
"preview": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom src.facerender.modules.util import ResBlock2d, Sa"
},
{
"path": "src/facerender/modules/keypoint_detector.py",
"chars": 6687,
"preview": "from torch import nn\nimport torch\nimport torch.nn.functional as F\n\nfrom src.facerender.sync_batchnorm import Synchronize"
},
{
"path": "src/facerender/modules/make_animation.py",
"chars": 6802,
"preview": "from scipy.spatial import ConvexHull\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom tqdm import tq"
},
{
"path": "src/facerender/modules/mapping.py",
"chars": 1585,
"preview": "import numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass MappingNet(nn.Module):\n "
},
{
"path": "src/facerender/modules/util.py",
"chars": 20173,
"preview": "from torch import nn\n\nimport torch.nn.functional as F\nimport torch\n\nfrom src.facerender.sync_batchnorm import Synchroniz"
},
{
"path": "src/facerender/sync_batchnorm/__init__.py",
"chars": 449,
"preview": "# -*- coding: utf-8 -*-\n# File : __init__.py\n# Author : Jiayuan Mao\n# Email : maojiayuan@gmail.com\n# Date : 27/01/2"
},
{
"path": "src/facerender/sync_batchnorm/batchnorm.py",
"chars": 12973,
"preview": "# -*- coding: utf-8 -*-\n# File : batchnorm.py\n# Author : Jiayuan Mao\n# Email : maojiayuan@gmail.com\n# Date : 27/01/"
},
{
"path": "src/facerender/sync_batchnorm/comm.py",
"chars": 4449,
"preview": "# -*- coding: utf-8 -*-\n# File : comm.py\n# Author : Jiayuan Mao\n# Email : maojiayuan@gmail.com\n# Date : 27/01/2018\n"
},
{
"path": "src/facerender/sync_batchnorm/replicate.py",
"chars": 3226,
"preview": "# -*- coding: utf-8 -*-\n# File : replicate.py\n# Author : Jiayuan Mao\n# Email : maojiayuan@gmail.com\n# Date : 27/01/"
},
{
"path": "src/facerender/sync_batchnorm/unittest.py",
"chars": 835,
"preview": "# -*- coding: utf-8 -*-\n# File : unittest.py\n# Author : Jiayuan Mao\n# Email : maojiayuan@gmail.com\n# Date : 27/01/2"
},
{
"path": "src/generate_batch.py",
"chars": 4453,
"preview": "import os\n\nfrom tqdm import tqdm\nimport torch\nimport numpy as np\nimport random\nimport scipy.io as scio\nimport src.utils."
},
{
"path": "src/generate_facerender_batch.py",
"chars": 5757,
"preview": "import os\nimport numpy as np\nfrom PIL import Image\nfrom skimage import io, img_as_float32, transform\nimport torch\nimport"
},
{
"path": "src/gradio_demo.py",
"chars": 6819,
"preview": "import torch, uuid\r\nimport os, sys, shutil\r\nfrom src.utils.preprocess import CropAndExtract\r\nfrom src.test_audio2coeff i"
},
{
"path": "src/test_audio2coeff.py",
"chars": 5393,
"preview": "import os \nimport torch\nimport numpy as np\nfrom scipy.io import savemat, loadmat\nfrom yacs.config import CfgNode as CN\nf"
},
{
"path": "src/utils/audio.py",
"chars": 4518,
"preview": "import librosa\nimport librosa.filters\nimport numpy as np\n# import tensorflow as tf\nfrom scipy import signal\nfrom scipy.i"
},
{
"path": "src/utils/croper.py",
"chars": 5870,
"preview": "import os\nimport cv2\nimport time\nimport glob\nimport argparse\nimport scipy\nimport numpy as np\nfrom PIL import Image\nimpor"
},
{
"path": "src/utils/face_enhancer.py",
"chars": 4436,
"preview": "import os\nimport torch \n\nfrom gfpgan import GFPGANer\n\nfrom tqdm import tqdm\n\nfrom src.utils.videoio import load_video_to"
},
{
"path": "src/utils/hparams.py",
"chars": 6055,
"preview": "from glob import glob\nimport os\n\nclass HParams:\n\tdef __init__(self, **kwargs):\n\t\tself.data = {}\n\n\t\tfor key, value in kwa"
},
{
"path": "src/utils/init_path.py",
"chars": 2607,
"preview": "import os\nimport glob\n\ndef init_path(checkpoint_dir, config_dir, size=512, old_version=False, preprocess='crop'):\n\n i"
},
{
"path": "src/utils/model2safetensor.py",
"chars": 6027,
"preview": "import torch\nimport yaml\nimport os\n\nimport safetensors\nfrom safetensors.torch import save_file\nfrom yacs.config import C"
},
{
"path": "src/utils/paste_pic.py",
"chars": 2390,
"preview": "import cv2, os\nimport numpy as np\nfrom tqdm import tqdm\nimport uuid\n\nfrom src.utils.videoio import save_video_with_water"
},
{
"path": "src/utils/preprocess.py",
"chars": 6890,
"preview": "import numpy as np\nimport cv2, os, sys, torch\nfrom tqdm import tqdm\nfrom PIL import Image \n\n# 3dmm extraction\nimport saf"
},
{
"path": "src/utils/safetensor_helper.py",
"chars": 198,
"preview": "\n\ndef load_x_from_safetensor(checkpoint, key):\n x_generator = {}\n for k,v in checkpoint.items():\n if key in"
},
{
"path": "src/utils/text2speech.py",
"chars": 489,
"preview": "import os\nimport tempfile\nfrom TTS.api import TTS\n\n\nclass TTSTalker():\n def __init__(self) -> None:\n model_nam"
},
{
"path": "src/utils/videoio.py",
"chars": 1455,
"preview": "import shutil\nimport uuid\n\nimport os\n\nimport cv2\n\ndef load_video_to_cv2(input_path):\n video_stream = cv2.VideoCapture"
},
{
"path": "webui.bat",
"chars": 275,
"preview": "@echo off\n\nIF NOT EXIST venv (\npython -m venv venv\n) ELSE (\necho venv folder already exists, skipping creation...\n)\ncall"
},
{
"path": "webui.sh",
"chars": 3747,
"preview": "#!/usr/bin/env bash\n\n\n# If run from macOS, load defaults from webui-macos-env.sh\nif [[ \"$OSTYPE\" == \"darwin\"* ]]; then\n "
}
]
// ... and 2 more files (download for full content)
About this extraction
This page contains the full source code of the OpenTalker/SadTalker GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 141 files (575.1 KB), approximately 156.7k tokens, and a symbol index with 702 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.