Repository: OpenTalker/SadTalker
Branch: main
Commit: cd4c0465ae0b
Files: 141
Total size: 575.1 KB

Directory structure:
gitextract_pkbb6gh_/

├── .gitignore
├── LICENSE
├── README.md
├── app_sadtalker.py
├── cog.yaml
├── docs/
│   ├── FAQ.md
│   ├── best_practice.md
│   ├── changlelog.md
│   ├── face3d.md
│   ├── install.md
│   └── webui_extension.md
├── inference.py
├── launcher.py
├── predict.py
├── quick_demo.ipynb
├── req.txt
├── requirements.txt
├── requirements3d.txt
├── scripts/
│   ├── download_models.sh
│   ├── extension.py
│   └── test.sh
├── src/
│   ├── audio2exp_models/
│   │   ├── audio2exp.py
│   │   └── networks.py
│   ├── audio2pose_models/
│   │   ├── audio2pose.py
│   │   ├── audio_encoder.py
│   │   ├── cvae.py
│   │   ├── discriminator.py
│   │   ├── networks.py
│   │   └── res_unet.py
│   ├── config/
│   │   ├── auido2exp.yaml
│   │   ├── auido2pose.yaml
│   │   ├── facerender.yaml
│   │   ├── facerender_still.yaml
│   │   └── similarity_Lm3D_all.mat
│   ├── face3d/
│   │   ├── data/
│   │   │   ├── __init__.py
│   │   │   ├── base_dataset.py
│   │   │   ├── flist_dataset.py
│   │   │   ├── image_folder.py
│   │   │   └── template_dataset.py
│   │   ├── extract_kp_videos.py
│   │   ├── extract_kp_videos_safe.py
│   │   ├── models/
│   │   │   ├── __init__.py
│   │   │   ├── arcface_torch/
│   │   │   │   ├── README.md
│   │   │   │   ├── backbones/
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── iresnet.py
│   │   │   │   │   ├── iresnet2060.py
│   │   │   │   │   └── mobilefacenet.py
│   │   │   │   ├── configs/
│   │   │   │   │   ├── 3millions.py
│   │   │   │   │   ├── 3millions_pfc.py
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── base.py
│   │   │   │   │   ├── glint360k_mbf.py
│   │   │   │   │   ├── glint360k_r100.py
│   │   │   │   │   ├── glint360k_r18.py
│   │   │   │   │   ├── glint360k_r34.py
│   │   │   │   │   ├── glint360k_r50.py
│   │   │   │   │   ├── ms1mv3_mbf.py
│   │   │   │   │   ├── ms1mv3_r18.py
│   │   │   │   │   ├── ms1mv3_r2060.py
│   │   │   │   │   ├── ms1mv3_r34.py
│   │   │   │   │   ├── ms1mv3_r50.py
│   │   │   │   │   └── speed.py
│   │   │   │   ├── dataset.py
│   │   │   │   ├── docs/
│   │   │   │   │   ├── eval.md
│   │   │   │   │   ├── install.md
│   │   │   │   │   ├── modelzoo.md
│   │   │   │   │   └── speed_benchmark.md
│   │   │   │   ├── eval/
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   └── verification.py
│   │   │   │   ├── eval_ijbc.py
│   │   │   │   ├── inference.py
│   │   │   │   ├── losses.py
│   │   │   │   ├── onnx_helper.py
│   │   │   │   ├── onnx_ijbc.py
│   │   │   │   ├── partial_fc.py
│   │   │   │   ├── requirement.txt
│   │   │   │   ├── run.sh
│   │   │   │   ├── torch2onnx.py
│   │   │   │   ├── train.py
│   │   │   │   └── utils/
│   │   │   │       ├── __init__.py
│   │   │   │       ├── plot.py
│   │   │   │       ├── utils_amp.py
│   │   │   │       ├── utils_callbacks.py
│   │   │   │       ├── utils_config.py
│   │   │   │       ├── utils_logging.py
│   │   │   │       └── utils_os.py
│   │   │   ├── base_model.py
│   │   │   ├── bfm.py
│   │   │   ├── facerecon_model.py
│   │   │   ├── losses.py
│   │   │   ├── networks.py
│   │   │   └── template_model.py
│   │   ├── options/
│   │   │   ├── __init__.py
│   │   │   ├── base_options.py
│   │   │   ├── inference_options.py
│   │   │   ├── test_options.py
│   │   │   └── train_options.py
│   │   ├── util/
│   │   │   ├── BBRegressorParam_r.mat
│   │   │   ├── __init__.py
│   │   │   ├── detect_lm68.py
│   │   │   ├── generate_list.py
│   │   │   ├── html.py
│   │   │   ├── load_mats.py
│   │   │   ├── my_awing_arch.py
│   │   │   ├── nvdiffrast.py
│   │   │   ├── preprocess.py
│   │   │   ├── skin_mask.py
│   │   │   ├── test_mean_face.txt
│   │   │   ├── util.py
│   │   │   └── visualizer.py
│   │   └── visualize.py
│   ├── facerender/
│   │   ├── animate.py
│   │   ├── modules/
│   │   │   ├── dense_motion.py
│   │   │   ├── discriminator.py
│   │   │   ├── generator.py
│   │   │   ├── keypoint_detector.py
│   │   │   ├── make_animation.py
│   │   │   ├── mapping.py
│   │   │   └── util.py
│   │   └── sync_batchnorm/
│   │       ├── __init__.py
│   │       ├── batchnorm.py
│   │       ├── comm.py
│   │       ├── replicate.py
│   │       └── unittest.py
│   ├── generate_batch.py
│   ├── generate_facerender_batch.py
│   ├── gradio_demo.py
│   ├── test_audio2coeff.py
│   └── utils/
│       ├── audio.py
│       ├── croper.py
│       ├── face_enhancer.py
│       ├── hparams.py
│       ├── init_path.py
│       ├── model2safetensor.py
│       ├── paste_pic.py
│       ├── preprocess.py
│       ├── safetensor_helper.py
│       ├── text2speech.py
│       └── videoio.py
├── webui.bat
└── webui.sh

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

examples/results/*
gfpgan/*
checkpoints/*
assets/*
results/*
Dockerfile
start_docker.sh
start.sh

checkpoints

# Mac
.DS_Store


================================================
FILE: LICENSE
================================================
Tencent is pleased to support the open source community by making SadTalker available.

Copyright (C), a Tencent company. All rights reserved.

SadTalker is licensed under the Apache 2.0 License, except for the third-party components listed below.

Terms of the Apache License Version 2.0:
---------------------------------------------
                                Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
<div align="center">

<img src='https://user-images.githubusercontent.com/4397546/229094115-862c747e-7397-4b54-ba4a-bd368bfe2e0f.png' width='500px'/>


<!--<h2> 😭 SadTalker: <span style="font-size:12px">Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation </span> </h2> -->

  <a href='https://arxiv.org/abs/2211.12194'><img src='https://img.shields.io/badge/ArXiv-PDF-red'></a> &nbsp; <a href='https://sadtalker.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> &nbsp; [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb) &nbsp; [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/vinthony/SadTalker) &nbsp; [![sd webui-colab](https://img.shields.io/badge/Automatic1111-Colab-green)](https://colab.research.google.com/github/camenduru/stable-diffusion-webui-colab/blob/main/video/stable/stable_diffusion_1_5_video_webui_colab.ipynb) &nbsp; <br> [![Replicate](https://replicate.com/cjwbw/sadtalker/badge)](https://replicate.com/cjwbw/sadtalker) [![Discord](https://dcbadge.vercel.app/api/server/rrayYqZ4tf?style=flat)](https://discord.gg/rrayYqZ4tf)

<div>
    <a target='_blank'>Wenxuan Zhang <sup>*,1,2</sup> </a>&emsp;
    <a href='https://vinthony.github.io/' target='_blank'>Xiaodong Cun <sup>*,2</sup></a>&emsp;
    <a href='https://xuanwangvc.github.io/' target='_blank'>Xuan Wang <sup>3</sup></a>&emsp;
    <a href='https://yzhang2016.github.io/' target='_blank'>Yong Zhang <sup>2</sup></a>&emsp;
    <a href='https://xishen0220.github.io/' target='_blank'>Xi Shen <sup>2</sup></a>&emsp; <br>
    <a href='https://yuguo-xjtu.github.io/' target='_blank'>Yu Guo<sup>1</sup> </a>&emsp;
    <a href='https://scholar.google.com/citations?hl=zh-CN&user=4oXBp9UAAAAJ' target='_blank'>Ying Shan <sup>2</sup> </a>&emsp;
    <a target='_blank'>Fei Wang <sup>1</sup> </a>&emsp;
</div>
<br>
<div>
    <sup>1</sup> Xi'an Jiaotong University &emsp; <sup>2</sup> Tencent AI Lab &emsp; <sup>3</sup> Ant Group &emsp; 
</div>
<br>
<i><strong><a href='https://arxiv.org/abs/2211.12194' target='_blank'>CVPR 2023</a></strong></i>
<br>
<br>


![sadtalker](https://user-images.githubusercontent.com/4397546/222490039-b1f6156b-bf00-405b-9fda-0c9a9156f991.gif)

<b>TL;DR: &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; single portrait image 🙎‍♂️  &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;+  &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; audio 🎤  &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; =  &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; talking head video 🎞.</b>

<br>

</div>



## Highlights

- The license has been updated to Apache 2.0, and we've removed the non-commercial restriction.
- **SadTalker has now officially been integrated into Discord, where you can use it for free by sending files. You can also generate high-quality videos from text prompts. Join: [![Discord](https://dcbadge.vercel.app/api/server/rrayYqZ4tf?style=flat)](https://discord.gg/rrayYqZ4tf)**

- We've published a [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) extension. Check out more details [here](docs/webui_extension.md). [Demo Video](https://user-images.githubusercontent.com/4397546/231495639-5d4bb925-ea64-4a36-a519-6389917dac29.mp4)

- Full image mode is now available! [More details...](https://github.com/OpenTalker/SadTalker#full-bodyimage-generation)

| still+enhancer in v0.0.1                 | still + enhancer   in v0.0.2       |   [input image @bagbag1815](https://twitter.com/bagbag1815/status/1642754319094108161) |
|:--------------------: |:--------------------: | :----: |
| <video  src="https://user-images.githubusercontent.com/48216707/229484996-5d7be64f-2553-4c9e-a452-c5cf0b8ebafe.mp4" type="video/mp4"> </video> | <video  src="https://user-images.githubusercontent.com/4397546/230717873-355b7bf3-d3de-49f9-a439-9220e623fce7.mp4" type="video/mp4"> </video>  | <img src='./examples/source_image/full_body_2.png' width='380'> 

- Several new modes (still, reference, and resize) are now available!

- We're happy to see more community demos on [bilibili](https://search.bilibili.com/all?keyword=sadtalker), [YouTube](https://www.youtube.com/results?search_query=sadtalker) and [X (#sadtalker)](https://twitter.com/search?q=%23sadtalker&src).

## Changelog 

The previous changelog can be found [here](docs/changlelog.md).

- __[2023.06.12]__: Added more features to the WebUI extension, see the discussion [here](https://github.com/OpenTalker/SadTalker/discussions/386).

- __[2023.06.05]__: Released a new 512x512px (beta) face model. Fixed some bugs and improved performance.

- __[2023.04.15]__: Added a WebUI Colab notebook by [@camenduru](https://github.com/camenduru/): [![sd webui-colab](https://img.shields.io/badge/Automatic1111-Colab-green)](https://colab.research.google.com/github/camenduru/stable-diffusion-webui-colab/blob/main/video/stable/stable_diffusion_1_5_video_webui_colab.ipynb)

- __[2023.04.12]__: Added a more detailed WebUI installation document and fixed a problem when reinstalling.

- __[2023.04.12]__: Fixed WebUI safety issues caused by third-party packages, and optimized the output path in `sd-webui-extension`.

- __[2023.04.08]__: In v0.0.2, we added a logo watermark to the generated video to prevent abuse. _This watermark has since been removed in a later release._

- __[2023.04.08]__: In v0.0.2, we added features for full image animation and a link to download checkpoints from Baidu. We also optimized the enhancer logic.

## To-Do

We're tracking new updates in [issue #280](https://github.com/OpenTalker/SadTalker/issues/280).

## Troubleshooting

If you have any problems, please read our [FAQs](docs/FAQ.md) before opening an issue.



## 1. Installation

Community tutorials: [中文Windows教程 (Chinese Windows tutorial)](https://www.bilibili.com/video/BV1Dc411W7V6/) | [日本語コース (Japanese tutorial)](https://br-d.fanbox.cc/posts/5685086).

### Linux/Unix

1. Install [Anaconda](https://www.anaconda.com/), Python and `git`.

2. Create the environment and install the requirements.
  ```bash
  git clone https://github.com/OpenTalker/SadTalker.git

  cd SadTalker 

  conda create -n sadtalker python=3.8

  conda activate sadtalker

  pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113

  conda install ffmpeg

  pip install -r requirements.txt

  ### Coqui TTS is optional for gradio demo. 
  ### pip install TTS

  ```  
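Before moving on, you can sanity-check the environment created above. The following is a minimal sketch, not part of SadTalker: the helper name `check_environment` is hypothetical, and it only checks the Python version targeted by `conda create ... python=3.8`, whether an `ffmpeg` binary is on `PATH`, and whether PyTorch imports and sees a CUDA device.

```python
import shutil
import sys


def check_environment(expected=(3, 8)):
    """Hypothetical helper: report whether the interpreter and ffmpeg match the install steps."""
    report = {
        # The conda env above is created with python=3.8.
        "python_ok": sys.version_info[:2] == tuple(expected),
        # `conda install ffmpeg` should put an ffmpeg binary on PATH (None if not found).
        "ffmpeg_path": shutil.which("ffmpeg"),
    }
    try:
        import torch  # installed by the pip command above
        report["torch_version"] = torch.__version__
        report["cuda_available"] = torch.cuda.is_available()
    except ImportError:
        # PyTorch is missing, so neither a version nor CUDA can be reported.
        report["torch_version"] = None
        report["cuda_available"] = False
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```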
### Windows

A video tutorial in Chinese is available [here](https://www.bilibili.com/video/BV1Dc411W7V6/). Alternatively, follow these steps:

1. Install [Python 3.8](https://www.python.org/downloads/windows/) and check "Add Python to PATH".
2. Install [git](https://git-scm.com/download/win) manually or using [Scoop](https://scoop.sh/): `scoop install git`.
3. Install `ffmpeg`, following [this tutorial](https://www.wikihow.com/Install-FFmpeg-on-Windows) or using [scoop](https://scoop.sh/): `scoop install ffmpeg`.
4. Download the SadTalker repository by running `git clone https://github.com/Winfredy/SadTalker.git`.
5. Download the checkpoints and gfpgan models in the [downloads section](#2-download-models).
6. Run `webui.bat` from Windows Explorer as a normal (non-administrator) user to start a Gradio-powered WebUI demo.

### macOS

A tutorial on installing SadTalker on macOS can be found [here](docs/install.md).

### Docker, WSL, etc.

Please check out additional tutorials [here](docs/install.md).

## 2. Download Models

You can run the following script on Linux/macOS to automatically download all the models:

```bash
bash scripts/download_models.sh
```

We also provide an offline patch for GFPGAN (`gfpgan/`), so no models need to be downloaded at generation time.

### Pre-Trained Models

* [Google Drive](https://drive.google.com/file/d/1gwWh45pF7aelNP_P78uDJL8Sycep-K7j/view?usp=sharing)
* [GitHub Releases](https://github.com/OpenTalker/SadTalker/releases)
* [Baidu (百度云盘)](https://pan.baidu.com/s/1kb1BCPaLOWX1JJb9Czbn6w?pwd=sadt) (Password: `sadt`)

<!-- TODO add Hugging Face links -->

### GFPGAN Offline Patch

* [Google Drive](https://drive.google.com/file/d/19AIBsmfcHW6BRJmeqSFlG5fL445Xmsyi?usp=sharing)
* [GitHub Releases](https://github.com/OpenTalker/SadTalker/releases)
* [Baidu (百度云盘)](https://pan.baidu.com/s/1P4fRgk9gaSutZnn8YW034Q?pwd=sadt) (Password: `sadt`)

<!-- TODO add Hugging Face links -->


<details><summary>Model Details</summary>


Model descriptions:

##### New version 
| Model | Description
| :--- | :----------
|checkpoints/mapping_00229-model.pth.tar | Pre-trained MappingNet in SadTalker.
|checkpoints/mapping_00109-model.pth.tar | Pre-trained MappingNet in SadTalker.
|checkpoints/SadTalker_V0.0.2_256.safetensors | Packaged SadTalker checkpoints of the old version (256 face render).
|checkpoints/SadTalker_V0.0.2_512.safetensors | Packaged SadTalker checkpoints of the old version (512 face render).
|gfpgan/weights | Face detection and enhancement models used in `facexlib` and `gfpgan`.
  
  
##### Old version
| Model | Description
| :--- | :----------
|checkpoints/auido2exp_00300-model.pth | Pre-trained ExpNet in SadTalker.
|checkpoints/auido2pose_00140-model.pth | Pre-trained PoseVAE in SadTalker.
|checkpoints/mapping_00229-model.pth.tar | Pre-trained MappingNet in SadTalker.
|checkpoints/mapping_00109-model.pth.tar | Pre-trained MappingNet in SadTalker.
|checkpoints/facevid2vid_00189-model.pth.tar | Pre-trained face-vid2vid model from [the reproduction of face-vid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis).
|checkpoints/epoch_20.pth | Pre-trained 3DMM extractor in [Deep3DFaceReconstruction](https://github.com/microsoft/Deep3DFaceReconstruction).
|checkpoints/wav2lip.pth | Highly accurate lip-sync model in [Wav2lip](https://github.com/Rudrabha/Wav2Lip).
|checkpoints/shape_predictor_68_face_landmarks.dat | Face landmark model used in [dlib](http://dlib.net/).
|checkpoints/BFM | 3DMM library file.
|checkpoints/hub | Face detection models used in [face alignment](https://github.com/1adrianb/face-alignment).
|gfpgan/weights | Face detection and enhancement models used in `facexlib` and `gfpgan`.

The final folder structure should look like this:

<img width="331" alt="image" src="https://user-images.githubusercontent.com/4397546/232511411-4ca75cbf-a434-48c5-9ae0-9009e8316484.png">


</details>
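After running `scripts/download_models.sh` or unpacking the archives manually, a quick check against the "New version" table in Model Details can catch missing files. This is a sketch, not a SadTalker utility: `missing_models` is a hypothetical helper, it assumes the paths are relative to the repository root exactly as listed in the table, and it only checks that files and folders exist, not that they are intact.

```python
from pathlib import Path

# Paths taken from the "New version" table; relative to the SadTalker repo root.
NEW_VERSION_FILES = [
    "checkpoints/mapping_00229-model.pth.tar",
    "checkpoints/mapping_00109-model.pth.tar",
    "checkpoints/SadTalker_V0.0.2_256.safetensors",
    "checkpoints/SadTalker_V0.0.2_512.safetensors",
]
# gfpgan/weights is a folder of face detection/enhancement models.
NEW_VERSION_DIRS = ["gfpgan/weights"]


def missing_models(repo_root="."):
    """Hypothetical helper: return the table entries that are not present under repo_root."""
    root = Path(repo_root)
    missing = [p for p in NEW_VERSION_FILES if not (root / p).is_file()]
    missing += [d for d in NEW_VERSION_DIRS if not (root / d).is_dir()]
    return missing


if __name__ == "__main__":
    gaps = missing_models()
    print("All models found." if not gaps else "Missing:\n" + "\n".join(gaps))
```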

## 3. Quick Start

Please read our document on [best practices and configuration tips](docs/best_practice.md).

### WebUI Demos

**Online Demo**: [HuggingFace](https://huggingface.co/spaces/vinthony/SadTalker) | [SDWebUI-Colab](https://colab.research.google.com/github/camenduru/stable-diffusion-webui-colab/blob/main/video/stable/stable_diffusion_1_5_video_webui_colab.ipynb) | [Colab](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb)

**Local WebUI extension**: Please refer to [WebUI docs](docs/webui_extension.md).

**Local gradio demo (recommended)**: A Gradio instance similar to our [Hugging Face demo](https://huggingface.co/spaces/vinthony/SadTalker) can be run locally:

```bash
## You need to manually install TTS (https://github.com/coqui-ai/TTS) via `pip install TTS` in advance.
python app_sadtalker.py
```

You can also start it more easily:

- Windows: double-click `webui.bat`; the requirements will be installed automatically.
- Linux/macOS: run `bash webui.sh` to start the WebUI.


### CLI usage

##### Animating a portrait image from default config:
```bash
python inference.py --driven_audio <audio.wav> \
                    --source_image <video.mp4 or picture.png> \
                    --enhancer gfpgan 
```
The result will be saved as `results/$SOME_TIMESTAMP.mp4` (the intermediate folder `results/$SOME_TIMESTAMP/` is deleted unless `--verbose` is set).
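Because `inference.py` names each run with `strftime("%Y_%m_%d_%H.%M.%S")` and moves the final video to `<run folder>.mp4`, the newest result can be picked out by name alone. The sketch below assumes the default `./results` directory; `latest_result` is a hypothetical helper, not part of SadTalker.

```python
import glob
import os


def latest_result(result_dir="./results"):
    """Return the newest generated video in result_dir, or None if there is none.

    Assumes file names follow inference.py's "%Y_%m_%d_%H.%M.%S.mp4" pattern,
    which is zero-padded and therefore sorts chronologically as plain strings.
    """
    videos = sorted(glob.glob(os.path.join(result_dir, "*.mp4")))
    return videos[-1] if videos else None
```

For example, `print(latest_result())` after a run prints the path of the video that was just generated.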

##### Full body/image Generation:

Use `--still` to generate a natural full-body video. You can add `--enhancer` to improve the quality of the generated video. 

```bash
python inference.py --driven_audio <audio.wav> \
                    --source_image <video.mp4 or picture.png> \
                    --result_dir <a directory to store results> \
                    --still \
                    --preprocess full \
                    --enhancer gfpgan 
```
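If you drive the full-body command above from another Python script, you can build the argument list programmatically. This is a minimal sketch. `build_inference_cmd` is a hypothetical helper. The mode list comes from the gradio demo's `preprocess` selector, and the warning encodes the advice to pair `full` with `--still`.

```python
import sys
import warnings

# Modes offered by the gradio demo's "preprocess" selector
PREPROCESS_MODES = ("crop", "resize", "full", "extcrop", "extfull")


def build_inference_cmd(driven_audio, source_image, result_dir="./results",
                        preprocess="crop", still=False, enhancer=None):
    """Return an argv list for inference.py built from the flags documented above."""
    if preprocess not in PREPROCESS_MODES:
        raise ValueError(f"unknown preprocess mode: {preprocess!r}")
    if preprocess == "full" and not still:
        # The docs recommend combining full-image animation with still mode
        warnings.warn("`--preprocess full` is recommended together with `--still`")
    cmd = [sys.executable, "inference.py",
           "--driven_audio", driven_audio,
           "--source_image", source_image,
           "--result_dir", result_dir,
           "--preprocess", preprocess]
    if still:
        cmd.append("--still")
    if enhancer is not None:
        cmd += ["--enhancer", enhancer]
    return cmd
```

Run the returned list with `subprocess.run(cmd, check=True)` from the repository root.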

More examples, configuration options, and tips can be found in the [ >>> best practice documents <<<](docs/best_practice.md).

## Citation

If you find our work useful in your research, please consider citing:

```bibtex
@article{zhang2022sadtalker,
  title={SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation},
  author={Zhang, Wenxuan and Cun, Xiaodong and Wang, Xuan and Zhang, Yong and Shen, Xi and Guo, Yu and Shan, Ying and Wang, Fei},
  journal={arXiv preprint arXiv:2211.12194},
  year={2022}
}
```

## Acknowledgements

Facerender code borrows heavily from [zhanglonghao's reproduction of face-vid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis) and [PIRender](https://github.com/RenYurui/PIRender). We thank the authors for sharing their wonderful code. During training, we also used the models from [Deep3DFaceReconstruction](https://github.com/microsoft/Deep3DFaceReconstruction) and [Wav2lip](https://github.com/Rudrabha/Wav2Lip). We thank them for their wonderful work.

We also use the following 3rd-party libraries:

- **Face Utils**: https://github.com/xinntao/facexlib
- **Face Enhancement**: https://github.com/TencentARC/GFPGAN
- **Image/Video Enhancement**: https://github.com/xinntao/Real-ESRGAN

## Extensions

- [SadTalker-Video-Lip-Sync](https://github.com/Zz-ww/SadTalker-Video-Lip-Sync) from [@Zz-ww](https://github.com/Zz-ww): SadTalker for Video Lip Editing

## Related Works
- [StyleHEAT: One-Shot High-Resolution Editable Talking Face Generation via Pre-trained StyleGAN (ECCV 2022)](https://github.com/FeiiYin/StyleHEAT)
- [CodeTalker: Speech-Driven 3D Facial Animation with Discrete Motion Prior (CVPR 2023)](https://github.com/Doubiiu/CodeTalker)
- [VideoReTalking: Audio-based Lip Synchronization for Talking Head Video Editing In the Wild (SIGGRAPH Asia 2022)](https://github.com/vinthony/video-retalking)
- [DPE: Disentanglement of Pose and Expression for General Video Portrait Editing (CVPR 2023)](https://github.com/Carlyx/DPE)
- [3D GAN Inversion with Facial Symmetry Prior (CVPR 2023)](https://github.com/FeiiYin/SPI/)
- [T2M-GPT: Generating Human Motion from Textual Descriptions with Discrete Representations (CVPR 2023)](https://github.com/Mael-zys/T2M-GPT)

## Disclaimer

This is not an official product of Tencent. 

```
1. Please carefully read and comply with the open-source license applicable to this code before using it. 
2. Please carefully read and comply with the intellectual property declaration applicable to this code before using it.
3. This open-source code runs completely offline and does not collect any personal information or other data. If you use this code to provide services to end-users and collect related data, please take necessary compliance measures according to applicable laws and regulations (such as publishing privacy policies, adopting necessary data security strategies, etc.). If the collected data involves personal information, user consent must be obtained (if applicable). Any legal liabilities arising from this are unrelated to Tencent.
4. Without Tencent's written permission, you are not authorized to use the names or logos legally owned by Tencent, such as "Tencent." Otherwise, you may be liable for legal responsibilities.
5. This open-source code does not have the ability to directly provide services to end-users. If you need to use this code for further model training or demos, as part of your product to provide services to end-users, or for similar use, please comply with applicable laws and regulations for your product or service. Any legal liabilities arising from this are unrelated to Tencent.
6. It is prohibited to use this open-source code for activities that harm the legitimate rights and interests of others (including but not limited to fraud, deception, infringement of others' portrait rights, reputation rights, etc.), or other behaviors that violate applicable laws and regulations or go against social ethics and good customs (including providing incorrect or false information, spreading pornographic, terrorist, and violent information, etc.). Otherwise, you may be liable for legal responsibilities.
```

LOGO: color and font suggestion: [ChatGPT](https://chat.openai.com), logo font: [Montserrat Alternates](https://fonts.google.com/specimen/Montserrat+Alternates?preview.text=SadTalker&preview.text_type=custom&query=mont).

All copyrights of the demo images and audio belong to community users or come from Stable Diffusion generations. Feel free to contact us if you would like us to remove them.


<!-- Spelling fixed on Tuesday, September 12, 2023 by @fakerybakery (https://github.com/fakerybakery). These changes are licensed under the Apache 2.0 license. -->


================================================
FILE: app_sadtalker.py
================================================
import os, sys
import gradio as gr
from src.gradio_demo import SadTalker  


try:
    import webui  # in webui
    in_webui = True
except:
    in_webui = False


def toggle_audio_file(choice):
    if choice == False:
        return gr.update(visible=True), gr.update(visible=False)
    else:
        return gr.update(visible=False), gr.update(visible=True)
    
def ref_video_fn(path_of_ref_video):
    if path_of_ref_video is not None:
        return gr.update(value=True)
    else:
        return gr.update(value=False)

def sadtalker_demo(checkpoint_path='checkpoints', config_path='src/config', warpfn=None):

    sad_talker = SadTalker(checkpoint_path, config_path, lazy_load=True)

    with gr.Blocks(analytics_enabled=False) as sadtalker_interface:
        gr.Markdown("<div align='center'> <h2> 😭 SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation (CVPR 2023) </h2> \
                    <a style='font-size:18px;color: #efefef' href='https://arxiv.org/abs/2211.12194'>Arxiv</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; \
                    <a style='font-size:18px;color: #efefef' href='https://sadtalker.github.io'>Homepage</a>  &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; \
                     <a style='font-size:18px;color: #efefef' href='https://github.com/Winfredy/SadTalker'> Github </a> </div>")
        
        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="sadtalker_source_image"):
                    with gr.TabItem('Upload image'):
                        with gr.Row():
                            source_image = gr.Image(label="Source image", source="upload", type="filepath", elem_id="img2img_image").style(width=512)

                with gr.Tabs(elem_id="sadtalker_driven_audio"):
                    with gr.TabItem('Upload OR TTS'):
                        with gr.Column(variant='panel'):
                            driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath")

                        if sys.platform != 'win32' and not in_webui: 
                            from src.utils.text2speech import TTSTalker
                            tts_talker = TTSTalker()
                            with gr.Column(variant='panel'):
                                input_text = gr.Textbox(label="Generating audio from text", lines=5, placeholder="please enter some text here, we generate the audio from text using @Coqui.ai TTS.")
                                tts = gr.Button('Generate audio',elem_id="sadtalker_audio_generate", variant='primary')
                                tts.click(fn=tts_talker.test, inputs=[input_text], outputs=[driven_audio])
                            
            with gr.Column(variant='panel'): 
                with gr.Tabs(elem_id="sadtalker_checkbox"):
                    with gr.TabItem('Settings'):
                        gr.Markdown("Need help? Please visit our [best practice page](https://github.com/OpenTalker/SadTalker/blob/main/docs/best_practice.md) for more details.")
                        with gr.Column(variant='panel'):
                            # width = gr.Slider(minimum=64, elem_id="img2img_width", maximum=2048, step=8, label="Manually Crop Width", value=512) # img2img_width
                            # height = gr.Slider(minimum=64, elem_id="img2img_height", maximum=2048, step=8, label="Manually Crop Height", value=512) # img2img_width
                            pose_style = gr.Slider(minimum=0, maximum=46, step=1, label="Pose style", value=0) # 
                            size_of_image = gr.Radio([256, 512], value=256, label='face model resolution', info="use 256/512 model?") # 
                            preprocess_type = gr.Radio(['crop', 'resize','full', 'extcrop', 'extfull'], value='crop', label='preprocess', info="How to handle input image?")
                            is_still_mode = gr.Checkbox(label="Still Mode (less head motion, works with preprocess `full`)")
                            batch_size = gr.Slider(label="batch size in generation", step=1, maximum=10, value=2)
                            enhancer = gr.Checkbox(label="GFPGAN as Face enhancer")
                            submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')
                            
                with gr.Tabs(elem_id="sadtalker_genearted"):
                        gen_video = gr.Video(label="Generated video", format="mp4").style(width=256)

        # Wrap the inference function with the caller-supplied warpfn, if any
        test_fn = warpfn(sad_talker.test) if warpfn else sad_talker.test
        submit.click(
                    fn=test_fn,
                    inputs=[source_image,
                            driven_audio,
                            preprocess_type,
                            is_still_mode,
                            enhancer,
                            batch_size,
                            size_of_image,
                            pose_style
                            ],
                    outputs=[gen_video]
                    )

    return sadtalker_interface
 

if __name__ == "__main__":

    demo = sadtalker_demo()
    demo.queue()
    demo.launch()




================================================
FILE: cog.yaml
================================================
build:
  gpu: true
  cuda: "11.3"
  python_version: "3.8"
  system_packages:
    - "ffmpeg"
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
  python_packages:
    - "torch==1.12.1"
    - "torchvision==0.13.1"
    - "torchaudio==0.12.1"
    - "joblib==1.1.0"
    - "scikit-image==0.19.3"
    - "basicsr==1.4.2"
    - "facexlib==0.3.0"
    - "resampy==0.3.1"
    - "pydub==0.25.1"
    - "scipy==1.10.1"
    - "kornia==0.6.8"
    - "face_alignment==1.3.5"
    - "imageio==2.19.3"
    - "imageio-ffmpeg==0.4.7"
    - "librosa==0.9.2" #
    - "tqdm==4.65.0"
    - "yacs==0.1.8"
    - "gfpgan==1.3.8"
    - "dlib-bin==19.24.1"
    - "av==10.0.0"
    - "trimesh==3.9.20"
  run:
    - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/s3fd-619a316812.pth" "https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth"
    - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/2DFAN4-cd938726ad.zip" "https://www.adrianbulat.com/downloads/python-fan/2DFAN4-cd938726ad.zip"

predict: "predict.py:Predictor"


================================================
FILE: docs/FAQ.md
================================================

## Frequently Asked Questions

**Q: `ffmpeg` is not recognized as an internal or external command**

On Linux, you can install ffmpeg via `conda install ffmpeg`. On macOS, try `brew install ffmpeg`. On Windows, install `ffmpeg` by following [this](https://www.geeksforgeeks.org/how-to-install-ffmpeg-on-windows/) guide and make sure it is in your `%PATH%`, as suggested in [#54](https://github.com/Winfredy/SadTalker/issues/54).
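To confirm that `ffmpeg` is on your `PATH` after installation, the standard library's `shutil.which` is enough. `ffmpeg_path` is a hypothetical helper name used only for illustration.

```python
import shutil


def ffmpeg_path():
    """Return the full path of the ffmpeg executable found on PATH, or None."""
    return shutil.which("ffmpeg")


if ffmpeg_path() is None:
    print("ffmpeg was not found on PATH; install it as described above.")
else:
    print("ffmpeg found at", ffmpeg_path())
```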

**Q: Running Requirements**

Please refer to the discussion here: https://github.com/Winfredy/SadTalker/issues/124#issuecomment-1508113989


**Q: ModuleNotFoundError: No module named 'ai'**

Please check the file size of the `epoch_20.pth` checkpoint; the file may be incompletely downloaded. (https://github.com/Winfredy/SadTalker/issues/167, https://github.com/Winfredy/SadTalker/issues/113)

**Q: Illegal Hardware Error: Mac M1**

Please reinstall `dlib` separately via `pip install dlib`. (https://github.com/Winfredy/SadTalker/issues/129, https://github.com/Winfredy/SadTalker/issues/109)


**Q: FileNotFoundError: [Errno 2] No such file or directory: checkpoints\BFM_Fitting\similarity_Lm3D_all.mat**

Make sure you have downloaded the checkpoints and gfpgan weights as described [here](https://github.com/Winfredy/SadTalker#-2-download-trained-models) and placed them in the right folders. 

**Q: RuntimeError: unexpected EOF, expected 237192 more bytes. The file might be corrupted.**

The files were not downloaded automatically. Please update the code and download the gfpgan folders manually as described [here](https://github.com/Winfredy/SadTalker#-2-download-trained-models).

**Q: CUDA out of memory error**

Please refer to https://stackoverflow.com/questions/73747731/runtimeerror-cuda-out-of-memory-how-setting-max-split-size-mb and set `PYTORCH_CUDA_ALLOC_CONF` before running:

``` 
# windows
set PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 
python inference.py ...

# linux
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 
python inference.py ...
```
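If you launch `inference.py` from Python rather than from a shell, you can pass the same allocator setting through the child process's environment. This is a minimal sketch. `env_with_cuda_alloc_conf` is a hypothetical helper, and `128` is simply the value used in the commands above.

```python
import os


def env_with_cuda_alloc_conf(max_split_size_mb=128, base_env=None):
    """Return a copy of the environment with PYTORCH_CUDA_ALLOC_CONF set.

    The copy is meant for a child process, e.g.
    subprocess.run([sys.executable, "inference.py", ...], env=env, check=True)
    """
    env = dict(os.environ if base_env is None else base_env)
    env["PYTORCH_CUDA_ALLOC_CONF"] = f"max_split_size_mb:{max_split_size_mb}"
    return env
```

Using a fresh environment for the child process means the setting is in place before PyTorch initializes CUDA in that process.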

**Q: Error while decoding stream #0:0: Invalid data found when processing input [mp3float @ 0000015037628c00] Header missing**

Our method only supports WAV or MP3 files as input; please make sure the input audio is in one of these formats.
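One possible cause of the `Header missing` error is a file whose contents do not match its extension. A best-effort check of the first bytes can catch this before you run inference. This is a heuristic sketch, and `sniff_audio_format` is hypothetical. It recognizes a RIFF/WAVE header, an ID3 tag, or an MPEG audio frame header, and it does not validate the rest of the file.

```python
def sniff_audio_format(path):
    """Best-effort guess of 'wav', 'mp3' or None from the file's first bytes."""
    with open(path, "rb") as f:
        head = f.read(12)
    # WAV: "RIFF" chunk id followed by the "WAVE" form type at offset 8
    if len(head) >= 12 and head[:4] == b"RIFF" and head[8:12] == b"WAVE":
        return "wav"
    # MP3 with a leading ID3v2 tag
    if head[:3] == b"ID3":
        return "mp3"
    # Bare MPEG audio frame: 11 sync bits set and a non-zero layer field
    # (layer 00 is reserved, which also rules out AAC/ADTS streams)
    if len(head) >= 2 and head[0] == 0xFF and (head[1] & 0xE0) == 0xE0 \
            and (head[1] >> 1) & 0x03 != 0:
        return "mp3"
    return None
```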


================================================
FILE: docs/best_practice.md
================================================
# Best Practices and Configuration Tips

> Our model only works on REAL people or portrait images similar to a REAL person. An anime talking-head generation method will be released in the future.

Advanced configuration options for `inference.py`:

| Name        | Configuration | Default |   Explanation  | 
|:------------- |:------------- |:----- | :------------- |
| Enhance Mode | `--enhancer` | None | Use `gfpgan` or `RestoreFormer` to enhance the generated face via a face restoration network. |
| Background Enhancer | `--background_enhancer` | None | Use `realesrgan` to enhance the full video. |
| Still Mode   | `--still` | False | Use the same pose parameters as the original image, with less head motion. |
| Expressive Mode | `--expression_scale` | 1.0 | A larger value makes the expression motion stronger. |
| Save path | `--result_dir` |`./results` | The files will be saved in this location. |
| Preprocess | `--preprocess` | `crop` | Run and produce the results on the cropped input image. Other choices: `resize`, where the images are resized to the specific resolution, and `full`, which animates the full image (use with `--still` for better results). |
| Ref Mode (eye) | `--ref_eyeblink` | None | A video path; we borrow the eye blinks from this reference video to provide more natural eyebrow movement. |
| Ref Mode (pose) | `--ref_pose` | None | A video path; we borrow the head pose from this reference video. |
| 3D Mode | `--face3dvis` | False | Requires additional installation. More details on generating the 3D face can be found [here](face3d.md). |
| Free-view Mode | `--input_yaw`,<br> `--input_pitch`,<br> `--input_roll` | None | Generate novel-view or free-view 4D talking heads from a single image. More details can be found [here](https://github.com/Winfredy/SadTalker#generating-4d-free-view-talking-examples-from-audio-and-a-single-image). |
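The effect of `--expression_scale` can be pictured as multiplying the expression part of each frame's motion coefficients. This is a minimal sketch under an assumption: each frame's coefficient vector stores expression coefficients in its first 64 entries, followed by pose (a 64 + 6 layout we believe SadTalker uses). `scale_expression` is hypothetical and not the project's code.

```python
def scale_expression(coeffs, scale=1.0, num_exp=64):
    """Return per-frame coefficients with the expression part multiplied by scale.

    coeffs: list of frames, each a list of floats whose first num_exp entries are
    assumed to be expression coefficients; the remaining (pose) entries are copied.
    """
    return [[c * scale for c in frame[:num_exp]] + list(frame[num_exp:])
            for frame in coeffs]
```

With `scale` above 1.0 the expression moves further from neutral, while the head pose is left unchanged.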


### About `--preprocess`

Our system automatically handles the input images via `crop`, `resize` and `full`.

In `crop` mode, we crop the image around the facial keypoints and generate an animated facial avatar from the crop. Both the expression and head pose animations are realistic.

> Still mode stops the eye blinking and head pose movement.

|  [input image @bagbag1815](https://twitter.com/bagbag1815/status/1642754319094108161) | crop | crop w/still |
|:--------------------: |:--------------------: | :----: |
| <img src='../examples/source_image/full_body_2.png' width='380'> | ![full_body_2](example_crop.gif) | ![full_body_2](example_crop_still.gif) |


In `resize` mode, we resize the whole image to generate a full talking-head video, which suits ID-photo-like inputs. ⚠️ It produces bad results for full-body images.


 

| <img src='../examples/source_image/full_body_2.png' width='380'> |  <img src='../examples/source_image/full4.jpeg' width='380'> |
|:--------------------: |:--------------------: |
| ❌ not suitable for resize mode | ✅ good for resize mode |
| <img src='resize_no.gif'> |  <img src='resize_good.gif' width='380'> |

In `full` mode, our model automatically processes the cropped region and pastes it back into the original image. Remember to use `--still` to keep the original head pose.

| input | `--still` | `--still` & `enhancer` |
|:--------------------: |:--------------------: | :--:|
| <img src='../examples/source_image/full_body_2.png' width='380'> |  <img src='./example_full.gif' width='380'> |  <img src='./example_full_enhanced.gif' width='380'> 


### About `--enhancer`

For higher resolution, we integrate [gfpgan](https://github.com/TencentARC/GFPGAN) and [real-esrgan](https://github.com/xinntao/Real-ESRGAN) for different purposes. Add `--enhancer <gfpgan or RestoreFormer>` to enhance the face, or `--background_enhancer <realesrgan>` to enhance the full image.

```bash
# make sure the above packages are available:
pip install gfpgan
pip install realesrgan
```
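To confirm that the optional enhancer packages are importable in the current environment before you pass `--enhancer` or `--background_enhancer`, you can query the import system without importing the packages themselves. This is a minimal sketch. `enhancer_packages_available` is a hypothetical helper, and we assume the import names match the pip names used above.

```python
import importlib.util


def enhancer_packages_available(names=("gfpgan", "realesrgan")):
    """Map each top-level package name to True if it can be found for import."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


print(enhancer_packages_available())
```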

### About `--face3dvis`

This flag generates the 3D-rendered face and its 3D facial landmarks. More details can be found [here](face3d.md).

| Input        | Animated 3d face | 
|:-------------: | :-------------: |
|  <img src='../examples/source_image/art_0.png' width='200px'> | <video src="https://user-images.githubusercontent.com/4397546/226856847-5a6a0a4d-a5ec-49e2-9b05-3206db65e8e3.mp4"></video>  | 

> Please unmute the video manually, as audio does not play by default on GitHub.



#### Reference eye-blink mode

| Input, w/ reference video   ,  reference video    | 
|:-------------: | 
|  ![free_view](using_ref_video.gif)| 
| If the reference video is shorter than the input audio, we will loop the reference video. |
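The looping rule in the last row can be expressed as cyclic indexing over the reference frames. This is a minimal sketch on plain lists. `loop_reference` is hypothetical. It does not describe how SadTalker's own code handles a reference that is longer than the audio; here, that case is simply truncated.

```python
def loop_reference(ref_frames, num_frames):
    """Repeat ref_frames cyclically until exactly num_frames items are produced."""
    if not ref_frames:
        raise ValueError("reference video has no frames")
    return [ref_frames[i % len(ref_frames)] for i in range(num_frames)]
```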



#### Generating 4D free-view talking examples from audio and a single image

We use `input_yaw`, `input_pitch`, and `input_roll` to control the head pose. For example, `--input_yaw -20 30 10` means the head yaw changes from -20 to 30 degrees and then from 30 to 10 degrees.
```bash
python inference.py --driven_audio <audio.wav> \
                    --source_image <video.mp4 or picture.png> \
                    --result_dir <a directory to store results> \
                    --input_yaw -20 30 10
```

| Results, Free-view results,  Novel view results  | 
|:-------------: | 
|  ![free_view](free_view_result.gif)| 
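One way to read `--input_yaw -20 30 10` is as keyframes joined by straight lines across the video. The sketch below shows that interpretation with piecewise-linear interpolation over a given number of frames. `interpolate_angles` is hypothetical, and SadTalker's own code may distribute frames across segments differently.

```python
def interpolate_angles(keyframes, num_frames):
    """Piecewise-linear angle (degrees) for each frame, passing through every keyframe.

    With two or more frames, the first frame is keyframes[0], the last is
    keyframes[-1], and each segment between keyframes spans an equal share of the clip.
    """
    if not keyframes or num_frames < 1:
        raise ValueError("need at least one keyframe and one frame")
    if len(keyframes) == 1 or num_frames == 1:
        return [float(keyframes[0])] * num_frames
    segments = len(keyframes) - 1
    angles = []
    for i in range(num_frames):
        t = i * segments / (num_frames - 1)  # position in [0, segments]
        k = min(int(t), segments - 1)        # index of the current segment
        frac = t - k                         # progress within that segment
        angles.append(keyframes[k] + (keyframes[k + 1] - keyframes[k]) * frac)
    return angles


print(interpolate_angles([-20, 30, 10], 5))  # [-20.0, 5.0, 30.0, 20.0, 10.0]
```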


================================================
FILE: docs/changlelog.md
================================================
## changelogs


- __[2023.04.06]__: The stable-diffusion webui extension is released.

- __[2023.04.03]__: Enable TTS in huggingface and gradio local demo.

- __[2023.03.30]__: Launch beta version of the full body mode.

- __[2023.03.30]__: Launch new feature: through using reference videos, our algorithm can generate videos with more natural eye blinking and some eyebrow movement.

- __[2023.03.29]__: `resize mode` is online via `python inference.py --preprocess resize`, which produces a larger crop of the image, as discussed in https://github.com/Winfredy/SadTalker/issues/35.

- __[2023.03.29]__: The local gradio demo is online! Run `python app.py` to start the demo. A new `requirements.txt` is used to avoid bugs in `librosa`.

- __[2023.03.28]__: Online demo is launched in [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/vinthony/SadTalker), thanks AK!
 
- __[2023.03.22]__: Launch new feature: generating the 3d face animation from a single image. New applications about it will be updated.

- __[2023.03.22]__: Launch new feature: `still mode`, where only a small head pose will be produced via `python inference.py --still`. 

- __[2023.03.18]__: Support `expression intensity`, now you can change the intensity of the generated motion: `python inference.py --expression_scale 1.3 (some value > 1)`.

- __[2023.03.18]__: Reconfigured the data folders; you can now download the checkpoints automatically using `bash scripts/download_models.sh`.
- __[2023.03.18]__: We have officially integrated [GFPGAN](https://github.com/TencentARC/GFPGAN) for face enhancement; use `python inference.py --enhancer gfpgan` for better visual quality.
- __[2023.03.14]__: Specify the version of package `joblib` to remove the errors in using `librosa`, [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb) is online!
- __[2023.03.06]__: Solve some bugs in code and errors in installation 
- __[2023.03.03]__: Release the test code for audio-driven single image animation!
- __[2023.02.28]__: SadTalker has been accepted by CVPR 2023!


================================================
FILE: docs/face3d.md
================================================
## 3D Face Visualization

We use `pytorch3d` to visualize the 3D faces from a single image.

The requirements for 3D visualization are difficult to install, so here's a tutorial:

```bash
git clone https://github.com/OpenTalker/SadTalker.git
cd SadTalker 
conda create -n sadtalker3d python=3.8
source activate sadtalker3d

conda install ffmpeg
conda install -c fvcore -c iopath -c conda-forge fvcore iopath
conda install libgcc gmp

pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113

# install pytorch3d
pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu113_pyt1110/download.html

pip install -r requirements3d.txt

### install gfpgan for the enhancer
pip install git+https://github.com/TencentARC/GFPGAN


### if a gcc version problem occurs at `from pytorch import _C` in pytorch3d, add the anaconda lib path to LD_LIBRARY_PATH
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/$YOUR_ANACONDA_PATH/lib/

``` 

Then, generate the result via:

```bash


python inference.py --driven_audio <audio.wav> \
                    --source_image <video.mp4 or picture.png> \
                    --result_dir <a directory to store results> \
                    --face3dvis

```

The 3D face result is saved as `3dface.mp4` in the timestamped result folder (this folder is removed after the run unless `--verbose` is set).

More applications about 3D face rendering will be released soon.


================================================
FILE: docs/install.md
================================================
### macOS

This method has been tested on an M1 Mac (macOS 13.3).

```bash
git clone https://github.com/OpenTalker/SadTalker.git
cd SadTalker 
conda create -n sadtalker python=3.8
conda activate sadtalker
# install pytorch 2.0
pip install torch torchvision torchaudio
conda install ffmpeg
pip install -r requirements.txt
pip install dlib # macOS needs to install the original dlib.
```

### Windows Native

- Make sure you have `ffmpeg` in your `%PATH%`, as suggested in [#54](https://github.com/Winfredy/SadTalker/issues/54). You can install `ffmpeg` by following [this](https://www.geeksforgeeks.org/how-to-install-ffmpeg-on-windows/) tutorial or by using scoop.


### Windows WSL


- Make sure the WSL GPU libraries are on the library path: `export LD_LIBRARY_PATH=/usr/lib/wsl/lib:$LD_LIBRARY_PATH`


### Docker Installation

A community Docker image by [@thegenerativegeneration](https://github.com/thegenerativegeneration) is available on the [Docker hub](https://hub.docker.com/repository/docker/wawa9000/sadtalker), which can be used directly:
```bash
docker run --gpus "all" --rm -v $(pwd):/host_dir wawa9000/sadtalker \
    --driven_audio /host_dir/deyu.wav \
    --source_image /host_dir/image.jpg \
    --expression_scale 1.0 \
    --still \
    --result_dir /host_dir
```



================================================
FILE: docs/webui_extension.md
================================================
## Run SadTalker as a Stable Diffusion WebUI Extension.

1. Install the latest version of [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) and install SadTalker via `extension`.
<img width="726" alt="image" src="https://user-images.githubusercontent.com/4397546/230698519-267d1d1f-6e99-4dd4-81e1-7b889259efbd.png">

2. Download the checkpoints manually, for Linux and Mac:

    ```bash

    cd SOMEWHERE_YOU_LIKE

    bash <(wget -qO- https://raw.githubusercontent.com/OpenTalker/SadTalker/main/scripts/download_models.sh)
    ```

    For Windows, you can download all the checkpoints [here](https://github.com/OpenTalker/SadTalker/tree/main#2-download-models).

3.1. Option 1: put the checkpoint in `stable-diffusion-webui/models/SadTalker` or `stable-diffusion-webui/extensions/SadTalker/checkpoints/`, the checkpoints will be detected automatically.

3.2. Option 2: Set `SADTALKER_CHECKPOINTS` in `webui_user.sh` (Linux) or `webui_user.bat` (Windows):

    > This only works if you start the WebUI directly from `webui_user.sh` or `webui_user.bat`.

    ```bash
    # Windows (webui_user.bat)
    set SADTALKER_CHECKPOINTS=D:\SadTalker\checkpoints

    # Linux/macOS (webui_user.sh)
    export SADTALKER_CHECKPOINTS=/path/to/SadTalker/checkpoints
    ```

4. Start the WebUI via `webui.sh` or `webui_user.sh` (Linux), `webui_user.bat` (Windows), or any other method. SadTalker can then be used directly in stable-diffusion-webui.
    
    <img width="726" alt="image" src="https://user-images.githubusercontent.com/4397546/230698614-58015182-2916-4240-b324-e69022ef75b3.png">
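Options 1 and 2 above describe where the extension looks for checkpoints. The sketch below shows a lookup over those locations. The search order is an assumption: the environment variable first, then the two Option 1 folders. `find_sadtalker_checkpoints` is hypothetical and not the extension's actual code.

```python
import os


def find_sadtalker_checkpoints(webui_root, env=None):
    """Return the first existing checkpoint folder, or None if none exists.

    Assumed order: $SADTALKER_CHECKPOINTS, then models/SadTalker, then
    extensions/SadTalker/checkpoints under the stable-diffusion-webui root.
    """
    env = os.environ if env is None else env
    candidates = []
    if env.get("SADTALKER_CHECKPOINTS"):
        candidates.append(env["SADTALKER_CHECKPOINTS"])
    candidates += [
        os.path.join(webui_root, "models", "SadTalker"),
        os.path.join(webui_root, "extensions", "SadTalker", "checkpoints"),
    ]
    for path in candidates:
        if os.path.isdir(path):
            return path
    return None
```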
    
## Questions

1. If you are running on CPU, you need to specify `--disable-safe-unpickle` in `webui_user.sh` or `webui_user.bat`.

    ```bash
    # windows (webui_user.bat)
    set COMMANDLINE_ARGS="--disable-safe-unpickle"

    # linux (webui_user.sh)
    export COMMANDLINE_ARGS="--disable-safe-unpickle"
    ```



(If you're unable to use the `full` mode, please read this [discussion](https://github.com/Winfredy/SadTalker/issues/78).)


================================================
FILE: inference.py
================================================
from glob import glob
import shutil
import torch
from time import strftime
import os, sys, time
from argparse import ArgumentParser

from src.utils.preprocess import CropAndExtract
from src.test_audio2coeff import Audio2Coeff  
from src.facerender.animate import AnimateFromCoeff
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data
from src.utils.init_path import init_path

def main(args):
    #torch.backends.cudnn.enabled = False

    pic_path = args.source_image
    audio_path = args.driven_audio
    save_dir = os.path.join(args.result_dir, strftime("%Y_%m_%d_%H.%M.%S"))
    os.makedirs(save_dir, exist_ok=True)
    pose_style = args.pose_style
    device = args.device
    batch_size = args.batch_size
    input_yaw_list = args.input_yaw
    input_pitch_list = args.input_pitch
    input_roll_list = args.input_roll
    ref_eyeblink = args.ref_eyeblink
    ref_pose = args.ref_pose

    current_root_path = os.path.split(sys.argv[0])[0]

    sadtalker_paths = init_path(args.checkpoint_dir, os.path.join(current_root_path, 'src/config'), args.size, args.old_version, args.preprocess)

    #init model
    preprocess_model = CropAndExtract(sadtalker_paths, device)

    audio_to_coeff = Audio2Coeff(sadtalker_paths,  device)
    
    animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device)

    #crop image and extract 3dmm from image
    first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
    os.makedirs(first_frame_dir, exist_ok=True)
    print('3DMM Extraction for source image')
    first_coeff_path, crop_pic_path, crop_info =  preprocess_model.generate(pic_path, first_frame_dir, args.preprocess,\
                                                                             source_image_flag=True, pic_size=args.size)
    if first_coeff_path is None:
        print("Can't get the coeffs of the input")
        return

    if ref_eyeblink is not None:
        ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[0]
        ref_eyeblink_frame_dir = os.path.join(save_dir, ref_eyeblink_videoname)
        os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
        print('3DMM Extraction for the reference video providing eye blinking')
        ref_eyeblink_coeff_path, _, _ =  preprocess_model.generate(ref_eyeblink, ref_eyeblink_frame_dir, args.preprocess, source_image_flag=False)
    else:
        ref_eyeblink_coeff_path=None

    if ref_pose is not None:
        if ref_pose == ref_eyeblink: 
            ref_pose_coeff_path = ref_eyeblink_coeff_path
        else:
            ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
            ref_pose_frame_dir = os.path.join(save_dir, ref_pose_videoname)
            os.makedirs(ref_pose_frame_dir, exist_ok=True)
            print('3DMM Extraction for the reference video providing pose')
            ref_pose_coeff_path, _, _ =  preprocess_model.generate(ref_pose, ref_pose_frame_dir, args.preprocess, source_image_flag=False)
    else:
        ref_pose_coeff_path=None

    # audio2coeff
    batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=args.still)
    coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)

    # 3dface render
    if args.face3dvis:
        from src.face3d.visualize import gen_composed_video
        gen_composed_video(args, device, first_coeff_path, coeff_path, audio_path, os.path.join(save_dir, '3dface.mp4'))
    
    #coeff2video
    data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, 
                                batch_size, input_yaw_list, input_pitch_list, input_roll_list,
                                expression_scale=args.expression_scale, still_mode=args.still, preprocess=args.preprocess, size=args.size)
    
    result = animate_from_coeff.generate(data, save_dir, pic_path, crop_info, \
                                enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess, img_size=args.size)
    
    shutil.move(result, save_dir+'.mp4')
    print('The generated video is named:', save_dir+'.mp4')

    if not args.verbose:
        shutil.rmtree(save_dir)

    
if __name__ == '__main__':

    parser = ArgumentParser()  
    parser.add_argument("--driven_audio", default='./examples/driven_audio/bus_chinese.wav', help="path to driven audio")
    parser.add_argument("--source_image", default='./examples/source_image/full_body_1.png', help="path to source image")
    parser.add_argument("--ref_eyeblink", default=None, help="path to reference video providing eye blinking")
    parser.add_argument("--ref_pose", default=None, help="path to reference video providing pose")
    parser.add_argument("--checkpoint_dir", default='./checkpoints', help="path to the checkpoint directory")
    parser.add_argument("--result_dir", default='./results', help="path to output")
    parser.add_argument("--pose_style", type=int, default=0,  help="input pose style from [0, 46)")
    parser.add_argument("--batch_size", type=int, default=2,  help="the batch size of facerender")
    parser.add_argument("--size", type=int, default=256,  help="the image size of the facerender")
    parser.add_argument("--expression_scale", type=float, default=1.,  help="the scale of the facial expression; larger values give stronger expression motion")
    parser.add_argument('--input_yaw', nargs='+', type=int, default=None, help="the input yaw degree of the user ")
    parser.add_argument('--input_pitch', nargs='+', type=int, default=None, help="the input pitch degree of the user")
    parser.add_argument('--input_roll', nargs='+', type=int, default=None, help="the input roll degree of the user")
    parser.add_argument('--enhancer',  type=str, default=None, help="Face enhancer, [gfpgan, RestoreFormer]")
    parser.add_argument('--background_enhancer',  type=str, default=None, help="background enhancer, [realesrgan]")
    parser.add_argument("--cpu", dest="cpu", action="store_true") 
    parser.add_argument("--face3dvis", action="store_true", help="generate 3d face and 3d landmarks") 
    parser.add_argument("--still", action="store_true", help="can crop back to the original videos for the full body animation") 
    parser.add_argument("--preprocess", default='crop', choices=['crop', 'extcrop', 'resize', 'full', 'extfull'], help="how to preprocess the images" ) 
    parser.add_argument("--verbose",action="store_true", help="save the intermediate outputs" ) 
    parser.add_argument("--old_version",action="store_true", help="use the .pth checkpoints instead of the safetensors version" ) 


    # net structure and parameters
    parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='unused')
    parser.add_argument('--init_path', type=str, default=None, help='unused')
    parser.add_argument('--use_last_fc',default=False, help='zero initialize the last fc')
    parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/')
    parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')

    # default renderer parameters
    parser.add_argument('--focal', type=float, default=1015.)
    parser.add_argument('--center', type=float, default=112.)
    parser.add_argument('--camera_d', type=float, default=10.)
    parser.add_argument('--z_near', type=float, default=5.)
    parser.add_argument('--z_far', type=float, default=15.)

    args = parser.parse_args()

    if torch.cuda.is_available() and not args.cpu:
        args.device = "cuda"
    else:
        args.device = "cpu"

    main(args)



================================================
FILE: launcher.py
================================================
# this script installs the necessary requirements and launches the main program in app_sadtalker.py
# borrowed from: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/launch.py
import subprocess
import os
import sys
import importlib.util
import shlex
import platform
import json

python = sys.executable
git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "")
stored_commit_hash = None
skip_install = False
dir_repos = "repositories"
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
    os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'


def check_python_version():
    is_windows = platform.system() == "Windows"
    major = sys.version_info.major
    minor = sys.version_info.minor
    micro = sys.version_info.micro

    if is_windows:
        supported_minors = [10]
    else:
        supported_minors = [7, 8, 9, 10, 11]

    if not (major == 3 and minor in supported_minors):

        raise RuntimeError(f"""
INCOMPATIBLE PYTHON VERSION
This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
If you encounter an error with "RuntimeError: Couldn't install torch." message,
or any other error regarding unsuccessful package (library) installation,
please downgrade (or upgrade) to the latest version of 3.10 Python
and delete current Python and "venv" folder in WebUI's directory.
You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/
{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
Use --skip-python-version-check to suppress this warning.
""")
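
# A minimal, hypothetical helper (not called anywhere in this launcher) that
# isolates the version rule applied in check_python_version above, so the rule
# can be checked without reading sys.version_info: only Python 3.10 is accepted
# on Windows, and 3.7 through 3.11 elsewhere.
def python_version_supported(major, minor, is_windows):
    supported_minors = [10] if is_windows else [7, 8, 9, 10, 11]
    return major == 3 and minor in supported_minors
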


def commit_hash():
    global stored_commit_hash

    if stored_commit_hash is not None:
        return stored_commit_hash

    try:
        stored_commit_hash = run(f"{git} rev-parse HEAD").strip()
    except Exception:
        stored_commit_hash = "<none>"

    return stored_commit_hash


def run(command, desc=None, errdesc=None, custom_env=None, live=False):
    if desc is not None:
        print(desc)

    if live:
        result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
        if result.returncode != 0:
            raise RuntimeError(f"""{errdesc or 'Error running command'}.
Command: {command}
Error code: {result.returncode}""")

        return ""

    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)

    if result.returncode != 0:

        message = f"""{errdesc or 'Error running command'}.
Command: {command}
Error code: {result.returncode}
stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
"""
        raise RuntimeError(message)

    return result.stdout.decode(encoding="utf8", errors="ignore")


def check_run(command):
    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
    return result.returncode == 0


def is_installed(package):
    try:
        spec = importlib.util.find_spec(package)
    except ModuleNotFoundError:
        return False

    return spec is not None


def repo_dir(name):
    return os.path.join(script_path, dir_repos, name)


def run_python(code, desc=None, errdesc=None):
    return run(f'"{python}" -c "{code}"', desc, errdesc)


def run_pip(args, desc=None):
    if skip_install:
        return

    index_url_line = f' --index-url {index_url}' if index_url != '' else ''
    return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")


def check_run_python(code):
    return check_run(f'"{python}" -c "{code}"')


def git_clone(url, dir, name, commithash=None):
    # TODO clone into temporary dir and move if successful

    if os.path.exists(dir):
        if commithash is None:
            return

        current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
        if current_hash == commithash:
            return

        run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
        run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
        return

    run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")

    if commithash is not None:
        run(f'"{git}" -C "{dir}" checkout {commithash}', None, f"Couldn't checkout {name}'s hash: {commithash}")


def git_pull_recursive(dir):
    for subdir, _, _ in os.walk(dir):
        if os.path.exists(os.path.join(subdir, '.git')):
            try:
                output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])
                print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n")
            except subprocess.CalledProcessError as e:
                print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n")


def run_extension_installer(extension_dir):
    path_installer = os.path.join(extension_dir, "install.py")
    if not os.path.isfile(path_installer):
        return

    try:
        env = os.environ.copy()
        env['PYTHONPATH'] = os.path.abspath(".")

        print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
    except Exception as e:
        print(e, file=sys.stderr)


def prepare_environment():
    global skip_install

    torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113")

    ## check windows 
    if sys.platform != 'win32':
        requirements_file = os.environ.get('REQS_FILE', "req.txt")
    else:
        requirements_file = os.environ.get('REQS_FILE', "requirements.txt")

    commit = commit_hash()

    print(f"Python {sys.version}")
    print(f"Commit hash: {commit}")

    if not is_installed("torch") or not is_installed("torchvision"):
        run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)

    run_pip(f"install -r \"{requirements_file}\"", "requirements for SadTalker WebUI (this may take a while the first time)")

    if sys.platform != 'win32' and not is_installed('TTS'):
        run_pip("install TTS", "TTS for SadTalker (skipped on Windows, where it may fail to install)")


def start():
    print("Launching SadTalker Web UI")
    from app_sadtalker import sadtalker_demo
    demo = sadtalker_demo()
    demo.queue()
    demo.launch()

if __name__ == "__main__":
    prepare_environment()
    start()

================================================
FILE: predict.py
================================================
"""run bash scripts/download_models.sh first to prepare the weights file"""
import os
import shutil
from argparse import Namespace
from src.utils.preprocess import CropAndExtract
from src.test_audio2coeff import Audio2Coeff
from src.facerender.animate import AnimateFromCoeff
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data
from src.utils.init_path import init_path
from cog import BasePredictor, Input, Path

checkpoints = "checkpoints"


class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        device = "cuda"

        
        sadtalker_paths = init_path(checkpoints,os.path.join("src","config"))

        # init model
        self.preprocess_model = CropAndExtract(sadtalker_paths, device)

        self.audio_to_coeff = Audio2Coeff(
            sadtalker_paths,
            device,
        )

        self.animate_from_coeff = {
            "full": AnimateFromCoeff(
                sadtalker_paths,
                device,
            ),
            "others": AnimateFromCoeff(
                sadtalker_paths,
                device,
            ),
        }

    def predict(
        self,
        source_image: Path = Input(
            description="Upload the source image, it can be video.mp4 or picture.png",
        ),
        driven_audio: Path = Input(
            description="Upload the driven audio, accepts .wav and .mp4 file",
        ),
        enhancer: str = Input(
            description="Choose a face enhancer",
            choices=["gfpgan", "RestoreFormer"],
            default="gfpgan",
        ),
        preprocess: str = Input(
            description="how to preprocess the images",
            choices=["crop", "resize", "full"],
            default="full",
        ),
        ref_eyeblink: Path = Input(
            description="path to reference video providing eye blinking",
            default=None,
        ),
        ref_pose: Path = Input(
            description="path to reference video providing pose",
            default=None,
        ),
        still: bool = Input(
            description="can crop back to the original videos for the full body animation when preprocess is full",
            default=True,
        ),
    ) -> Path:
        """Run a single prediction on the model"""

        animate_from_coeff = (
            self.animate_from_coeff["full"]
            if preprocess == "full"
            else self.animate_from_coeff["others"]
        )

        args = load_default()
        args.pic_path = str(source_image)
        args.audio_path = str(driven_audio)
        device = "cuda"
        args.still = still
        args.ref_eyeblink = None if ref_eyeblink is None else str(ref_eyeblink)
        args.ref_pose = None if ref_pose is None else str(ref_pose)

        # crop image and extract 3dmm from image
        results_dir = "results"
        if os.path.exists(results_dir):
            shutil.rmtree(results_dir)
        os.makedirs(results_dir)
        first_frame_dir = os.path.join(results_dir, "first_frame_dir")
        os.makedirs(first_frame_dir)

        print("3DMM Extraction for source image")
        first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
            args.pic_path, first_frame_dir, preprocess, source_image_flag=True
        )
        if first_coeff_path is None:
            raise ValueError("Can't get the coeffs of the input")

        if ref_eyeblink is not None:
            ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[
                0
            ]
            ref_eyeblink_frame_dir = os.path.join(results_dir, ref_eyeblink_videoname)
            os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
            print("3DMM Extraction for the reference video providing eye blinking")
            ref_eyeblink_coeff_path, _, _ = self.preprocess_model.generate(
                args.ref_eyeblink, ref_eyeblink_frame_dir, preprocess, source_image_flag=False
            )
        else:
            ref_eyeblink_coeff_path = None

        if ref_pose is not None:
            if ref_pose == ref_eyeblink:
                ref_pose_coeff_path = ref_eyeblink_coeff_path
            else:
                ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
                ref_pose_frame_dir = os.path.join(results_dir, ref_pose_videoname)
                os.makedirs(ref_pose_frame_dir, exist_ok=True)
                print("3DMM Extraction for the reference video providing pose")
                ref_pose_coeff_path, _, _ = self.preprocess_model.generate(
                    args.ref_pose, ref_pose_frame_dir, preprocess, source_image_flag=False
                )
        else:
            ref_pose_coeff_path = None

        # audio2coeff
        batch = get_data(
            first_coeff_path,
            args.audio_path,
            device,
            ref_eyeblink_coeff_path,
            still=still,
        )
        coeff_path = self.audio_to_coeff.generate(
            batch, results_dir, args.pose_style, ref_pose_coeff_path
        )
        # coeff2video
        print("coeff2video")
        data = get_facerender_data(
            coeff_path,
            crop_pic_path,
            first_coeff_path,
            args.audio_path,
            args.batch_size,
            args.input_yaw,
            args.input_pitch,
            args.input_roll,
            expression_scale=args.expression_scale,
            still_mode=still,
            preprocess=preprocess,
        )
        animate_from_coeff.generate(
            data, results_dir, args.pic_path, crop_info,
            enhancer=enhancer, background_enhancer=args.background_enhancer,
            preprocess=preprocess)

        output = "/tmp/out.mp4"
        mp4_path = os.path.join(results_dir, [f for f in os.listdir(results_dir) if "enhanced.mp4" in f][0])
        shutil.copy(mp4_path, output)

        return Path(output)


def load_default():
    return Namespace(
        pose_style=0,
        batch_size=2,
        expression_scale=1.0,
        input_yaw=None,
        input_pitch=None,
        input_roll=None,
        background_enhancer=None,
        face3dvis=False,
        net_recon="resnet50",
        init_path=None,
        use_last_fc=False,
        bfm_folder="./src/config/",
        bfm_model="BFM_model_front.mat",
        focal=1015.0,
        center=112.0,
        camera_d=10.0,
        z_near=5.0,
        z_far=15.0,
    )
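
# A hypothetical helper (not used by Predictor above) sketching a more explicit
# version of the output lookup in predict(): the original indexes [0] into a
# list comprehension, so it fails with a bare IndexError when no
# "enhanced.mp4" file was written. This variant matches on the filename suffix
# (rather than the original substring test), returns the first match in sorted
# order, and raises an error that names the expected suffix.
def find_output_video(filenames, suffix="enhanced.mp4"):
    matches = sorted(f for f in filenames if f.endswith(suffix))
    if not matches:
        raise FileNotFoundError(f"no file ending in {suffix!r} was generated")
    return matches[0]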


================================================
FILE: quick_demo.ipynb
================================================
{
  "cells": [
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "M74Gs_TjYl_B"
      },
      "source": [
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github"
      },
      "source": [
        "### SadTalker:Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation \n",
        "\n",
        "[arxiv](https://arxiv.org/abs/2211.12194) | [project](https://sadtalker.github.io) | [Github](https://github.com/Winfredy/SadTalker)\n",
        "\n",
        "Wenxuan Zhang, Xiaodong Cun, Xuan Wang, Yong Zhang, Xi Shen, Yu Guo, Ying Shan, Fei Wang.\n",
        "\n",
        "Xi'an Jiaotong University, Tencent AI Lab, Ant Group\n",
        "\n",
        "CVPR 2023\n",
        "\n",
        "TL;DR: A realistic and stylized talking head video generation method from a single image and audio\n"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "kA89DV-sKS4i"
      },
      "source": [
        "Installation (around 5 mins)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qJ4CplXsYl_E"
      },
      "outputs": [],
      "source": [
        "### make sure that CUDA is available in Edit -> Notebook settings -> GPU\n",
        "!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Mdq6j4E5KQAR"
      },
      "outputs": [],
      "source": [
        "!update-alternatives --install /usr/local/bin/python3 python3 /usr/bin/python3.8 2\n",
        "!update-alternatives --install /usr/local/bin/python3 python3 /usr/bin/python3.9 1\n",
        "!sudo apt install python3.8\n",
        "\n",
        "!sudo apt-get install python3.8-distutils\n",
        "\n",
        "!python --version\n",
        "\n",
        "!apt-get update\n",
        "\n",
        "!apt install software-properties-common\n",
        "\n",
        "!sudo dpkg --remove --force-remove-reinstreq python3-pip python3-setuptools python3-wheel\n",
        "\n",
        "!apt-get install python3-pip\n",
        "\n",
        "print('Git clone project and install requirements...')\n",
        "!git clone https://github.com/Winfredy/SadTalker &> /dev/null\n",
        "%cd SadTalker\n",
        "!export PYTHONPATH=/content/SadTalker:$PYTHONPATH\n",
        "!python3.8 -m pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113\n",
        "!apt update\n",
        "!apt install ffmpeg &> /dev/null\n",
        "!python3.8 -m pip install -r requirements.txt"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "DddcKB_nKsnk"
      },
      "source": [
        "Download models (about 1 min)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eDw3_UN8K2xa"
      },
      "outputs": [],
      "source": [
        "print('Download pre-trained models...')\n",
        "!rm -rf checkpoints\n",
        "!bash scripts/download_models.sh"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kK7DYeo7Yl_H"
      },
      "outputs": [],
      "source": [
        "# borrowed from MakeItTalk\n",
        "import ipywidgets as widgets\n",
        "import glob\n",
        "import matplotlib.pyplot as plt\n",
        "print(\"Choose the image name to animate: (saved in folder 'examples/')\")\n",
        "img_list = glob.glob1('examples/source_image', '*.png')\n",
        "img_list.sort()\n",
        "img_list = [item.split('.')[0] for item in img_list]\n",
        "default_head_name = widgets.Dropdown(options=img_list, value='full3')\n",
        "def on_change(change):\n",
        "    if change['type'] == 'change' and change['name'] == 'value':\n",
        "        plt.imshow(plt.imread('examples/source_image/{}.png'.format(default_head_name.value)))\n",
        "        plt.axis('off')\n",
        "        plt.show()\n",
        "default_head_name.observe(on_change)\n",
        "display(default_head_name)\n",
        "plt.imshow(plt.imread('examples/source_image/{}.png'.format(default_head_name.value)))\n",
        "plt.axis('off')\n",
        "plt.show()"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "-khNZcnGK4UK"
      },
      "source": [
        "Animation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ToBlDusjK5sS"
      },
      "outputs": [],
      "source": [
        "# select audio from examples/driven_audio\n",
        "img = 'examples/source_image/{}.png'.format(default_head_name.value)\n",
        "print(img)\n",
        "!python3.8 inference.py --driven_audio ./examples/driven_audio/RD_Radio31_000.wav \\\n",
        "           --source_image {img} \\\n",
        "           --result_dir ./results --still --preprocess full --enhancer gfpgan"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fAjwGmKKYl_I"
      },
      "outputs": [],
      "source": [
        "# visualize code from makeittalk\n",
        "from IPython.display import HTML\n",
        "from base64 import b64encode\n",
        "import os, sys\n",
        "\n",
        "# get the latest result\n",
        "\n",
        "results = sorted(os.listdir('./results/'))\n",
        "\n",
        "mp4_name = sorted(glob.glob('./results/*.mp4'))[-1]\n",
        "\n",
        "mp4 = open('{}'.format(mp4_name),'rb').read()\n",
        "data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
        "\n",
        "print('Display animation: {}'.format(mp4_name), file=sys.stderr)\n",
        "display(HTML(\"\"\"\n",
        "  <video width=256 controls>\n",
        "        <source src=\"%s\" type=\"video/mp4\">\n",
        "  </video>\n",
        "  \"\"\" % data_url))\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": []
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "base",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.9.7"
    },
    "vscode": {
      "interpreter": {
        "hash": "db5031b3636a3f037ea48eb287fd3d023feb9033aefc2a9652a92e470fb0851b"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}


================================================
FILE: req.txt
================================================
llvmlite==0.38.1
numpy==1.21.6
face_alignment==1.3.5
imageio==2.19.3
imageio-ffmpeg==0.4.7
librosa==0.10.0.post2
numba==0.55.1
resampy==0.3.1
pydub==0.25.1 
scipy==1.10.1
kornia==0.6.8
tqdm
yacs==0.1.8
pyyaml  
joblib==1.1.0
scikit-image==0.19.3
basicsr==1.4.2
facexlib==0.3.0
gradio
gfpgan
av
safetensors


================================================
FILE: requirements.txt
================================================
numpy==1.23.4
face_alignment==1.3.5
imageio==2.19.3
imageio-ffmpeg==0.4.7
librosa==0.9.2 # 
numba
resampy==0.3.1
pydub==0.25.1 
scipy==1.10.1
kornia==0.6.8
tqdm
yacs==0.1.8
pyyaml  
joblib==1.1.0
scikit-image==0.19.3
basicsr==1.4.2
facexlib==0.3.0
gradio
gfpgan
av
safetensors


================================================
FILE: requirements3d.txt
================================================
numpy==1.23.4
face_alignment==1.3.5
imageio==2.19.3
imageio-ffmpeg==0.4.7
librosa==0.9.2 # 
numba
resampy==0.3.1
pydub==0.25.1 
scipy==1.5.3
kornia==0.6.8
tqdm
yacs==0.1.8
pyyaml  
joblib==1.1.0
scikit-image==0.19.3
basicsr==1.4.2
facexlib==0.3.0
trimesh==3.9.20
gradio
gfpgan
safetensors

================================================
FILE: scripts/download_models.sh
================================================
mkdir -p ./checkpoints

# legacy download links
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2exp_00300-model.pth -O ./checkpoints/auido2exp_00300-model.pth
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2pose_00140-model.pth -O ./checkpoints/auido2pose_00140-model.pth
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/epoch_20.pth -O ./checkpoints/epoch_20.pth
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/facevid2vid_00189-model.pth.tar -O ./checkpoints/facevid2vid_00189-model.pth.tar
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/shape_predictor_68_face_landmarks.dat -O ./checkpoints/shape_predictor_68_face_landmarks.dat
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/wav2lip.pth -O ./checkpoints/wav2lip.pth
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00229-model.pth.tar -O ./checkpoints/mapping_00229-model.pth.tar
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00109-model.pth.tar -O ./checkpoints/mapping_00109-model.pth.tar
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/hub.zip -O ./checkpoints/hub.zip
# unzip -n ./checkpoints/hub.zip -d ./checkpoints/


#### download the new links.
wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00109-model.pth.tar -O  ./checkpoints/mapping_00109-model.pth.tar
wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00229-model.pth.tar -O  ./checkpoints/mapping_00229-model.pth.tar
wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_256.safetensors -O  ./checkpoints/SadTalker_V0.0.2_256.safetensors
wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_512.safetensors -O  ./checkpoints/SadTalker_V0.0.2_512.safetensors


# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/BFM_Fitting.zip -O ./checkpoints/BFM_Fitting.zip
# unzip -n ./checkpoints/BFM_Fitting.zip -d ./checkpoints/

### enhancer 
mkdir -p ./gfpgan/weights
wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth -O ./gfpgan/weights/alignment_WFLW_4HG.pth 
wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth -O ./gfpgan/weights/detection_Resnet50_Final.pth 
wget -nc https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -O ./gfpgan/weights/GFPGANv1.4.pth 
wget -nc https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth -O ./gfpgan/weights/parsing_parsenet.pth 



================================================
FILE: scripts/extension.py
================================================
import os, sys
from pathlib import Path
import tempfile
import gradio as gr
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call
from modules.shared import opts, OptionInfo
from modules import shared, paths, script_callbacks
import launch
import glob
from huggingface_hub import snapshot_download



def check_all_files_safetensor(current_dir):
    kv = {
        "SadTalker_V0.0.2_256.safetensors": "sadtalker-256",
        "SadTalker_V0.0.2_512.safetensors": "sadtalker-512",
        "mapping_00109-model.pth.tar" : "mapping-109" ,
        "mapping_00229-model.pth.tar" : "mapping-229" ,
    }

    if not os.path.isdir(current_dir):
        return False
    
    dirs = os.listdir(current_dir)

    for f in dirs:
        if f in kv.keys():
            del kv[f]

    return len(kv.keys()) == 0

def check_all_files(current_dir):
    kv = {
        "auido2exp_00300-model.pth": "audio2exp",
        "auido2pose_00140-model.pth": "audio2pose",
        "epoch_20.pth": "face_recon",
        "facevid2vid_00189-model.pth.tar": "face-render",
        "mapping_00109-model.pth.tar" : "mapping-109" ,
        "mapping_00229-model.pth.tar" : "mapping-229" ,
        "wav2lip.pth": "wav2lip",
        "shape_predictor_68_face_landmarks.dat": "dlib",
    }

    if not os.path.isdir(current_dir):
        return False
    
    dirs = os.listdir(current_dir)

    for f in dirs:
        if f in kv.keys():
            del kv[f]

    return len(kv.keys()) == 0
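
# A hypothetical diagnostic helper (not part of the original extension): it
# applies the same membership test as check_all_files and
# check_all_files_safetensor, but returns the sorted names of the missing files
# instead of a bare bool, so a caller could report exactly which checkpoint
# still needs downloading, e.g. missing_checkpoints(os.listdir(current_dir), kv).
def missing_checkpoints(present_files, required):
    present = set(present_files)
    return sorted(name for name in required if name not in present)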

    

def download_model(local_dir='./checkpoints'):
    REPO_ID = 'vinthony/SadTalker'
    snapshot_download(repo_id=REPO_ID, local_dir=local_dir, local_dir_use_symlinks=False)

def get_source_image(image):
    return image

def get_img_from_txt2img(x):
    talker_path = Path(paths.script_path) / "outputs"
    imgs_from_txt_dir = str(talker_path / "txt2img-images/")
    imgs = glob.glob(imgs_from_txt_dir+'/*/*.png')
    # glob already returns paths that include imgs_from_txt_dir, so sort on them directly
    imgs.sort(key=os.path.getmtime)
    img_from_txt_path = imgs[-1]
    return img_from_txt_path, img_from_txt_path

def get_img_from_img2img(x):
    talker_path = Path(paths.script_path) / "outputs"
    imgs_from_img_dir = str(talker_path / "img2img-images/")
    imgs = glob.glob(imgs_from_img_dir+'/*/*.png')
    # glob already returns paths that include imgs_from_img_dir, so sort on them directly
    imgs.sort(key=os.path.getmtime)
    img_from_img_path = imgs[-1]
    return img_from_img_path, img_from_img_path
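
# A hypothetical helper (not in the original extension) sketching the
# "newest image" lookup performed by get_img_from_txt2img and
# get_img_from_img2img: pick the candidate with the largest modification time,
# and return None when nothing matched instead of raising IndexError. The mtime
# function is a parameter so the selection logic can be exercised without
# real files, e.g. newest_path(glob.glob(imgs_from_txt_dir + '/*/*.png')).
def newest_path(candidates, mtime=os.path.getmtime):
    candidates = list(candidates)
    if not candidates:
        return None
    return max(candidates, key=mtime)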
 
def get_default_checkpoint_path():
    # check the path of models/checkpoints and extensions/
    checkpoint_path = Path(paths.script_path) / "models"/ "SadTalker" 
    extension_checkpoint_path = Path(paths.script_path) / "extensions"/ "SadTalker" / "checkpoints"

    if check_all_files_safetensor(checkpoint_path):
        # print('found sadtalker checkpoint in ' + str(checkpoint_path))
        return checkpoint_path

    if check_all_files_safetensor(extension_checkpoint_path):
        # print('found sadtalker checkpoint in ' + str(extension_checkpoint_path))
        return extension_checkpoint_path
    
    if check_all_files(checkpoint_path):
        # print('found sadtalker checkpoint in ' + str(checkpoint_path))
        return checkpoint_path

    if check_all_files(extension_checkpoint_path):
        # print('found sadtalker checkpoint in ' + str(extension_checkpoint_path))
        return extension_checkpoint_path

    return None



def install():

    kv = {
        "face_alignment": "face-alignment==1.3.5",
        "imageio": "imageio==2.19.3",
        "imageio_ffmpeg": "imageio-ffmpeg==0.4.7",
        "librosa":"librosa==0.8.0",
        "pydub":"pydub==0.25.1",
        "scipy":"scipy==1.8.1",
        "tqdm": "tqdm",
        "yacs":"yacs==0.1.8",
        "yaml": "pyyaml", 
        "av":"av",
        "gfpgan": "gfpgan",
    }

    # # dlib is not necessary currently
    # if 'darwin' in sys.platform:
    #     kv['dlib'] = "dlib"
    # else:
    #     kv['dlib'] = 'dlib-bin'

    # #### we need to have a newer version of imageio for our method.
    # launch.run_pip("install imageio==2.19.3", "requirements for SadTalker")

    for k, v in kv.items():
        if not launch.is_installed(k):
            print(f'SadTalker: {k} is not installed, installing {v}')
            launch.run_pip("install " + v, "requirements for SadTalker")

    default_checkpoint_path = get_default_checkpoint_path()

    if os.getenv('SADTALKER_CHECKPOINTS'):
        print('Loading SadTalker checkpoints from ' + os.getenv('SADTALKER_CHECKPOINTS'))

    elif default_checkpoint_path is not None:
        os.environ['SADTALKER_CHECKPOINTS'] = str(default_checkpoint_path)
    else:

        print(
            """
            SadTalker no longer downloads all of its checkpoint files from Hugging Face automatically, because the download takes a long time.

            Please set SADTALKER_CHECKPOINTS manually in `webui_user.bat` (Windows) or `webui_user.sh` (Linux).
            """
            )

        # python = sys.executable

        # launch.run(f'"{python}" -m pip uninstall -y huggingface_hub', live=True)
        # launch.run(f'"{python}" -m pip install --upgrade git+https://github.com/huggingface/huggingface_hub@main', live=True)
        # ### run the scripts to download the models to the correct location.
        # # print('download models for SadTalker')
        # # launch.run("cd " + paths.script_path+"/extensions/SadTalker && bash ./scripts/download_models.sh", live=True)
        # # print('SadTalker is successfully installed!')
        # download_model(paths.script_path+'/extensions/SadTalker/checkpoints')
    
 
def on_ui_tabs():
    install()

    sys.path.extend([paths.script_path+'/extensions/SadTalker']) 
    
    repo_dir = paths.script_path+'/extensions/SadTalker/'

    result_dir = opts.sadtalker_result_dir
    os.makedirs(result_dir, exist_ok=True)

    from app_sadtalker import sadtalker_demo  

    if os.getenv('SADTALKER_CHECKPOINTS'):
        checkpoint_path = os.getenv('SADTALKER_CHECKPOINTS')
    else:
        checkpoint_path = repo_dir+'checkpoints/'

    audio_to_video = sadtalker_demo(checkpoint_path=checkpoint_path, config_path=repo_dir+'src/config', warpfn = wrap_queued_call)
   
    return [(audio_to_video, "SadTalker", "extension")]

def on_ui_settings():
    talker_path = Path(paths.script_path) / "outputs"
    section = ('extension', "SadTalker") 
    opts.add_option("sadtalker_result_dir", OptionInfo(str(talker_path / "SadTalker/"), "Path to save results of sadtalker", section=section)) 

script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_ui_tabs(on_ui_tabs)


================================================
FILE: scripts/test.sh
================================================
# ### test commands to run before committing.
# python inference.py --preprocess crop --size 256
# python inference.py --preprocess crop --size 512

# python inference.py --preprocess extcrop --size 256
# python inference.py --preprocess extcrop --size 512

# python inference.py --preprocess resize --size 256
# python inference.py --preprocess resize --size 512

# python inference.py --preprocess full --size 256
# python inference.py --preprocess full --size 512

# python inference.py --preprocess extfull --size 256
# python inference.py --preprocess extfull --size 512

python inference.py --preprocess full --size 256 --enhancer gfpgan
python inference.py --preprocess full --size 512 --enhancer gfpgan

python inference.py --preprocess full --size 256 --enhancer gfpgan --still
python inference.py --preprocess full --size 512 --enhancer gfpgan --still


================================================
FILE: src/audio2exp_models/audio2exp.py
================================================
from tqdm import tqdm
import torch
from torch import nn


class Audio2Exp(nn.Module):
    def __init__(self, netG, cfg, device, prepare_training_loss=False):
        super(Audio2Exp, self).__init__()
        self.cfg = cfg
        self.device = device
        self.netG = netG.to(device)

    def test(self, batch):

        mel_input = batch['indiv_mels']                         # bs T 1 80 16
        bs = mel_input.shape[0]
        T = mel_input.shape[1]

        exp_coeff_pred = []

        for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames
            
            current_mel_input = mel_input[:,i:i+10]

            #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1))           #bs T 64
            ref = batch['ref'][:, :, :64][:, i:i+10]
            ratio = batch['ratio_gt'][:, i:i+10]                               #bs T

            audiox = current_mel_input.view(-1, 1, 80, 16)                  # bs*T 1 80 16

            curr_exp_coeff_pred  = self.netG(audiox, ref, ratio)         # bs T 64 

            exp_coeff_pred += [curr_exp_coeff_pred]

        # BS x T x 64
        results_dict = {
            'exp_coeff_pred': torch.cat(exp_coeff_pred, dim=1)
            }
        return results_dict
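

# Illustrative sketch (not part of the original SadTalker code): Audio2Exp.test
# walks the mel sequence in windows of 10 frames (mel_input[:, i:i+10]) and
# concatenates the per-window predictions along the time axis. This pure-Python
# helper (the name _window_slices is hypothetical) reproduces only that index
# arithmetic, showing that the last window may be shorter and that together the
# windows cover all T frames exactly once.
def _window_slices(num_frames, window=10):
    """Return the (start, end) bounds of each slice taken by the loop in Audio2Exp.test."""
    return [(i, min(i + window, num_frames)) for i in range(0, num_frames, window)]

# Example: _window_slices(25) -> [(0, 10), (10, 20), (20, 25)]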




================================================
FILE: src/audio2exp_models/networks.py
================================================
import torch
import torch.nn.functional as F
from torch import nn

class Conv2d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
                            nn.Conv2d(cin, cout, kernel_size, stride, padding),
                            nn.BatchNorm2d(cout)
                            )
        self.act = nn.ReLU()
        self.residual = residual
        self.use_act = use_act

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x
        
        if self.use_act:
            return self.act(out)
        else:
            return out

class SimpleWrapperV2(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
            )

        #### load the pre-trained audio_encoder 
        #self.audio_encoder = self.audio_encoder.to(device)  
        '''
        wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
        state_dict = self.audio_encoder.state_dict()

        for k,v in wav2lip_state_dict.items():
            if 'audio_encoder' in k:
                print('init:', k)
                state_dict[k.replace('module.audio_encoder.', '')] = v
        self.audio_encoder.load_state_dict(state_dict)
        '''

        self.mapping1 = nn.Linear(512+64+1, 64)
        #self.mapping2 = nn.Linear(30, 64)
        #nn.init.constant_(self.mapping1.weight, 0.)
        nn.init.constant_(self.mapping1.bias, 0.)

    def forward(self, x, ref, ratio):
        x = self.audio_encoder(x).view(x.size(0), -1)
        ref_reshape = ref.reshape(x.size(0), -1)
        ratio = ratio.reshape(x.size(0), -1)
        
        y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1)) 
        out = y.reshape(ref.shape[0], ref.shape[1], -1)  # + ref  (residual connection, disabled)
        return out
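

# Illustrative sketch (not part of the original SadTalker code): trace the
# spatial size of one 80x16 mel window through the size-changing layers of
# SimpleWrapperV2.audio_encoder, using the standard Conv2d output-size rule
# floor((n + 2*padding - kernel) / stride) + 1. The residual layers (kernel 3,
# stride 1, padding 1) keep the size and are omitted. The final 1x1 map with 512
# channels is why mapping1 takes 512 + 64 + 1 inputs: the audio feature, the
# first 64 reference coefficients, and one ratio value. Helper names are hypothetical.
def _conv_out(n, kernel, stride, padding):
    return (n + 2 * padding - kernel) // stride + 1


def _trace_audio_encoder(height=80, width=16):
    # (kernel, (stride_h, stride_w), padding) of each size-changing layer, in order
    layers = [
        (3, (1, 1), 1),  # Conv2d(1, 32)
        (3, (3, 1), 1),  # Conv2d(32, 64)
        (3, (3, 3), 1),  # Conv2d(64, 128)
        (3, (3, 2), 1),  # Conv2d(128, 256)
        (3, (1, 1), 0),  # Conv2d(256, 512)
        (1, (1, 1), 0),  # Conv2d(512, 512)
    ]
    sizes = [(height, width)]
    for kernel, (stride_h, stride_w), padding in layers:
        height = _conv_out(height, kernel, stride_h, padding)
        width = _conv_out(width, kernel, stride_w, padding)
        sizes.append((height, width))
    return sizes

# _trace_audio_encoder() -> [(80, 16), (80, 16), (27, 16), (9, 6), (3, 3), (1, 1), (1, 1)]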


================================================
FILE: src/audio2pose_models/audio2pose.py
================================================
import torch
from torch import nn
from src.audio2pose_models.cvae import CVAE
from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
from src.audio2pose_models.audio_encoder import AudioEncoder

class Audio2Pose(nn.Module):
    def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
        super().__init__()
        self.cfg = cfg
        self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
        self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
        self.device = device

        self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
        self.audio_encoder.eval()
        for param in self.audio_encoder.parameters():
            param.requires_grad = False

        self.netG = CVAE(cfg)
        self.netD_motion = PoseSequenceDiscriminator(cfg)
        
        
    def forward(self, x):

        batch = {}
        coeff_gt = x['gt'].to(self.device).squeeze(0)           #bs frame_len+1 73
        batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70] #bs frame_len 6
        batch['ref'] = coeff_gt[:, 0, 64:70]  #bs  6
        batch['class'] = x['class'].squeeze(0).to(self.device) # bs
        indiv_mels= x['indiv_mels'].to(self.device).squeeze(0) # bs seq_len+1 80 16

        # forward
        audio_emb_list = []
        audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512
        batch['audio_emb'] = audio_emb
        batch = self.netG(batch)

        pose_motion_pred = batch['pose_motion_pred']           # bs frame_len 6
        pose_gt = coeff_gt[:, 1:, 64:70].clone()               # bs frame_len 6
        pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred  # bs frame_len 6

        batch['pose_pred'] = pose_pred
        batch['pose_gt'] = pose_gt

        return batch

    def test(self, x):

        batch = {}
        ref = x['ref']                            #bs 1 70
        batch['ref'] = x['ref'][:,0,-6:]  
        batch['class'] = x['class']  
        bs = ref.shape[0]
        
        indiv_mels= x['indiv_mels']               # bs T 1 80 16
        indiv_mels_use = indiv_mels[:, 1:]        # we regard the ref as the first frame
        num_frames = x['num_frames']
        num_frames = int(num_frames) - 1

        # split num_frames into div full windows of seq_len frames plus a remainder of re frames
        div = num_frames//self.seq_len
        re = num_frames%self.seq_len
        audio_emb_list = []
        pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype, 
                                                device=batch['ref'].device)]

        for i in range(div):
            z = torch.randn(bs, self.latent_dim).to(ref.device)
            batch['z'] = z
            audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512
            batch['audio_emb'] = audio_emb
            batch = self.netG.test(batch)
            pose_motion_pred_list.append(batch['pose_motion_pred'])  #list of bs seq_len 6
        
        if re != 0:
            z = torch.randn(bs, self.latent_dim).to(ref.device)
            batch['z'] = z
            audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len  512
            if audio_emb.shape[1] != self.seq_len:
                pad_dim = self.seq_len-audio_emb.shape[1]
                pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1) 
                audio_emb = torch.cat([pad_audio_emb, audio_emb], 1) 
            batch['audio_emb'] = audio_emb
            batch = self.netG.test(batch)
            pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:])   
        
        pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1)
        batch['pose_motion_pred'] = pose_motion_pred

        pose_pred = ref[:, :1, -6:] + pose_motion_pred  # bs T 6

        batch['pose_pred'] = pose_pred
        return batch
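

# Illustrative sketch (not part of the original SadTalker code): the window
# bookkeeping of Audio2Pose.test. After dropping the reference frame, the
# remaining frames are split into `div` full windows of seq_len frames. If `re`
# frames remain, the last seq_len frames are decoded once more and only their
# final `re` predictions are kept. When fewer than seq_len frames exist in total,
# that window is shorter and test() pads the audio embedding at the front by
# repeating its first frame. The helper name _pose_windows is hypothetical.
def _pose_windows(num_frames, seq_len=32):
    """Return (start, end, kept) per window over the num_frames - 1 non-reference frames."""
    n = int(num_frames) - 1  # the first frame is the reference pose
    div, re = divmod(n, seq_len)
    windows = [(i * seq_len, (i + 1) * seq_len, seq_len) for i in range(div)]
    if re != 0:
        windows.append((max(n - seq_len, 0), n, re))
    return windows

# test() also prepends one zero-motion entry for the reference frame, so
# 1 + sum(kept) equals num_frames, e.g. _pose_windows(70) ->
# [(0, 32, 32), (32, 64, 32), (37, 69, 5)] and 1 + 32 + 32 + 5 == 70.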


================================================
FILE: src/audio2pose_models/audio_encoder.py
================================================
import torch
from torch import nn
from torch.nn import functional as F

class Conv2d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
                            nn.Conv2d(cin, cout, kernel_size, stride, padding),
                            nn.BatchNorm2d(cout)
                            )
        self.act = nn.ReLU()
        self.residual = residual

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x
        return self.act(out)

class AudioEncoder(nn.Module):
    def __init__(self, wav2lip_checkpoint, device):
        super(AudioEncoder, self).__init__()

        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)

        #### load the pre-trained audio_encoder, we do not need to load wav2lip model here.
        # wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
        # state_dict = self.audio_encoder.state_dict()

        # for k,v in wav2lip_state_dict.items():
        #     if 'audio_encoder' in k:
        #         state_dict[k.replace('module.audio_encoder.', '')] = v
        # self.audio_encoder.load_state_dict(state_dict)


    def forward(self, audio_sequences):
        # audio_sequences = (B, T, 1, 80, 16)
        B = audio_sequences.size(0)

        # (B, T, 1, 80, 16) -> (B*T, 1, 80, 16), batch-major so the reshape below restores (B, T, ...)
        audio_sequences = audio_sequences.flatten(0, 1)

        audio_embedding = self.audio_encoder(audio_sequences)  # B*T, 512, 1, 1
        dim = audio_embedding.shape[1]
        audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))

        return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512 
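

# Illustrative sketch (not part of the original SadTalker code): why the
# flattening order matters in AudioEncoder.forward. Collapsing (B, T) into one
# axis and restoring it with reshape((B, -1, ...)) is only correct when the
# flattened axis is batch-major (index b*T + t). Concatenating per-time-step
# slices along dim 0 gives time-major order (index t*B + b) instead; the two
# orders coincide only when B == 1, which is the inference case. Plain Python
# lists stand in for tensors, and all helper names are hypothetical.
def _flatten_batch_major(seq):
    """seq[b][t] -> flat list indexed b*T + t (like tensor.flatten(0, 1))."""
    return [item for row in seq for item in row]


def _flatten_time_major(seq):
    """seq[b][t] -> flat list indexed t*B + b (like torch.cat([x[:, t] for t ...], dim=0))."""
    return [seq[b][t] for t in range(len(seq[0])) for b in range(len(seq))]


def _unflatten_batch_major(flat, batch):
    """Split a flat list back into `batch` rows, mirroring reshape((B, -1, ...))."""
    steps = len(flat) // batch
    return [flat[b * steps:(b + 1) * steps] for b in range(batch)]

# With seq = [["a0", "a1", "a2"], ["b0", "b1", "b2"]]:
#   _unflatten_batch_major(_flatten_batch_major(seq), 2) == seq
#   _unflatten_batch_major(_flatten_time_major(seq), 2)
#       -> [["a0", "b0", "a1"], ["b1", "a2", "b2"]]  (frames mixed across samples)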


================================================
FILE: src/audio2pose_models/cvae.py
================================================
import torch
import torch.nn.functional as F
from torch import nn
from src.audio2pose_models.res_unet import ResUnet

def class2onehot(idx, class_num):

    assert torch.max(idx).item() < class_num
    onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
    onehot.scatter_(1, idx, 1)
    return onehot
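

# Illustrative sketch (not part of the original SadTalker code): the
# reparameterization used by CVAE.reparameterize, z = mu + eps * exp(0.5 * logvar)
# with eps ~ N(0, 1), written for plain Python floats. The KL term is the
# standard closed form for a diagonal Gaussian against N(0, I); it is an
# assumption that the LAMBDA_KL weight in auido2pose.yaml scales a term of this
# form, because the training loss is not part of this file. Helper names are hypothetical.
import math
import random


def _reparameterize(mu, logvar, eps=None):
    """Sample z element-wise; pass eps explicitly for a deterministic result."""
    if eps is None:
        eps = [random.gauss(0.0, 1.0) for _ in mu]
    return [m + e * math.exp(0.5 * lv) for m, lv, e in zip(mu, logvar, eps)]


def _kl_to_standard_normal(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, 1)), summed over the latent dimensions."""
    return -0.5 * sum(1 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))
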

class CVAE(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
        decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
        latent_size = cfg.MODEL.CVAE.LATENT_SIZE
        num_classes = cfg.DATASET.NUM_CLASSES
        audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
        audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
        seq_len = cfg.MODEL.CVAE.SEQ_LEN

        self.latent_size = latent_size

        self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
                                audio_emb_in_size, audio_emb_out_size, seq_len)
        self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
                                audio_emb_in_size, audio_emb_out_size, seq_len)
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, batch):
        batch = self.encoder(batch)
        mu = batch['mu']
        logvar = batch['logvar']
        z = self.reparameterize(mu, logvar)
        batch['z'] = z
        return self.decoder(batch)

    def test(self, batch):
        '''
        class_id = batch['class']
        z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
        batch['z'] = z
        '''
        return self.decoder(batch)

class ENCODER(nn.Module):
    def __init__(self, layer_sizes, latent_size, num_classes, 
                audio_emb_in_size, audio_emb_out_size, seq_len):
        super().__init__()

        self.resunet = ResUnet()
        self.num_classes = num_classes
        self.seq_len = seq_len

        self.MLP = nn.Sequential()
        layer_sizes = list(layer_sizes)  # copy, so the config list passed in is not modified in place
        layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
        for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            self.MLP.add_module(
                name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
            self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())

        self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
        self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
        self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)

        self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))

    def forward(self, batch):
        class_id = batch['class']
        pose_motion_gt = batch['pose_motion_gt']                             #bs seq_len 6
        ref = batch['ref']                             #bs 6
        bs = pose_motion_gt.shape[0]
        audio_in = batch['audio_emb']                          # bs seq_len audio_emb_in_size

        #pose encode
        pose_emb = self.resunet(pose_motion_gt.unsqueeze(1))          #bs 1 seq_len 6 
        pose_emb = pose_emb.reshape(bs, -1)                    #bs seq_len*6

        #audio mapping
        audio_out = self.linear_audio(audio_in)                # bs seq_len audio_emb_out_size
        audio_out = audio_out.reshape(bs, -1)

        class_bias = self.classbias[class_id]                  #bs latent_size
        x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size
        x_out = self.MLP(x_in)

        mu = self.linear_means(x_out)
        logvar = self.linear_logvar(x_out)                     #bs latent_size 

        batch.update({'mu':mu, 'logvar':logvar})
        return batch

class DECODER(nn.Module):
    def __init__(self, layer_sizes, latent_size, num_classes, 
                audio_emb_in_size, audio_emb_out_size, seq_len):
        super().__init__()

        self.resunet = ResUnet()
        self.num_classes = num_classes
        self.seq_len = seq_len

        self.MLP = nn.Sequential()
        input_size = latent_size + seq_len*audio_emb_out_size + 6
        for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
            self.MLP.add_module(
                name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
            if i+1 < len(layer_sizes):
                self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
            else:
                self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
        
        self.pose_linear = nn.Linear(6, 6)
        self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)

        self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))

    def forward(self, batch):

        z = batch['z']                                          #bs latent_size
        bs = z.shape[0]
        class_id = batch['class']
        ref = batch['ref']                             #bs 6
        audio_in = batch['audio_emb']                           # bs seq_len audio_emb_in_size
        #print('audio_in: ', audio_in[:, :, :10])

        audio_out = self.linear_audio(audio_in)                 # bs seq_len audio_emb_out_size
        #print('audio_out: ', audio_out[:, :, :10])
        audio_out = audio_out.reshape([bs, -1])                 # bs seq_len*audio_emb_out_size
        class_bias = self.classbias[class_id]                   #bs latent_size

        z = z + class_bias
        x_in = torch.cat([ref, z, audio_out], dim=-1)
        x_out = self.MLP(x_in)                                  # bs layer_sizes[-1]
        x_out = x_out.reshape((bs, self.seq_len, -1))

        #print('x_out: ', x_out)

        pose_emb = self.resunet(x_out.unsqueeze(1))             #bs 1 seq_len 6

        pose_motion_pred = self.pose_linear(pose_emb.squeeze(1))       #bs seq_len 6

        batch.update({'pose_motion_pred':pose_motion_pred})
        return batch
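

# Illustrative sketch (not part of the original SadTalker code): the input
# widths of the ENCODER and DECODER MLPs, derived from what each forward()
# concatenates. ENCODER: ref (6) + pose embedding (seq_len * 6, which is the
# first entry of ENCODER_LAYER_SIZES) + audio (seq_len * AUDIO_EMB_OUT_SIZE) +
# class bias (LATENT_SIZE). DECODER: ref (6) + z (LATENT_SIZE) + audio
# (seq_len * AUDIO_EMB_OUT_SIZE). The decoder's last layer must output
# seq_len * 6 values so that x_out reshapes to (bs, seq_len, 6) for pose_linear.
# The function name is hypothetical.
def _cvae_mlp_widths(seq_len, latent_size, audio_emb_out_size, encoder_sizes, decoder_sizes):
    encoder_in = encoder_sizes[0] + latent_size + seq_len * audio_emb_out_size + 6
    decoder_in = latent_size + seq_len * audio_emb_out_size + 6
    if decoder_sizes[-1] != seq_len * 6:
        raise ValueError("DECODER_LAYER_SIZES[-1] must equal SEQ_LEN * 6")
    return encoder_in, decoder_in

# With the values in src/config/auido2pose.yaml:
# _cvae_mlp_widths(32, 64, 6, [192, 128], [128, 192]) -> (454, 262)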


================================================
FILE: src/audio2pose_models/discriminator.py
================================================
import torch
import torch.nn.functional as F
from torch import nn

class ConvNormRelu(nn.Module):
    def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
                 kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
        super().__init__()
        if kernel_size is None:
            if downsample:
                kernel_size, stride, padding = 4, 2, 1
            else:
                kernel_size, stride, padding = 3, 1, 1

        if conv_type == '2d':
            self.conv = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            )
            if norm == 'BN':
                self.norm = nn.BatchNorm2d(out_channels)
            elif norm == 'IN':
                self.norm = nn.InstanceNorm2d(out_channels)
            else:
                raise NotImplementedError
        elif conv_type == '1d':
            self.conv = nn.Conv1d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            )
            if norm == 'BN':
                self.norm = nn.BatchNorm1d(out_channels)
            elif norm == 'IN':
                self.norm = nn.InstanceNorm1d(out_channels)
            else:
                raise NotImplementedError
        nn.init.kaiming_normal_(self.conv.weight)

        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        if isinstance(self.norm, nn.InstanceNorm1d):
            x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1))  # normalize on [C]
        else:
            x = self.norm(x)
        x = self.act(x)
        return x


class PoseSequenceDiscriminator(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU

        self.seq = nn.Sequential(
            # L = input sequence length (assumed divisible by 4)
            ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky),  # B, 256, L/2
            ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky),  # B, 512, L/4
            ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky),  # B, 1024, L/4
            nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True)  # B, 1, L/4
        )

    def forward(self, x):
        x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
        x = self.seq(x)
        x = x.squeeze(1)
        return x
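

# Illustrative sketch (not part of the original SadTalker code): temporal length
# through PoseSequenceDiscriminator.seq, using the Conv1d output-length rule
# floor((L + 2*padding - kernel) / stride) + 1. The two downsampling layers use
# kernel 4, stride 2, padding 1 (halving an even length); the last two use
# kernel 3, stride 1, padding 1 (length preserved). The helper name is hypothetical.
def _discriminator_lengths(length):
    layers = [(4, 2, 1), (4, 2, 1), (3, 1, 1), (3, 1, 1)]  # (kernel, stride, padding)
    lengths = []
    for kernel, stride, padding in layers:
        length = (length + 2 * padding - kernel) // stride + 1
        lengths.append(length)
    return lengths

# _discriminator_lengths(128) -> [64, 32, 32, 32]
# _discriminator_lengths(32)  -> [16, 8, 8, 8]   (SEQ_LEN in auido2pose.yaml)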

================================================
FILE: src/audio2pose_models/networks.py
================================================
import torch.nn as nn
import torch


class ResidualConv(nn.Module):
    def __init__(self, input_dim, output_dim, stride, padding):
        super(ResidualConv, self).__init__()

        self.conv_block = nn.Sequential(
            nn.BatchNorm2d(input_dim),
            nn.ReLU(),
            nn.Conv2d(
                input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
            ),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(output_dim),
        )

    def forward(self, x):

        return self.conv_block(x) + self.conv_skip(x)


class Upsample(nn.Module):
    def __init__(self, input_dim, output_dim, kernel, stride):
        super(Upsample, self).__init__()

        self.upsample = nn.ConvTranspose2d(
            input_dim, output_dim, kernel_size=kernel, stride=stride
        )

    def forward(self, x):
        return self.upsample(x)


class Squeeze_Excite_Block(nn.Module):
    def __init__(self, channel, reduction=16):
        super(Squeeze_Excite_Block, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


class ASPP(nn.Module):
    def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
        super(ASPP, self).__init__()

        self.aspp_block1 = nn.Sequential(
            nn.Conv2d(
                in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_dims),
        )
        self.aspp_block2 = nn.Sequential(
            nn.Conv2d(
                in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_dims),
        )
        self.aspp_block3 = nn.Sequential(
            nn.Conv2d(
                in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_dims),
        )

        self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
        self._init_weights()

    def forward(self, x):
        x1 = self.aspp_block1(x)
        x2 = self.aspp_block2(x)
        x3 = self.aspp_block3(x)
        out = torch.cat([x1, x2, x3], dim=1)
        return self.output(out)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
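

# Illustrative sketch (not part of the original SadTalker code): why each ASPP
# branch keeps the spatial size. A 3x3 convolution with dilation d covers an
# effective kernel of d * (3 - 1) + 1 = 2d + 1 pixels, so padding d on each side
# gives n + 2d - (2d + 1) + 1 = n outputs at stride 1. The helper name is
# hypothetical; the formula is the standard Conv2d output-size rule.
def _dilated_conv_out(n, kernel=3, dilation=1, padding=0, stride=1):
    effective_kernel = dilation * (kernel - 1) + 1
    return (n + 2 * padding - effective_kernel) // stride + 1

# [_dilated_conv_out(32, dilation=r, padding=r) for r in (6, 12, 18)] -> [32, 32, 32]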


class Upsample_(nn.Module):
    def __init__(self, scale=2):
        super(Upsample_, self).__init__()

        self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)

    def forward(self, x):
        return self.upsample(x)


class AttentionBlock(nn.Module):
    def __init__(self, input_encoder, input_decoder, output_dim):
        super(AttentionBlock, self).__init__()

        self.conv_encoder = nn.Sequential(
            nn.BatchNorm2d(input_encoder),
            nn.ReLU(),
            nn.Conv2d(input_encoder, output_dim, 3, padding=1),
            nn.MaxPool2d(2, 2),
        )

        self.conv_decoder = nn.Sequential(
            nn.BatchNorm2d(input_decoder),
            nn.ReLU(),
            nn.Conv2d(input_decoder, output_dim, 3, padding=1),
        )

        self.conv_attn = nn.Sequential(
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, 1, 1),
        )

    def forward(self, x1, x2):
        out = self.conv_encoder(x1) + self.conv_decoder(x2)
        out = self.conv_attn(out)
        return out * x2

================================================
FILE: src/audio2pose_models/res_unet.py
================================================
import torch
import torch.nn as nn
from src.audio2pose_models.networks import ResidualConv, Upsample


class ResUnet(nn.Module):
    def __init__(self, channel=1, filters=[32, 64, 128, 256]):
        super(ResUnet, self).__init__()

        self.input_layer = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(filters[0]),
            nn.ReLU(),
            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
        )
        self.input_skip = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
        )

        self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1)
        self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1)

        self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1)

        self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1))
        self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1)

        self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1))
        self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1)

        self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1))
        self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1)

        self.output_layer = nn.Sequential(
            nn.Conv2d(filters[0], 1, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Encode
        x1 = self.input_layer(x) + self.input_skip(x)
        x2 = self.residual_conv_1(x1)
        x3 = self.residual_conv_2(x2)
        # Bridge
        x4 = self.bridge(x3)

        # Decode
        x4 = self.upsample_1(x4)
        x5 = torch.cat([x4, x3], dim=1)

        x6 = self.up_residual_conv1(x5)

        x6 = self.upsample_2(x6)
        x7 = torch.cat([x6, x2], dim=1)

        x8 = self.up_residual_conv2(x7)

        x8 = self.upsample_3(x8)
        x9 = torch.cat([x8, x1], dim=1)

        x10 = self.up_residual_conv3(x9)

        output = self.output_layer(x10)

        return output
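

# Illustrative sketch (not part of the original SadTalker code): ResUnet halves
# the sequence axis three times (ResidualConv with kernel 3, stride (2, 1),
# padding 1) and doubles it three times with ConvTranspose2d(kernel=(2, 1),
# stride=(2, 1)), concatenating each upsampled map with the encoder map of the
# same level. The concatenations only line up when the sequence length is
# divisible by 8; SEQ_LEN = 32 in the configs satisfies this. The pose axis
# (width 6) is never strided. The helper name is hypothetical.
def _resunet_seq_lengths(seq_len):
    down = [seq_len]
    for _ in range(3):
        down.append((down[-1] + 2 - 3) // 2 + 1)  # encoder: kernel 3, stride 2, padding 1
    up = down[-1]
    for skip in reversed(down[:-1]):
        up = up * 2  # decoder: ConvTranspose2d output length (L - 1) * 2 + 2 = 2L
        if up != skip:
            raise ValueError(f"cannot concatenate length {up} with skip length {skip}")
    return down

# _resunet_seq_lengths(32) -> [32, 16, 8, 4]; _resunet_seq_lengths(30) raises ValueError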

================================================
FILE: src/config/auido2exp.yaml
================================================
DATASET:
  TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
  EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
  TRAIN_BATCH_SIZE: 32
  EVAL_BATCH_SIZE: 32
  EXP: True
  EXP_DIM: 64
  FRAME_LEN: 32
  COEFF_LEN: 73
  NUM_CLASSES: 46
  AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
  COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
  LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
  DEBUG: True
  NUM_REPEATS: 2
  T: 40
  

MODEL:
  FRAMEWORK: V2
  AUDIOENCODER:
    LEAKY_RELU: True
    NORM: 'IN'
  DISCRIMINATOR:
    LEAKY_RELU: False
    INPUT_CHANNELS: 6
  CVAE:
    AUDIO_EMB_IN_SIZE: 512
    AUDIO_EMB_OUT_SIZE: 128
    SEQ_LEN: 32
    LATENT_SIZE: 256
    ENCODER_LAYER_SIZES: [192, 1024]
    DECODER_LAYER_SIZES: [1024, 192]
    

TRAIN:
  MAX_EPOCH: 300
  GENERATOR:
    LR: 2.0e-5
  DISCRIMINATOR:
    LR: 1.0e-5
  LOSS:
    W_FEAT: 0
    W_COEFF_EXP: 2
    W_LM: 1.0e-2
    W_LM_MOUTH: 0
    W_REG: 0
    W_SYNC: 0
    W_COLOR: 0
    W_EXPRESSION: 0
    W_LIPREADING: 0.01
    W_LIPREADING_VV: 0
    W_EYE_BLINK: 4

TAG:
  NAME:  small_dataset




================================================
FILE: src/config/auido2pose.yaml
================================================
DATASET:
  TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
  EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
  TRAIN_BATCH_SIZE: 64
  EVAL_BATCH_SIZE: 1
  EXP: True
  EXP_DIM: 64
  FRAME_LEN: 32
  COEFF_LEN: 73
  NUM_CLASSES: 46
  AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
  COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
  DEBUG: True
  

MODEL:
  AUDIOENCODER:
    LEAKY_RELU: True
    NORM: 'IN'
  DISCRIMINATOR:
    LEAKY_RELU: False
    INPUT_CHANNELS: 6
  CVAE:
    AUDIO_EMB_IN_SIZE: 512
    AUDIO_EMB_OUT_SIZE: 6
    SEQ_LEN: 32
    LATENT_SIZE: 64
    ENCODER_LAYER_SIZES: [192, 128]
    DECODER_LAYER_SIZES: [128, 192]
    

TRAIN:
  MAX_EPOCH: 150
  GENERATOR:
    LR: 1.0e-4
  DISCRIMINATOR:
    LR: 1.0e-4
  LOSS:
    LAMBDA_REG: 1
    LAMBDA_LANDMARKS: 0
    LAMBDA_VERTICES: 0
    LAMBDA_GAN_MOTION: 0.7
    LAMBDA_GAN_COEFF: 0
    LAMBDA_KL: 1

TAG:
  NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder




================================================
FILE: src/config/facerender.yaml
================================================
model_params:
  common_params:
    num_kp: 15 
    image_channel: 3                    
    feature_channel: 32
    estimate_jacobian: False   # True
  kp_detector_params:
     temperature: 0.1
     block_expansion: 32            
     max_features: 1024
     scale_factor: 0.25         # 0.25
     num_blocks: 5
     reshape_channel: 16384  # 16384 = 1024 * 16
     reshape_depth: 16
  he_estimator_params:
     block_expansion: 64            
     max_features: 2048
     num_bins: 66
  generator_params:
    block_expansion: 64
    max_features: 512
    num_down_blocks: 2
    reshape_channel: 32
    reshape_depth: 16         # 512 = 32 * 16
    num_resblocks: 6
    estimate_occlusion_map: True
    dense_motion_params:
      block_expansion: 32
      max_features: 1024
      num_blocks: 5
      reshape_depth: 16
      compress: 4
  discriminator_params:
    scales: [1]
    block_expansion: 32                 
    max_features: 512
    num_blocks: 4
    sn: True
  mapping_params:
      coeff_nc: 70
      descriptor_nc: 1024
      layer: 3
      num_kp: 15
      num_bins: 66

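The `reshape_channel` / `reshape_depth` comments describe how a 2D feature map is folded into a 3D volume. In the generator, 512 channels become 32 feature channels (matching `feature_channel: 32`) times depth 16. In the keypoint detector, 16384 channels become 1024 times 16. Apart from this, `facerender_still.yaml` below is identical except for `mapping_params.coeff_nc` (73 instead of 70). The following is a minimal NumPy sketch of the reshape, assuming a channels-first `(batch, C, H, W)` layout. The batch and spatial sizes are hypothetical.

```python
import numpy as np

feature_channel = 32      # common_params.feature_channel
reshape_depth = 16        # generator_params.reshape_depth
kp_reshape_channel = 16384  # kp_detector_params.reshape_channel
batch, h, w = 2, 16, 16   # hypothetical sizes

# 2D feature map with feature_channel * reshape_depth = 512 channels ...
feat_2d = np.zeros((batch, feature_channel * reshape_depth, h, w), dtype=np.float32)
# ... viewed as a (C, D, H, W) volume, as the "512 = 32 * 16" comment describes.
feat_3d = feat_2d.reshape(batch, feature_channel, reshape_depth, h, w)
print(feat_2d.shape, "->", feat_3d.shape)

# The keypoint detector follows the same rule: 16384 = 1024 * 16.
print(kp_reshape_channel // reshape_depth)  # 1024
```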


================================================
FILE: src/config/facerender_still.yaml
================================================
model_params:
  common_params:
    num_kp: 15 
    image_channel: 3                    
    feature_channel: 32
    estimate_jacobian: False   # True
  kp_detector_params:
     temperature: 0.1
     block_expansion: 32            
     max_features: 1024
     scale_factor: 0.25         # 0.25
     num_blocks: 5
     reshape_channel: 16384  # 16384 = 1024 * 16
     reshape_depth: 16
  he_estimator_params:
     block_expansion: 64            
     max_features: 2048
     num_bins: 66
  generator_params:
    block_expansion: 64
    max_features: 512
    num_down_blocks: 2
    reshape_channel: 32
    reshape_depth: 16         # 512 = 32 * 16
    num_resblocks: 6
    estimate_occlusion_map: True
    dense_motion_params:
      block_expansion: 32
      max_features: 1024
      num_blocks: 5
      reshape_depth: 16
      compress: 4
  discriminator_params:
    scales: [1]
    block_expansion: 32                 
    max_features: 512
    num_blocks: 4
    sn: True
  mapping_params:
      coeff_nc: 73
      descriptor_nc: 1024
      layer: 3
      num_kp: 15
      num_bins: 66



================================================
FILE: src/face3d/data/__init__.py
================================================
"""This package includes all the modules related to data loading and preprocessing

 To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
 You need to implement four functions:
    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>:                       return the size of dataset.
    -- <__getitem__>:                   get a data point from data loader.
    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.

Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import numpy as np
import importlib
import torch.utils.data
from face3d.data.base_dataset import BaseDataset


def find_dataset_using_name(dataset_name):
    """Import the module "data/[dataset_name]_dataset.py".

    In the file, the class called DatasetNameDataset() will
    be instantiated. It has to be a subclass of BaseDataset,
    and it is case-insensitive.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        if name.lower() == target_dataset_name.lower() \
           and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))

    return dataset


def get_option_setter(dataset_name):
    """Return the static method <modify_commandline_options> of the dataset class."""
    dataset_class = find_dataset_using_name(dataset_name)
    return dataset_class.modify_commandline_options


def create_dataset(opt, rank=0):
    """Create a dataset given the option.

    This function wraps the class CustomDatasetDataLoader.
        This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    data_loader = CustomDatasetDataLoader(opt, rank=rank)
    dataset = data_loader.load_data()
    return dataset

class CustomDatasetDataLoader():
    """Wrapper class of Dataset class that performs multi-threaded data loading"""

    def __init__(self, opt, rank=0):
        """Initialize this class

        Step 1: create a dataset instance given the name [dataset_mode]
        Step 2: create a multi-threaded data loader.
        """
        self.opt = opt
        dataset_class = find_dataset_using_name(opt.dataset_mode)
        self.dataset = dataset_class(opt)
        self.sampler = None
        print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
        if opt.use_ddp and opt.isTrain:
            world_size = opt.world_size
            self.sampler = torch.utils.data.distributed.DistributedSampler(
                    self.dataset,
                    num_replicas=world_size,
                    rank=rank,
                    shuffle=not opt.serial_batches
                )
            self.dataloader = torch.utils.data.DataLoader(
                        self.dataset,
                        sampler=self.sampler,
                        num_workers=int(opt.num_threads / world_size), 
                        batch_size=int(opt.batch_size / world_size), 
                        drop_last=True)
        else:
            self.dataloader = torch.utils.data.DataLoader(
                self.dataset,
                batch_size=opt.batch_size,
                shuffle=(not opt.serial_batches) and opt.isTrain,
                num_workers=int(opt.num_threads),
                drop_last=True
            )

    def set_epoch(self, epoch):
        self.dataset.current_epoch = epoch
        if self.sampler is not None:
            self.sampler.set_epoch(epoch)

    def load_data(self):
        return self

    def __len__(self):
        """Return the number of data in the dataset"""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Return a batch of data"""
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data

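`find_dataset_using_name` resolves `--dataset_mode` by convention. It imports `data/<mode>_dataset.py`, then picks the attribute whose lowercased name equals `<mode without underscores>dataset` and which subclasses `BaseDataset`. The sketch below reproduces that matching rule against an in-memory module, so no import is needed. `BaseDataset` here is a hypothetical stand-in, `my_faces` is a made-up mode, and the extra `isinstance(obj, type)` guard is an addition.

```python
import types


class BaseDataset:
    """Hypothetical stand-in for face3d.data.base_dataset.BaseDataset."""


def resolve_dataset_class(module, dataset_name, base_cls=BaseDataset):
    # Same matching rule as find_dataset_using_name: drop underscores, append
    # 'dataset', compare case-insensitively, and require a base_cls subclass.
    target = dataset_name.replace('_', '') + 'dataset'
    found = None
    for name, obj in vars(module).items():
        if name.lower() == target.lower() and isinstance(obj, type) and issubclass(obj, base_cls):
            found = obj
    if found is None:
        raise NotImplementedError(f"no {base_cls.__name__} subclass named like {target!r}")
    return found


# A fake module standing in for data/my_faces_dataset.py (hypothetical name).
fake = types.ModuleType("data.my_faces_dataset")


class MyFacesDataset(BaseDataset):
    pass


fake.MyFacesDataset = MyFacesDataset
fake.helper = lambda: None  # unrelated attributes are ignored
print(resolve_dataset_class(fake, "my_faces").__name__)  # MyFacesDataset
```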

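`CustomDatasetDataLoader.__iter__` checks `max_dataset_size` before yielding each batch, so the last batch can overshoot the cap. With 100 samples, `batch_size=8` and `max_dataset_size=20`, three batches (24 samples) are yielded while `__len__` reports 20. In the DDP branch, the per-rank batch is `batch_size / world_size`, yet `__iter__` still multiplies by `opt.batch_size`, so the cap appears to be applied on a different scale there. The stdlib sketch below counts yielded batches under the same rule. The function name is hypothetical, and `drop_last` mirrors the DataLoader argument.

```python
import math


def batches_yielded(dataset_len, batch_size, max_dataset_size=float("inf"), drop_last=True):
    # Number of batches the underlying DataLoader would produce.
    n = dataset_len // batch_size if drop_last else math.ceil(dataset_len / batch_size)
    count = 0
    for i in range(n):
        # Same stopping rule as CustomDatasetDataLoader.__iter__.
        if i * batch_size >= max_dataset_size:
            break
        count += 1
    return count


print(batches_yielded(100, 8, 20))  # 3 batches, i.e. 24 samples for a cap of 20
```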
================================================
FILE: src/face3d/data/base_dataset.py
================================================
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.

It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
"""
import random
import numpy as np
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
from abc import ABC, abstractmethod


class BaseDataset(data.Dataset, ABC):
    """This class is an abstract base class (ABC) for datasets.

    To create a subclass, you need to implement the following four functions:
    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>:                       return the size of dataset.
    -- <__getitem__>:                   get a data point.
    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the class; save the options in the class

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        self.opt = opt
        # self.root = opt.dataroot
        self.current_epoch = 0

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def __len__(self):
        """Return the total number of images in the dataset."""
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index (int) -- a random integer for data indexing

        Returns:
            a dictionary of data with their names. It usually contains the data itself and its metadata information.
        """
        pass


def get_transform(grayscale=False):
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    transform_list += [transforms.ToTensor()]
    return transforms.Compose(transform_list)

def get_affine_mat(opt, size):
    shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False
    w, h = size

    if 'shift' in opt.preprocess:
        shift_pixs = int(opt.shift_pixs)
        shift_x = random.randint(-shift_pixs, shift_pixs)
        shift_y = random.randint(-shift_pixs, shift_pixs)
    if 'scale' in opt.preprocess:
        scale = 1 + opt.scale_delta * (2 * random.random() - 1)
    if 'rot' in opt.preprocess:
        rot_angle = opt.rot_angle * (2 * random.random() - 1)
    rot_rad = -rot_angle * np.pi / 180  # defined even when 'rot' is not requested (angle 0)
    if 'flip' in opt.preprocess:
        flip = random.random() > 0.5

    shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
    flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
    shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
    rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
    scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
    shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])
    
    affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin    
    affine_inv = np.linalg.inv(affine)
    return affine, affine_inv, flip

def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
    # PIL's AFFINE transform expects the inverse (output -> input) mapping as 6 coefficients
    return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=method)

def apply_lm_affine(landmark, affine, flip, size):
    _, h = size
    lm = landmark.copy()
    lm[:, 1] = h - 1 - lm[:, 1]
    lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
    lm = lm @ np.transpose(affine)
    lm[:, :2] = lm[:, :2] / lm[:, 2:]
    lm = lm[:, :2]
    lm[:, 1] = h - 1 - lm[:, 1]
    if flip:
        lm_ = lm.copy()
        lm_[:17] = lm[16::-1]
        lm_[17:22] = lm[26:21:-1]
        lm_[22:27] = lm[21:16:-1]
        lm_[31:36] = lm[35:30:-1]
        lm_[36:40] = lm[45:41:-1]
        lm_[40:42] = lm[47:45:-1]
        lm_[42:46] = lm[39:35:-1]
        lm_[46:48] = lm[41:39:-1]
        lm_[48:55] = lm[54:47:-1]
        lm_[55:60] = lm[59:54:-1]
        lm_[60:65] = lm[64:59:-1]
        lm_[65:68] = lm[67:64:-1]
        lm = lm_
    return lm

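`get_affine_mat` builds the augmentation as a chain of 3x3 homogeneous matrices applied right to left. It moves the image center to the origin, then flips, shifts, rotates and scales, and finally moves the result back. Rotation and scaling act about the origin, so the image center stays fixed when no shift is applied. The pure-Python sketch below reproduces that composition with random sampling replaced by explicit arguments. The function names are hypothetical, and the y-axis flip that `apply_lm_affine` wraps around the affine is omitted.

```python
import math


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def compose_affine(w, h, shift_x=0.0, shift_y=0.0, scale=1.0, rot_angle=0.0, flip=False):
    rot_rad = -rot_angle * math.pi / 180
    c, s = math.cos(rot_rad), math.sin(rot_rad)
    shift_to_origin = [[1, 0, -(w // 2)], [0, 1, -(h // 2)], [0, 0, 1]]
    flip_mat = [[-1 if flip else 1, 0, 0], [0, 1, 0], [0, 0, 1]]
    shift_mat = [[1, 0, shift_x], [0, 1, shift_y], [0, 0, 1]]
    rot_mat = [[c, s, 0], [-s, c, 0], [0, 0, 1]]
    scale_mat = [[scale, 0, 0], [0, scale, 0], [0, 0, 1]]
    shift_to_center = [[1, 0, w // 2], [0, 1, h // 2], [0, 0, 1]]
    # Left-multiply in order, giving center @ scale @ rot @ shift @ flip @ origin.
    m = shift_to_origin
    for step in (flip_mat, shift_mat, rot_mat, scale_mat, shift_to_center):
        m = matmul(step, m)
    return m


def apply_point(m, x, y):
    return (m[0][0] * x + m[0][1] * y + m[0][2],
            m[1][0] * x + m[1][1] * y + m[1][2])


print(apply_point(compose_affine(224, 224, scale=1.2, rot_angle=30), 112, 112))  # ~(112, 112)
```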

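When `flip` is set, `apply_lm_affine` mirrors the x coordinate and also reorders the 68 landmarks so that left and right features swap slots. The slicing below is copied from that function but applied to landmark indices, which makes the reordering easy to check. It should be a permutation that undoes itself when applied twice. The region labels in the comments assume the standard 68-point (iBUG 300-W) layout.

```python
def flip_permutation():
    # perm[i] is the source landmark that lands in slot i after a horizontal flip.
    lm = list(range(68))
    lm_ = lm.copy()
    lm_[:17] = lm[16::-1]      # jaw line
    lm_[17:22] = lm[26:21:-1]  # eyebrows swap sides
    lm_[22:27] = lm[21:16:-1]
    lm_[31:36] = lm[35:30:-1]  # lower nose
    lm_[36:40] = lm[45:41:-1]  # eyes swap sides
    lm_[40:42] = lm[47:45:-1]
    lm_[42:46] = lm[39:35:-1]
    lm_[46:48] = lm[41:39:-1]
    lm_[48:55] = lm[54:47:-1]  # outer lip
    lm_[55:60] = lm[59:54:-1]
    lm_[60:65] = lm[64:59:-1]  # inner lip
    lm_[65:68] = lm[67:64:-1]
    return lm_                 # 27..30 (nose bridge, on the midline) stay in place


perm = flip_permutation()
print(perm[36], perm[27])  # 45 27
```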
================================================
FILE: src/face3d/data/flist_dataset.py
================================================
"""This script defines the custom dataset for Deep3DFaceRecon_pytorch
"""

import os.path
from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
from data.image_folder import make_dataset
from PIL import Image
import random
import util.util as util
import numpy as np
import json
import torch
from scipy.io import loadmat, savemat
import pickle
from util.preprocess import align_img, estimate_norm
from util.load_mats import load_lm3d


def default_flist_reader(flist):
    """
    flist format: one path per line (impath\nimpath\n ...). Each line is stripped and returned whole; no label column is parsed.
    """
    imlist = []
    with open(flist, 'r') as rf:
        for line in rf.readlines():
            impath = line.strip()
            imlist.append(impath)

    return imlist

def jason_flist_reader(flist):
    with open(flist, 'r') as fp:
        info = json.load(fp)
    return info

def parse_label(label):
    return torch.tensor(np.array(label).astype(np.float32))


class FlistDataset(BaseDataset):
    """
    Reads a file list (opt.flist) of mask paths relative to opt.data_root.
    The matching image and landmark paths are derived from each mask path in __getitem__.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)
        
        self.lm3d_std = load_lm3d(opt.bfm_folder)
        
        msk_names = default_flist_reader(opt.flist)
        self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]

        self.size = len(self.msk_paths) 
        self.opt = opt
        
        self.name = 'train' if opt.isTrain else 'val'
        if '_' in opt.flist:
            self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]
        

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index (int)      -- a random integer for data indexing

        Returns a dictionary that contains
            imgs (tensor)      -- an image in the input domain
            lms  (tensor)      -- its corresponding landmarks
            msks (tensor)      -- its corresponding attention mask
            M    (tensor)      -- the alignment matrix returned by estimate_norm
            im_paths (str)     -- image path
            aug_flag (bool)    -- a flag telling whether the sample is raw or augmented
            dataset (str)      -- the dataset name, e.g. 'train' or 'val'
        """
        msk_path = self.msk_paths[index % self.size]  # make sure index is within the range
        img_path = msk_path.replace('mask/', '')
        lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'

        raw_img = Image.open(img_path).convert('RGB')
        raw_msk = Image.open(msk_path).convert('RGB')
        raw_lm = np.loadtxt(lm_path).astype(np.float32)

        _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)
        
        aug_flag = self.opt.use_aug and self.opt.isTrain
        if aug_flag:
            img, lm, msk = self._augmentation(img, lm, self.opt, msk)
        
        _, H = img.size
        M = estimate_norm(lm, H)
        transform = get_transform()
        img_tensor = transform(img)
        msk_tensor = transform(msk)[:1, ...]
        lm_tensor = parse_label(lm)
        M_tensor = parse_label(M)


        return {'imgs': img_tensor, 
                'lms': lm_tensor, 
                'msks': msk_tensor, 
                'M': M_tensor,
                'im_paths': img_path, 
                'aug_flag': aug_flag,
                'dataset': self.name}

    def _augmentation(self, img, lm, opt, msk=None):
        affine, affine_inv, flip = get_affine_mat(opt, img.size)
        img = apply_img_affine(img, affine_inv)
        lm = apply_lm_affine(lm, affine, flip, img.size)
        if msk is not None:
            msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
        return img, lm, msk
    



    def __len__(self):
        """Return the total number of images in the dataset.
        """
        return self.size

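`FlistDataset` stores only mask paths and derives the image and landmark paths by string substitution. It also builds its name from the file-list path. The sketch below copies those rules into two hypothetical helper functions. Two caveats follow from the string operations. Every occurrence of `mask` in a path is rewritten, not just the directory name. The underscore test runs on the whole `opt.flist` path, not only on the file name.

```python
import os


def derive_paths(msk_path):
    # Same string rules as FlistDataset.__getitem__.
    img_path = msk_path.replace('mask/', '')
    lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'
    return img_path, lm_path


def dataset_name(flist, is_train=True):
    # Same naming rule as FlistDataset.__init__.
    name = 'train' if is_train else 'val'
    if '_' in flist:
        name += '_' + flist.split(os.sep)[-1].split('_')[0]
    return name


print(derive_paths('faces/mask/celeba/000001.png'))
# ('faces/celeba/000001.png', 'faces/landmarks/celeba/000001.txt')
```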

================================================
FILE: src/face3d/data/image_folder.py
================================================
"""A modified image folder class

We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both current directory and its subdirectories.
"""
import numpy as np
import torch.utils.data as data

from PIL import Image
import os
import os.path

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.tif', '.TIF', '.tiff', '.TIFF',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir, max_dataset_size=float("inf")):
    images = []
    assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images[:min(max_dataset_size, len(images))]


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)

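`make_dataset` walks the directory tree recursively and keeps files whose names end with one of the listed extensions. The match is case-sensitive on those exact spellings, so a file such as `c.Png` is skipped. The sketch below is a trimmed copy without the directory assert, run against a temporary tree. Only the file layout is made up.

```python
import os
import tempfile

IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM',
                  '.bmp', '.BMP', '.tif', '.TIF', '.tiff', '.TIFF']


def is_image_file(filename):
    return any(filename.endswith(ext) for ext in IMG_EXTENSIONS)


def make_dataset(dir, max_dataset_size=float("inf")):
    images = []
    for root, _, fnames in sorted(os.walk(dir, followlinks=True)):  # includes subdirectories
        for fname in fnames:
            if is_image_file(fname):
                images.append(os.path.join(root, fname))
    return images[:min(max_dataset_size, len(images))]


with tempfile.TemporaryDirectory() as tmp:
    os.makedirs(os.path.join(tmp, 'sub'))
    for rel in ['a.png', 'notes.txt', os.path.join('sub', 'b.JPG'), os.path.join('sub', 'c.Png')]:
        open(os.path.join(tmp, rel), 'w').close()
    found = [os.path.relpath(p, tmp) for p in make_dataset(tmp)]
    print(found)  # a.png and sub/b.JPG; notes.txt and c.Png are skipped
```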

================================================
FILE: src/face3d/data/template_dataset.py
================================================
"""Dataset class template

This module provides a template for users to implement custom datasets.
You can specify '--dataset_mode template' to use this dataset.
The class name should be consistent with both the filename and its dataset_mode option.
The filename should be <dataset_mode>_dataset.py
The class name should be <Dataset_mode>Dataset
You need to implement the following functions:
    -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
    -- <__init__>: Initialize this dataset class.
    -- <__getitem__>: Return a data point and its metadata information.
    -- <__len__>: Return the number of images.
"""
from data.base_dataset import BaseDataset, get_transform
# from data.image_folder import make_dataset
# from PIL import Image


class TemplateDataset(BaseDataset):
    """A template dataset class for you to implement custom datasets."""
    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
        parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0)  # specify dataset-specific default values
        return parser

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

        A few things can be done here.
        - save the options (have been done in BaseDataset)
        - get image paths and meta information of the dataset.
        - define the image transformation.
        """
        # save the option and dataset root
        BaseDataset.__init__(self, opt)
        # get the image paths of your dataset;
        self.image_paths = []  # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
        # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
        self.transform = get_transform()  # get_transform takes a grayscale flag, not opt

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns:
            a dictionary of data with their names. It usually contains the data itself and its metadata information.

        Step 1: get a random image path: e.g., path = self.image_paths[index]
        Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
        Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
        Step 4: return a data point as a dictionary.
        """
        path = 'temp'    # needs to be a string
        data_A = None    # needs to be a tensor
        data_B = None    # needs to be a tensor
        return {'data_A': data_A, 'data_B': data_B, 'path': path}

    def __len__(self):
        """Return the total number of images."""
        return len(self.image_paths)


================================================
FILE: src/face3d/extract_kp_videos.py
================================================
import os
import cv2
import time
import glob
import argparse
import face_alignment
import numpy as np
from PIL import Image
from tqdm import tqdm
from itertools import cycle

from torch.multiprocessing import Pool, Process, set_start_method

class KeypointExtractor():
    def __init__(self, device='cuda'):  # default lets run() call KeypointExtractor() without arguments
        self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, 
                                                     device=device)   

    def extract_keypoint(self, images, name=None, info=True):
        if isinstance(images, list):
            keypoints = []
            if info:
                i_range = tqdm(images,desc='landmark Det:')
            else:
                i_range = images

            for image in i_range:
                current_kp = self.extract_keypoint(image)
                if np.mean(current_kp) == -1 and keypoints:
                    keypoints.append(keypoints[-1])
                else:
                    keypoints.append(current_kp[None])

            keypoints = np.concatenate(keypoints, 0)
            if name is not None:  # only save when an output name was given
                np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
            return keypoints
        else:
            while True:
                try:
                    keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
                    break
                except RuntimeError as e:
                    if str(e).startswith('CUDA'):
                        print("Warning: out of memory, sleep for 1s")
                        time.sleep(1)
                    else:
                        print(e)
                        keypoints = -1. * np.ones([68, 2])  # same "no face" sentinel, so keypoints is always defined
                        break
                except TypeError:
                    print('No face detected in this image')
                    shape = [68, 2]
                    keypoints = -1. * np.ones(shape)                    
                    break
            if name is not None:
                np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
            return keypoints

def read_video(filename):
    frames = []
    cap = cv2.VideoCapture(filename)
    while cap.isOpened():
        ret, frame = cap.read()
        if ret:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
            frames.append(frame)
        else:
            break
    cap.release()
    return frames

def run(data):
    filename, opt, device = data
    os.environ['CUDA_VISIBLE_DEVICES'] = device
    kp_extractor = KeypointExtractor()
    images = read_video(filename)
    name = filename.split('/')[-2:]
    os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
    kp_extractor.extract_keypoint(
        images, 
        name=os.path.join(opt.output_dir, name[-2], name[-1])
    )

if __name__ == '__main__':
    set_start_method('spawn')
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input_dir', type=str, help='the folder of the input files')
    parser.add_argument('--output_dir', type=str, help='the folder of the output files')
    parser.add_argument('--device_ids', type=str, default='0,1')
    parser.add_argument('--workers', type=int, default=4)

    opt = parser.parse_args()
    filenames = list()
    VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
    VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
    extensions = VIDEO_EXTENSIONS
    
    for ext in extensions:
        os.listdir(f'{opt.input_dir}')
        print(f'{opt.input_dir}/*.{ext}')
        filenames.extend(glob.glob(f'{opt.input_dir}/*.{ext}'))
    filenames = sorted(set(filenames))  # set() drops duplicates on case-insensitive filesystems
    print('Total number of videos:', len(filenames))
    pool = Pool(opt.workers)
    args_list = cycle([opt])
    device_ids = opt.device_ids.split(",")
    device_ids = cycle(device_ids)
    for _ in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
        pass  # consume results so tqdm advances as each video finishes

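In list mode, `extract_keypoint` marks a failed detection with an all `-1` `(68, 2)` array. It replaces such a frame with the previous frame's landmarks, so the saved track has no gaps. If the very first frame fails, its `-1` sentinel is kept, because there is nothing earlier to copy. The sketch below isolates that stacking rule with synthetic per-frame arrays. The function name is hypothetical.

```python
import numpy as np


def stack_with_fallback(per_frame):
    """Stack per-frame (68, 2) landmarks; a frame whose values are all -1 means 'no face'."""
    keypoints = []
    for kp in per_frame:
        if np.mean(kp) == -1 and keypoints:
            keypoints.append(keypoints[-1])  # reuse the previous frame's landmarks
        else:
            keypoints.append(kp[None])       # add a leading frame axis: (1, 68, 2)
    return np.concatenate(keypoints, 0)      # (num_frames, 68, 2)


ones = np.ones((68, 2))
track = stack_with_fallback([-ones, 5 * ones, -ones, 7 * ones])
print(track[:, 0, 0])  # [-1.  5.  5.  7.]
```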

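The `__main__` block pairs each video with a GPU id using `zip` over the finite file list and two `cycle`s. `run()` then writes `<output_dir>/<parent folder>/<video stem>.txt`. The sketch below reproduces that job planning without starting a pool. The function name and example paths are hypothetical. There is also a hedged caveat. CUDA generally reads `CUDA_VISIBLE_DEVICES` once, when it initializes in a process, and `Pool` workers are reused across tasks. A worker may therefore keep the GPU from its first job, whatever id later jobs assign.

```python
import os
from itertools import cycle


def plan_jobs(filenames, output_dir, device_ids='0,1'):
    # zip stops at the finite filename list; cycle() hands out GPU ids round-robin.
    jobs = []
    for filename, device in zip(filenames, cycle(device_ids.split(','))):
        parent, base = filename.split('/')[-2:]  # same split as run()
        out_txt = os.path.splitext(os.path.join(output_dir, parent, base))[0] + '.txt'
        jobs.append((filename, device, out_txt))
    return jobs


for job in plan_jobs(['/data/vox/id1.mp4', '/data/vox/id2.mp4', '/data/vox/id3.mp4'], 'kp'):
    print(job)
```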
================================================
FILE: src/face3d/extract_kp_videos_safe.py
================================================
import os
import cv2
import time
import glob
import argparse
import numpy as np
from PIL import Image
import torch
from tqdm import tqdm
from itertools import cycle
from torch.multiprocessing import Pool, Process, set_start_method

from facexlib.alignment import landmark_98_to_68
from facexlib.detection import init_detection_model

from facexlib.utils import load_file_from_url
from src.face3d.util.my_awing_arch import FAN

def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None):
    if model_name == 'awing_fan':
        model = FAN(num_modules=4, num_landmarks=98, device=device)
        model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth'
    else:
        raise NotImplementedError(f'{model_name} is not implemented.')

    model_path = load_file_from_url(
        url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
    model.load_state_dict(torch.load(model_path, map_location=device)['state_dict'], strict=True)
    model.eval()
    model = model.to(device)
    return model


class KeypointExtractor():
    def __init__(self, device='cuda'):

        ### gfpgan/weights
        try:
            import webui  # in webui
            root_path = 'extensions/SadTalker/gfpgan/weights' 

        except ImportError:
            root_path = 'gfpgan/weights'

        self.detector = init_alignment_model('awing_fan',device=device, model_rootpath=root_path)   
        self.det_net = init_detection_model('retinaface_resnet50', half=False,device=device, model_rootpath=root_path)

    def extract_keypoint(self, images, name=None, info=True):
        if isinstance(images, list):
            keypoints = []
            if info:
                i_range = tqdm(images,desc='landmark Det:')
            else:
                i_range = images

            for image in i_range:
                current_kp = self.extract_keypoint(image)
                # current_kp = self.detector.get_landmarks(np.array(image))
                if np.mean(current_kp) == -1 and keypoints:
                    keypoints.append(keypoints[-1])
                else:
                    keypoints.append(current_kp[None])

            keypoints = np.concatenate(keypoints, 0)
            if name is not None:  # only save when an output name was given
                np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
            return keypoints
        else:
            while True:
                try:
                    with torch.no_grad():
                        # face detection -> face alignment.
                        img = np.array(images)
                        bboxes = self.det_net.detect_faces(images, 0.97)
                        
                        bboxes = bboxes[0]
                        img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :]

                        keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0]

                        #### keypoints to the original location
                        keypoints[:,0] += int(bboxes[0])
                        keypoints[:,1] += int(bboxes[1])

                        break
                except RuntimeError as e:
                    if str(e).startswith('CUDA'):
                        print("Warning: out of memory, sleep for 1s")
                        time.sleep(1)
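                    else:
                        print(e)
                        keypoints = -1. * np.ones([68, 2])  # same "no face" sentinel, so keypoints is always defined
                        break
                except (TypeError, IndexError):
                    # IndexError: detect_faces returned no boxes, so bboxes[0] fails
                    print('No face detected in this image')
                    shape = [68, 2]
                    keypoints = -1. * np.ones(shape)
                    break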
            if name is not None:
                np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
            return keypoints

def read_video(filename):
    frames = []
    cap = cv2.VideoCapture(filename)
    while cap.isOpened():
        ret, frame = cap.read()
        if ret:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
            frames.append(frame)
        else:
            break
    cap.release()
    return frames

def run(data):
    filename, opt, device = data
    os.environ['CUDA_VISIBLE_DEVICES'] = device
    kp_extractor = KeypointExtractor()
    images = read_video(filename)
    name = filename.split('/')[-2:]
    os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
    kp_extractor.extract_keypoint(
        images, 
        name=os.path.join(opt.output_dir, name[-2], name[-1])
    )

if __name__ == '__main__':
    set_start_method('spawn')
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input_dir', type=str, help='the folder of the input files')
    parser.add_argument('--output_dir', type=str, help='the folder of the output files')
    parser.add_argument('--device_ids', type=str, default='0,1')
    parser.add_argument('--workers', type=int, default=4)

    opt = parser.parse_args()
    VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
    VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
    extensions = VIDEO_EXTENSIONS

    # Accumulate matches for every extension instead of overwriting them on each
    # iteration. A set avoids duplicates on case-insensitive file systems, where
    # *.mp4 and *.MP4 match the same files.
    filenames = set()
    for ext in extensions:
        filenames.update(glob.glob(f'{opt.input_dir}/*.{ext}'))
    filenames = sorted(filenames)
    print('Total number of videos:', len(filenames))

    pool = Pool(opt.workers)
    args_list = cycle([opt])
    # Assign GPUs to videos round-robin.
    device_ids = cycle(opt.device_ids.split(","))
    for _ in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids)), total=len(filenames)):
        pass
    pool.close()
    pool.join()


================================================
FILE: src/face3d/models/__init__.py
================================================
"""This package contains modules related to objective functions, optimizations, and network architectures.

To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
    -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).
    -- <set_input>:                     unpack data from dataset and apply preprocessing.
    -- <forward>:                       produce intermediate results.
    -- <optimize_parameters>:           calculate loss, gradients, and update network weights.
    -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.

In the function <__init__>, you need to define four lists:
    -- self.loss_names (str list):          specify the training losses that you want to plot and save.
    -- self.model_names (str list):         define networks used in our training.
    -- self.visual_names (str list):        specify the images that you want to display and save.
    -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.

Now you can use the model class by specifying flag '--model dummy'.
See our template model class 'template_model.py' for more details.
"""

import importlib
from src.face3d.models.base_model import BaseModel


def find_model_using_name(model_name):
    """Import the module "models/[model_name]_model.py".

    In the file, the class called [ModelName]Model() will
    be instantiated. It has to be a subclass of BaseModel,
    and it is case-insensitive.
    """
    # Import with the same `src.` package prefix used for BaseModel above.
    model_filename = "src.face3d.models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower() \
           and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        exit(1)

    return model


def get_option_setter(model_name):
    """Return the static method <modify_commandline_options> of the model class."""
    model_class = find_model_using_name(model_name)
    return model_class.modify_commandline_options


def create_model(opt):
    """Create a model given the option.

    This function wraps the model class selected by opt.model.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from src.face3d.models import create_model
        >>> model = create_model(opt)
    """
    model = find_model_using_name(opt.model)
    instance = model(opt)
    print("model [%s] was created" % type(instance).__name__)
    return instance


================================================
FILE: src/face3d/models/arcface_torch/README.md
================================================
# Distributed Arcface Training in Pytorch

This is a deep learning library that makes face recognition training efficient and effective. It can train on tens of
millions of identities on a single server.

## Requirements

- Install [PyTorch](http://pytorch.org) (torch>=1.6.0); see our [install.md](docs/install.md).
- `pip install -r requirements.txt`.
- Download the dataset from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_).

## How to Train

To train a model, run `train.py` with the path to the configs:

### 1. Single node, 8 GPUs:

```shell
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
```

### 2. Multiple nodes, each node 8 GPUs:

Node 0:

```shell
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
```

Node 1:

```shell
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
```

### 3. Training resnet2060 with 8 GPUs:

```shell
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py
```
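
For reference, the stage depths behind these backbone names come from the constructors in `backbones/iresnet.py` and `backbones/iresnet2060.py`. r2060 stacks 896 blocks in its third stage, and `iresnet2060.py` wraps `layer2` and `layer3` in `checkpoint_sequential` during training, which trades recomputation for memory. The sketch below assumes the usual ResNet naming convention of counting weighted layers: two 3x3 convolutions per `IBasicBlock`, plus the stem convolution and the final fc layer, with the 1x1 downsample convolutions excluded. Under that assumption, r18 to r100 match their names exactly, while r200 and r2060 come out at 198 and 2062, so those two names appear to be rounded. `STAGE_BLOCKS` and `weighted_depth` are illustrative names only.

```python
# Residual blocks per stage, copied from the iresnet*() constructors.
STAGE_BLOCKS = {
    "r18": [2, 2, 2, 2],
    "r34": [3, 4, 6, 3],
    "r50": [3, 4, 14, 3],
    "r100": [3, 13, 30, 3],
    "r200": [6, 26, 60, 6],
    "r2060": [3, 128, 1024 - 128, 3],
}


def weighted_depth(blocks_per_stage):
    # Assumed convention: two 3x3 convs per IBasicBlock, plus the stem conv and the fc layer.
    return 2 * sum(blocks_per_stage) + 2


for name, blocks in STAGE_BLOCKS.items():
    print(name, weighted_depth(blocks))
```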

## Model Zoo

- The models are available for non-commercial research purposes only.
- All models can be found here:
- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g) (extraction code: e8pw)
- [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)

### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/)

The ICCV2021-MFR test set consists of non-celebrities, so it has very little overlap with publicly available face
recognition training sets such as MS1M and CASIA, which were mostly collected from online celebrities.
As a result, it allows a fair comparison of different algorithms.

For the **ICCV2021-MFR-ALL** set, TAR is measured with an all-to-all 1:1 protocol at FAR less than 0.000001 (1e-6). This
globalised multi-racial test set contains 242,143 identities and 1,624,305 images.

For the **ICCV2021-MFR-MASK** set, TAR is measured with a mask-to-nonmask 1:1 protocol at FAR less than 0.0001 (1e-4).
The mask test set contains 6,964 identities, 6,964 masked images and 13,928 non-masked images.
In total, there are 13,928 positive pairs and 96,983,824 negative pairs.
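
Both protocols report TAR at a fixed FAR. The threshold is chosen from the impostor (negative) pairs so that the fraction of them accepted stays at or below the FAR. TAR is then the fraction of genuine (positive) pairs that score above that threshold. What follows is a minimal NumPy sketch. It assumes that higher scores mean more similar faces and that a pair is accepted when its score is strictly above the threshold. `tar_at_far` is a hypothetical helper, not the official ICCV2021-MFR evaluation code.

```python
import numpy as np


def tar_at_far(pos_scores, neg_scores, far):
    """TAR at a target FAR from similarity scores (hypothetical helper).

    pos_scores: scores of genuine pairs (same identity).
    neg_scores: scores of impostor pairs (different identities).
    A pair is accepted when its score is strictly greater than the threshold.
    """
    neg = np.sort(np.asarray(neg_scores, dtype=np.float64))[::-1]  # highest first
    # At most k impostor pairs may be accepted at this FAR.
    k = int(np.floor(far * len(neg)))
    # Placing the threshold at the (k+1)-th highest impostor score
    # accepts at most k impostor pairs.
    threshold = float(neg[k]) if k < len(neg) else -np.inf
    tar = float(np.mean(np.asarray(pos_scores, dtype=np.float64) > threshold))
    return tar, threshold


neg = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]
pos = [0.85, 0.92, 0.99, 0.5]
print(tar_at_far(pos, neg, far=0.1))
```

With 96,983,824 negative pairs and FAR 1e-4, for example, this rule lets at most 9,698 impostor pairs through.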

| Datasets  | backbone      | Training throughput | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** |
| :---:     | :---          | :---                | :---      | :---                  | :---                 |
| MS1MV3    | r18           | -                   | 91        | **47.85**             | **68.33**            |
| Glint360k | r18           | 8536                | 91        | **53.32**             | **72.07**            |
| MS1MV3    | r34           | -                   | 130       | **58.72**             | **77.36**            |
| Glint360k | r34           | 6344                | 130       | **65.10**             | **83.02**            |
| MS1MV3    | r50           | 5500                | 166       | **63.85**             | **80.53**            |
| Glint360k | r50           | 5136                | 166       | **70.23**             | **87.08**            |
| MS1MV3    | r100          | -                   | 248       | **69.09**             | **84.31**            |
| Glint360k | r100          | 3332                | 248       | **75.57**             | **90.66**            |
| MS1MV3    | mobilefacenet | 12185               | 7.8       | **41.52**             | **65.26**            |
| Glint360k | mobilefacenet | 11197               | 7.8       | **44.52**             | **66.48**            |

### Performance on IJB-C and Verification Datasets

|   Datasets | backbone      | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw  |  log    |
| :---:      |    :---       | :---          | :---  | :---  |:---   |:---    |:---     |  
| MS1MV3     | r18      | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)|         
| MS1MV3     | r34      | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)|        
| MS1MV3     | r50      | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)|         
| MS1MV3     | r100     | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)|        
| MS1MV3     | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)|
| Glint360k  |r18-0.1   | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)| 
| Glint360k  |r34-0.1   | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)| 
| Glint360k  |r50-0.1   | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)| 
| Glint360k  |r100-0.1  | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)|

[comment]: <> (More details see [model.md]&#40;docs/modelzoo.md&#41; in docs.)


## [Speed Benchmark](docs/speed_benchmark.md)

**Arcface Torch** can train on large-scale face recognition training sets efficiently and quickly. When the training set
has more than 300K classes and training is sufficient, the Partial FC sampling strategy reaches the same accuracy while
training several times faster and using less GPU memory.
Partial FC is a sparse variant of the model-parallel architecture for large-scale face recognition. Partial FC uses a
sparse softmax, where each batch dynamically samples a subset of class centers for training. In each iteration, only a
sparse part of the parameters is updated, which substantially reduces GPU memory and computation. With Partial FC,
we can scale the training set to 29 million identities, the largest to date. Partial FC also supports multi-machine distributed
training and mixed-precision training.
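
The sampling step described above can be sketched in a single process. This is a minimal NumPy sketch. It assumes that every class present in the batch (the positives) is always kept, and that the rest of a budget of `sample_rate * num_classes` centers is filled with negatives drawn uniformly at random. The "0.1" in the tables below is presumably this sampling rate. The real Partial FC implementation also shards the centers across GPUs and applies the margin-based loss on top of these logits, and neither of those parts is shown. `sample_class_centers` and `partial_cosine_logits` are hypothetical names.

```python
import numpy as np


def sample_class_centers(labels, num_classes, sample_rate, rng):
    """Choose the class centers used for one batch (hypothetical helper).

    All classes present in `labels` are kept; the remaining budget of
    int(sample_rate * num_classes) centers is filled with negative classes
    drawn uniformly at random without replacement.
    Returns (sampled_ids, targets) with sampled_ids[targets] == labels.
    """
    labels = np.asarray(labels)
    positives = np.unique(labels)
    num_sample = max(int(sample_rate * num_classes), len(positives))
    is_negative = np.ones(num_classes, dtype=bool)
    is_negative[positives] = False
    negatives = rng.choice(np.flatnonzero(is_negative),
                           size=num_sample - len(positives), replace=False)
    sampled_ids = np.sort(np.concatenate([positives, negatives]))
    # Each label's index inside the sampled subset becomes its softmax target.
    targets = np.searchsorted(sampled_ids, labels)
    return sampled_ids, targets


def partial_cosine_logits(embeddings, weight, sampled_ids):
    """Cosine similarities against the sampled centers only: (batch, num_sample)."""
    emb = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    centers = weight[sampled_ids]
    centers = centers / np.linalg.norm(centers, axis=1, keepdims=True)
    return emb @ centers.T


rng = np.random.default_rng(0)
num_classes, dim = 1000, 8
weight = rng.standard_normal((num_classes, dim))  # one center per identity
labels = np.array([3, 17, 3, 999])
embeddings = rng.standard_normal((len(labels), dim))

sampled_ids, targets = sample_class_centers(labels, num_classes, 0.1, rng)
logits = partial_cosine_logits(embeddings, weight, sampled_ids)
print(len(sampled_ids), logits.shape)
```

Only the 100 sampled rows of `weight` take part in this step's logits. That is where the memory and compute savings described above come from.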

![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png)

For more details, see
[speed_benchmark.md](docs/speed_benchmark.md) in docs.

### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)

`-` means training failed because of GPU memory limitations.

| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
| :---    | :--- | :--- | :--- |
|125000   | 4681         | 4824          | 5004     |
|1400000  | **1672**     | 3043          | 4738     |
|5500000  | **-**        | **1389**      | 3975     |
|8000000  | **-**        | **-**         | 3565     |
|16000000 | **-**        | **-**         | 2679     |
|29000000 | **-**        | **-**         | **1855** |

### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)

| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
| :---    | :---      | :---      | :---  |
|125000   | 7358      | 5306      | 4868  |
|1400000  | 32252     | 11178     | 6056  |
|5500000  | **-**     | 32188     | 9854  |
|8000000  | **-**     | **-**     | 12310 |
|16000000 | **-**     | **-**     | 19950 |
|29000000 | **-**     | **-**     | 32324 |
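
Here is a rough estimate of why the classification layer dominates memory as the number of identities grows. The FC layer holds one center per identity, so it has `num_classes x 512` parameters, where 512 is the default `num_features` of the backbones. The sketch assumes fp32 weights and ignores gradients, optimizer state, logits and activations. For that reason it estimates only the weight matrix and will not reproduce the measured numbers in the table. `fc_weight_mb` is a hypothetical helper. Here, "replicated" models data parallelism (a full copy on every GPU) and "sharded" models splitting the centers evenly across the 8 GPUs.

```python
def fc_weight_mb(num_classes, embedding_dim=512, bytes_per_param=4, num_gpus=1):
    """MB (10**6 bytes) per GPU for the FC weight matrix alone (hypothetical helper).

    num_gpus=1: a full replica on each GPU (data parallel).
    num_gpus=8: the matrix split evenly over 8 GPUs (model parallel).
    """
    return num_classes * embedding_dim * bytes_per_param / num_gpus / 1e6


for n in (125_000, 1_400_000, 5_500_000, 29_000_000):
    print(f"{n:>10,}  replicated: {fc_weight_mb(n):9.1f} MB"
          f"  sharded over 8: {fc_weight_mb(n, num_gpus=8):8.1f} MB")
```

At 29 million identities, the replicated weights alone would take about 59 GB per GPU, which is more than a 32 GB V100 holds. Sharded over 8 GPUs, they take about 7.4 GB each.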

## Evaluation ICCV2021-MFR and IJB-C

For more details, see [eval.md](docs/eval.md) in docs.

## Test

We tested many versions of PyTorch. Please create an issue if you are having trouble.  

- [x] torch 1.6.0
- [x] torch 1.7.1
- [x] torch 1.8.0
- [x] torch 1.9.0

## Citation

```
@inproceedings{deng2019arcface,
  title={Arcface: Additive angular margin loss for deep face recognition},
  author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  pages={4690--4699},
  year={2019}
}
@inproceedings{an2020partical_fc,
  title={Partial FC: Training 10 Million Identities on a Single Machine},
  author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and
  Zhang, Debing and Fu, Ying},
  booktitle={Arxiv 2010.05222},
  year={2020}
}
```


================================================
FILE: src/face3d/models/arcface_torch/backbones/__init__.py
================================================
from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
from .mobilefacenet import get_mbf


def get_model(name, **kwargs):
    # resnet
    if name == "r18":
        return iresnet18(False, **kwargs)
    elif name == "r34":
        return iresnet34(False, **kwargs)
    elif name == "r50":
        return iresnet50(False, **kwargs)
    elif name == "r100":
        return iresnet100(False, **kwargs)
    elif name == "r200":
        return iresnet200(False, **kwargs)
    elif name == "r2060":
        from .iresnet2060 import iresnet2060
        return iresnet2060(False, **kwargs)
    elif name == "mbf":
        fp16 = kwargs.get("fp16", False)
        num_features = kwargs.get("num_features", 512)
        return get_mbf(fp16=fp16, num_features=num_features)
    else:
        raise ValueError("Unknown backbone name: {}".format(name))

================================================
FILE: src/face3d/models/arcface_torch/backbones/iresnet.py
================================================
import torch
from torch import nn

__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=dilation,
                     groups=groups,
                     bias=False,
                     dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=1,
                     stride=stride,
                     bias=False)


class IBasicBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super(IBasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.bn1(x)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out


class IResNet(nn.Module):
    fc_scale = 7 * 7
    def __init__(self,
                 block, layers, dropout=0, num_features=512, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
        super(IResNet, self).__init__()
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        with torch.cuda.amp.autocast(self.fp16):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.prelu(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.bn2(x)
            x = torch.flatten(x, 1)
            x = self.dropout(x)
        x = self.fc(x.float() if self.fp16 else x)
        x = self.features(x)
        return x


def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    model = IResNet(block, layers, **kwargs)
    if pretrained:
        raise ValueError()
    return model


def iresnet18(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
                    progress, **kwargs)


def iresnet34(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
                    progress, **kwargs)


def iresnet50(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
                    progress, **kwargs)


def iresnet100(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
                    progress, **kwargs)


def iresnet200(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
                    progress, **kwargs)



================================================
FILE: src/face3d/models/arcface_torch/backbones/iresnet2060.py
================================================
import torch
from torch import nn

assert torch.__version__ >= "1.8.1"
from torch.utils.checkpoint import checkpoint_sequential

__all__ = ['iresnet2060']


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=dilation,
                     groups=groups,
                     bias=False,
                     dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=1,
                     stride=stride,
                     bias=False)


class IBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super(IBasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, )
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, )
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, )
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.bn1(x)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out


class IResNet(nn.Module):
    fc_scale = 7 * 7

    def __init__(self,
                 block, layers, dropout=0, num_features=512, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
        super(IResNet, self).__init__()
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation))

        return nn.Sequential(*layers)

    def checkpoint(self, func, num_seg, x):
        if self.training:
            return checkpoint_sequential(func, num_seg, x)
        else:
            return func(x)

    def forward(self, x):
        with torch.cuda.amp.autocast(self.fp16):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.prelu(x)
            x = self.layer1(x)
            x = self.checkpoint(self.layer2, 20, x)
            x = self.checkpoint(self.layer3, 100, x)
            x = self.layer4(x)
            x = self.bn2(x)
            x = torch.flatten(x, 1)
            x = self.dropout(x)
        x = self.fc(x.float() if self.fp16 else x)
        x = self.features(x)
        return x


def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    model = IResNet(block, layers, **kwargs)
    if pretrained:
        raise ValueError()
    return model


def iresnet2060(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)


================================================
FILE: src/face3d/models/arcface_torch/backbones/mobilefacenet.py
================================================
'''
Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
Original author cavalleria
'''

import torch.nn as nn
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
import torch


class Flatten(Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class ConvBlock(Module):
    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(ConvBlock, self).__init__()
        self.layers = nn.Sequential(
            Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
            BatchNorm2d(num_features=out_c),
            PReLU(num_parameters=out_c)
        )

    def forward(self, x):
        return self.layers(x)


class LinearBlock(Module):
    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(LinearBlock, self).__init__()
        self.layers = nn.Sequential(
            Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
            BatchNorm2d(num_features=out_c)
        )

    def forward(self, x):
        return self.layers(x)


class DepthWise(Module):
    def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
        super(DepthWise, self).__init__()
        self.residual = residual
        self.layers = nn.Sequential(
            ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
            ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
            LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        )

    def forward(self, x):
        short_cut = None
        if self.residual:
            short_cut = x
        x = self.layers(x)
        if self.residual:
            output = short_cut + x
        else:
            output = x
        return output


class Residual(Module):
    def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
        super(Residual, self).__init__()
        modules = []
        for _ in range(num_block):
            modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
        self.layers = Sequential(*modules)

    def forward(self, x):
        return self.layers(x)


class GDC(Module):
    def __init__(self, embedding_size):
        super(GDC, self).__init__()
        self.layers = nn.Sequential(
            LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
            Flatten(),
            Linear(512, embedding_size, bias=False),
            BatchNorm1d(embedding_size))

    def forward(self, x):
        return self.layers(x)


class MobileFaceNet(Module):
    def __init__(self, fp16=False, num_features=512):
        super(MobileFaceNet, self).__init__()
        scale = 2
        self.fp16 = fp16
        self.layers = nn.Sequential(
            ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)),
            ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64),
            DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
            Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
            DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
            Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
            DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
            Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
        )
        self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
        self.features = GDC(num_features)
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        with torch.cuda.amp.autocast(self.fp16):
            x = self.layers(x)
        x = self.conv_sep(x.float() if self.fp16 else x)
        x = self.features(x)
        return x


def get_mbf(fp16, num_features):
    return MobileFaceNet(fp16, num_features)
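The GDC head above hard-codes a 7x7 depthwise kernel with no padding. It only collapses the feature map to 1x1, which `Flatten` followed by `Linear(512, embedding_size)` requires, if the backbone delivers a 7x7 map. The sketch below traces the spatial size through the stage list, assuming a 112x112 input (the resolution suggested by the `faces_webface_112x112` path in `base.py`). `conv_out` and `trace_spatial_size` are hypothetical helpers, not part of the repository.

```python
from typing import List, Tuple


def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output height/width of a Conv2d layer (dilation 1), PyTorch floor formula."""
    return (size + 2 * padding - kernel) // stride + 1


# (stage, kernel, stride, padding) of the spatial 3x3 convolutions in
# MobileFaceNet.layers above. The 1x1 convs inside DepthWise use stride 1 and
# padding 0, so they never change the spatial size and are omitted here.
# Each Residual row stands for all of its blocks, which all use stride 1.
STAGES: List[Tuple[str, int, int, int]] = [
    ("ConvBlock", 3, 2, 1),
    ("ConvBlock(g=64)", 3, 1, 1),
    ("DepthWise", 3, 2, 1),
    ("Residual x4", 3, 1, 1),
    ("DepthWise", 3, 2, 1),
    ("Residual x6", 3, 1, 1),
    ("DepthWise", 3, 2, 1),
    ("Residual x2", 3, 1, 1),
]


def trace_spatial_size(size: int = 112) -> List[Tuple[str, int]]:
    """Return the feature-map side length after each stage for a square input."""
    shapes = []
    for name, kernel, stride, padding in STAGES:
        size = conv_out(size, kernel, stride, padding)
        shapes.append((name, size))
    # conv_sep is a 1x1 conv with stride 1, so it keeps the size;
    # GDC then applies a 7x7 depthwise conv with no padding.
    shapes.append(("GDC 7x7", conv_out(size, 7, 1, 0)))
    return shapes


if __name__ == "__main__":
    for name, side in trace_spatial_size(112):
        print(f"{name:16s} -> {side}x{side}")
```

With a 128x128 input the same arithmetic ends at 8x8, so GDC would output 2x2 and the `Linear` layer would receive 2048 features instead of 512.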

================================================
FILE: src/face3d/models/arcface_torch/configs/3millions.py
================================================
from easydict import EasyDict as edict

# configs for test speed

config = edict()
config.loss = "arcface"
config.network = "r50"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1  # batch size is 512

config.rec = "synthetic"
config.num_classes = 300 * 10000
config.num_epoch = 30
config.warmup_epoch = -1
config.decay_epoch = [10, 16, 22]
config.val_targets = []


================================================
FILE: src/face3d/models/arcface_torch/configs/3millions_pfc.py
================================================
from easydict import EasyDict as edict

# configs for test speed

config = edict()
config.loss = "arcface"
config.network = "r50"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 0.1
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1  # batch size is 512

config.rec = "synthetic"
config.num_classes = 300 * 10000
config.num_epoch = 30
config.warmup_epoch = -1
config.decay_epoch = [10, 16, 22]
config.val_targets = []
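`3millions.py` and `3millions_pfc.py` differ only in `sample_rate` (1.0 vs 0.1). Assuming `sample_rate` is the Partial FC fraction of class centers that take part in each step's softmax, the pfc variant would use about 300,000 of the 3,000,000 synthetic classes per step. The sketch below is a simplified single-process illustration of that sampling idea. `sample_class_centers` is a hypothetical helper; the repository's `partial_fc.py` (listed in the tree, not shown in this section) shards centers across GPUs and may differ in detail.

```python
import random
from typing import List, Sequence, Set


def sample_class_centers(labels: Sequence[int], num_classes: int,
                         sample_rate: float, seed: int = 0) -> List[int]:
    """Pick the class-center indices used for one training step.

    Simplified illustration: every class present in the batch (the positives)
    is always kept, and the remaining budget is filled with random negatives.
    """
    if sample_rate >= 1.0:
        # Full softmax: every class center participates.
        return list(range(num_classes))
    num_sample = int(num_classes * sample_rate)
    positives: Set[int] = set(labels)
    if num_sample <= len(positives):
        # Budget too small: keep only the positives.
        return sorted(positives)
    rng = random.Random(seed)
    negatives: Set[int] = set()
    need = num_sample - len(positives)
    # Rejection sampling avoids building a list of all num_classes ids.
    while len(negatives) < need:
        c = rng.randrange(num_classes)
        if c not in positives:
            negatives.add(c)
    return sorted(positives | negatives)


if __name__ == "__main__":
    # Per-step budget for 3millions_pfc.py under this sketch's assumptions.
    print(int(300 * 10000 * 0.1))
    batch_labels = [3, 17, 17, 42]
    picked = sample_class_centers(batch_labels, num_classes=1000, sample_rate=0.1)
    print(len(picked), all(c in picked for c in batch_labels))
```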


================================================
FILE: src/face3d/models/arcface_torch/configs/__init__.py
================================================


================================================
FILE: src/face3d/models/arcface_torch/configs/base.py
================================================
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G  tmpfs /train_tmp

config = edict()
config.loss = "arcface"
config.network = "r50"
config.resume = False
config.output = "ms1mv3_arcface_r50"

config.dataset = "ms1m-retinaface-t1"
config.embedding_size = 512
config.sample_rate = 1
config.fp16 = False
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1  # batch size is 512

if config.dataset == "emore":
    config.rec = "/train_tmp/faces_emore"
    config.num_classes = 85742
    config.num_image = 5822653
    config.num_epoch = 16
    config.warmup_epoch = -1
    config.decay_epoch = [8, 14, ]
    config.val_targets = ["lfw", ]

elif config.dataset == "ms1m-retinaface-t1":
    config.rec = "/train_tmp/ms1m-retinaface-t1"
    config.num_classes = 93431
    config.num_image = 5179510
    config.num_epoch = 25
    config.warmup_epoch = -1
    config.decay_epoch = [11, 17, 22]
    config.val_targets = ["lfw", "cfp_fp", "agedb_30"]

elif config.dataset == "glint360k":
    config.rec = "/train_tmp/glint360k"
    config.num_classes = 360232
    config.num_image = 17091657
    config.num_epoch = 20
    config.warmup_epoch = -1
    config.decay_epoch = [8, 12, 15, 18]
    config.val_targets = ["lfw", "cfp_fp", "agedb_30"]

elif config.dataset == "webface":
    config.rec = "/train_tmp/faces_webface_112x112"
    config.num_classes = 10572
    config.num_image = "forget"
    config.num_epoch = 34
    config.warmup_epoch = -1
    config.decay_epoch = [20, 28, 32]
    config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
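Because `base.py` sets `config.dataset = "ms1m-retinaface-t1"` before its `if/elif` chain, only that branch runs at import time; the other branches take effect only if `config.dataset` is edited in place. The per-experiment files below appear to be overlays on this base: they restate `rec`, `num_classes` and the schedule, and set `output = None`. The sketch assumes a loader that copies the base config, updates it with the job config, and derives an output directory when `output` is still `None`; `utils_config.py` is listed in the tree but not shown here, so `merge_config` and its `work_dirs` default are hypothetical.

```python
import os.path as osp
from typing import Any, Dict

# Hypothetical subsets of configs/base.py and configs/glint360k_mbf.py,
# written as plain dicts (EasyDict is a dict subclass with attribute access).
BASE: Dict[str, Any] = {
    "loss": "arcface", "network": "r50", "output": "ms1mv3_arcface_r50",
    "fp16": False, "weight_decay": 5e-4,
    "rec": "/train_tmp/ms1m-retinaface-t1", "num_classes": 93431,
}
JOB: Dict[str, Any] = {
    "loss": "cosface", "network": "mbf", "output": None,
    "fp16": True, "weight_decay": 2e-4,
    "rec": "/train_tmp/glint360k", "num_classes": 360232,
}


def merge_config(base: Dict[str, Any], job: Dict[str, Any], job_name: str,
                 work_dir: str = "work_dirs") -> Dict[str, Any]:
    """Overlay a job config on a copy of the base config; keys in `job` win."""
    cfg = dict(base)
    cfg.update(job)
    # A job that leaves output as None gets a directory named after itself.
    if cfg["output"] is None:
        cfg["output"] = osp.join(work_dir, job_name)
    return cfg


if __name__ == "__main__":
    cfg = merge_config(BASE, JOB, "glint360k_mbf")
    print(cfg["network"], cfg["num_classes"], cfg["output"])
```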


================================================
FILE: src/face3d/models/arcface_torch/configs/glint360k_mbf.py
================================================
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G  tmpfs /train_tmp

config = edict()
config.loss = "cosface"
config.network = "mbf"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 0.1
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 2e-4
config.batch_size = 128
config.lr = 0.1  # batch size is 512

config.rec = "/train_tmp/glint360k"
config.num_classes = 360232
config.num_image = 17091657
config.num_epoch = 20
config.warmup_epoch = -1
config.decay_epoch = [8, 12, 15, 18]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
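The `decay_epoch` and `warmup_epoch` fields describe a step learning-rate schedule. The sketch below assumes the common reading: the rate is divided by 10 at each listed epoch, and `warmup_epoch = -1` disables warmup. The decay factor, the zero-based epoch count, and the linear warmup formula are assumptions; the repository's `train.py` is not shown in this section and may count milestones differently. `lr_at_epoch` is a hypothetical helper.

```python
from typing import Sequence


def lr_at_epoch(epoch: int, base_lr: float, decay_epoch: Sequence[int],
                warmup_epoch: int = -1, gamma: float = 0.1) -> float:
    """Step schedule: multiply base_lr by gamma once per milestone reached.

    Epochs are counted from 0. Any warmup_epoch <= 0 disables warmup, because
    no non-negative epoch is below it.
    """
    if epoch < warmup_epoch:
        # Hypothetical linear warmup; not taken from the repository.
        return base_lr * (epoch + 1) / (warmup_epoch + 1)
    passed = sum(1 for milestone in decay_epoch if epoch >= milestone)
    return base_lr * gamma ** passed


if __name__ == "__main__":
    # Values from glint360k_mbf.py: lr 0.1, 20 epochs, milestones [8, 12, 15, 18].
    for e in range(20):
        print(e, f"{lr_at_epoch(e, 0.1, [8, 12, 15, 18]):.0e}")
```

Under these assumptions, `glint360k_mbf.py` trains at 0.1 for epochs 0-7, then 0.01, 0.001 and 1e-4, and ends at 1e-5 for epochs 18-19.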


================================================
FILE: src/face3d/models/arcface_torch/configs/glint360k_r100.py
================================================
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G  tmpfs /train_tmp

config = edict()
config.loss = "cosface"
config.network = "r100"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1  # batch size is 512

config.rec = "/train_tmp/glint360k"
config.num_classes = 360232
config.num_image = 17091657
config.num_epoch = 20
config.warmup_epoch = -1
config.decay_epoch = [8, 12, 15, 18]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]


================================================
FILE: src/face3d/models/arcface_torch/configs/glint360k_r18.py
================================================
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G  tmpfs /train_tmp

config = edict()
config.loss = "cosface"
config.network = "r18"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1  # batch size is 512

config.rec = "/train_tmp/glint360k"
config.num_classes = 360232
config.num_image = 17091657
config.num_epoch = 20
config.warmup_epoch = -1
config.decay_epoch = [8, 12, 15, 18]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]


================================================
FILE: src/face3d/models/arcface_torch/configs/glint360k_r34.py
================================================
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G  tmpfs /train_tmp

config = edict()
config.loss = "cosface"
config.network = "r34"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1  # batch size is 512

config.rec = "/train_tmp/glint360k"
config.num_classes = 360232
config.num_image = 17091657
config.num_epoch = 20
config.warmup_epoch = -1
config.decay_epoch = [8, 12, 15, 18]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]


================================================
FILE: src/face3d/models/arcface_torch/configs/glint360k_r50.py
================================================
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G  tmpfs /train_tmp

config = edict()
config.loss = "cosface"
config.network = "r50"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1  # batch size is 512

config.rec = "/train_tmp/glint360k"
config.num_classes = 360232
config.num_image = 17091657
config.num_epoch = 20
config.warmup_epoch = -1
config.decay_epoch = [8, 12, 15, 18]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]


================================================
FILE: src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py
================================================
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G  tmpfs /train_tmp

config = edict()
config.loss = "arcface"
config.network = "mbf"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 2e-4
config.batch_size = 128
config.lr = 0.1  # batch size is 512

config.rec = "/train_tmp/ms1m-retinaface-t1"
config.num_classes = 93431
config.num_image = 5179510
config.num_epoch = 30
config.warmup_epoch = -1
config.decay_epoch = [10, 20, 25]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
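Every config pairs `config.batch_size = 128` with a `# batch size is 512` comment on `config.lr`. One reading is that `batch_size` is per GPU and the learning rate was tuned for a global batch of 512, i.e. 4 GPUs x 128. The sketch below turns `num_image`, `batch_size` and `num_epoch` into a step count under that reading. The GPU count (`world_size`) and dropping the incomplete last batch of each epoch are assumptions, and `total_steps` is a hypothetical helper.

```python
def total_steps(num_image: int, batch_size: int, world_size: int,
                num_epoch: int) -> int:
    """Optimizer steps for a run, assuming each GPU sees batch_size images
    per step and the incomplete last batch of every epoch is dropped."""
    global_batch = batch_size * world_size
    steps_per_epoch = num_image // global_batch
    return steps_per_epoch * num_epoch


if __name__ == "__main__":
    # ms1mv3_mbf.py: 5,179,510 images, batch_size 128, 30 epochs.
    # world_size=4 is hypothetical (4 x 128 = the 512 in the lr comment).
    print(total_steps(5179510, 128, 4, 30))
```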


================================================
FILE: src/face3d/models/arcface_torch/configs/ms1mv3_r18.py
================================================
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G  tmpfs /train_tmp

config = edict()
config.loss = "arcface"
config.network = "r18"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1  # batch size is 512

config.rec = "/train_tmp/ms1m-retinaface-t1"
config.num_classes = 93431
config.num_image = 5179510
config.num_epoch = 25
config.warmup_epoch = -1
config.decay_epoch = [10, 16, 22]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]


================================================
FILE: src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py
================================================
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G  tmpfs /train_tmp

config = edict()
config.loss = "arcface"
config.network = "r2060"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 64
config.lr = 0.1  # batch size is 512

config.rec = "/train_tmp/ms1m-retinaface-t1"
config.num_classes = 93431
config.num_image = 5179510
config.num_epoch = 25
config.warmup_epoch = -1
config.decay_epoch = [10, 16, 22]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]


================================================
FILE: src/face3d/models/arcface_torch/configs/ms1mv3_r34.py
================================================
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G  tmpfs /train_tmp

config = edict()
config.loss = "arcface"
config.network = "r34"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1  # batch size is 512

config.rec = "/train_tmp/ms1m-retinaface-t1"
config.num_classes = 93431
config.num_image = 5179510
config.num_epoch = 25
config.warmup_epoch = -1
config.decay_epoch = [10, 16, 22]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]


================================================
FILE: src/face3d/models/arcface_torch/configs/ms1mv3_r50.py
================================================
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G  tmpfs /train_tmp

config = edict()
config.loss = "arcface"
config.network = "r50"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate =
SYMBOL INDEX (702 symbols across 88 files)

FILE: app_sadtalker.py
  function toggle_audio_file (line 13) | def toggle_audio_file(choice):
  function ref_video_fn (line 19) | def ref_video_fn(path_of_ref_video):
  function sadtalker_demo (line 25) | def sadtalker_demo(checkpoint_path='checkpoints', config_path='src/confi...

FILE: inference.py
  function main (line 15) | def main(args):

FILE: launcher.py
  function check_python_version (line 23) | def check_python_version():
  function commit_hash (line 49) | def commit_hash():
  function run (line 63) | def run(command, desc=None, errdesc=None, custom_env=None, live=False):
  function check_run (line 91) | def check_run(command):
  function is_installed (line 96) | def is_installed(package):
  function repo_dir (line 105) | def repo_dir(name):
  function run_python (line 109) | def run_python(code, desc=None, errdesc=None):
  function run_pip (line 113) | def run_pip(args, desc=None):
  function check_run_python (line 121) | def check_run_python(code):
  function git_clone (line 125) | def git_clone(url, dir, name, commithash=None):
  function git_pull_recursive (line 146) | def git_pull_recursive(dir):
  function run_extension_installer (line 156) | def run_extension_installer(extension_dir):
  function prepare_environment (line 170) | def prepare_environment():
  function start (line 195) | def start():

FILE: predict.py
  class Predictor (line 16) | class Predictor(BasePredictor):
    method setup (line 17) | def setup(self):
    method predict (line 44) | def predict(
  function load_default (line 172) | def load_default():

FILE: scripts/extension.py
  function check_all_files_safetensor (line 14) | def check_all_files_safetensor(current_dir):
  function check_all_files (line 33) | def check_all_files(current_dir):
  function download_model (line 58) | def download_model(local_dir='./checkpoints'):
  function get_source_image (line 62) | def get_source_image(image):
  function get_img_from_txt2img (line 65) | def get_img_from_txt2img(x):
  function get_img_from_img2img (line 73) | def get_img_from_img2img(x):
  function get_default_checkpoint_path (line 81) | def get_default_checkpoint_path():
  function install (line 106) | def install():
  function on_ui_tabs (line 162) | def on_ui_tabs():
  function on_ui_settings (line 183) | def on_ui_settings():

FILE: src/audio2exp_models/audio2exp.py
  class Audio2Exp (line 6) | class Audio2Exp(nn.Module):
    method __init__ (line 7) | def __init__(self, netG, cfg, device, prepare_training_loss=False):
    method test (line 13) | def test(self, batch):

FILE: src/audio2exp_models/networks.py
  class Conv2d (line 5) | class Conv2d(nn.Module):
    method __init__ (line 6) | def __init__(self, cin, cout, kernel_size, stride, padding, residual=F...
    method forward (line 16) | def forward(self, x):
  class SimpleWrapperV2 (line 26) | class SimpleWrapperV2(nn.Module):
    method __init__ (line 27) | def __init__(self) -> None:
    method forward (line 67) | def forward(self, x, ref, ratio):

FILE: src/audio2pose_models/audio2pose.py
  class Audio2Pose (line 7) | class Audio2Pose(nn.Module):
    method __init__ (line 8) | def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
    method forward (line 24) | def forward(self, x):
    method test (line 48) | def test(self, x):

FILE: src/audio2pose_models/audio_encoder.py
  class Conv2d (line 5) | class Conv2d(nn.Module):
    method __init__ (line 6) | def __init__(self, cin, cout, kernel_size, stride, padding, residual=F...
    method forward (line 15) | def forward(self, x):
  class AudioEncoder (line 21) | class AudioEncoder(nn.Module):
    method __init__ (line 22) | def __init__(self, wav2lip_checkpoint, device):
    method forward (line 54) | def forward(self, audio_sequences):

FILE: src/audio2pose_models/cvae.py
  function class2onehot (line 6) | def class2onehot(idx, class_num):
  class CVAE (line 13) | class CVAE(nn.Module):
    method __init__ (line 14) | def __init__(self, cfg):
    method reparameterize (line 30) | def reparameterize(self, mu, logvar):
    method forward (line 35) | def forward(self, batch):
    method test (line 43) | def test(self, batch):
  class ENCODER (line 51) | class ENCODER(nn.Module):
    method __init__ (line 52) | def __init__(self, layer_sizes, latent_size, num_classes,
    method forward (line 73) | def forward(self, batch):
  class DECODER (line 99) | class DECODER(nn.Module):
    method __init__ (line 100) | def __init__(self, layer_sizes, latent_size, num_classes,
    method forward (line 123) | def forward(self, batch):

FILE: src/audio2pose_models/discriminator.py
  class ConvNormRelu (line 5) | class ConvNormRelu(nn.Module):
    method __init__ (line 6) | def __init__(self, conv_type='1d', in_channels=3, out_channels=64, dow...
    method forward (line 49) | def forward(self, x):
  class PoseSequenceDiscriminator (line 59) | class PoseSequenceDiscriminator(nn.Module):
    method __init__ (line 60) | def __init__(self, cfg):
    method forward (line 72) | def forward(self, x):

FILE: src/audio2pose_models/networks.py
  class ResidualConv (line 5) | class ResidualConv(nn.Module):
    method __init__ (line 6) | def __init__(self, input_dim, output_dim, stride, padding):
    method forward (line 24) | def forward(self, x):
  class Upsample (line 29) | class Upsample(nn.Module):
    method __init__ (line 30) | def __init__(self, input_dim, output_dim, kernel, stride):
    method forward (line 37) | def forward(self, x):
  class Squeeze_Excite_Block (line 41) | class Squeeze_Excite_Block(nn.Module):
    method __init__ (line 42) | def __init__(self, channel, reduction=16):
    method forward (line 52) | def forward(self, x):
  class ASPP (line 59) | class ASPP(nn.Module):
    method __init__ (line 60) | def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
    method forward (line 88) | def forward(self, x):
    method _init_weights (line 95) | def _init_weights(self):
  class Upsample_ (line 104) | class Upsample_(nn.Module):
    method __init__ (line 105) | def __init__(self, scale=2):
    method forward (line 110) | def forward(self, x):
  class AttentionBlock (line 114) | class AttentionBlock(nn.Module):
    method __init__ (line 115) | def __init__(self, input_encoder, input_decoder, output_dim):
    method forward (line 137) | def forward(self, x1, x2):

FILE: src/audio2pose_models/res_unet.py
  class ResUnet (line 6) | class ResUnet(nn.Module):
    method __init__ (line 7) | def __init__(self, channel=1, filters=[32, 64, 128, 256]):
    method forward (line 39) | def forward(self, x):

FILE: src/face3d/data/__init__.py
  function find_dataset_using_name (line 19) | def find_dataset_using_name(dataset_name):
  function get_option_setter (line 42) | def get_option_setter(dataset_name):
  function create_dataset (line 48) | def create_dataset(opt, rank=0):
  class CustomDatasetDataLoader (line 62) | class CustomDatasetDataLoader():
    method __init__ (line 65) | def __init__(self, opt, rank=0):
    method set_epoch (line 99) | def set_epoch(self, epoch):
    method load_data (line 104) | def load_data(self):
    method __len__ (line 107) | def __len__(self):
    method __iter__ (line 111) | def __iter__(self):

FILE: src/face3d/data/base_dataset.py
  class BaseDataset (line 13) | class BaseDataset(data.Dataset, ABC):
    method __init__ (line 23) | def __init__(self, opt):
    method modify_commandline_options (line 34) | def modify_commandline_options(parser, is_train):
    method __len__ (line 47) | def __len__(self):
    method __getitem__ (line 52) | def __getitem__(self, index):
  function get_transform (line 64) | def get_transform(grayscale=False):
  function get_affine_mat (line 71) | def get_affine_mat(opt, size):
  function apply_img_affine (line 98) | def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
  function apply_lm_affine (line 101) | def apply_lm_affine(landmark, affine, flip, size):

FILE: src/face3d/data/flist_dataset.py
  function default_flist_reader (line 19) | def default_flist_reader(flist):
  function jason_flist_reader (line 31) | def jason_flist_reader(flist):
  function parse_label (line 36) | def parse_label(label):
  class FlistDataset (line 40) | class FlistDataset(BaseDataset):
    method __init__ (line 46) | def __init__(self, opt):
    method __getitem__ (line 67) | def __getitem__(self, index):
    method _augmentation (line 111) | def _augmentation(self, img, lm, opt, msk=None):
    method __len__ (line 122) | def __len__(self):

FILE: src/face3d/data/image_folder.py
  function is_image_file (line 20) | def is_image_file(filename):
  function make_dataset (line 24) | def make_dataset(dir, max_dataset_size=float("inf")):
  function default_loader (line 36) | def default_loader(path):
  class ImageFolder (line 40) | class ImageFolder(data.Dataset):
    method __init__ (line 42) | def __init__(self, root, transform=None, return_paths=False,
    method __getitem__ (line 55) | def __getitem__(self, index):
    method __len__ (line 65) | def __len__(self):

FILE: src/face3d/data/template_dataset.py
  class TemplateDataset (line 19) | class TemplateDataset(BaseDataset):
    method modify_commandline_options (line 22) | def modify_commandline_options(parser, is_train):
    method __init__ (line 36) | def __init__(self, opt):
    method __getitem__ (line 54) | def __getitem__(self, index):
    method __len__ (line 73) | def __len__(self):

FILE: src/face3d/extract_kp_videos.py
  class KeypointExtractor (line 14) | class KeypointExtractor():
    method __init__ (line 15) | def __init__(self, device):
    method extract_keypoint (line 19) | def extract_keypoint(self, images, name=None, info=True):
  function read_video (line 58) | def read_video(filename):
  function run (line 72) | def run(data):

FILE: src/face3d/extract_kp_videos_safe.py
  function init_alignment_model (line 19) | def init_alignment_model(model_name, half=False, device='cuda', model_ro...
  class KeypointExtractor (line 34) | class KeypointExtractor():
    method __init__ (line 35) | def __init__(self, device='cuda'):
    method extract_keypoint (line 48) | def extract_keypoint(self, images, name=None, info=True):
  function read_video (line 101) | def read_video(filename):
  function run (line 115) | def run(data):

FILE: src/face3d/models/__init__.py
  function find_model_using_name (line 25) | def find_model_using_name(model_name):
  function get_option_setter (line 48) | def get_option_setter(model_name):
  function create_model (line 54) | def create_model(opt):

FILE: src/face3d/models/arcface_torch/backbones/__init__.py
  function get_model (line 5) | def get_model(name, **kwargs):

FILE: src/face3d/models/arcface_torch/backbones/iresnet.py
  function conv3x3 (line 7) | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
  function conv1x1 (line 19) | def conv1x1(in_planes, out_planes, stride=1):
  class IBasicBlock (line 28) | class IBasicBlock(nn.Module):
    method __init__ (line 30) | def __init__(self, inplanes, planes, stride=1, downsample=None,
    method forward (line 46) | def forward(self, x):
  class IResNet (line 60) | class IResNet(nn.Module):
    method __init__ (line 62) | def __init__(self,
    method _make_layer (line 114) | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
    method forward (line 140) | def forward(self, x):
  function _iresnet (line 157) | def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
  function iresnet18 (line 164) | def iresnet18(pretrained=False, progress=True, **kwargs):
  function iresnet34 (line 169) | def iresnet34(pretrained=False, progress=True, **kwargs):
  function iresnet50 (line 174) | def iresnet50(pretrained=False, progress=True, **kwargs):
  function iresnet100 (line 179) | def iresnet100(pretrained=False, progress=True, **kwargs):
  function iresnet200 (line 184) | def iresnet200(pretrained=False, progress=True, **kwargs):

FILE: src/face3d/models/arcface_torch/backbones/iresnet2060.py
  function conv3x3 (line 10) | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
  function conv1x1 (line 22) | def conv1x1(in_planes, out_planes, stride=1):
  class IBasicBlock (line 31) | class IBasicBlock(nn.Module):
    method __init__ (line 34) | def __init__(self, inplanes, planes, stride=1, downsample=None,
    method forward (line 50) | def forward(self, x):
  class IResNet (line 64) | class IResNet(nn.Module):
    method __init__ (line 67) | def __init__(self,
    method _make_layer (line 119) | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
    method checkpoint (line 145) | def checkpoint(self, func, num_seg, x):
    method forward (line 151) | def forward(self, x):
  function _iresnet (line 168) | def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
  function iresnet2060 (line 175) | def iresnet2060(pretrained=False, progress=True, **kwargs):

FILE: src/face3d/models/arcface_torch/backbones/mobilefacenet.py
  class Flatten (line 11) | class Flatten(Module):
    method forward (line 12) | def forward(self, x):
  class ConvBlock (line 16) | class ConvBlock(Module):
    method __init__ (line 17) | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=...
    method forward (line 25) | def forward(self, x):
  class LinearBlock (line 29) | class LinearBlock(Module):
    method __init__ (line 30) | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=...
    method forward (line 37) | def forward(self, x):
  class DepthWise (line 41) | class DepthWise(Module):
    method __init__ (line 42) | def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=...
    method forward (line 51) | def forward(self, x):
  class Residual (line 63) | class Residual(Module):
    method __init__ (line 64) | def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1),...
    method forward (line 71) | def forward(self, x):
  class GDC (line 75) | class GDC(Module):
    method __init__ (line 76) | def __init__(self, embedding_size):
    method forward (line 84) | def forward(self, x):
  class MobileFaceNet (line 88) | class MobileFaceNet(Module):
    method __init__ (line 89) | def __init__(self, fp16=False, num_features=512):
    method _initialize_weights (line 107) | def _initialize_weights(self):
    method forward (line 121) | def forward(self, x):
  function get_mbf (line 129) | def get_mbf(fp16, num_features):

FILE: src/face3d/models/arcface_torch/dataset.py
  class BackgroundGenerator (line 13) | class BackgroundGenerator(threading.Thread):
    method __init__ (line 14) | def __init__(self, generator, local_rank, max_prefetch=6):
    method run (line 22) | def run(self):
    method next (line 28) | def next(self):
    method __next__ (line 34) | def __next__(self):
    method __iter__ (line 37) | def __iter__(self):
  class DataLoaderX (line 41) | class DataLoaderX(DataLoader):
    method __init__ (line 43) | def __init__(self, local_rank, **kwargs):
    method __iter__ (line 48) | def __iter__(self):
    method preload (line 54) | def preload(self):
    method __next__ (line 62) | def __next__(self):
  class MXFaceDataset (line 71) | class MXFaceDataset(Dataset):
    method __init__ (line 72) | def __init__(self, root_dir, local_rank):
    method __getitem__ (line 93) | def __getitem__(self, index):
    method __len__ (line 106) | def __len__(self):
  class SyntheticDataset (line 110) | class SyntheticDataset(Dataset):
    method __init__ (line 111) | def __init__(self, local_rank):
    method __getitem__ (line 120) | def __getitem__(self, index):
    method __len__ (line 123) | def __len__(self):

FILE: src/face3d/models/arcface_torch/eval/verification.py
  class LFold (line 41) | class LFold:
    method __init__ (line 42) | def __init__(self, n_splits=2, shuffle=False):
    method split (line 47) | def split(self, indices):
  function calculate_roc (line 54) | def calculate_roc(thresholds,
  function calculate_accuracy (line 109) | def calculate_accuracy(threshold, dist, actual_issame):
  function calculate_val (line 124) | def calculate_val(thresholds,
  function calculate_val_far (line 165) | def calculate_val_far(threshold, dist, actual_issame):
  function evaluate (line 179) | def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0):
  function load_bin (line 200) | def load_bin(path, image_size):
  function test (line 227) | def test(data_set, backbone, batch_size, nfolds=10):
  function dumpR (line 275) | def dumpR(data_set,

FILE: src/face3d/models/arcface_torch/eval_ijbc.py
  class Embedding (line 54) | class Embedding(object):
    method __init__ (line 55) | def __init__(self, prefix, data_shape, batch_size=1):
    method get (line 75) | def get(self, rimg, landmark):
    method forward_db (line 104) | def forward_db(self, batch_data):
  function divideIntoNstrand (line 113) | def divideIntoNstrand(listTemp, n):
  function read_template_media_list (line 120) | def read_template_media_list(path):
  function read_template_pair_list (line 131) | def read_template_pair_list(path):
  function read_image_feature (line 145) | def read_image_feature(path):
  function get_image_feature (line 154) | def get_image_feature(img_path, files_list, model_path, epoch, gpu_id):
  function image2template_feature (line 212) | def image2template_feature(img_feats=None, templates=None, medias=None):
  function verification (line 252) | def verification(template_norm_feats=None,
  function verification2 (line 282) | def verification2(template_norm_feats=None,
  function read_score (line 306) | def read_score(path):

FILE: src/face3d/models/arcface_torch/inference.py
  function inference (line 11) | def inference(weight, name, img):

FILE: src/face3d/models/arcface_torch/losses.py
  function get_loss (line 5) | def get_loss(name):
  class CosFace (line 14) | class CosFace(nn.Module):
    method __init__ (line 15) | def __init__(self, s=64.0, m=0.40):
    method forward (line 20) | def forward(self, cosine, label):
  class ArcFace (line 29) | class ArcFace(nn.Module):
    method __init__ (line 30) | def __init__(self, s=64.0, m=0.5):
    method forward (line 35) | def forward(self, cosine: torch.Tensor, label):

FILE: src/face3d/models/arcface_torch/onnx_helper.py
  class ArcFaceORT (line 15) | class ArcFaceORT:
    method __init__ (line 16) | def __init__(self, model_path, cpu=False):
    method check (line 22) | def check(self, track='cfat', test_img = None):
    method check_batch (line 184) | def check_batch(self, img):
    method meta_info (line 202) | def meta_info(self):
    method forward (line 206) | def forward(self, imgs):
    method benchmark (line 222) | def benchmark(self, img):

FILE: src/face3d/models/arcface_torch/onnx_ijbc.py
  class AlignedDataSet (line 28) | class AlignedDataSet(mx.gluon.data.Dataset):
    method __init__ (line 29) | def __init__(self, root, lines, align=True):
    method __len__ (line 34) | def __len__(self):
    method __getitem__ (line 37) | def __getitem__(self, idx):
  function extract (line 54) | def extract(model_root, dataset):
  function read_template_media_list (line 78) | def read_template_media_list(path):
  function read_template_pair_list (line 85) | def read_template_pair_list(path):
  function read_image_feature (line 93) | def read_image_feature(path):
  function image2template_feature (line 99) | def image2template_feature(img_feats=None,
  function verification (line 125) | def verification(template_norm_feats=None,
  function verification2 (line 147) | def verification2(template_norm_feats=None,
  function main (line 169) | def main(args):

FILE: src/face3d/models/arcface_torch/partial_fc.py
  class PartialFC (line 11) | class PartialFC(Module):
    method __init__ (line 20) | def __init__(self, rank, local_rank, world_size, batch_size, resume,
    method save_params (line 93) | def save_params(self):
    method sample (line 100) | def sample(self, total_label):
    method forward (line 125) | def forward(self, total_features, norm_weight):
    method update (line 133) | def update(self):
    method prepare (line 139) | def prepare(self, label, optimizer):
    method forward_backward (line 159) | def forward_backward(self, label, features, optimizer):

FILE: src/face3d/models/arcface_torch/torch2onnx.py
  function convert_onnx (line 6) | def convert_onnx(net, path_module, output, opset=11, simplify=False):

FILE: src/face3d/models/arcface_torch/train.py
  function main (line 21) | def main(args):

FILE: src/face3d/models/arcface_torch/utils/plot.py
  function read_template_pair_list (line 19) | def read_template_pair_list(path):

FILE: src/face3d/models/arcface_torch/utils/utils_amp.py
  class _MultiDeviceReplicator (line 14) | class _MultiDeviceReplicator(object):
    method __init__ (line 19) | def __init__(self, master_tensor: torch.Tensor) -> None:
    method get (line 24) | def get(self, device) -> torch.Tensor:
  class MaxClipGradScaler (line 32) | class MaxClipGradScaler(GradScaler):
    method __init__ (line 33) | def __init__(self, init_scale, max_scale: float, growth_interval=100):
    method scale_clip (line 37) | def scale_clip(self):
    method scale (line 46) | def scale(self, outputs):

FILE: src/face3d/models/arcface_torch/utils/utils_callbacks.py
  class CallBackVerification (line 12) | class CallBackVerification(object):
    method __init__ (line 13) | def __init__(self, frequent, rank, val_targets, rec_prefix, image_size...
    method ver_test (line 23) | def ver_test(self, backbone: torch.nn.Module, global_step: int):
    method init_dataset (line 36) | def init_dataset(self, val_targets, data_dir, image_size):
    method __call__ (line 44) | def __call__(self, num_update, backbone: torch.nn.Module):
  class CallBackLogging (line 51) | class CallBackLogging(object):
    method __init__ (line 52) | def __init__(self, frequent, rank, total_step, batch_size, world_size,...
    method __call__ (line 64) | def __call__(self,
  class CallBackModelCheckpoint (line 105) | class CallBackModelCheckpoint(object):
    method __init__ (line 106) | def __init__(self, rank, output="./"):
    method __call__ (line 110) | def __call__(self, global_step, backbone, partial_fc, ):

FILE: src/face3d/models/arcface_torch/utils/utils_config.py
  function get_config (line 5) | def get_config(config_file):

FILE: src/face3d/models/arcface_torch/utils/utils_logging.py
  class AverageMeter (line 6) | class AverageMeter(object):
    method __init__ (line 10) | def __init__(self):
    method reset (line 17) | def reset(self):
    method update (line 23) | def update(self, val, n=1):
  function init_logging (line 30) | def init_logging(rank, models_root):

FILE: src/face3d/models/base_model.py
  class BaseModel (line 12) | class BaseModel(ABC):
    method __init__ (line 22) | def __init__(self, opt):
    method dict_grad_hook_factory (line 49) | def dict_grad_hook_factory(add_func=lambda x: x):
    method modify_commandline_options (line 60) | def modify_commandline_options(parser, is_train):
    method set_input (line 73) | def set_input(self, input):
    method forward (line 82) | def forward(self):
    method optimize_parameters (line 87) | def optimize_parameters(self):
    method setup (line 91) | def setup(self, opt):
    method parallelize (line 107) | def parallelize(self, convert_sync_batchnorm=True):
    method data_dependent_initialize (line 138) | def data_dependent_initialize(self, data):
    method train (line 141) | def train(self):
    method eval (line 148) | def eval(self):
    method test (line 155) | def test(self):
    method compute_visuals (line 165) | def compute_visuals(self):
    method get_image_paths (line 169) | def get_image_paths(self, name='A'):
    method update_learning_rate (line 173) | def update_learning_rate(self):
    method get_current_visuals (line 184) | def get_current_visuals(self):
    method get_current_losses (line 192) | def get_current_losses(self):
    method save_networks (line 200) | def save_networks(self, epoch):
    method __patch_instance_norm_state_dict (line 230) | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i...
    method load_networks (line 244) | def load_networks(self, epoch):
    method print_networks (line 284) | def print_networks(self, verbose):
    method set_requires_grad (line 302) | def set_requires_grad(self, nets, requires_grad=False):
    method generate_visuals_for_evaluation (line 315) | def generate_visuals_for_evaluation(self, data, mode):

FILE: src/face3d/models/bfm.py
  function perspective_projection (line 11) | def perspective_projection(focal, center):
  class SH (line 19) | class SH:
    method __init__ (line 20) | def __init__(self):
  class ParametricFaceModel (line 26) | class ParametricFaceModel:
    method __init__ (line 27) | def __init__(self,
    method to (line 80) | def to(self, device):
    method compute_shape (line 87) | def compute_shape(self, id_coeff, exp_coeff):
    method compute_texture (line 103) | def compute_texture(self, tex_coeff, normalize=True):
    method compute_norm (line 118) | def compute_norm(self, face_shape):
    method compute_color (line 141) | def compute_color(self, face_texture, face_norm, gamma):
    method compute_rotation (line 175) | def compute_rotation(self, angles):
    method to_camera (line 211) | def to_camera(self, face_shape):
    method to_image (line 215) | def to_image(self, face_shape):
    method transform (line 230) | def transform(self, face_shape, rot, trans):
    method get_landmarks (line 243) | def get_landmarks(self, face_proj):
    method split_coeff (line 253) | def split_coeff(self, coeffs):
    method compute_for_render (line 275) | def compute_for_render(self, coeffs):
    method compute_for_render_woRotation (line 302) | def compute_for_render_woRotation(self, coeffs):

FILE: src/face3d/models/facerecon_model.py
  class FaceReconModel (line 17) | class FaceReconModel(BaseModel):
    method modify_commandline_options (line 20) | def modify_commandline_options(parser, is_train=False):
    method __init__ (line 71) | def __init__(self, opt):
    method set_input (line 115) | def set_input(self, input):
    method forward (line 127) | def forward(self, output_coeff, device):
    method compute_losses (line 137) | def compute_losses(self):
    method optimize_parameters (line 169) | def optimize_parameters(self, isTrain=True):
    method compute_visuals (line 178) | def compute_visuals(self):
    method save_mesh (line 200) | def save_mesh(self, name):
    method save_coeff (line 211) | def save_coeff(self,name):

FILE: src/face3d/models/losses.py
  function resize_n_crop (line 7) | def resize_n_crop(image, M, dsize=112):
  class PerceptualLoss (line 13) | class PerceptualLoss(nn.Module):
    method __init__ (line 14) | def __init__(self, recog_net, input_size=112):
    method forward (line 19) | def forward(imageA, imageB, M):
  function perceptual_loss (line 39) | def perceptual_loss(id_featureA, id_featureB):
  function photo_loss (line 45) | def photo_loss(imageA, imageB, mask, eps=1e-6):
  function landmark_loss (line 56) | def landmark_loss(predict_lm, gt_lm, weight=None):
  function reg_loss (line 76) | def reg_loss(coeffs_dict, opt=None):
  function reflectance_loss (line 101) | def reflectance_loss(texture, mask):

FILE: src/face3d/models/networks.py
  function resize_n_crop (line 21) | def resize_n_crop(image, M, dsize=112):
  function filter_state_dict (line 26) | def filter_state_dict(state_dict, remove_name='fc'):
  function get_scheduler (line 34) | def get_scheduler(optimizer, opt):
  function define_net_recon (line 61) | def define_net_recon(net_recon, use_last_fc=False, init_path=None):
  function define_net_recog (line 64) | def define_net_recog(net_recog, pretrained_path=None):
  class ReconNetWrapper (line 69) | class ReconNetWrapper(nn.Module):
    method __init__ (line 71) | def __init__(self, net_recon, use_last_fc=False, init_path=None):
    method forward (line 97) | def forward(self, x):
  class RecogNetWrapper (line 107) | class RecogNetWrapper(nn.Module):
    method __init__ (line 108) | def __init__(self, net_recog, pretrained_path=None, input_size=112):
    method forward (line 121) | def forward(self, image, M):
  function conv3x3 (line 146) | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: in...
  function conv1x1 (line 152) | def conv1x1(in_planes: int, out_planes: int, stride: int = 1, bias: bool...
  class BasicBlock (line 157) | class BasicBlock(nn.Module):
    method __init__ (line 160) | def __init__(
    method forward (line 187) | def forward(self, x: Tensor) -> Tensor:
  class Bottleneck (line 206) | class Bottleneck(nn.Module):
    method __init__ (line 215) | def __init__(
    method forward (line 241) | def forward(self, x: Tensor) -> Tensor:
  class ResNet (line 264) | class ResNet(nn.Module):
    method __init__ (line 266) | def __init__(
    method _make_layer (line 331) | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], plan...
    method _forward_impl (line 356) | def _forward_impl(self, x: Tensor) -> Tensor:
    method forward (line 374) | def forward(self, x: Tensor) -> Tensor:
  function _resnet (line 378) | def _resnet(
  function resnet18 (line 394) | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: ...
  function resnet34 (line 406) | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: ...
  function resnet50 (line 418) | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: ...
  function resnet101 (line 430) | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs:...
  function resnet152 (line 442) | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs:...
  function resnext50_32x4d (line 454) | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **k...
  function resnext101_32x8d (line 468) | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **...
  function wide_resnet50_2 (line 482) | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **k...
  function wide_resnet101_2 (line 500) | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **...

FILE: src/face3d/models/template_model.py
  class TemplateModel (line 24) | class TemplateModel(BaseModel):
    method modify_commandline_options (line 26) | def modify_commandline_options(parser, is_train=True):
    method __init__ (line 42) | def __init__(self, opt):
    method set_input (line 73) | def set_input(self, input):
    method forward (line 84) | def forward(self):
    method backward (line 88) | def backward(self):
    method optimize_parameters (line 95) | def optimize_parameters(self):

FILE: src/face3d/options/base_options.py
  class BaseOptions (line 13) | class BaseOptions():
    method __init__ (line 20) | def __init__(self, cmd_line=None):
    method initialize (line 27) | def initialize(self, parser):
    method gather_options (line 52) | def gather_options(self):
    method print_options (line 93) | def print_options(self, opt):
    method parse (line 122) | def parse(self):

FILE: src/face3d/options/inference_options.py
  class InferenceOptions (line 4) | class InferenceOptions(BaseOptions):
    method initialize (line 10) | def initialize(self, parser):

FILE: src/face3d/options/test_options.py
  class TestOptions (line 7) | class TestOptions(BaseOptions):
    method initialize (line 13) | def initialize(self, parser):

FILE: src/face3d/options/train_options.py
  class TrainOptions (line 7) | class TrainOptions(BaseOptions):
    method initialize (line 13) | def initialize(self, parser):

FILE: src/face3d/util/detect_lm68.py
  function save_label (line 12) | def save_label(labels, save_path):
  function draw_landmarks (line 15) | def draw_landmarks(img, landmark, save_name):
  function load_data (line 35) | def load_data(img_name, txt_name):
  function load_lm_graph (line 39) | def load_lm_graph(graph_filename):
  function detect_68p (line 53) | def detect_68p(img_path,sess,input_op,output_op):

FILE: src/face3d/util/generate_list.py
  function write_list (line 7) | def write_list(lms_list, imgs_list, msks_list, mode='train',save_folder=...
  function check_list (line 21) | def check_list(rlms_list, rimgs_list, rmsks_list):

FILE: src/face3d/util/html.py
  class HTML (line 6) | class HTML:
    method __init__ (line 14) | def __init__(self, web_dir, title, refresh=0):
    method get_image_dir (line 35) | def get_image_dir(self):
    method add_header (line 39) | def add_header(self, text):
    method add_images (line 48) | def add_images(self, ims, txts, links, width=400):
    method save (line 68) | def save(self):

FILE: src/face3d/util/load_mats.py
  function LoadExpBasis (line 11) | def LoadExpBasis(bfm_folder='BFM'):
  function transferBFM09 (line 32) | def transferBFM09(bfm_folder='BFM'):
  function load_lm3d (line 105) | def load_lm3d(bfm_folder):

FILE: src/face3d/util/my_awing_arch.py
  function calculate_points (line 8) | def calculate_points(heatmaps):
  class AddCoordsTh (line 44) | class AddCoordsTh(nn.Module):
    method __init__ (line 46) | def __init__(self, x_dim=64, y_dim=64, with_r=False, with_boundary=Fal...
    method forward (line 53) | def forward(self, input_tensor, heatmap=None):
  class CoordConvTh (line 110) | class CoordConvTh(nn.Module):
    method __init__ (line 113) | def __init__(self, x_dim, y_dim, with_r, with_boundary, in_channels, f...
    method forward (line 123) | def forward(self, input_tensor, heatmap=None):
  function conv3x3 (line 130) | def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False, dilati...
  class BasicBlock (line 135) | class BasicBlock(nn.Module):
    method __init__ (line 138) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 148) | def forward(self, x):
  class ConvBlock (line 165) | class ConvBlock(nn.Module):
    method __init__ (line 167) | def __init__(self, in_planes, out_planes):
    method forward (line 185) | def forward(self, x):
  class HourGlass (line 210) | class HourGlass(nn.Module):
    method __init__ (line 212) | def __init__(self, num_modules, depth, num_features, first_one=False):
    method _generate_network (line 230) | def _generate_network(self, level):
    method _forward (line 242) | def _forward(self, level, inp):
    method forward (line 264) | def forward(self, x, heatmap):
  class FAN (line 269) | class FAN(nn.Module):
    method __init__ (line 271) | def __init__(self, num_modules=1, end_relu=False, gray_scale=False, nu...
    method forward (line 324) | def forward(self, x):
    method get_landmarks (line 359) | def get_landmarks(self, img):

FILE: src/face3d/util/nvdiffrast.py
  class MeshRenderer (line 32) | class MeshRenderer(nn.Module):
    method __init__ (line 33) | def __init__(self,
    method forward (line 50) | def forward(self, vertex, tri, feat=None):

FILE: src/face3d/util/preprocess.py
  function POS (line 17) | def POS(xp, x):
  function resize_n_crop_img (line 42) | def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None):
  function extract_5p (line 66) | def extract_5p(lm):
  function align_img (line 74) | def align_img(img, lm, lm3D, mask=None, target_size=224., rescale_factor...

FILE: src/face3d/util/skin_mask.py
  class GMM (line 9) | class GMM:
    method __init__ (line 10) | def __init__(self, dim, num, w, mu, cov, cov_det, cov_inv):
    method likelihood (line 23) | def likelihood(self, data):
  function _rgb2ycbcr (line 42) | def _rgb2ycbcr(rgb):
  function _bgr2ycbcr (line 54) | def _bgr2ycbcr(bgr):
  function skinmask (line 90) | def skinmask(imbgr):
  function get_skin_mask (line 111) | def get_skin_mask(img_path):

FILE: src/face3d/util/util.py
  function str2bool (line 14) | def str2bool(v):
  function copyconf (line 25) | def copyconf(default_opt, **kwargs):
  function genvalconf (line 31) | def genvalconf(train_opt, **kwargs):
  function find_class_in_module (line 43) | def find_class_in_module(target_cls_name, module):
  function tensor2im (line 56) | def tensor2im(input_image, imtype=np.uint8):
  function diagnose_network (line 77) | def diagnose_network(net, name='network'):
  function save_image (line 96) | def save_image(image_numpy, image_path, aspect_ratio=1.0):
  function print_numpy (line 116) | def print_numpy(x, val=True, shp=False):
  function mkdirs (line 132) | def mkdirs(paths):
  function mkdir (line 145) | def mkdir(path):
  function correct_resize_label (line 155) | def correct_resize_label(t, size):
  function correct_resize (line 169) | def correct_resize(t, size, mode=Image.BICUBIC):
  function draw_landmarks (line 180) | def draw_landmarks(img, landmark, color='r', step=2):

FILE: src/face3d/util/visualizer.py
  function save_images (line 13) | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
  class Visualizer (line 44) | class Visualizer():
    method __init__ (line 50) | def __init__(self, opt):
    method reset (line 77) | def reset(self):
    method display_current_results (line 82) | def display_current_results(self, visuals, total_iters, epoch, save_re...
    method plot_current_losses (line 117) | def plot_current_losses(self, total_iters, losses):
    method print_current_losses (line 131) | def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
  class MyVisualizer (line 150) | class MyVisualizer:
    method __init__ (line 151) | def __init__(self, opt):
    method display_current_results (line 174) | def display_current_results(self, visuals, total_iters, epoch, dataset...
    method plot_current_losses (line 205) | def plot_current_losses(self, total_iters, losses, dataset='train'):
    method print_current_losses (line 210) | def print_current_losses(self, epoch, iters, losses, t_comp, t_data, d...

FILE: src/face3d/visualize.py
  function gen_composed_video (line 12) | def gen_composed_video(args, device, first_frame_coeff, coeff_path, audi...

FILE: src/facerender/animate.py
  class AnimateFromCoeff (line 33) | class AnimateFromCoeff():
    method __init__ (line 35) | def __init__(self, sadtalker_path, device):
    method load_cpk_facevid2vid_safetensor (line 86) | def load_cpk_facevid2vid_safetensor(self, checkpoint_path, generator=N...
    method load_cpk_facevid2vid (line 113) | def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discri...
    method load_cpk_mapping (line 143) | def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminato...
    method generate (line 157) | def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=No...

FILE: src/facerender/modules/dense_motion.py
  class DenseMotionNetwork (line 9) | class DenseMotionNetwork(nn.Module):
    method __init__ (line 14) | def __init__(self, block_expansion, num_blocks, max_features, num_kp, ...
    method create_sparse_motions (line 34) | def create_sparse_motions(self, feature, kp_driving, kp_source):
    method create_deformed_feature (line 59) | def create_deformed_feature(self, feature, sparse_motions):
    method create_heatmap_representations (line 68) | def create_heatmap_representations(self, feature, kp_driving, kp_source):
    method forward (line 80) | def forward(self, feature, kp_driving, kp_source):

FILE: src/facerender/modules/discriminator.py
  class DownBlock2d (line 7) | class DownBlock2d(nn.Module):
    method __init__ (line 12) | def __init__(self, in_features, out_features, norm=False, kernel_size=...
    method forward (line 25) | def forward(self, x):
  class Discriminator (line 36) | class Discriminator(nn.Module):
    method __init__ (line 41) | def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, m...
    method forward (line 57) | def forward(self, x):
  class MultiScaleDiscriminator (line 69) | class MultiScaleDiscriminator(nn.Module):
    method __init__ (line 74) | def __init__(self, scales=(), **kwargs):
    method forward (line 82) | def forward(self, x):

FILE: src/facerender/modules/generator.py
  class OcclusionAwareGenerator (line 8) | class OcclusionAwareGenerator(nn.Module):
    method __init__ (line 13) | def __init__(self, image_channel, feature_channel, num_kp, block_expan...
    method deform_input (line 61) | def deform_input(self, inp, deformation):
    method forward (line 70) | def forward(self, source_image, kp_driving, kp_source):
  class SPADEDecoder (line 120) | class SPADEDecoder(nn.Module):
    method __init__ (line 121) | def __init__(self):
    method forward (line 140) | def forward(self, feature):
  class OcclusionAwareSPADEGenerator (line 161) | class OcclusionAwareSPADEGenerator(nn.Module):
    method __init__ (line 163) | def __init__(self, image_channel, feature_channel, num_kp, block_expan...
    method deform_input (line 201) | def deform_input(self, inp, deformation):
    method forward (line 210) | def forward(self, source_image, kp_driving, kp_source):

FILE: src/facerender/modules/keypoint_detector.py
  class KPDetector (line 9) | class KPDetector(nn.Module):
    method __init__ (line 14) | def __init__(self, block_expansion, feature_channel, num_kp, image_cha...
    method gaussian2kp (line 44) | def gaussian2kp(self, heatmap):
    method forward (line 56) | def forward(self, x):
  class HEEstimator (line 85) | class HEEstimator(nn.Module):
    method __init__ (line 90) | def __init__(self, block_expansion, feature_channel, num_kp, image_cha...
    method forward (line 136) | def forward(self, x):

FILE: src/facerender/modules/make_animation.py
  function normalize_kp (line 7) | def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_moveme...
  function headpose_pred_to_degree (line 29) | def headpose_pred_to_degree(pred):
  function get_rotation_matrix (line 37) | def get_rotation_matrix(yaw, pitch, roll):
  function keypoint_transformation (line 65) | def keypoint_transformation(kp_canonical, he, wo_exp=False):
  function make_animation (line 102) | def make_animation(source_image, source_semantics, target_semantics,
  class AnimateModel (line 141) | class AnimateModel(torch.nn.Module):
    method __init__ (line 146) | def __init__(self, generator, kp_extractor, mapping):
    method forward (line 156) | def forward(self, x):

FILE: src/facerender/modules/mapping.py
  class MappingNet (line 8) | class MappingNet(nn.Module):
    method __init__ (line 9) | def __init__(self, coeff_nc, descriptor_nc, layer, num_kp, num_bins):
    method forward (line 32) | def forward(self, input_3dmm):

FILE: src/facerender/modules/util.py
  function kp2gaussian (line 12) | def kp2gaussian(kp, spatial_size, kp_variance):
  function make_coordinate_grid_2d (line 35) | def make_coordinate_grid_2d(spatial_size, type):
  function make_coordinate_grid (line 54) | def make_coordinate_grid(spatial_size, type):
  class ResBottleneck (line 73) | class ResBottleneck(nn.Module):
    method __init__ (line 74) | def __init__(self, in_features, stride):
    method forward (line 88) | def forward(self, x):
  class ResBlock2d (line 105) | class ResBlock2d(nn.Module):
    method __init__ (line 110) | def __init__(self, in_features, kernel_size, padding):
    method forward (line 119) | def forward(self, x):
  class ResBlock3d (line 130) | class ResBlock3d(nn.Module):
    method __init__ (line 135) | def __init__(self, in_features, kernel_size, padding):
    method forward (line 144) | def forward(self, x):
  class UpBlock2d (line 155) | class UpBlock2d(nn.Module):
    method __init__ (line 160) | def __init__(self, in_features, out_features, kernel_size=3, padding=1...
    method forward (line 167) | def forward(self, x):
  class UpBlock3d (line 174) | class UpBlock3d(nn.Module):
    method __init__ (line 179) | def __init__(self, in_features, out_features, kernel_size=3, padding=1...
    method forward (line 186) | def forward(self, x):
  class DownBlock2d (line 195) | class DownBlock2d(nn.Module):
    method __init__ (line 200) | def __init__(self, in_features, out_features, kernel_size=3, padding=1...
    method forward (line 207) | def forward(self, x):
  class DownBlock3d (line 215) | class DownBlock3d(nn.Module):
    method __init__ (line 220) | def __init__(self, in_features, out_features, kernel_size=3, padding=1...
    method forward (line 231) | def forward(self, x):
  class SameBlock2d (line 239) | class SameBlock2d(nn.Module):
    method __init__ (line 244) | def __init__(self, in_features, out_features, groups=1, kernel_size=3,...
    method forward (line 254) | def forward(self, x):
  class Encoder (line 261) | class Encoder(nn.Module):
    method __init__ (line 266) | def __init__(self, block_expansion, in_features, num_blocks=3, max_fea...
    method forward (line 276) | def forward(self, x):
  class Decoder (line 283) | class Decoder(nn.Module):
    method __init__ (line 288) | def __init__(self, block_expansion, in_features, num_blocks=3, max_fea...
    method forward (line 305) | def forward(self, x):
  class Hourglass (line 319) | class Hourglass(nn.Module):
    method __init__ (line 324) | def __init__(self, block_expansion, in_features, num_blocks=3, max_fea...
    method forward (line 330) | def forward(self, x):
  class KPHourglass (line 334) | class KPHourglass(nn.Module):
    method __init__ (line 339) | def __init__(self, block_expansion, in_features, reshape_features, res...
    method forward (line 360) | def forward(self, x):
  class AntiAliasInterpolation2d (line 371) | class AntiAliasInterpolation2d(nn.Module):
    method __init__ (line 375) | def __init__(self, channels, scale):
    method forward (line 409) | def forward(self, input):
  class SPADE (line 420) | class SPADE(nn.Module):
    method __init__ (line 421) | def __init__(self, norm_nc, label_nc):
    method forward (line 433) | def forward(self, x, segmap):
  class SPADEResnetBlock (line 443) | class SPADEResnetBlock(nn.Module):
    method __init__ (line 444) | def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation...
    method forward (line 467) | def forward(self, x, seg1):
    method shortcut (line 474) | def shortcut(self, x, seg1):
    method actvn (line 481) | def actvn(self, x):
  class audio2image (line 484) | class audio2image(nn.Module):
    method __init__ (line 485) | def __init__(self, generator, kp_extractor, he_estimator_video, he_est...
    method headpose_pred_to_degree (line 494) | def headpose_pred_to_degree(self, pred):
    method get_rotation_matrix (line 503) | def get_rotation_matrix(self, yaw, pitch, roll):
    method keypoint_transformation (line 531) | def keypoint_transformation(self, kp_canonical, he):
    method forward (line 557) | def forward(self, source_image, target_audio):
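
The `audio2image` helpers `headpose_pred_to_degree` and `get_rotation_matrix` listed above turn head-pose predictions into angles and then into rotation matrices. The stdlib-only sketch below illustrates the face-vid2vid-style convention these names suggest. It is an approximation, not the repository's code, and it rests on two assumptions. First, each angle is predicted as 66 logits and decoded as the softmax-weighted expected bin index, at 3 degrees per bin with a -99 degree offset. Second, the rotation is composed as pitch @ yaw @ roll, about the x, y and z axes respectively. Check both against the actual source before relying on them.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def headpose_pred_to_degree(logits: Sequence[float]) -> float:
    """Decode 66 bin logits into degrees (assumed: 3-degree bins, -99 offset)."""
    peak = max(logits)                      # subtract the max for a stable softmax
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    expected_bin = sum(i * e for i, e in enumerate(exps)) / total
    return expected_bin * 3 - 99


def matmul3(a: Matrix, b: Matrix) -> Matrix:
    """Multiply two 3x3 matrices."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]


def get_rotation_matrix(yaw: float, pitch: float, roll: float) -> Matrix:
    """Build a 3x3 rotation from angles in degrees as pitch @ yaw @ roll."""
    y, p, r = (math.radians(a) for a in (yaw, pitch, roll))
    pitch_mat = [[1.0, 0.0, 0.0],
                 [0.0, math.cos(p), -math.sin(p)],
                 [0.0, math.sin(p), math.cos(p)]]   # about the x axis
    yaw_mat = [[math.cos(y), 0.0, math.sin(y)],
               [0.0, 1.0, 0.0],
               [-math.sin(y), 0.0, math.cos(y)]]    # about the y axis
    roll_mat = [[math.cos(r), -math.sin(r), 0.0],
                [math.sin(r), math.cos(r), 0.0],
                [0.0, 0.0, 1.0]]                    # about the z axis
    return matmul3(matmul3(pitch_mat, yaw_mat), roll_mat)
```

With uniform logits, the expected bin is 32.5, which decodes to -1.5 degrees. A logit vector sharply peaked at bin 33 decodes to roughly 0 degrees.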

FILE: src/facerender/sync_batchnorm/batchnorm.py
  function _sum_ft (line 24) | def _sum_ft(tensor):
  function _unsqueeze_ft (line 29) | def _unsqueeze_ft(tensor):
  class _SynchronizedBatchNorm (line 38) | class _SynchronizedBatchNorm(_BatchNorm):
    method __init__ (line 39) | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
    method forward (line 48) | def forward(self, input):
    method __data_parallel_replicate__ (line 80) | def __data_parallel_replicate__(self, ctx, copy_id):
    method _data_parallel_master (line 90) | def _data_parallel_master(self, intermediates):
    method _compute_mean_std (line 113) | def _compute_mean_std(self, sum_, ssum, size):
  class SynchronizedBatchNorm1d (line 128) | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    method _check_input_dim (line 184) | def _check_input_dim(self, input):
  class SynchronizedBatchNorm2d (line 191) | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    method _check_input_dim (line 247) | def _check_input_dim(self, input):
  class SynchronizedBatchNorm3d (line 254) | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    method _check_input_dim (line 311) | def _check_input_dim(self, input):
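
`_SynchronizedBatchNorm._compute_mean_std(sum_, ssum, size)` receives per-channel sums that have been gathered from every device. The sketch below shows, for a single channel, how such partial statistics are typically combined. It assumes each device contributes a (sum, sum of squares, count) triple. `reduce_partials` and `compute_mean_std` are hypothetical stand-ins. The real method works on tensors and also updates the running mean and variance.

```python
from typing import Iterable, Tuple


def reduce_partials(partials: Iterable[Tuple[float, float, int]]) -> Tuple[float, float, int]:
    """Hypothetical helper: add up (sum, sum of squares, count) from each device."""
    total_sum, total_ssum, total_size = 0.0, 0.0, 0
    for s, ss, n in partials:
        total_sum += s
        total_ssum += ss
        total_size += n
    return total_sum, total_ssum, total_size


def compute_mean_std(sum_: float, ssum: float, size: int, eps: float = 1e-5):
    """Return (mean, inverse std, unbiased variance) for one channel."""
    assert size > 1, "BatchNorm needs more than one value per channel"
    mean = sum_ / size
    sumvar = ssum - sum_ * mean          # sum of squared deviations from the mean
    unbiased_var = sumvar / (size - 1)   # suited to the running-variance estimate
    biased_var = sumvar / size           # used to normalize the current batch
    inv_std = max(biased_var, eps) ** -0.5
    return mean, inv_std, unbiased_var
```

For example, device A holds [1, 2, 3] (sum 6, sum of squares 14) and device B holds [4, 5] (sum 9, sum of squares 41). Together they give mean 3, biased variance 2 and unbiased variance 2.5, the same values a single device would compute over all five values.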

FILE: src/facerender/sync_batchnorm/comm.py
  class FutureResult (line 18) | class FutureResult(object):
    method __init__ (line 21) | def __init__(self):
    method put (line 26) | def put(self, result):
    method get (line 32) | def get(self):
  class SlavePipe (line 46) | class SlavePipe(_SlavePipeBase):
    method run_slave (line 49) | def run_slave(self, msg):
  class SyncMaster (line 56) | class SyncMaster(object):
    method __init__ (line 67) | def __init__(self, master_callback):
    method __getstate__ (line 78) | def __getstate__(self):
    method __setstate__ (line 81) | def __setstate__(self, state):
    method register_slave (line 84) | def register_slave(self, identifier):
    method run_master (line 102) | def run_master(self, master_msg):
    method nr_slaves (line 136) | def nr_slaves(self):
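
In `comm.py`, `FutureResult` exposes `put` and `get`, which is the shape of a one-shot handoff between the master and slave replicas. The sketch below implements that pattern with `threading.Condition` and treats `None` as an empty slot. It loops on `wait()` to tolerate spurious wakeups, which may differ from the original.

```python
import threading


class FutureResult:
    """One-shot slot: one thread put()s a result, another blocks in get() until it arrives."""

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is None, "previous result has not been fetched yet"
            self._result = result
            self._cond.notify()          # wake a thread waiting in get()

    def get(self):
        with self._lock:
            while self._result is None:  # re-check after every wakeup
                self._cond.wait()
            res = self._result
            self._result = None          # empty the slot so it can be reused
            return res
```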

FILE: src/facerender/sync_batchnorm/replicate.py
  class CallbackContext (line 23) | class CallbackContext(object):
  function execute_replication_callbacks (line 27) | def execute_replication_callbacks(modules):
  class DataParallelWithCallback (line 50) | class DataParallelWithCallback(DataParallel):
    method replicate (line 64) | def replicate(self, module, device_ids):
  function patch_replication_callback (line 70) | def patch_replication_callback(data_parallel):

FILE: src/facerender/sync_batchnorm/unittest.py
  function as_numpy (line 17) | def as_numpy(v):
  class TorchTestCase (line 23) | class TorchTestCase(unittest.TestCase):
    method assertTensorClose (line 24) | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):

FILE: src/generate_batch.py
  function crop_pad_audio (line 10) | def crop_pad_audio(wav, audio_length):
  function parse_audio_length (line 17) | def parse_audio_length(audio_length, sr, fps):
  function generate_blink_seq (line 25) | def generate_blink_seq(num_frames):
  function generate_blink_seq_randomly (line 37) | def generate_blink_seq_randomly(num_frames):
  function get_data (line 51) | def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_pa...
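
The helpers `crop_pad_audio` and `parse_audio_length` in `src/generate_batch.py` appear to align the audio track with the video frame rate. The sketch below shows the arithmetic. It assumes lengths are measured in samples and that each video frame spans `sr / fps` samples, which is 640 samples at 16 kHz and 25 fps. Plain lists stand in for the NumPy arrays the repository probably uses.

```python
from typing import List, Tuple


def parse_audio_length(audio_length: int, sr: int, fps: int) -> Tuple[int, int]:
    """Round an audio length (in samples) down to a whole number of video frames."""
    samples_per_frame = sr / fps
    num_frames = int(audio_length / samples_per_frame)
    return int(num_frames * samples_per_frame), num_frames


def crop_pad_audio(wav: List[float], audio_length: int) -> List[float]:
    """Truncate or zero-pad a waveform to exactly audio_length samples."""
    if len(wav) > audio_length:
        return wav[:audio_length]
    return wav + [0.0] * (audio_length - len(wav))
```

For example, 16,100 samples at 16 kHz and 25 fps cover 25 whole frames, so the clip is trimmed to 16,000 samples.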

FILE: src/generate_facerender_batch.py
  function get_facerender_data (line 8) | def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path,
  function transform_semantic_1 (line 88) | def transform_semantic_1(semantic, semantic_radius):
  function transform_semantic_target (line 93) | def transform_semantic_target(coeff_3dmm, frame_index, semantic_radius):
  function gen_camera_pose (line 100) | def gen_camera_pose(camera_degree_list, frame_num, batch_size):

FILE: src/gradio_demo.py
  function mp3_to_wav (line 14) | def mp3_to_wav(mp3_filename,wav_filename,frame_rate):
  class SadTalker (line 19) | class SadTalker():
    method __init__ (line 21) | def __init__(self, checkpoint_path='checkpoints', config_path='src/con...
    method test (line 36) | def test(self, source_image, driven_audio, preprocess='crop',

FILE: src/test_audio2coeff.py
  function load_cpk (line 16) | def load_cpk(checkpoint_path, model=None, optimizer=None, device="cpu"):
  class Audio2Coeff (line 25) | class Audio2Coeff():
    method __init__ (line 27) | def __init__(self, sadtalker_path, device):
    method generate (line 74) | def generate(self, batch, coeff_save_dir, pose_style, ref_pose_coeff_p...
    method using_refpose (line 107) | def using_refpose(self, coeffs_pred_numpy, ref_pose_coeff_path):

FILE: src/utils/audio.py
  function load_wav (line 9) | def load_wav(path, sr):
  function save_wav (line 12) | def save_wav(wav, path, sr):
  function save_wavenet_wav (line 17) | def save_wavenet_wav(wav, path, sr):
  function preemphasis (line 20) | def preemphasis(wav, k, preemphasize=True):
  function inv_preemphasis (line 25) | def inv_preemphasis(wav, k, inv_preemphasize=True):
  function get_hop_size (line 30) | def get_hop_size():
  function linearspectrogram (line 37) | def linearspectrogram(wav):
  function melspectrogram (line 45) | def melspectrogram(wav):
  function _lws_processor (line 53) | def _lws_processor():
  function _stft (line 57) | def _stft(y):
  function num_frames (line 65) | def num_frames(length, fsize, fshift):
  function pad_lr (line 76) | def pad_lr(x, fsize, fshift):
  function librosa_pad_lr (line 86) | def librosa_pad_lr(x, fsize, fshift):
  function _linear_to_mel (line 92) | def _linear_to_mel(spectogram):
  function _build_mel_basis (line 98) | def _build_mel_basis():
  function _amp_to_db (line 103) | def _amp_to_db(x):
  function _db_to_amp (line 107) | def _db_to_amp(x):
  function _normalize (line 110) | def _normalize(S):
  function _denormalize (line 124) | def _denormalize(D):
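
`src/utils/audio.py` appears to follow the common Tacotron/Wav2Lip mel-spectrogram recipe. The scalar sketch below covers its `_amp_to_db` and `_normalize` steps. It assumes a -100 dB floor and a symmetric output range of [-4, 4]. Both constants are assumptions here; the real values presumably come from the hyperparameters in `src/utils/hparams.py`.

```python
import math

MIN_LEVEL_DB = -100.0  # assumed dB floor
MAX_ABS_VALUE = 4.0    # assumed symmetric output range [-4, 4]


def amp_to_db(x: float) -> float:
    """Linear magnitude -> decibels, floored at MIN_LEVEL_DB."""
    min_level = 10 ** (MIN_LEVEL_DB / 20)
    return 20 * math.log10(max(min_level, x))


def normalize(s_db: float) -> float:
    """Map [MIN_LEVEL_DB, 0] dB linearly onto [-MAX_ABS_VALUE, MAX_ABS_VALUE], clipping outside."""
    scaled = 2 * MAX_ABS_VALUE * (s_db - MIN_LEVEL_DB) / -MIN_LEVEL_DB - MAX_ABS_VALUE
    return min(max(scaled, -MAX_ABS_VALUE), MAX_ABS_VALUE)
```

Under these assumptions, silence maps to -4, a 0 dB bin maps to +4, and -50 dB lands in the middle at 0.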

FILE: src/utils/croper.py
  class Preprocesser (line 19) | class Preprocesser:
    method __init__ (line 20) | def __init__(self, device='cuda'):
    method get_landmark (line 23) | def get_landmark(self, img_np):
    method align_face (line 43) | def align_face(self, img, lm, output_size=1024):
    method crop (line 126) | def crop(self, img_np_list, still=False, xsize=512):    # first frame ...

FILE: src/utils/face_enhancer.py
  class GeneratorWithLen (line 13) | class GeneratorWithLen(object):
    method __init__ (line 16) | def __init__(self, gen, length):
    method __len__ (line 20) | def __len__(self):
    method __iter__ (line 23) | def __iter__(self):
  function enhancer_list (line 26) | def enhancer_list(images, method='gfpgan', bg_upsampler='realesrgan'):
  function enhancer_generator_with_len (line 30) | def enhancer_generator_with_len(images, method='gfpgan', bg_upsampler='r...
  function enhancer_generator_no_len (line 42) | def enhancer_generator_no_len(images, method='gfpgan', bg_upsampler='rea...
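
`GeneratorWithLen` pairs a lazy generator with a known length. This is presumably so that `enhancer_generator_with_len` can stream enhanced frames to consumers that call `len()`, such as a progress bar, without materializing every frame first. The sketch below assumes the length is known before iteration starts. Unlike returning the generator directly, `__iter__` here calls `iter()`, so a plain list also works.

```python
from typing import Generic, Iterable, Iterator, TypeVar

T = TypeVar("T")


class GeneratorWithLen(Generic[T]):
    """Wrap a lazy iterable together with a length known in advance."""

    def __init__(self, gen: Iterable[T], length: int):
        self.gen = gen
        self.length = length

    def __len__(self) -> int:
        return self.length

    def __iter__(self) -> Iterator[T]:
        return iter(self.gen)


frames = ["frame0", "frame1", "frame2"]
# Stand-in for per-frame enhancement; nothing is processed until iteration.
lazy = GeneratorWithLen((f.upper() for f in frames), len(frames))
```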

FILE: src/utils/hparams.py
  class HParams (line 4) | class HParams:
    method __init__ (line 5) | def __init__(self, **kwargs):
    method __getattr__ (line 11) | def __getattr__(self, key):
    method set_hparam (line 16) | def set_hparam(self, key, value):
  function hparams_debug_string (line 157) | def hparams_debug_string():
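
`HParams` provides attribute-style access to hyperparameters, plus a `set_hparam` setter. The sketch below assumes the values are stored in an internal dict. The lookup reads `self.__dict__` directly, so a missing `data` attribute (for example during unpickling) cannot trigger infinite `__getattr__` recursion.

```python
class HParams:
    """Attribute-style access to a dict of hyperparameters."""

    def __init__(self, **kwargs):
        self.data = dict(kwargs)

    def __getattr__(self, key):
        # Called only when normal attribute lookup fails, e.g. hp.sample_rate.
        data = self.__dict__.get("data", {})
        if key not in data:
            raise AttributeError(f"'HParams' object has no attribute {key!r}")
        return data[key]

    def set_hparam(self, key, value):
        self.data[key] = value


hp = HParams(sample_rate=16000, num_mels=80)  # example values only
hp.set_hparam("num_mels", 64)
```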

FILE: src/utils/init_path.py
  function init_path (line 4) | def init_path(checkpoint_dir, config_dir, size=512, old_version=False, p...

FILE: src/utils/model2safetensor.py
  function load_cpk_facevid2vid (line 43) | def load_cpk_facevid2vid(checkpoint_path, generator=None, discriminator=...
  function load_cpk_facevid2vid_safetensor (line 75) | def load_cpk_facevid2vid_safetensor(checkpoint_path, generator=None,
  class SadTalker (line 125) | class SadTalker(torch.nn.Module):
    method __init__ (line 126) | def __init__(self, kp_extractor, generator, netG, audio2pose, face_3dr...

FILE: src/utils/paste_pic.py
  function paste_pic (line 8) | def paste_pic(video_path, pic_path, crop_info, new_audio_path, full_vide...

FILE: src/utils/preprocess.py
  function split_coeff (line 22) | def split_coeff(coeffs):
  class CropAndExtract (line 46) | class CropAndExtract():
    method __init__ (line 47) | def __init__(self, sadtalker_path, device):
    method generate (line 63) | def generate(self, input_path, save_dir, crop_or_resize='crop', source...

FILE: src/utils/safetensor_helper.py
  function load_x_from_safetensor (line 3) | def load_x_from_safetensor(checkpoint, key):
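
`load_x_from_safetensor(checkpoint, key)` pulls one sub-model's weights out of a combined checkpoint. The sketch below assumes the checkpoint is a flat dict whose keys carry a `'<key>.'` prefix, which is what `state_dict()` produces for a wrapper module such as the `SadTalker(torch.nn.Module)` class in `model2safetensor.py`. It reuses the repository's function name, but the original may match keys differently (for example, by substring). Plain integers stand in for tensors.

```python
from typing import Dict, TypeVar

V = TypeVar("V")


def load_x_from_safetensor(checkpoint: Dict[str, V], key: str) -> Dict[str, V]:
    """Return the entries under '<key>.' with that prefix stripped."""
    prefix = key + "."
    return {name[len(prefix):]: value
            for name, value in checkpoint.items()
            if name.startswith(prefix)}
```

Matching on the full `'<key>.'` prefix prevents a key such as `generator` from also picking up weights stored under a name like `generator2`.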

FILE: src/utils/text2speech.py
  class TTSTalker (line 6) | class TTSTalker():
    method __init__ (line 7) | def __init__(self) -> None:
    method test (line 11) | def test(self, text, language='en'):

FILE: src/utils/videoio.py
  function load_video_to_cv2 (line 8) | def load_video_to_cv2(input_path):
  function save_video_with_watermark (line 20) | def save_video_with_watermark(video, audio, save_path, watermark=False):
Condensed preview: 141 files, each showing path, character count, and a content snippet (617K chars of full structured content).
[
  {
    "path": ".gitignore",
    "chars": 3205,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "LICENSE",
    "chars": 11692,
    "preview": "Tencent is pleased to support the open source community by making SadTalker available.\n\nCopyright (C), a Tencent company"
  },
  {
    "path": "README.md",
    "chars": 17068,
    "preview": "<div align=\"center\">\n\n<img src='https://user-images.githubusercontent.com/4397546/229094115-862c747e-7397-4b54-ba4a-bd36"
  },
  {
    "path": "app_sadtalker.py",
    "chars": 5836,
    "preview": "import os, sys\nimport gradio as gr\nfrom src.gradio_demo import SadTalker  \n\n\ntry:\n    import webui  # in webui\n    in_we"
  },
  {
    "path": "cog.yaml",
    "chars": 1115,
    "preview": "build:\n  gpu: true\n  cuda: \"11.3\"\n  python_version: \"3.8\"\n  system_packages:\n    - \"ffmpeg\"\n    - \"libgl1-mesa-glx\"\n    "
  },
  {
    "path": "docs/FAQ.md",
    "chars": 2168,
    "preview": "\n## Frequency Asked Question\n\n**Q: `ffmpeg` is not recognized as an internal or external command**\n\nIn Linux, you can in"
  },
  {
    "path": "docs/best_practice.md",
    "chars": 5334,
    "preview": "# Best Practices and Tips for configuration\n\n> Our model only works on REAL people or the portrait image similar to REAL"
  },
  {
    "path": "docs/changlelog.md",
    "chars": 2230,
    "preview": "## changelogs\n\n\n- __[2023.04.06]__: stable-diffiusion webui extension is release.\n\n- __[2023.04.03]__: Enable TTS in hug"
  },
  {
    "path": "docs/face3d.md",
    "chars": 1412,
    "preview": "## 3D Face Visualization\n\nWe use `pytorch3d` to visualize the 3D faces from a single image.\n\nThe requirements for 3D vis"
  },
  {
    "path": "docs/install.md",
    "chars": 1235,
    "preview": "### macOS\n\nThis method has been tested on a M1 Mac (13.3)\n\n```bash\ngit clone https://github.com/OpenTalker/SadTalker.git"
  },
  {
    "path": "docs/webui_extension.md",
    "chars": 2070,
    "preview": "## Run SadTalker as a Stable Diffusion WebUI Extension.\n\n1. Install the lastest version of [stable-diffusion-webui](http"
  },
  {
    "path": "inference.py",
    "chars": 7610,
    "preview": "from glob import glob\nimport shutil\nimport torch\nfrom time import  strftime\nimport os, sys, time\nfrom argparse import Ar"
  },
  {
    "path": "launcher.py",
    "chars": 7039,
    "preview": "# this scripts installs necessary requirements and launches main program in webui.py\n# borrow from : https://github.com/"
  },
  {
    "path": "predict.py",
    "chars": 6481,
    "preview": "\"\"\"run bash scripts/download_models.sh first to prepare the weights file\"\"\"\nimport os\nimport shutil\nfrom argparse import"
  },
  {
    "path": "quick_demo.ipynb",
    "chars": 6917,
    "preview": "{\n  \"cells\": [\n    {\n      \"attachments\": {},\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"M74Gs_Tj"
  },
  {
    "path": "req.txt",
    "chars": 306,
    "preview": "llvmlite==0.38.1\nnumpy==1.21.6\nface_alignment==1.3.5\nimageio==2.19.3\nimageio-ffmpeg==0.4.7\nlibrosa==0.10.0.post2\nnumba=="
  },
  {
    "path": "requirements.txt",
    "chars": 277,
    "preview": "numpy==1.23.4\nface_alignment==1.3.5\nimageio==2.19.3\nimageio-ffmpeg==0.4.7\nlibrosa==0.9.2 # \nnumba\nresampy==0.3.1\npydub=="
  },
  {
    "path": "requirements3d.txt",
    "chars": 288,
    "preview": "numpy==1.23.4\nface_alignment==1.3.5\nimageio==2.19.3\nimageio-ffmpeg==0.4.7\nlibrosa==0.9.2 # \nnumba\nresampy==0.3.1\npydub=="
  },
  {
    "path": "scripts/download_models.sh",
    "chars": 2763,
    "preview": "mkdir ./checkpoints  \n\n# lagency download link\n# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2"
  },
  {
    "path": "scripts/extension.py",
    "chars": 6754,
    "preview": "import os, sys\r\nfrom pathlib import Path\r\nimport tempfile\r\nimport gradio as gr\r\nfrom modules.call_queue import wrap_grad"
  },
  {
    "path": "scripts/test.sh",
    "chars": 855,
    "preview": "# ### some test command before commit.\n# python inference.py --preprocess crop --size 256\n# python inference.py --prepro"
  },
  {
    "path": "src/audio2exp_models/audio2exp.py",
    "chars": 1253,
    "preview": "from tqdm import tqdm\nimport torch\nfrom torch import nn\n\n\nclass Audio2Exp(nn.Module):\n    def __init__(self, netG, cfg, "
  },
  {
    "path": "src/audio2exp_models/networks.py",
    "chars": 2977,
    "preview": "import torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nclass Conv2d(nn.Module):\n    def __init__(self, cin, "
  },
  {
    "path": "src/audio2pose_models/audio2pose.py",
    "chars": 3812,
    "preview": "import torch\nfrom torch import nn\nfrom src.audio2pose_models.cvae import CVAE\nfrom src.audio2pose_models.discriminator i"
  },
  {
    "path": "src/audio2pose_models/audio_encoder.py",
    "chars": 2745,
    "preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nclass Conv2d(nn.Module):\n    def __init__(self, "
  },
  {
    "path": "src/audio2pose_models/cvae.py",
    "chars": 6072,
    "preview": "import torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom src.audio2pose_models.res_unet import ResUnet\n\nde"
  },
  {
    "path": "src/audio2pose_models/discriminator.py",
    "chars": 2669,
    "preview": "import torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nclass ConvNormRelu(nn.Module):\n    def __init__(self,"
  },
  {
    "path": "src/audio2pose_models/networks.py",
    "chars": 4247,
    "preview": "import torch.nn as nn\nimport torch\n\n\nclass ResidualConv(nn.Module):\n    def __init__(self, input_dim, output_dim, stride"
  },
  {
    "path": "src/audio2pose_models/res_unet.py",
    "chars": 2229,
    "preview": "import torch\nimport torch.nn as nn\nfrom src.audio2pose_models.networks import ResidualConv, Upsample\n\n\nclass ResUnet(nn."
  },
  {
    "path": "src/config/auido2exp.yaml",
    "chars": 1209,
    "preview": "DATASET:\n  TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt\n  EVAL_FILE_LIST: /apdceph"
  },
  {
    "path": "src/config/auido2pose.yaml",
    "chars": 1097,
    "preview": "DATASET:\n  TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt\n "
  },
  {
    "path": "src/config/facerender.yaml",
    "chars": 1087,
    "preview": "model_params:\n  common_params:\n    num_kp: 15 \n    image_channel: 3                    \n    feature_channel: 32\n    esti"
  },
  {
    "path": "src/config/facerender_still.yaml",
    "chars": 1087,
    "preview": "model_params:\n  common_params:\n    num_kp: 15 \n    image_channel: 3                    \n    feature_channel: 32\n    esti"
  },
  {
    "path": "src/face3d/data/__init__.py",
    "chars": 4584,
    "preview": "\"\"\"This package includes all the modules related to data loading and preprocessing\n\n To add a custom dataset class calle"
  },
  {
    "path": "src/face3d/data/base_dataset.py",
    "chars": 4679,
    "preview": "\"\"\"This module implements an abstract base class (ABC) 'BaseDataset' for datasets.\n\nIt also includes common transformati"
  },
  {
    "path": "src/face3d/data/flist_dataset.py",
    "chars": 4093,
    "preview": "\"\"\"This script defines the custom dataset for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport os.path\nfrom data.base_dataset import "
  },
  {
    "path": "src/face3d/data/image_folder.py",
    "chars": 1959,
    "preview": "\"\"\"A modified image folder class\n\nWe modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/ma"
  },
  {
    "path": "src/face3d/data/template_dataset.py",
    "chars": 3506,
    "preview": "\"\"\"Dataset class template\n\nThis module provides a template for users to implement custom datasets.\nYou can specify '--da"
  },
  {
    "path": "src/face3d/extract_kp_videos.py",
    "chars": 3915,
    "preview": "import os\nimport cv2\nimport time\nimport glob\nimport argparse\nimport face_alignment\nimport numpy as np\nfrom PIL import Im"
  },
  {
    "path": "src/face3d/extract_kp_videos_safe.py",
    "chars": 5705,
    "preview": "import os\nimport cv2\nimport time\nimport glob\nimport argparse\nimport numpy as np\nfrom PIL import Image\nimport torch\nfrom "
  },
  {
    "path": "src/face3d/models/__init__.py",
    "chars": 3090,
    "preview": "\"\"\"This package contains modules related to objective functions, optimizations, and network architectures.\n\nTo add a cus"
  },
  {
    "path": "src/face3d/models/arcface_torch/README.md",
    "chars": 8597,
    "preview": "# Distributed Arcface Training in Pytorch\n\nThis is a deep learning library that makes face recognition efficient, and ef"
  },
  {
    "path": "src/face3d/models/arcface_torch/backbones/__init__.py",
    "chars": 822,
    "preview": "from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200\nfrom .mobilefacenet import get_mbf\n\n\ndef ge"
  },
  {
    "path": "src/face3d/models/arcface_torch/backbones/iresnet.py",
    "chars": 7149,
    "preview": "import torch\nfrom torch import nn\n\n__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']\n\n\ndef c"
  },
  {
    "path": "src/face3d/models/arcface_torch/backbones/iresnet2060.py",
    "chars": 6708,
    "preview": "import torch\nfrom torch import nn\n\nassert torch.__version__ >= \"1.8.1\"\nfrom torch.utils.checkpoint import checkpoint_seq"
  },
  {
    "path": "src/face3d/models/arcface_torch/backbones/mobilefacenet.py",
    "chars": 4895,
    "preview": "'''\nAdapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py\nOriginal author ca"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/3millions.py",
    "chars": 519,
    "preview": "from easydict import EasyDict as edict\n\n# configs for test speed\n\nconfig = edict()\nconfig.loss = \"arcface\"\nconfig.networ"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/3millions_pfc.py",
    "chars": 519,
    "preview": "from easydict import EasyDict as edict\n\n# configs for test speed\n\nconfig = edict()\nconfig.loss = \"arcface\"\nconfig.networ"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/base.py",
    "chars": 1628,
    "preview": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G  tmpfs /t"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/glint360k_mbf.py",
    "chars": 647,
    "preview": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G  tmpfs /t"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/glint360k_r100.py",
    "chars": 648,
    "preview": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G  tmpfs /t"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/glint360k_r18.py",
    "chars": 647,
    "preview": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G  tmpfs /t"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/glint360k_r34.py",
    "chars": 647,
    "preview": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G  tmpfs /t"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/glint360k_r50.py",
    "chars": 647,
    "preview": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G  tmpfs /t"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py",
    "chars": 651,
    "preview": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G  tmpfs /t"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/ms1mv3_r18.py",
    "chars": 651,
    "preview": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G  tmpfs /t"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py",
    "chars": 652,
    "preview": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G  tmpfs /t"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/ms1mv3_r34.py",
    "chars": 651,
    "preview": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G  tmpfs /t"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/ms1mv3_r50.py",
    "chars": 651,
    "preview": "from easydict import EasyDict as edict\n\n# make training faster\n# our RAM is 256G\n# mount -t tmpfs -o size=140G  tmpfs /t"
  },
  {
    "path": "src/face3d/models/arcface_torch/configs/speed.py",
    "chars": 519,
    "preview": "from easydict import EasyDict as edict\n\n# configs for test speed\n\nconfig = edict()\nconfig.loss = \"arcface\"\nconfig.networ"
  },
  {
    "path": "src/face3d/models/arcface_torch/dataset.py",
    "chars": 3868,
    "preview": "import numbers\nimport os\nimport queue as Queue\nimport threading\n\nimport mxnet as mx\nimport numpy as np\nimport torch\nfrom"
  },
  {
    "path": "src/face3d/models/arcface_torch/docs/eval.md",
    "chars": 655,
    "preview": "## Eval on ICCV2021-MFR\n\ncoming soon.\n\n\n## Eval IJBC\nYou can eval ijbc with pytorch or onnx.\n\n\n1. Eval IJBC With Onnx\n``"
  },
  {
    "path": "src/face3d/models/arcface_torch/docs/install.md",
    "chars": 1613,
    "preview": "## v1.8.0 \n### Linux and Windows  \n```shell\n# CUDA 11.0\npip --default-timeout=100 install torch==1.8.0+cu111 torchvision"
  },
  {
    "path": "src/face3d/models/arcface_torch/docs/modelzoo.md",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/face3d/models/arcface_torch/docs/speed_benchmark.md",
    "chars": 5465,
    "preview": "## Test Training Speed\n\n- Test Commands\n\nYou need to use the following two commands to test the Partial FC training perf"
  },
  {
    "path": "src/face3d/models/arcface_torch/eval/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/face3d/models/arcface_torch/eval/verification.py",
    "chars": 16092,
    "preview": "\"\"\"Helper for evaluation on the Labeled Faces in the Wild dataset \n\"\"\"\n\n# MIT License\n#\n# Copyright (c) 2016 David Sandb"
  },
  {
    "path": "src/face3d/models/arcface_torch/eval_ijbc.py",
    "chars": 17222,
    "preview": "# coding: utf-8\n\nimport os\nimport pickle\n\nimport matplotlib\nimport pandas as pd\n\nmatplotlib.use('Agg')\nimport matplotlib"
  },
  {
    "path": "src/face3d/models/arcface_torch/inference.py",
    "chars": 1033,
    "preview": "import argparse\n\nimport cv2\nimport numpy as np\nimport torch\n\nfrom backbones import get_model\n\n\n@torch.no_grad()\ndef infe"
  },
  {
    "path": "src/face3d/models/arcface_torch/losses.py",
    "chars": 1137,
    "preview": "import torch\nfrom torch import nn\n\n\ndef get_loss(name):\n    if name == \"cosface\":\n        return CosFace()\n    elif name"
  },
  {
    "path": "src/face3d/models/arcface_torch/onnx_helper.py",
    "chars": 10422,
    "preview": "from __future__ import division\nimport datetime\nimport os\nimport os.path as osp\nimport glob\nimport numpy as np\nimport cv"
  },
  {
    "path": "src/face3d/models/arcface_torch/onnx_ijbc.py",
    "chars": 10321,
    "preview": "import argparse\nimport os\nimport pickle\nimport timeit\n\nimport cv2\nimport mxnet as mx\nimport numpy as np\nimport pandas as"
  },
  {
    "path": "src/face3d/models/arcface_torch/partial_fc.py",
    "chars": 9492,
    "preview": "import logging\nimport os\n\nimport torch\nimport torch.distributed as dist\nfrom torch.nn import Module\nfrom torch.nn.functi"
  },
  {
    "path": "src/face3d/models/arcface_torch/requirement.txt",
    "chars": 40,
    "preview": "tensorboard\neasydict\nmxnet\nonnx\nsklearn\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/run.sh",
    "chars": 260,
    "preview": "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --ma"
  },
  {
    "path": "src/face3d/models/arcface_torch/torch2onnx.py",
    "chars": 2365,
    "preview": "import numpy as np\nimport onnx\nimport torch\n\n\ndef convert_onnx(net, path_module, output, opset=11, simplify=False):\n    "
  },
  {
    "path": "src/face3d/models/arcface_torch/train.py",
    "chars": 6023,
    "preview": "import argparse\nimport logging\nimport os\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn.functional as F\n"
  },
  {
    "path": "src/face3d/models/arcface_torch/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/face3d/models/arcface_torch/utils/plot.py",
    "chars": 2222,
    "preview": "# coding: utf-8\n\nimport os\nfrom pathlib import Path\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as"
  },
  {
    "path": "src/face3d/models/arcface_torch/utils/utils_amp.py",
    "chars": 3286,
    "preview": "from typing import Dict, List\n\nimport torch\n\nif torch.__version__ < '1.9':\n    Iterable = torch._six.container_abcs.Iter"
  },
  {
    "path": "src/face3d/models/arcface_torch/utils/utils_callbacks.py",
    "chars": 5038,
    "preview": "import logging\nimport os\nimport time\nfrom typing import List\n\nimport torch\n\nfrom eval import verification\nfrom utils.uti"
  },
  {
    "path": "src/face3d/models/arcface_torch/utils/utils_config.py",
    "chars": 571,
    "preview": "import importlib\nimport os.path as osp\n\n\ndef get_config(config_file):\n    assert config_file.startswith('configs/'), 'co"
  },
  {
    "path": "src/face3d/models/arcface_torch/utils/utils_logging.py",
    "chars": 1110,
    "preview": "import logging\nimport os\nimport sys\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current val"
  },
  {
    "path": "src/face3d/models/arcface_torch/utils/utils_os.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/face3d/models/base_model.py",
    "chars": 13168,
    "preview": "\"\"\"This script defines the base network model for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport os\nimport numpy as np\nimport torch"
  },
  {
    "path": "src/face3d/models/bfm.py",
    "chars": 12349,
    "preview": "\"\"\"This script defines the parametric 3d face model for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport numpy as np\nimport  torch\nim"
  },
  {
    "path": "src/face3d/models/facerecon_model.py",
    "chars": 10843,
    "preview": "\"\"\"This script defines the face reconstruction model for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport numpy as np\nimport torch\nfr"
  },
  {
    "path": "src/face3d/models/losses.py",
    "chars": 4171,
    "preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom kornia.geometry import warp_affine\nimport torch.nn.functional"
  },
  {
    "path": "src/face3d/models/networks.py",
    "chars": 20766,
    "preview": "\"\"\"This script defines deep neural networks for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport os\nimport numpy as np\nimport torch.n"
  },
  {
    "path": "src/face3d/models/template_model.py",
    "chars": 5970,
    "preview": "\"\"\"Model class template\n\nThis module provides a template for users to implement custom models.\nYou can specify '--model "
  },
  {
    "path": "src/face3d/options/__init__.py",
    "chars": 136,
    "preview": "\"\"\"This package options includes option modules: training options, test options, and basic options (used in both trainin"
  },
  {
    "path": "src/face3d/options/base_options.py",
    "chars": 7493,
    "preview": "\"\"\"This script contains base options for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport argparse\nimport os\nfrom util import util\nim"
  },
  {
    "path": "src/face3d/options/inference_options.py",
    "chars": 1171,
    "preview": "from face3d.options.base_options import BaseOptions\n\n\nclass InferenceOptions(BaseOptions):\n    \"\"\"This class includes te"
  },
  {
    "path": "src/face3d/options/test_options.py",
    "chars": 830,
    "preview": "\"\"\"This script contains the test options for Deep3DFaceRecon_pytorch\n\"\"\"\n\nfrom .base_options import BaseOptions\n\n\nclass "
  },
  {
    "path": "src/face3d/options/train_options.py",
    "chars": 3744,
    "preview": "\"\"\"This script contains the training options for Deep3DFaceRecon_pytorch\n\"\"\"\n\nfrom .base_options import BaseOptions\nfrom"
  },
  {
    "path": "src/face3d/util/__init__.py",
    "chars": 114,
    "preview": "\"\"\"This package includes a miscellaneous collection of useful helper functions.\"\"\"\nfrom src.face3d.util import *\n\n"
  },
  {
    "path": "src/face3d/util/detect_lm68.py",
    "chars": 4033,
    "preview": "import os\nimport cv2\nimport numpy as np\nfrom scipy.io import loadmat\nimport tensorflow as tf\nfrom util.preprocess import"
  },
  {
    "path": "src/face3d/util/generate_list.py",
    "chars": 1346,
    "preview": "\"\"\"This script is to generate training list files for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport os\n\n# save path to training da"
  },
  {
    "path": "src/face3d/util/html.py",
    "chars": 3223,
    "preview": "import dominate\nfrom dominate.tags import meta, h3, table, tr, td, p, a, img, br\nimport os\n\n\nclass HTML:\n    \"\"\"This HTM"
  },
  {
    "path": "src/face3d/util/load_mats.py",
    "chars": 4445,
    "preview": "\"\"\"This script is to load 3D face model for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport numpy as np\nfrom PIL import Image\nfrom s"
  },
  {
    "path": "src/face3d/util/my_awing_arch.py",
    "chars": 12503,
    "preview": "import cv2\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef calculate_points("
  },
  {
    "path": "src/face3d/util/nvdiffrast.py",
    "chars": 4627,
    "preview": "\"\"\"This script is the differentiable renderer for Deep3DFaceRecon_pytorch\n    Attention, antialiasing step is missing in"
  },
  {
    "path": "src/face3d/util/preprocess.py",
    "chars": 3336,
    "preview": "\"\"\"This script contains the image preprocessing code for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport numpy as np\nfrom scipy.io i"
  },
  {
    "path": "src/face3d/util/skin_mask.py",
    "chars": 5333,
    "preview": "\"\"\"This script is to generate skin attention mask for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport math\nimport numpy as np\nimport"
  },
  {
    "path": "src/face3d/util/test_mean_face.txt",
    "chars": 3481,
    "preview": "-5.228591537475585938e+01\n2.078247070312500000e-01\n-5.064269638061523438e+01\n-1.315765380859375000e+01\n-4.95293922424316"
  },
  {
    "path": "src/face3d/util/util.py",
    "chars": 6588,
    "preview": "\"\"\"This script contains basic utilities for Deep3DFaceRecon_pytorch\n\"\"\"\nfrom __future__ import print_function\nimport num"
  },
  {
    "path": "src/face3d/util/visualizer.py",
    "chars": 10485,
    "preview": "\"\"\"This script defines the visualizer for Deep3DFaceRecon_pytorch\n\"\"\"\n\nimport numpy as np\nimport os\nimport sys\nimport nt"
  },
  {
    "path": "src/face3d/visualize.py",
    "chars": 1719,
    "preview": "# check the sync of 3dmm feature and the audio\nimport cv2\nimport numpy as np\nfrom src.face3d.models.bfm import Parametri"
  },
  {
    "path": "src/facerender/animate.py",
    "chars": 11529,
    "preview": "import os\nimport cv2\nimport yaml\nimport numpy as np\nimport warnings\nfrom skimage import img_as_ubyte\nimport safetensors\n"
  },
  {
    "path": "src/facerender/modules/dense_motion.py",
    "chars": 5866,
    "preview": "from torch import nn\nimport torch.nn.functional as F\nimport torch\nfrom src.facerender.modules.util import Hourglass, mak"
  },
  {
    "path": "src/facerender/modules/discriminator.py",
    "chars": 2872,
    "preview": "from torch import nn\nimport torch.nn.functional as F\nfrom facerender.modules.util import kp2gaussian\nimport torch\n\n\nclas"
  },
  {
    "path": "src/facerender/modules/generator.py",
    "chars": 11493,
    "preview": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom src.facerender.modules.util import ResBlock2d, Sa"
  },
  {
    "path": "src/facerender/modules/keypoint_detector.py",
    "chars": 6687,
    "preview": "from torch import nn\nimport torch\nimport torch.nn.functional as F\n\nfrom src.facerender.sync_batchnorm import Synchronize"
  },
  {
    "path": "src/facerender/modules/make_animation.py",
    "chars": 6802,
    "preview": "from scipy.spatial import ConvexHull\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom tqdm import tq"
  },
  {
    "path": "src/facerender/modules/mapping.py",
    "chars": 1585,
    "preview": "import numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass MappingNet(nn.Module):\n  "
  },
  {
    "path": "src/facerender/modules/util.py",
    "chars": 20173,
    "preview": "from torch import nn\n\nimport torch.nn.functional as F\nimport torch\n\nfrom src.facerender.sync_batchnorm import Synchroniz"
  },
  {
    "path": "src/facerender/sync_batchnorm/__init__.py",
    "chars": 449,
    "preview": "# -*- coding: utf-8 -*-\n# File   : __init__.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2"
  },
  {
    "path": "src/facerender/sync_batchnorm/batchnorm.py",
    "chars": 12973,
    "preview": "# -*- coding: utf-8 -*-\n# File   : batchnorm.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/"
  },
  {
    "path": "src/facerender/sync_batchnorm/comm.py",
    "chars": 4449,
    "preview": "# -*- coding: utf-8 -*-\n# File   : comm.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n"
  },
  {
    "path": "src/facerender/sync_batchnorm/replicate.py",
    "chars": 3226,
    "preview": "# -*- coding: utf-8 -*-\n# File   : replicate.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/"
  },
  {
    "path": "src/facerender/sync_batchnorm/unittest.py",
    "chars": 835,
    "preview": "# -*- coding: utf-8 -*-\n# File   : unittest.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2"
  },
  {
    "path": "src/generate_batch.py",
    "chars": 4453,
    "preview": "import os\n\nfrom tqdm import tqdm\nimport torch\nimport numpy as np\nimport random\nimport scipy.io as scio\nimport src.utils."
  },
  {
    "path": "src/generate_facerender_batch.py",
    "chars": 5757,
    "preview": "import os\nimport numpy as np\nfrom PIL import Image\nfrom skimage import io, img_as_float32, transform\nimport torch\nimport"
  },
  {
    "path": "src/gradio_demo.py",
    "chars": 6819,
    "preview": "import torch, uuid\r\nimport os, sys, shutil\r\nfrom src.utils.preprocess import CropAndExtract\r\nfrom src.test_audio2coeff i"
  },
  {
    "path": "src/test_audio2coeff.py",
    "chars": 5393,
    "preview": "import os \nimport torch\nimport numpy as np\nfrom scipy.io import savemat, loadmat\nfrom yacs.config import CfgNode as CN\nf"
  },
  {
    "path": "src/utils/audio.py",
    "chars": 4518,
    "preview": "import librosa\nimport librosa.filters\nimport numpy as np\n# import tensorflow as tf\nfrom scipy import signal\nfrom scipy.i"
  },
  {
    "path": "src/utils/croper.py",
    "chars": 5870,
    "preview": "import os\nimport cv2\nimport time\nimport glob\nimport argparse\nimport scipy\nimport numpy as np\nfrom PIL import Image\nimpor"
  },
  {
    "path": "src/utils/face_enhancer.py",
    "chars": 4436,
    "preview": "import os\nimport torch \n\nfrom gfpgan import GFPGANer\n\nfrom tqdm import tqdm\n\nfrom src.utils.videoio import load_video_to"
  },
  {
    "path": "src/utils/hparams.py",
    "chars": 6055,
    "preview": "from glob import glob\nimport os\n\nclass HParams:\n\tdef __init__(self, **kwargs):\n\t\tself.data = {}\n\n\t\tfor key, value in kwa"
  },
  {
    "path": "src/utils/init_path.py",
    "chars": 2607,
    "preview": "import os\nimport glob\n\ndef init_path(checkpoint_dir, config_dir, size=512, old_version=False, preprocess='crop'):\n\n    i"
  },
  {
    "path": "src/utils/model2safetensor.py",
    "chars": 6027,
    "preview": "import torch\nimport yaml\nimport os\n\nimport safetensors\nfrom safetensors.torch import save_file\nfrom yacs.config import C"
  },
  {
    "path": "src/utils/paste_pic.py",
    "chars": 2390,
    "preview": "import cv2, os\nimport numpy as np\nfrom tqdm import tqdm\nimport uuid\n\nfrom src.utils.videoio import save_video_with_water"
  },
  {
    "path": "src/utils/preprocess.py",
    "chars": 6890,
    "preview": "import numpy as np\nimport cv2, os, sys, torch\nfrom tqdm import tqdm\nfrom PIL import Image \n\n# 3dmm extraction\nimport saf"
  },
  {
    "path": "src/utils/safetensor_helper.py",
    "chars": 198,
    "preview": "\n\ndef load_x_from_safetensor(checkpoint, key):\n    x_generator = {}\n    for k,v in checkpoint.items():\n        if key in"
  },
  {
    "path": "src/utils/text2speech.py",
    "chars": 489,
    "preview": "import os\nimport tempfile\nfrom TTS.api import TTS\n\n\nclass TTSTalker():\n    def __init__(self) -> None:\n        model_nam"
  },
  {
    "path": "src/utils/videoio.py",
    "chars": 1455,
    "preview": "import shutil\nimport uuid\n\nimport os\n\nimport cv2\n\ndef load_video_to_cv2(input_path):\n    video_stream = cv2.VideoCapture"
  },
  {
    "path": "webui.bat",
    "chars": 275,
    "preview": "@echo off\n\nIF NOT EXIST venv (\npython -m venv venv\n) ELSE (\necho venv folder already exists, skipping creation...\n)\ncall"
  },
  {
    "path": "webui.sh",
    "chars": 3747,
    "preview": "#!/usr/bin/env bash\n\n\n# If run from macOS, load defaults from webui-macos-env.sh\nif [[ \"$OSTYPE\" == \"darwin\"* ]]; then\n "
  }
]

// ... and 2 more files
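Each entry in the listing above is an object with a `path`, a `chars` count, and a truncated `preview`. The sketch below assumes the full listing is a plain JSON array of such objects. It shows how the per-file `chars` values can be rolled up by directory to see where most of the repository's text sits. `chars_by_directory` and its `depth` parameter are hypothetical helpers written for illustration. They are not part of SadTalker or GitExtract.

```python
import json
from collections import defaultdict

# A few entries copied from the listing above, with previews elided.
# In practice the whole array would be loaded from the downloaded manifest.
MANIFEST = '''[
  {"path": "src/utils/safetensor_helper.py", "chars": 198, "preview": "..."},
  {"path": "src/utils/videoio.py", "chars": 1455, "preview": "..."},
  {"path": "src/facerender/animate.py", "chars": 11529, "preview": "..."},
  {"path": "webui.sh", "chars": 3747, "preview": "..."}
]'''


def chars_by_directory(entries, depth=2):
    """Sum each entry's `chars`, grouped by the first `depth` directory components of its path."""
    totals = defaultdict(int)
    for entry in entries:
        parts = entry["path"].split("/")
        # Drop the file name, keep at most `depth` directories; root-level files map to ".".
        directory = "/".join(parts[:-1][:depth]) or "."
        totals[directory] += entry["chars"]
    return dict(totals)


entries = json.loads(MANIFEST)
print(chars_by_directory(entries))
# → {'src/utils': 1653, 'src/facerender': 11529, '.': 3747}
```

Passing `depth=3` would instead separate deeper subtrees, for example `src/face3d/util` from `src/face3d/models`.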

About this extraction

This page contains the full source code of the OpenTalker/SadTalker GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 141 files (575.1 KB), approximately 156.7k tokens, and a symbol index with 702 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
