[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized files\n*__pycache__/\n*.py[cod]\n\n\n# Distribution / packaging\ndist/\nbuild/\n*.egg-info/\n*.egg\n\n# Virtual environments\nvenv/\nenv/\n.venv/\n.pip/\npipenv/\n\n# IDE specific files\n.idea/\n.vscode/\n\n# Logs and temporary files\n*.log\n*.swp\n*.swo\n*.tmp\n\n# Miscellaneous\n.DS_Store\nrecords.txt\nbinary_resources\ndo_experiment.py\nsmall_tools\nscancel_job.py\nresults\ntemplate_resources\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "\n# UniTalker: Scaling up Audio-Driven 3D Facial Animation through A Unified Model [ECCV2024]\n\n## Useful Links\n\n<div align=\"center\">\n    <a href=\"https://x-niper.github.io/projects/UniTalker/\" class=\"button\"><b>[Homepage]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;\n    <a href=\"https://arxiv.org/abs/2408.00762\" class=\"button\"><b>[arXiv]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;\n    <a href=\"https://www.youtube.com/watch?v=oUUh67ECzig\" class=\"button\"><b>[Video]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;\n</div>\n\n<div align=\"center\">\n<img src=\"https://github.com/X-niper/X-niper.github.io/blob/master/projects/UniTalker/assets/unitalker_architecture.png\" alt=\"UniTalker Architecture\" width=\"80%\" align=center/>\n</div>\n\n\n> UniTalker generates realistic facial motion from different audio domains, including clean and noisy voices in various languages, text-to-speech-generated audios, and even noisy songs accompanied by back-ground music.\n> \n> UniTalker can output multiple annotations.\n> \n> For datasets with new annotations, one can simply plug new heads into UniTalker and train it with existing datasets or solely with new ones, avoiding retopology.\n\n## Installation\n### Environment\n- Linux\n- Python 3.10\n- Pytorch 2.2.0\n- CUDA 12.1\n- transformers 4.39.3\n- Pytorch3d 0.7.7 (Optional: just for rendering the results)\n\n```bash\n  conda create -n unitalker python==3.10\n  conda activate unitalker\n  conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=12.1 -c pytorch -c nvidia\n  pip install transformers librosa tensorboardX smplx chumpy numpy==1.23.5 opencv-python\n```\n\n## Inference\n\n### Download checkpoints, PCA models and template resources\n\n[UniTalker-B-[D0-D7]](https://drive.google.com/file/d/1PmF8I6lyo0_64-NgeN5qIQAX6Bg0yw44/view?usp=sharing): The base model in paper. Download it and place it in `./pretrained_models` .\n\n[UniTalker-L-[D0-D7]](https://drive.google.com/file/d/1sH2T7KLFNjUnTM-V1eRMM1Tytxd2sYAp/view?usp=sharing): The default model in paper. Please first try the base model  to run the pipeline through.\n\n[Unitalker-data-release-V1](https://drive.google.com/file/d/1Un7TB0Z5A1CG6bgeqKlhnSOECFN-C6KK/view?usp=sharing): The released datasets, PCA models, data-split json files and id-template numpy array. 
\n```bash\npython -m main.train --config config/unitalker.yaml\n```"
  },
  {
    "path": "config/unitalker.yaml",
    "content": "DATA:\n  dataset: \"D0,D1,D2,D3,D4,D5,D6,D7\"\n  demo_dataset: \"D0,D1,D2,D3,D4,D6,D7\" # To visualize D5,  please download FLAME2020 model first and place it in \"resources/binary_resources/flame.pkl\"\n  data_root: \"./unitalker_data_release_V1/\"\n  duplicate_list: \"10,5,4,1,1,1,1,1\"\n  train_json_key: train.json\n\nLOSS:\n  rec_weight: 1.0\n  pca_weight: 0.01\n  blendshape_weight: 0.0001\n\n\nNETWORK:\n  use_pca: True\n  pca_dim: 512\n\nMODEL:\n  audio_encoder_repo: 'microsoft/wavlm-base-plus' # choose from 'microsoft/wavlm-base-plus' and facebook/wav2vec2-large-xlsr-53'\n  freeze_wav2vec: False\n  interpolate_pos: 1\n  decoder_dimension: 256\n  decoder_type: conv  # choose from conv and transformer\n  period: 30\n  headlayer: 1\n\nTRAIN:\n  fix_encoder_first: True\n  random_id_prob: 0.1\n  re_init_decoder_and_head: False\n  only_head: False\n  workers: 4  # data loader workers\n  batch_size: 1  # batch size for training\n  batch_size_val: 1  # batch size for validation during training, memory and speed tradeoff\n  lr: 0.0001\n  epochs: 100\n  save_every: 1\n  save_path: \"./results\"\n  weight_path: \"./pretrained_models/UniTalker-B-D0-D7.pt\"\n  eval_every: 1\n  test_every: 1\n  log_every: 1\n  do_test: True\n\nTEST:\n  test_wav_dir: \"./test_audios\"\n  test_out_path: \"./test_results/demo.npz\"\n"
  },
  {
    "path": "dataset/data_item.py",
    "content": "import librosa\nimport numpy as np\nfrom typing import Any\n\nfrom .dataset_config import dataset_config\n\n\nclass DataItem:\n\n    def __init__(self,\n                 annot_path: str,\n                 audio_path: str,\n                 identity_idx: int,\n                 annot_type: str = '',\n                 dataset_name: str = '',\n                 id_template: np.ndarray = None,\n                 fps: float = 30,\n                 processor=None) -> None:\n        self.annot_path = annot_path\n        self.audio_path = audio_path\n        self.identity_idx = identity_idx\n        self.fps = fps\n        self.annot_data = np.load(annot_path)\n        if self.fps == 60:\n            self.annot_data = self.annot_data[::2]\n            self.fps = 30\n        elif self.fps == 100:\n            self.annot_data = self.annot_data[::4]\n            self.fps = 25\n        elif self.fps != 30 and self.fps != 25:\n            raise ValueError('wrong fps')\n        scale = dataset_config[dataset_name]['scale']\n        self.annot_data = self.scale_and_offset(self.annot_data, scale, 0.0)\n        self.audio_data, sr = self.load_audio(audio_path)\n        self.original_audio_data, self.annot_data = self.truncate(\n            self.audio_data, sr, self.annot_data, self.fps)\n        self.audio_data = np.squeeze(\n            processor(self.original_audio_data, sampling_rate=sr).input_values)\n        self.audio_data = self.audio_data.astype(np.float32)\n        self.annot_type = annot_type\n        self.id_template = self.scale_and_offset(id_template, scale, 0.0)\n        self.annot_data = self.annot_data.reshape(len(self.annot_data),\n                                                  -1).astype(np.float32)\n        self.id_template = self.id_template.reshape(1, -1).astype(np.float32)\n        self.dataset_name = dataset_name\n        self.data_dict = self.to_dict()\n        return\n\n    def load_audio(self, audio_path: str):\n        if audio_path.endswith('.npy'):\n            return np.load(audio_path), 16000\n        else:\n            return librosa.load(audio_path, sr=16000)\n\n    def truncate(self, audio_array: np.ndarray, sr: int,\n                 annot_data: np.ndarray, fps: int):\n        audio_duration = len(audio_array) / sr\n        annot_duration = len(annot_data) / fps\n        duration = min(audio_duration, annot_duration, 20)\n        audio_length = round(duration * sr)\n        annot_length = round(duration * fps)\n        self.duration = duration\n        return audio_array[:audio_length], annot_data[:annot_length]\n\n    def offset_id(self, offset: int):\n        self.identity_idx = self.identity_idx + offset\n        self.data_dict['id'] = self.identity_idx\n        return\n\n    def scale_and_offset(self,\n                         data: np.ndarray,\n                         scale: float = 1.0,\n                         offset: np.ndarray = 0.0):\n        return data * scale + offset\n\n    def to_dict(self, ) -> Any:\n        item = {\n            'data': self.annot_data,\n            'audio': self.audio_data,\n            'original_audio': self.original_audio_data,\n            'fps': self.fps,\n            'id': self.identity_idx,\n            'annot_path': self.annot_path,\n            'audio_path': self.audio_path,\n            'annot_type': self.annot_type,\n            'template': self.id_template,\n            'dataset_name': self.dataset_name\n        }\n        return item\n\n    def get_dict(self, ):\n        return self.data_dict\n"
  },
  {
    "path": "dataset/dataset.py",
    "content": "import json\nimport numpy as np\nimport os.path as osp\nfrom torch.utils import data\nfrom tqdm import tqdm\nfrom transformers import Wav2Vec2FeatureExtractor\nfrom typing import Any, List\n\nfrom .data_item import DataItem\nfrom .dataset_config import dataset_config\n\n\nclass AudioFaceDataset(data.Dataset):\n\n    def __init__(self, data_label_path: str, args: dict = None) -> None:\n        super().__init__()\n        data_root = osp.dirname(data_label_path, )\n        id_template_path = osp.join(data_root, 'id_template.npy')\n        id_template_list = np.load(id_template_path)\n        with open(data_label_path) as f:\n            labels = json.load(f)\n        info = labels['info']\n        self.id_list = info['id_list']\n        self.sr = 16000\n        data_info_list = labels['data']\n        data_list = []\n        for data_info in tqdm(data_info_list):\n            data = DataItem(\n                annot_path=osp.join(data_root, data_info['annot_path']),\n                audio_path=osp.join(data_root, data_info['audio_path']),\n                identity_idx=data_info['id'],\n                annot_type=data_info['annot_type'],\n                dataset_name=data_info['dataset'],\n                id_template=id_template_list[data_info['id']],\n                fps=data_info['fps'],\n                processor=args.processor,\n            )\n            if data.duration < 0.5:\n                continue\n            data_list.append(data)\n        self.data_list = data_list\n        return\n\n    def __len__(self, ):\n        return len(self.data_list)\n\n    def __getitem__(self, index: Any) -> Any:\n        data = self.data_list[index]\n        return data.get_dict()\n\n    def get_identity_num(self, ):\n        return len(self.id_list)\n\n\nclass MixAudioFaceDataset(AudioFaceDataset):\n\n    def __init__(self,\n                 dataset_list: List[AudioFaceDataset],\n                 duplicate_list: list = None) -> None:\n        super(AudioFaceDataset).__init__()\n        self.id_list = []\n        self.sr = 16000\n        self.data_list = []\n        self.annot_type_list = []\n        if duplicate_list is None:\n            duplicate_list = [1] * len(dataset_list)\n        else:\n            assert len(duplicate_list) == len(dataset_list)\n        for dup, dataset in zip(duplicate_list, dataset_list):\n            id_index_offset = len(self.id_list)\n            for d in dataset.data_list:\n                d.offset_id(id_index_offset)\n                if d.annot_type not in self.annot_type_list:\n                    self.annot_type_list.append(d.annot_type)\n            self.id_list = self.id_list + dataset.id_list\n            self.data_list = self.data_list + dataset.data_list * dup\n        return\n\n\ndef get_dataset_list(args):\n    dataset_name_list = args.dataset\n    dataset_dir_list = [\n        dataset_config[name]['dirname'] for name in dataset_name_list\n    ]\n    train_json_list = [\n        osp.join(args.data_root, dataset_name, args.train_json_key)\n        for dataset_name in dataset_dir_list\n    ]\n    args.read_audio = False\n    args.processor = None\n    train_dataset_list = []\n    processor = Wav2Vec2FeatureExtractor.from_pretrained(\n        args.audio_encoder_repo)\n    args.processor = processor\n    for train_json in train_json_list:\n        dataset = AudioFaceDataset(\n            train_json,\n            args,\n        )\n        train_dataset_list.append(dataset)\n    return train_dataset_list\n\n\ndef get_single_dataset(args, json_path: 
str = ''):\n    args.read_audio = True\n    processor = Wav2Vec2FeatureExtractor.from_pretrained(\n        args.audio_encoder_repo)\n    args.processor = processor\n\n    dataset = AudioFaceDataset(\n        json_path,\n        args,\n    )\n    test_loader = data.DataLoader(\n        dataset=dataset,\n        batch_size=1,\n        shuffle=False,\n        pin_memory=True,\n        num_workers=args.workers)\n    return test_loader\n\n\ndef get_dataloaders(args):\n    dataset_name_list = args.dataset\n    print(dataset_name_list)\n    dataset_dir_list = [\n        dataset_config[name]['dirname'] for name in dataset_name_list\n    ]\n    train_json_list = [\n        osp.join(args.data_root, dataset_dir, args.train_json_key)\n        for dataset_dir in dataset_dir_list\n    ]\n    val_json_list = [\n        osp.join(args.data_root, dataset_dir, 'val.json')\n        for dataset_dir in dataset_dir_list\n    ]\n    test_json_list = [\n        osp.join(args.data_root, dataset_dir, 'test.json')\n        for dataset_dir in dataset_dir_list\n    ]\n    processor = Wav2Vec2FeatureExtractor.from_pretrained(\n        args.audio_encoder_repo)\n    args.processor = processor\n    if getattr(args, 'do_train', True) is True:\n        train_dataset_list = []\n        for train_json in train_json_list:\n            dataset = AudioFaceDataset(train_json, args)\n            train_dataset_list.append(dataset)\n        mix_train_dataset = MixAudioFaceDataset(train_dataset_list,\n                                                args.duplicate_list)\n        train_loader = data.DataLoader(\n            dataset=mix_train_dataset,\n            batch_size=args.batch_size,\n            shuffle=True,\n            pin_memory=True,\n            num_workers=args.workers)\n    else:\n        train_loader = None\n    if getattr(args, 'do_validate', True) is True:\n        val_dataset_list = []\n        for val_json in val_json_list:\n            dataset = AudioFaceDataset(val_json, args)\n            val_dataset_list.append(dataset)\n        mix_val_dataset = MixAudioFaceDataset(val_dataset_list)\n        val_loader = data.DataLoader(\n            dataset=mix_val_dataset,\n            batch_size=1,\n            shuffle=False,\n            pin_memory=True,\n            num_workers=args.workers)\n    else:\n        val_loader = None\n    if getattr(args, 'do_test', True) is True:\n        test_dataset_list = []\n        for test_json in test_json_list:\n            dataset = AudioFaceDataset(test_json, args)\n            test_dataset_list.append(dataset)\n        mix_test_dataset = MixAudioFaceDataset(test_dataset_list)\n        test_loader = data.DataLoader(\n            dataset=mix_test_dataset,\n            batch_size=1,\n            shuffle=False,\n            pin_memory=True,\n            num_workers=args.workers)\n    else:\n        test_loader = None\n    return train_loader, val_loader, test_loader\n"
  },
  {
    "path": "dataset/dataset_config.py",
    "content": "dataset_config = {\n    'D0': {\n        'dirname': 'D0_BIWI',\n        'annot_type': 'BIWI_23370_vertices',\n        'scale': 0.2,\n        'annot_dim': 23370 * 3,\n        'subjects': 6,\n        'pca': True,\n    },\n    'D1': {\n        'dirname': 'D1_vocaset',\n        'annot_type': 'FLAME_5023_vertices',\n        'scale': 1.0,\n        'annot_dim': 5023 * 3,\n        'subjects': 12,\n        'pca': True,\n    },\n    'D2': {\n        'dirname': 'D2_meshtalk',\n        'annot_type': 'meshtalk_6172_vertices',\n        'scale': 0.001,\n        'annot_dim': 6172 * 3,\n        'subjects': 13,\n        'pca': True,\n    },\n    'D3': {\n        'dirname': 'D3D4_3DETF/D3_HDTF',\n        'annot_type': '3DETF_blendshape_weight',\n        'scale': 1.0,\n        'annot_dim': 52,\n        'subjects': 141,\n        'pca': False,\n    },\n    'D4': {\n        'dirname': 'D3D4_3DETF/D4_RAVDESS',\n        'annot_type': '3DETF_blendshape_weight',\n        'scale': 1.0,\n        'annot_dim': 52,\n        'subjects': 24,\n        'pca': False,\n    },\n    'D5': {\n        'dirname': 'D5_unitalker_faceforensics++',\n        'annot_type': 'flame_params_from_dadhead',\n        'scale': 1.0,\n        'annot_dim': 413,\n        'subjects': 719,\n        'pca': False,\n    },\n    'D6': {\n        'dirname': 'D6_unitalker_Chinese_speech',\n        'annot_type': 'inhouse_blendshape_weight',\n        'scale': 1.0,\n        'annot_dim': 51,\n        'subjects': 8,\n        'pca': False,\n    },\n    'D7': {\n        'dirname': 'D7_unitalker_song',\n        'annot_type': 'inhouse_blendshape_weight',\n        'scale': 1.0,\n        'annot_dim': 51,\n        'subjects': 11,\n        'pca': False,\n    },\n}\n"
  },
  {
    "path": "loss/loss.py",
    "content": "import numpy as np\nimport pickle\nimport torch\nfrom dataclasses import dataclass\nfrom smplx.lbs import lbs\nfrom smplx.utils import Struct, to_np, to_tensor\nfrom torch import nn\n\n\n@dataclass\nclass FlameParams:\n    betas: torch.Tensor\n    jaw: torch.Tensor\n    eyeballs: torch.Tensor\n    neck: torch.Tensor\n\n\nclass FlameLayer(nn.Module):\n\n    def __init__(self, model_path: str) -> None:\n        super().__init__()\n        self.dtype = torch.float32\n        with open(model_path, 'rb') as f:\n            self.flame_model = Struct(**pickle.load(f, encoding='latin1'))\n        shapedirs = self.flame_model.shapedirs\n        # The shape components\n        self.register_buffer('shapedirs',\n                             to_tensor(to_np(shapedirs), dtype=self.dtype))\n\n        j_regressor = to_tensor(\n            to_np(self.flame_model.J_regressor), dtype=self.dtype)\n        self.register_buffer('J_regressor', j_regressor)\n        self.register_buffer(\n            'v_template',\n            to_tensor(to_np(self.flame_model.v_template), dtype=self.dtype),\n        )\n        num_pose_basis = self.flame_model.posedirs.shape[-1]\n        posedirs = np.reshape(self.flame_model.posedirs,\n                              [-1, num_pose_basis]).T\n        self.register_buffer('posedirs',\n                             to_tensor(to_np(posedirs), dtype=self.dtype))\n        parents = to_tensor(to_np(self.flame_model.kintree_table[0])).long()\n        parents[0] = -1\n        self.register_buffer('parents', parents)\n\n        self.register_buffer(\n            'lbs_weights',\n            to_tensor(to_np(self.flame_model.weights), dtype=self.dtype))\n        default_neck_pose = torch.zeros([1, 3],\n                                        dtype=self.dtype,\n                                        requires_grad=False)\n        self.register_parameter(\n            'neck_pose', nn.Parameter(default_neck_pose, requires_grad=False))\n        default_eyeball_pose = torch.zeros([1, 6],\n                                           dtype=self.dtype,\n                                           requires_grad=False)\n        self.register_parameter(\n            'eyeballs',\n            nn.Parameter(default_eyeball_pose, requires_grad=False))\n        default_jaw = torch.zeros([1, 3],\n                                  dtype=self.dtype,\n                                  requires_grad=False)\n        self.register_parameter('jaw',\n                                nn.Parameter(default_jaw, requires_grad=False))\n\n        self.FLAME_CONSTS = {\n            'betas': 400,\n            'rotation': 6,\n            'jaw': 3,\n            'eyeballs': 0,\n            'neck': 0,\n            'translation': 3,\n            'scale': 1,\n        }\n\n    def flame_params_from_3dmm(self,\n                               tensor_3dmm: torch.Tensor,\n                               zero_expr: bool = False):\n        constants = self.FLAME_CONSTS\n\n        assert tensor_3dmm.ndim == 2\n\n        cur_index: int = 0\n        betas = tensor_3dmm[:, :constants['betas']]\n        cur_index += constants['betas']\n        jaw = tensor_3dmm[:, cur_index:cur_index + constants['jaw']]\n        cur_index += constants['jaw']\n        cur_index += constants['rotation']\n        eyeballs = tensor_3dmm[:, cur_index:cur_index + constants['eyeballs']]\n        cur_index += constants['eyeballs']\n        neck = tensor_3dmm[:, cur_index:cur_index + constants['neck']]\n        cur_index += constants['neck']\n        
cur_index += constants['translation']\n        cur_index += constants['scale']\n        return FlameParams(betas=betas, jaw=jaw, eyeballs=eyeballs, neck=neck)\n\n    def forward(self, tensor_3dmm: torch.Tensor):\n        bs = tensor_3dmm.shape[0]\n        flame_params = self.flame_params_from_3dmm(tensor_3dmm)\n        betas = flame_params.betas\n        rotation = torch.zeros([bs, 3], device=tensor_3dmm.device)\n        neck_pose = flame_params.neck if not (\n            0 in flame_params.neck.shape) else self.neck_pose[[0]].expand(\n                bs, -1)\n        eyeballs = flame_params.eyeballs if not (\n            0 in flame_params.eyeballs.shape) else self.eyeballs[[0]].expand(\n                bs, -1)\n        jaw = flame_params.jaw if not (\n            0 in flame_params.jaw.shape) else self.jaw[[0]].expand(bs, -1)\n\n        full_pose = torch.cat([rotation, neck_pose, jaw, eyeballs], dim=1)\n        template_vertices = self.v_template.unsqueeze(0).repeat(bs, 1, 1)\n        vertices, _ = lbs(\n            betas,\n            full_pose,\n            template_vertices,\n            self.shapedirs,\n            self.posedirs,\n            self.J_regressor,\n            self.parents,\n            self.lbs_weights,\n        )\n        return vertices\n\n\nclass FlameParamLossForDADHead(nn.Module):\n\n    def __init__(self, mouth_indices_path: str, loss_mask_path: str,\n                 flame_model_path: str) -> None:\n        super().__init__()\n        mouth_indices = np.load(mouth_indices_path)\n        loss_indices = np.load(loss_mask_path)\n        self.register_buffer('mouth_idx', torch.from_numpy(mouth_indices))\n        self.register_buffer('loss_indices', torch.from_numpy(loss_indices))\n        self.mse = nn.MSELoss()\n        self.flame_layer = FlameLayer(flame_model_path)\n\n    def forward(self, x: torch.Tensor, target: torch.Tensor):\n        ori_shape = x.shape\n        x = self.flame_layer(x.reshape(-1, ori_shape[-1]))\n        target = self.flame_layer(target.reshape(-1, ori_shape[-1]))\n        x = x.reshape((ori_shape[0], ori_shape[1], -1, 3))\n        target = target.reshape((ori_shape[0], ori_shape[1], -1, 3))\n        return self.mse(x[:, :, self.loss_indices], target[:, :,\n                                                           self.loss_indices])\n\n    def get_vertices(\n        self,\n        x: torch.Tensor,\n    ):\n        x = self.flame_layer(x.reshape(-1, x.shape[-1]))\n        return x[:, self.loss_indices]\n\n    def mouth_metric(self, x: torch.Tensor, target: torch.Tensor):\n        ori_shape = x.shape\n        x = self.flame_layer(x.reshape(-1, ori_shape[-1]))\n        target = self.flame_layer(target.reshape(-1, ori_shape[-1]))\n        metric_L2 = (target[:, self.mouth_idx] - x[:, self.mouth_idx])**2\n        metric_L2 = metric_L2.sum(-1)\n        metric_L2 = metric_L2.max(-1)[0]\n        metric_L2norm = metric_L2**0.5\n        return metric_L2.mean(), metric_L2norm.mean()\n\n\nclass VerticesLoss(nn.Module):\n\n    def __init__(self, mouth_indices_path: str) -> None:\n        super().__init__()\n        mouth_indices = np.load(mouth_indices_path)\n        self.register_buffer('mouth_idx', torch.from_numpy(mouth_indices))\n        self.mse = nn.MSELoss()\n\n    def forward(self, x: torch.Tensor, target: torch.Tensor):\n        return self.mse(x, target)\n\n    def get_vertices(\n        self,\n        x: torch.Tensor,\n    ):\n        return x.reshape(-1, x.shape[-1] // 3, 3)\n\n    def mouth_metric(self, x: torch.Tensor, target: torch.Tensor):\n 
       ori_shape = x.shape\n        target = target.reshape(*ori_shape[:-1], -1, 3)\n        x = x.reshape(*ori_shape[:-1], -1, 3)\n        metric_L2 = (target[:, :, self.mouth_idx] - x[:, :, self.mouth_idx])**2\n        metric_L2 = metric_L2.sum(-1)\n        metric_L2 = metric_L2.max(-1)[0]\n        metric_L2norm = metric_L2**0.5\n        return metric_L2.mean(), metric_L2norm.mean()\n\n\nclass BlendShapeLoss(nn.Module):\n\n    def __init__(self,\n                 mouth_indices_path: str,\n                 blendshape_path: str,\n                 bs_beta: float = 0.0,\n                 components_num: int = None) -> None:\n        super().__init__()\n        mouth_indices = np.load(mouth_indices_path)\n        blendshape = np.load(blendshape_path)\n        self.register_buffer('mouth_idx', torch.from_numpy(mouth_indices))\n        mean_shape = blendshape['meanshape']\n        mean_shape = mean_shape.reshape(-1).astype(np.float32)\n        blend_shape = blendshape['blendshape']\n        blend_shape = blend_shape.reshape(len(blend_shape),\n                                          -1).astype(np.float32)\n        self.register_buffer('mean_shape', torch.from_numpy(mean_shape))\n        self.register_buffer('blend_shape', torch.from_numpy(blend_shape))\n        self.bs_beta = bs_beta\n        self.mse = nn.MSELoss()\n        self.components_num = components_num\n\n    def forward(self, x: torch.Tensor, target: torch.Tensor):\n        return self.bs_beta * self.mse(x, target) + self.mse(\n            self.bs2vertices(x), self.bs2vertices(target))\n\n    def bs2vertices(self, x: torch.Tensor):\n        return x[:, :self.components_num] @ \\\n            self.blend_shape[:self.components_num] + \\\n            self.mean_shape\n\n    def get_vertices(\n        self,\n        x: torch.Tensor,\n    ):\n        x = self.bs2vertices(x)\n        return x.reshape(-1, x.shape[-1] // 3, 3)\n\n    def mouth_metric(self, x: torch.Tensor, target: torch.Tensor):\n        target = self.bs2vertices(target)\n        x = self.bs2vertices(x)\n        ori_shape = x.shape\n        target = target.reshape(*ori_shape[:-1], -1, 3)\n        x = x.reshape(*ori_shape[:-1], -1, 3)\n        metric_L2 = (target[:, :, self.mouth_idx] - x[:, :, self.mouth_idx])**2\n        metric_L2 = metric_L2.sum(-1)\n        metric_L2 = metric_L2.max(-1)[0]\n        metric_L2norm = metric_L2**0.5\n        return metric_L2.mean(), metric_L2norm.mean()\n\n\nclass ParamLoss(nn.Module):\n\n    def __init__(self, weight: float):\n        super().__init__()\n        self.weight = weight\n        self.mse = nn.MSELoss()\n        self.L1 = nn.L1Loss()\n\n    def forward(self, x: torch.Tensor, target: torch.Tensor):\n        return self.weight * self.mse(x, target)\n\n    def mouth_metric(self, x: torch.Tensor, target: torch.Tensor):\n        return self.mse(x, target), self.L1(x, target)\n\n    def get_vertices(\n        self,\n        x: torch.Tensor,\n    ):\n        return x.squeeze(0)\n\n\nclass UniTalkerLoss(nn.Module):\n\n    def __init__(self, args) -> None:\n        super().__init__()\n        loss_config = {\n            'flame_params_from_dadhead': {\n                'class': FlameParamLossForDADHead,\n                'args': {\n                    'mouth_indices_path':\n                    'resources/binary_resources/02_flame_mouth_idx.npy',\n                    'loss_mask_path':\n                    'resources/binary_resources/flame_humanface_index.npy',\n                    'flame_model_path': 'resources/binary_resources/flame.pkl',\n  
              }\n            },\n            '3DETF_blendshape_weight': {\n                'class': BlendShapeLoss,\n                'args': {\n                    'mouth_indices_path':\n                    'resources/binary_resources/03_emo_talk_mouth_idx.npy',\n                    'blendshape_path':\n                    'resources/binary_resources/EmoTalk.npz',\n                    'bs_beta': args.blendshape_weight,\n                },\n            },\n            'inhouse_blendshape_weight': {\n                'class': BlendShapeLoss,\n                'args': {\n                    'mouth_indices_path':\n                    'resources/binary_resources/05_inhouse_arkit_mouth_idx.npy',\n                    'blendshape_path':\n                    'resources/binary_resources/inhouse_arkit.npz',\n                    'bs_beta': args.blendshape_weight,\n                },\n            },\n            'FLAME_5023_vertices': {\n                'class': VerticesLoss,\n                'args': {\n                    'mouth_indices_path':\n                    'resources/binary_resources/02_flame_mouth_idx.npy',\n                },\n            },\n            'BIWI_23370_vertices': {\n                'class': VerticesLoss,\n                'args': {\n                    'mouth_indices_path':\n                    'resources/binary_resources/04_BIWI_mouth_idx.npy',\n                },\n            },\n            'FLAME_SUB_2055_vertices': {\n                'class': VerticesLoss,\n                'args': {\n                    'mouth_indices_path':\n                    'resources/binary_resources/mouth_idx_FLAME_5023_vertices_wo_eyes_wo_head.npy',\n                },\n            },\n            'meshtalk_6172_vertices': {\n                'class': VerticesLoss,\n                'args': {\n                    'mouth_indices_path':\n                    'resources/binary_resources/06_mesh_talk_mouth_idx.npy',\n                },\n            },\n            '24_viseme': {\n                'class': ParamLoss,\n                'args': {\n                    'weight': args.blendshape_weight\n                },\n            }\n        }\n\n        loss_module_dict = {}\n        for k, v in loss_config.items():\n            loss_module = v['class'](**v['args'])\n            loss_module_dict[k] = loss_module\n        self.loss_module_dict = nn.ModuleDict(loss_module_dict)\n        self.mse = nn.MSELoss()\n        return\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        target: torch.Tensor,\n        annot_type: str,\n    ):\n        return self.loss_module_dict[annot_type](x, target)\n\n    def pca_loss(self, x: torch.Tensor, target: torch.Tensor):\n        return self.mse(x, target)\n\n    def mouth_metric(self, x: torch.Tensor, target: torch.Tensor,\n                     annot_type: str):\n        return self.loss_module_dict[annot_type].mouth_metric(x, target)\n\n    def get_vertices(self, x: torch.Tensor, annot_type: str):\n        return self.loss_module_dict[annot_type].get_vertices(x)\n"
  },
  {
    "path": "main/demo.py",
    "content": "import librosa\nimport numpy as np\nimport os\nimport torch\nimport torch.optim\nimport torch.utils.data\nfrom transformers import Wav2Vec2FeatureExtractor\n\nfrom dataset.dataset_config import dataset_config\nfrom loss.loss import UniTalkerLoss\nfrom models.unitalker import UniTalker\nfrom utils.utils import get_parser, get_template_verts, get_audio_encoder_dim\n\ndef get_all_audios(audio_root:str):\n    wav_f_names = []\n    for r, _, f in os.walk(audio_root):\n        for wav in f:\n            if wav.endswith('.wav') or wav.endswith('.mp3'):\n                relative_path = os.path.join(r, wav)\n                relative_path = os.path.relpath(relative_path,\n                                                audio_root)\n                wav_f_names.append(relative_path)\n    wav_f_names = sorted(wav_f_names)\n    return wav_f_names\n\ndef split_long_audio(\n    audio: np.ndarray,\n    processor:Wav2Vec2FeatureExtractor\n):\n    # audio = audio.squeeze(0)\n    a, b = 25, 5\n    sr = 16000 \n    total_length = len(audio) /sr\n    reps = max(0, int(np.ceil((total_length - a) / (a - b)))) + 1\n    in_audio_split_list = []\n    start, end = 0, int(a * sr)\n    step = int((a - b) * sr)\n    for i in range(reps):\n        audio_split = audio[start:end]\n        audio_split = np.squeeze(\n            processor(audio_split, sampling_rate=sr).input_values)\n        in_audio_split_list.append(audio_split)\n        start += step\n        end += step\n    return in_audio_split_list\n\ndef merge_out_list(out_list: list, fps:int):\n    if len(out_list) == 1:\n        return out_list[0]\n    a, b = 25, 5\n    left_weight = np.linspace(1, 0, b * fps)[:, np.newaxis]\n    right_weight = 1 - left_weight\n    a = a * fps \n    b = b * fps \n    offset = a - b\n\n    out_length = len(out_list[-1]) + offset * (len(out_list) - 1)\n    merged_out = np.empty((out_length, out_list[-1].shape[-1]),\n                            dtype=out_list[-1].dtype)\n    merged_out[:a] = out_list[0]\n    for out_piece in out_list[1:]:\n        merged_out[a - b:a] = left_weight * merged_out[\n            a - b:a] + right_weight * out_piece[:b]\n        merged_out[a:a + offset] = out_piece[b:]\n        a += offset\n    return merged_out\n\n\ndef main():\n    args = get_parser()\n    training_dataset_name_list = args.dataset\n    demo_dataset_name_list = args.demo_dataset\n    \n    condition_id_config = {\n        'D0': 3,\n        'D1': 3,\n        'D2': 0,\n        'D3': 0,\n        'D4': 0,\n        'D5': 0,\n        'D6': 4,\n        'D7': 0,\n    }\n    template_id_config = {\n        'D0': 3,\n        'D1': 3,\n        'D2': 0,\n        'D3': 0,\n        'D4': 0,\n        'D5': 0,\n        'D6': 4,\n        'D7': 0,\n    }\n\n    checkpoint = torch.load(args.weight_path, map_location='cpu')\n    args.identity_num = len(checkpoint['decoder.learnable_style_emb.weight'])\n\n\n    start_idx = 0\n    for dataset_name in training_dataset_name_list:\n        annot_type = dataset_config[dataset_name]['annot_type']\n        id_num = dataset_config[dataset_name]['subjects']\n        end_idx = start_idx + id_num\n        local_condition_idx = condition_id_config[dataset_name]\n        template_idx = template_id_config[dataset_name]\n        template = get_template_verts(args.data_root, dataset_name, template_idx)\n        template_id_config[dataset_name] = torch.Tensor(template.reshape(1, -1))\n        if args.condition_id == 'each':\n            condition_id_config[dataset_name] = torch.tensor(\n                start_idx 
+ local_condition_idx).reshape(1)\n        elif args.condition_id == 'common':\n            condition_id_config[dataset_name] = torch.tensor(args.identity_num -\n                                                      1).reshape(1)\n        else:\n            try:\n                condition_id = int(args.condition_id)\n                condition_id_config[dataset_name] = torch.tensor(\n                    condition_id).reshape(1)\n            except ValueError:\n                assert args.condition_id in dataset_config.keys()\n                condition_id_config[dataset_name] = torch.tensor(\n                    start_idx + local_condition_idx).reshape(1)\n        start_idx = end_idx\n    if args.random_id_prob > 0:\n        assert end_idx == (args.identity_num - 1)\n    else:\n        assert end_idx == args.identity_num\n\n    args.audio_encoder_feature_dim = get_audio_encoder_dim(\n        args.audio_encoder_repo)\n\n    model = UniTalker(args)\n    model.load_state_dict(checkpoint, strict=False)\n    model.eval()\n    model.cuda()\n\n    wav_f_names = get_all_audios(args.test_wav_dir)\n    out_npz_path = args.test_out_path\n    dirname = os.path.dirname(out_npz_path)\n    os.makedirs(dirname, exist_ok=True)\n\n\n    out_dict = {}\n    loss_module = UniTalkerLoss(args).cuda()\n    processor = Wav2Vec2FeatureExtractor.from_pretrained(\n        args.audio_encoder_repo,)\n    \n    with torch.no_grad():\n        for wav_f in wav_f_names:\n            out_dict[wav_f] = {}\n            wav_path = os.path.join(args.test_wav_dir, wav_f)\n            audio_data, sr = librosa.load(wav_path, sr=16000)\n            audio_data = np.squeeze(\n                processor(audio_data, sampling_rate=sr).input_values)\n\n            audio_data_splits = split_long_audio(audio_data, processor)\n            for dataset_name in demo_dataset_name_list:\n                template = template_id_config[dataset_name].cuda()\n                scale = dataset_config[dataset_name]['scale']\n                template = scale * template\n                condition_id = condition_id_config[dataset_name].cuda()\n                annot_type = dataset_config[dataset_name]['annot_type']\n                if annot_type == 'BIWI_23370_vertices':\n                    fps = 25\n                else:\n                    fps = 30\n\n                out_list = []\n                for audio_data in audio_data_splits:\n                    audio_data = torch.Tensor(audio_data[None]).cuda()\n                    out, _, _ = model(\n                        audio_data,\n                        template,\n                        face_motion=None,\n                        style_idx=condition_id,\n                        annot_type=annot_type,\n                        fps=fps,\n                    )\n                    out_list.append(out.cpu().numpy().squeeze(0))\n                out = merge_out_list(out_list, fps)\n                out = loss_module.get_vertices(torch.from_numpy(out).cuda(), annot_type)\n                out_dict[wav_f][dataset_name] = out.cpu() \n        print(f\"save results to {out_npz_path}\")\n        np.savez(out_npz_path, **out_dict)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "main/render.py",
    "content": "\"\"\"Max-Planck-Gesellschaft zur Foerderung der Wissenschaften e.V.\n\n(MPG) is holder of all proprietary rights on this computer program. You can\nonly use this computer program if you have closed a license agreement with MPG\nor you get the right to use the computer program from someone who is authorized\nto grant you that right. Any use of the computer program without a valid\nlicense is prohibited and liable to prosecution. Copyright 2019 Max-Planck-\nGesellschaft zur Foerderung der Wissenschaften e.V. (MPG). acting on behalf of\nits Max Planck Institute for Intelligent Systems and the Max Planck Institute\nfor Biological Cybernetics. All rights reserved. More information about VOCA is\navailable at http://voca.is.tue.mpg.de. For comments or questions, please email\nus at voca@tue.mpg.de\n\"\"\"\n\nimport cv2\nimport numpy as np\nimport os\nimport subprocess\nimport torch\nfrom tqdm import tqdm\nfrom dataset.dataset_config import dataset_config\n\ndef render_meshes(mesh_vertices: torch.Tensor,\n                  faces: torch.Tensor,\n                  img_size: tuple = (256, 256),\n                  aa_factor: int = 1,\n                  vis_bs: int = 100):\n    assert len(mesh_vertices) == len(faces) or len(faces) == 1\n    if len(faces) == 1:\n        faces = faces.repeat(len(mesh_vertices), 1, 1)\n\n    x_max = y_max = mesh_vertices[..., 0:2].abs().max()\n    from pytorch3d.renderer import (\n        DirectionalLights,\n        FoVOrthographicCameras,\n        HardPhongShader,\n        Materials,\n        MeshRasterizer,\n        MeshRenderer,\n        RasterizationSettings,\n        TexturesVertex,\n    )\n    from pytorch3d.renderer.blending import BlendParams\n    from pytorch3d.renderer.materials import Materials\n    from pytorch3d.structures import Meshes\n    aa_factor = int(aa_factor)\n    device = mesh_vertices.device\n    verts_batch_lst = torch.split(mesh_vertices, vis_bs)\n    faces_batch_lst = torch.split(faces, vis_bs)\n    R = torch.eye(3).to(mesh_vertices.device)\n    R[0, 0] = -1\n    R[2, 2] = -1\n    T = torch.zeros(3).to(mesh_vertices.device)\n    T[2] = 10\n    R, T = R[None], T[None]\n    W, H = img_size\n    cameras = FoVOrthographicCameras(\n        device=device,\n        R=R,\n        T=T,\n        znear=0.01,\n        zfar=3,\n        max_x=x_max * 1.2,\n        min_x=-x_max * 1.2,\n        max_y=y_max * 1.2,\n        min_y=-y_max * 1.2,\n    )\n    # cameras = FoVPerspectiveCameras(\n    # \tdevice=device,\n    # \tR=R,\n    # \tT=T,\n    # \tznear=0.01,\n    # \tzfar=3,\n    # \taspect_ratio=1,\n    # \tfov=0.3,\n    # \tdegrees=False\n    # \t)\n\n    raster_settings = RasterizationSettings(\n        image_size=(W * aa_factor, H * aa_factor),\n        blur_radius=0.0,\n        faces_per_pixel=1,\n    )\n    raster = MeshRasterizer(\n        cameras=cameras,\n        raster_settings=raster_settings,\n    )\n    blend_params = BlendParams(\n        sigma=0.0, gamma=0.0, background_color=(0.0, 0.0, 0.0))\n    lights = DirectionalLights(\n        device=device,\n        direction=((0, 0, 1), ),\n        ambient_color=((0.3, 0.3, 0.3), ),\n        diffuse_color=((0.6, 0.6, 0.6), ),\n        specular_color=((0.1, 0.1, 0.1), ))\n    materias = Materials(\n        ambient_color=((1, 1, 1), ),\n        diffuse_color=((1, 1, 1), ),\n        specular_color=((1, 1, 1), ),\n        shininess=15,\n        device=device)\n    shader = HardPhongShader(\n        device=device,\n        cameras=cameras,\n        lights=lights,\n        
materials=materias,\n        blend_params=blend_params)\n    renderer = MeshRenderer(\n        rasterizer=raster,\n        shader=shader,\n    )\n    rendered_imgs = []\n    for verts_batch, faces_batch in tqdm(\n            zip(verts_batch_lst, faces_batch_lst),\n            total=(len(verts_batch_lst))):\n        textures = TexturesVertex(verts_features=torch.ones_like(verts_batch))\n        meshes = Meshes(\n            verts=verts_batch, faces=faces_batch, textures=textures)\n        with torch.no_grad():\n            imgs = renderer(meshes)\n        rendered_imgs.append(imgs.cpu())\n    rendered_imgs = torch.cat(rendered_imgs, dim=0)\n    if aa_factor > 1:\n        rendered_imgs = rendered_imgs.permute(0, 3, 1, 2)  # NHWC -> NCHW\n        rendered_imgs = torch.nn.functional.interpolate(\n            rendered_imgs, scale_factor=1 / aa_factor, mode='bicubic')\n        rendered_imgs = rendered_imgs.permute(0, 2, 3, 1)  # NCHW -> NHWC\n    return rendered_imgs\n\n\ndef read_obj(in_path):\n    with open(in_path, 'r') as obj_file:\n        # Read the lines of the OBJ file\n        lines = obj_file.readlines()\n\n    # Initialize empty lists for vertices and faces\n    verts = []\n    faces = []\n    for line in lines:\n        line = line.strip()  # Remove leading/trailing whitespace\n        elements = line.split()  # Split the line into elements\n\n        if len(elements) == 0:\n            continue  # Skip empty lines\n\n        # Check the type of line (vertex or face)\n        if elements[0] == 'v':\n            # Vertex line\n            x, y, z = map(float,\n                          elements[1:4])  # Extract the vertex coordinates\n            verts.append((x, y, z))  # Add the vertex to the list\n        elif elements[0] == 'f':\n            # Face line\n            face_indices = [\n                int(index.split('/')[0]) for index in elements[1:]\n            ]  # Extract the vertex indices\n            faces.append(face_indices)  # Add the face to the list\n    return np.array(verts), np.array(faces)\n\n\ndef pad_for_libx264(image_array):\n    \"\"\"Pad zeros if width or height of image_array is not divisible by 2.\n    Otherwise you will get.\n\n    \\\"[libx264 @ 0x1b1d560] width not divisible by 2 \\\"\n\n    Args:\n            image_array (np.ndarray):\n                    Image or images load by cv2.imread().\n                    Possible shapes:\n                    1. [height, width]\n                    2. [height, width, channels]\n                    3. [images, height, width]\n                    4. 
[images, height, width, channels]\n\n    Returns:\n            np.ndarray:\n                    A image with both edges divisible by 2.\n    \"\"\"\n    if image_array.ndim == 2 or \\\n      (image_array.ndim == 3 and image_array.shape[2] == 3):\n        hei_index = 0\n        wid_index = 1\n    elif image_array.ndim == 4 or \\\n      (image_array.ndim == 3 and image_array.shape[2] != 3):\n        hei_index = 1\n        wid_index = 2\n    else:\n        return image_array\n    hei_pad = image_array.shape[hei_index] % 2\n    wid_pad = image_array.shape[wid_index] % 2\n    if hei_pad + wid_pad > 0:\n        pad_width = []\n        for dim_index in range(image_array.ndim):\n            if dim_index == hei_index:\n                pad_width.append((0, hei_pad))\n            elif dim_index == wid_index:\n                pad_width.append((0, wid_pad))\n            else:\n                pad_width.append((0, 0))\n        values = 0\n        image_array = \\\n         np.pad(image_array,\n             pad_width,\n             mode='constant', constant_values=values)\n    return image_array\n\n\ndef array_to_video(\n    image_array: np.ndarray,\n    output_path: str,\n    fps=30,\n    resolution=None,\n    disable_log: bool = False,\n) -> None:\n    \"\"\"Convert an array to a video directly, gif not supported.\n\n    Args:\n            image_array (np.ndarray): shape should be (f * h * w * 3).\n            output_path (str): output video file path.\n            fps (Union[int, float, optional): fps. Defaults to 30.\n            resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],\n                    optional): (height, width) of the output video.\n                    Defaults to None.\n            disable_log (bool, optional): whether close the ffmepg command info.\n                    Defaults to False.\n    Raises:\n            FileNotFoundError: check output path.\n            TypeError: check input array.\n\n    Returns:\n            None.\n    \"\"\"\n    if not isinstance(image_array, np.ndarray):\n        raise TypeError('Input should be np.ndarray.')\n    assert image_array.ndim == 4\n    assert image_array.shape[-1] == 3\n    if resolution:\n        height, width = resolution\n        width += width % 2\n        height += height % 2\n    else:\n        image_array = pad_for_libx264(image_array)\n        height, width = image_array.shape[1], image_array.shape[2]\n    command = [\n        '/usr/bin/ffmpeg',\n        '-y',  # (optional) overwrite output file if it exists\n        '-f',\n        'rawvideo',\n        '-s',\n        f'{int(width)}x{int(height)}',  # size of one frame\n        '-pix_fmt',\n        'bgr24',\n        '-r',\n        f'{fps}',  # frames per second\n        '-loglevel',\n        'error',\n        '-threads',\n        '4',\n        '-i',\n        '-',  # The input comes from a pipe\n        '-vcodec',\n        'libx264',\n        '-an',  # Tells FFMPEG not to expect any audio\n        output_path,\n    ]\n    if not disable_log:\n        print(f'Running \\\"{\" \".join(command)}\\\"')\n    process = subprocess.Popen(\n        command,\n        stdin=subprocess.PIPE,\n        stderr=subprocess.PIPE,\n    )\n    if process.stdin is None or process.stderr is None:\n        raise BrokenPipeError('No buffer received.')\n    index = 0\n    while True:\n        if index >= image_array.shape[0]:\n            break\n        process.stdin.write(image_array[index].tobytes())\n        index += 1\n    process.stdin.close()\n    process.stderr.close()\n    
process.wait()\n\n\ndef get_obj_faces(annot_type: str):\n    template_obj_path_dict = {\n        '3DETF_blendshape_weight': 'resources/obj_template/3DETF_blendshape_weight.obj',\n        'FLAME_5023_vertices': 'resources/obj_template/FLAME_5023_vertices.obj',\n        'BIWI_23370_vertices': 'resources/obj_template/BIWI_23370_vertices.obj',\n        'flame_params_from_dadhead': 'resources/obj_template/flame_params_from_dadhead.obj',\n        'inhouse_blendshape_weight': 'resources/obj_template/inhouse_blendshape_weight.obj',\n        'meshtalk_6172_vertices': 'resources/obj_template/meshtalk_6172_vertices.obj'\n    }\n\n    faces = np.array(read_obj(template_obj_path_dict[annot_type])[1])\n    if faces.min() == 1:\n        faces = faces - 1\n    return faces\n\n\ndef vis_model_out_proxy():\n    import sys\n    npz_path = sys.argv[1]\n    audio_dir = sys.argv[2]\n    out_dir = sys.argv[3]\n    save_images = False \n    os.makedirs(out_dir, exist_ok=True)\n    out_dict = np.load(npz_path, allow_pickle=True)\n    print(list(out_dict.keys()))\n    for wav_f in out_dict.keys():\n        wav_path = os.path.join(audio_dir, wav_f)\n        cur_out_dict = out_dict[wav_f]\n        for datasetname, out in cur_out_dict.item().items():\n            out = torch.Tensor(out)\n            annot_type = dataset_config[datasetname]['annot_type']\n            faces = get_obj_faces(annot_type)\n            with torch.no_grad():\n                if save_images:\n                    out_size = (1024, 1024)\n                else:\n                    out_size = (512, 512)\n                img_array = render_meshes(out.cuda(),\n                                          torch.from_numpy(faces[None]).cuda(),\n                                          out_size)\n                if save_images:\n                    channels = 4\n                else:\n                    channels = 3\n                img_array = (img_array[..., :channels].cpu().numpy() *\n                             255).astype(np.uint8)\n            out_key = f'{wav_f[:-4]}'\n            if save_images:\n                out_images_dir = os.path.join(out_dir, datasetname, out_key)\n                os.makedirs(out_images_dir, exist_ok=True)\n                for imgidx, img in enumerate(img_array):\n                    img_path = os.path.join(out_images_dir,\n                                            f'{imgidx:04d}.png')\n                    cv2.imwrite(img_path, img)\n                continue\n            out_video_path = os.path.join(out_dir,\n                                          f'{out_key}_{datasetname}.mp4')\n            out_video_dir = os.path.dirname(out_video_path)\n            os.makedirs(out_video_dir, exist_ok=True)\n            tmp_path = os.path.join(out_video_dir, 'tmp.mp4')\n            if os.path.exists(tmp_path):\n                os.remove(tmp_path)\n            if annot_type == 'BIWI_23370_vertices':\n                fps = 25\n            else:\n                fps = 30\n            print(img_array.shape, img_array.dtype)\n            array_to_video(img_array, output_path=tmp_path, fps=fps)\n\n            cmd = f'ffmpeg -i {tmp_path} -i  {wav_path} -c:v copy -c:a aac -shortest -strict -2  {out_video_path} -y '\n            print(cmd)\n            subprocess.run(cmd, shell=True)\n            os.remove(tmp_path)\n\n\nif __name__ == '__main__':\n    vis_model_out_proxy()\n"
  },
  {
    "path": "main/train.py",
    "content": "#!/usr/bin/env python\n# yapf: disable\nimport os\nimport torch\nimport torch.optim\nimport torch.utils.data\nfrom tensorboardX import SummaryWriter\n\nfrom dataset.dataset import get_dataloaders\nfrom loss.loss import UniTalkerLoss\nfrom models.unitalker import UniTalker\nfrom utils.utils import (\n    get_average_meter_dict, get_logger, get_parser, get_audio_encoder_dim,\n    load_ckpt, seed_everything, filter_unitalker_state_dict\n)\nfrom .train_eval_loop import train_epoch, validate_epoch\n\n# yapf: enable\n\n\ndef main():\n    seed_everything(42)\n    args = get_parser()\n    args.audio_encoder_feature_dim = get_audio_encoder_dim(args.audio_encoder_repo)\n    train_loader, val_loader, test_loader = get_dataloaders(args)\n    args.identity_num = train_loader.dataset.get_identity_num()\n    args.annot_type_list = train_loader.dataset.annot_type_list\n    if args.random_id_prob > 0.0:\n        args.identity_num = args.identity_num + 1\n\n    os.makedirs(args.save_path, exist_ok=True)\n    log_file = os.path.join(args.save_path, 'train.log')\n    logger = get_logger(log_file)\n    writer = SummaryWriter(args.save_path)\n    logger.info('=> creating model ...')\n    model = UniTalker(args)\n    if args.weight_path is not None:\n        logger.info(f'=> loading model {args.weight_path} ...')\n        load_ckpt(model, args.weight_path, args.re_init_decoder_and_head)\n    logger.info(args)\n    model = model.cuda()\n    model.summary(logger)\n\n    optimizer = torch.optim.Adam(\n        filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)\n\n    loss_module = UniTalkerLoss(args).cuda()\n\n    args.train_meter, args.train_metric_name_list = \\\n        get_average_meter_dict(args, 'train')\n    args.val_meter, args.val_metric_name_list = \\\n        get_average_meter_dict(args, 'val')\n    args.logger = logger\n    args.writer = writer\n\n    for epoch in range(args.epochs):\n        if args.fix_encoder_first:\n            if epoch == 0:\n                model.audio_encoder._freeze_wav2vec2_parameters(True)\n                optimizer = torch.optim.Adam(\n                    filter(lambda p: p.requires_grad, model.parameters()),\n                    lr=args.lr)\n            elif epoch == 1:\n                model.audio_encoder._freeze_wav2vec2_parameters(\n                    args.freeze_wav2vec)\n                model.audio_encoder.feature_extractor._freeze_parameters()\n                optimizer = torch.optim.Adam(\n                    filter(lambda p: p.requires_grad, model.parameters()),\n                    lr=args.lr)\n        epoch_plus = epoch + 1\n        train_epoch(train_loader, model, loss_module, optimizer, epoch_plus,\n                    args)\n\n        if epoch_plus % args.eval_every == 0:\n            validate_epoch(val_loader, model, loss_module, epoch_plus, 'val',\n                           args)\n\n        if epoch_plus % args.test_every == 0:\n            validate_epoch(test_loader, model, loss_module, epoch_plus, 'test',\n                           args)\n\n        if epoch_plus % args.save_every == 0:\n            model_path = os.path.join(args.save_path, f'{epoch_plus:03d}.pt')\n            torch.save(filter_unitalker_state_dict(model.state_dict()), model_path)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "main/train_eval_loop.py",
    "content": "import numpy as np\nimport torch\nfrom tqdm import tqdm\n\nfrom utils.utils import log_datasetloss, write_to_tensorboard\n\n\ndef train_epoch(data_loader, model, loss_module, optimizer, epoch, args):\n    model.train()\n    train_meter = args.train_meter\n    metric_name_list = args.train_metric_name_list\n    zero_loss = torch.tensor(0.0)\n    mixed_dataset_name = 'mixed_dataset'\n    pbar = tqdm(data_loader)\n    pbar.set_description(f'Epoch: {epoch} train')\n    for batch in pbar:\n        data = batch['data'].cuda(non_blocking=True)\n        template = batch['template'].cuda(non_blocking=True)\n        audio = batch['audio'].cuda(non_blocking=True)\n        identity = batch['id'].cuda(non_blocking=True)\n        if args.random_id_prob > 0.0:\n            if np.random.random() < args.random_id_prob:\n                identity = identity.fill_(args.identity_num - 1)\n        annot_type = batch['annot_type'][0]\n        dataset_name = batch['dataset_name'][0]\n\n        out_motion, out_pca, gt_pca = model(\n            audio, template, data, style_idx=identity, annot_type=annot_type)\n\n        rec_loss = loss_module(out_motion, data, annot_type)\n        if out_pca is not None:\n            pca_loss = loss_module.pca_loss(out_pca, gt_pca)\n        else:\n            pca_loss = zero_loss\n        loss = rec_loss + args.pca_weight * pca_loss\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        for k, v in zip(metric_name_list, (rec_loss, pca_loss)):\n            train_meter[dataset_name][k].update(v.item())\n            train_meter[mixed_dataset_name][k].update(v.item())\n\n    log_datasetloss(args.logger, epoch, 'train', train_meter, only_mix=True)\n    write_to_tensorboard(\n        args.writer, epoch, 'train', train_meter, only_mix=False)\n\n    for dataset_name, sub_metric_dict in train_meter.items():\n        for metric_name in sub_metric_dict.keys():\n            sub_metric_dict[metric_name].reset()\n    return\n\n\ndef validate_epoch(data_loader, model, loss_module, epoch, phase, args):\n    assert phase in ('val', 'test')\n    model.eval()\n    val_meter = args.val_meter\n    metric_name_list = args.val_metric_name_list\n    mixed_dataset_name = 'mixed_dataset'\n    with torch.no_grad():\n        pbar = tqdm(data_loader)\n        pbar.set_description(f'Epoch: {epoch} {phase:<5}')\n        for batch in pbar:\n            data = batch['data'].cuda(non_blocking=True)\n            template = batch['template'].cuda(non_blocking=True)\n            audio = batch['audio'].cuda(non_blocking=True)\n            identity = batch['id'].cuda(non_blocking=True)\n            annot_type = batch['annot_type'][0]\n            dataset_name = batch['dataset_name'][0]\n            if args.random_id_prob >= 1.0:\n                identity = identity.fill_(args.identity_num - 1)\n            out_motion, _, _ = model(\n                audio,\n                template,\n                data,\n                style_idx=identity,\n                annot_type=annot_type)\n            rec_loss = loss_module(out_motion, data, annot_type)\n            mouth_metric_L2, mouth_metric_L2_norm = loss_module.mouth_metric(\n                out_motion, data, annot_type)\n\n            for k, v in zip(metric_name_list,\n                            (rec_loss, mouth_metric_L2, mouth_metric_L2_norm)):\n                val_meter[dataset_name][k].update(v.item())\n                val_meter[mixed_dataset_name][k].update(v.item())\n        if getattr(args, 'logger', None) is not 
None:\n            log_datasetloss(\n                args.logger, epoch, phase, val_meter, only_mix=True)\n        if getattr(args, 'writer', None) is not None:\n            write_to_tensorboard(\n                args.writer, epoch, phase, val_meter, only_mix=False)\n        for dataset_name, sub_metric_dict in val_meter.items():\n            for metric_name in sub_metric_dict.keys():\n                sub_metric_dict[metric_name].reset()\n    return\n"
  },
  {
    "path": "main/train_pca.py",
    "content": "#!/usr/bin/env python\nimport numpy as np\n\nfrom models.pca import PCA\nfrom utils.utils import get_parser, get_audio_encoder_dim\n\n\ndef main():\n    args = get_parser()\n    args.audio_encoder_feature_dim = get_audio_encoder_dim(\n        args.wav2vec2_type)\n    main_worker(args)\n\n\ndef main_worker(cfg):\n    pca_builder = PCA()\n    from dataset.dataset import get_dataset_list\n    train_dataset_list = get_dataset_list(cfg)\n    for dataset in train_dataset_list:\n        data = pca_builder.load_dataset(dataset)\n        pca_info = pca_builder.build_incremental_PCA(\n            data, batch_size=1095, trunc_dim=647)\n        np.savez('mesh_talk_pca.npz', **pca_info)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "models/base_model.py",
    "content": "import numpy as np\nimport torch.nn as nn\n\n\nclass BaseModel(nn.Module):\n    \"\"\"Base class for all models.\"\"\"\n\n    def __init__(self):\n        super(BaseModel, self).__init__()\n        # self.logger = logging.getLogger(self.__class__.__name__)\n\n    def forward(self, *x):\n        \"\"\"Forward pass logic.\n\n        :return: Model output\n        \"\"\"\n        raise NotImplementedError\n\n    def freeze_model(self, do_freeze: bool = True):\n        for param in self.parameters():\n            param.requires_grad = (not do_freeze)\n\n    def summary(self, logger, writer=None):\n        \"\"\"Model summary.\"\"\"\n        model_parameters = filter(lambda p: p.requires_grad, self.parameters())\n        params = sum([np.prod(p.size())\n                      for p in model_parameters]) / 1e6  # Unit is Mega\n        logger.info('===>Trainable parameters: %.3f M' % params)\n        if writer is not None:\n            writer.add_text('Model Summary',\n                            'Trainable parameters: %.3f M' % params)\n"
  },
  {
    "path": "models/hubert.py",
    "content": "import numpy as np\nimport torch\nfrom transformers import HubertModel \nfrom transformers.modeling_outputs import BaseModelOutput\nfrom typing import Optional, Tuple\n\nfrom .model_utils import linear_interpolation\n\n\n# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model # noqa: E501\n# initialize our encoder with the pre-trained wav2vec 2.0 weights.\ndef _compute_mask_indices(\n    shape: Tuple[int, int],\n    mask_prob: float,\n    mask_length: int,\n    attention_mask: Optional[torch.Tensor] = None,\n    min_masks: int = 0,\n) -> np.ndarray:\n    bsz, all_sz = shape\n    mask = np.full((bsz, all_sz), False)\n\n    all_num_mask = int(mask_prob * all_sz / float(mask_length) +\n                       np.random.rand())\n    all_num_mask = max(min_masks, all_num_mask)\n    mask_idcs = []\n    padding_mask = attention_mask.ne(1) if attention_mask is not None else None\n    for i in range(bsz):\n        if padding_mask is not None:\n            sz = all_sz - padding_mask[i].long().sum().item()\n            num_mask = int(mask_prob * sz / float(mask_length) +\n                           np.random.rand())\n            num_mask = max(min_masks, num_mask)\n        else:\n            sz = all_sz\n            num_mask = all_num_mask\n\n        lengths = np.full(num_mask, mask_length)\n\n        if sum(lengths) == 0:\n            lengths[0] = min(mask_length, sz - 1)\n\n        min_len = min(lengths)\n        if sz - min_len <= num_mask:\n            min_len = sz - num_mask - 1\n\n        mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)\n        mask_idc = np.asarray([\n            mask_idc[j] + offset for j in range(len(mask_idc))\n            for offset in range(lengths[j])\n        ])\n        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))\n\n    min_len = min([len(m) for m in mask_idcs])\n    for i, mask_idc in enumerate(mask_idcs):\n        if len(mask_idc) > min_len:\n            mask_idc = np.random.choice(mask_idc, min_len, replace=False)\n        mask[i, mask_idc] = True\n    return mask\n\n\nclass HubertModel(HubertModel):\n\n    def __init__(self, config):\n        super().__init__(config)\n\n    def _freeze_parameters(self, do_freeze: bool = True):\n        for param in self.parameters():\n            param.requires_grad = (not do_freeze)\n\n    def forward(\n        self,\n        input_values,\n        attention_mask=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n        frame_num=None,\n        interpolate_pos: int = 0,\n    ):\n        self.config.output_attentions = True\n        output_attentions = output_attentions if \\\n            output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else\n            self.config.output_hidden_states)\n        return_dict = return_dict if return_dict is not None \\\n            else self.config.use_return_dict\n\n        conv_features = self.feature_extractor(input_values)\n        hidden_states = conv_features.transpose(1, 2)\n        if interpolate_pos == 0:\n            hidden_states = linear_interpolation(\n                hidden_states, output_len=frame_num)\n\n        if attention_mask is not None:\n            output_lengths = self._get_feat_extract_output_lengths(\n                attention_mask.sum(-1))\n            
attention_mask = torch.zeros(\n                hidden_states.shape[:2],\n                dtype=hidden_states.dtype,\n                device=hidden_states.device)\n            attention_mask[(torch.arange(\n                attention_mask.shape[0],\n                device=hidden_states.device), output_lengths - 1)] = 1\n            attention_mask = attention_mask.flip([-1]).cumsum(-1).flip(\n                [-1]).bool()\n\n        hidden_states = self.feature_projection(hidden_states)\n\n        if self.config.apply_spec_augment and self.training:\n            batch_size, sequence_length, hidden_size = hidden_states.size()\n            if self.config.mask_time_prob > 0:\n                mask_time_indices = _compute_mask_indices(\n                    (batch_size, sequence_length),\n                    self.config.mask_time_prob,\n                    self.config.mask_time_length,\n                    attention_mask=attention_mask,\n                    min_masks=2,\n                )\n                hidden_states[torch.from_numpy(\n                    mask_time_indices)] = self.masked_spec_embed.to(\n                        hidden_states.dtype)\n            if self.config.mask_feature_prob > 0:\n                mask_feature_indices = _compute_mask_indices(\n                    (batch_size, hidden_size),\n                    self.config.mask_feature_prob,\n                    self.config.mask_feature_length,\n                )\n                mask_feature_indices = torch.from_numpy(\n                    mask_feature_indices).to(hidden_states.device)\n                hidden_states[mask_feature_indices[:, None].expand(\n                    -1, sequence_length, -1)] = 0\n        encoder_outputs = self.encoder(\n            hidden_states,\n            attention_mask=attention_mask,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = encoder_outputs[0]\n        if interpolate_pos == 1:\n            hidden_states = linear_interpolation(\n                hidden_states, output_len=frame_num)\n\n        if not return_dict:\n            return (hidden_states, ) + encoder_outputs[1:]\n\n        return BaseModelOutput(\n            last_hidden_state=hidden_states,\n            hidden_states=encoder_outputs.hidden_states,\n            attentions=encoder_outputs.attentions,\n        )\n"
  },
  {
    "path": "models/model_utils.py",
    "content": "# Borrowed from https://github.com/EvelynFan/FaceFormer/blob/main/main.py\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\n# linear interpolation layer\ndef linear_interpolation(features, output_len: int):\n    features = features.transpose(1, 2)\n    output_features = F.interpolate(\n        features, size=output_len, align_corners=True, mode='linear')\n    return output_features.transpose(1, 2)\n\n\n# Temporal Bias\ndef init_biased_mask(n_head, max_seq_len, period):\n\n    def get_slopes(n):\n\n        def get_slopes_power_of_2(n):\n            start = (2**(-2**-(math.log2(n) - 3)))\n            ratio = start\n            return [start * ratio**i for i in range(n)]\n\n        if math.log2(n).is_integer():\n            return get_slopes_power_of_2(n)\n        else:\n            closest_power_of_2 = 2**math.floor(math.log2(n))\n            return get_slopes_power_of_2(closest_power_of_2) + get_slopes(\n                2 * closest_power_of_2)[0::2][:n - closest_power_of_2]\n\n    slopes = torch.Tensor(get_slopes(n_head))\n    bias = torch.div(\n        torch.arange(start=0, end=max_seq_len,\n                     step=period).unsqueeze(1).repeat(1, period).view(-1),\n        period,\n        rounding_mode='floor')\n    bias = -torch.flip(bias, dims=[0])\n    alibi = torch.zeros(max_seq_len, max_seq_len)\n    for i in range(max_seq_len):\n        alibi[i, :i + 1] = bias[-(i + 1):]\n    alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0)\n    mask = (torch.triu(torch.ones(max_seq_len,\n                                  max_seq_len)) == 1).transpose(0, 1)\n    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(\n        mask == 1, float(0.0))\n    mask = mask.unsqueeze(0) + alibi\n    return mask\n\n\n# Alignment Bias\ndef enc_dec_mask(device, T, S):\n    mask = torch.ones(T, S)\n    for i in range(T):\n        mask[i, i] = 0\n    return (mask == 1).to(device=device)\n\n\n# Periodic Positional Encoding\nclass PeriodicPositionalEncoding(nn.Module):\n\n    def __init__(self, d_model, dropout=0.1, period=25, max_seq_len=3000):\n        super(PeriodicPositionalEncoding, self).__init__()\n        self.dropout = nn.Dropout(p=dropout)\n        pe = torch.zeros(period, d_model)\n        position = torch.arange(0, period, dtype=torch.float).unsqueeze(1)\n        div_term = torch.exp(\n            torch.arange(0, d_model, 2).float() *\n            (-math.log(10000.0) / d_model))\n        pe[:, 0::2] = torch.sin(position * div_term)\n        pe[:, 1::2] = torch.cos(position * div_term)\n        pe = pe.unsqueeze(0)  # (1, period, d_model)\n        repeat_num = (max_seq_len // period) + 1\n        pe = pe.repeat(1, repeat_num, 1)\n        self.register_buffer('pe', pe)\n\n    def forward(self, x):\n        x = x + self.pe[:, :x.size(1), :]\n        return self.dropout(x)\n\n\nclass TCN(nn.Module):\n\n    def __init__(self, in_dim, out_dim):\n        super().__init__()\n        self.first_net = SeqTranslator1D(\n            in_dim, out_dim, min_layers_num=3, residual=True, norm='ln')\n        self.dropout = nn.Dropout(0.1)\n\n    def forward(self, spectrogram, style_embed, time_steps=None):\n\n        spectrogram = spectrogram\n        spectrogram = self.dropout(spectrogram)\n        style_embed = style_embed.unsqueeze(2)\n        style_embed = style_embed.repeat(1, 1, spectrogram.shape[2])\n        spectrogram = torch.cat([spectrogram, style_embed], dim=1)\n        x1 = self.first_net(spectrogram)  # .permute(0, 2, 1)\n     
   if time_steps is not None:\n            x1 = F.interpolate(\n                x1, size=time_steps, align_corners=False, mode='linear')\n\n        return x1\n\n\nclass SeqTranslator1D(nn.Module):\n    \"\"\"(B, C, T)->(B, C_out, T)\"\"\"\n\n    def __init__(self,\n                 C_in,\n                 C_out,\n                 kernel_size=None,\n                 stride=None,\n                 min_layers_num=None,\n                 residual=True,\n                 norm='bn'):\n        super(SeqTranslator1D, self).__init__()\n\n        conv_layers = nn.ModuleList([])\n        conv_layers.append(\n            ConvNormRelu(\n                in_channels=C_in,\n                out_channels=C_out,\n                type='1d',\n                kernel_size=kernel_size,\n                stride=stride,\n                residual=residual,\n                norm=norm))\n        self.num_layers = 1\n        if min_layers_num is not None and self.num_layers < min_layers_num:\n            while self.num_layers < min_layers_num:\n                conv_layers.append(\n                    ConvNormRelu(\n                        in_channels=C_out,\n                        out_channels=C_out,\n                        type='1d',\n                        kernel_size=kernel_size,\n                        stride=stride,\n                        residual=residual,\n                        norm=norm))\n                self.num_layers += 1\n        self.conv_layers = nn.Sequential(*conv_layers)\n\n    def forward(self, x):\n        return self.conv_layers(x)\n\n\nclass ConvNormRelu(nn.Module):\n    \"\"\"(B,C_in,H,W) -> (B, C_out, H, W) there exist some kernel size that makes\n    the result is not H/s.\n\n    #TODO: there might some problems with residual\n    \"\"\"\n\n    def __init__(self,\n                 in_channels,\n                 out_channels,\n                 type='1d',\n                 leaky=False,\n                 downsample=False,\n                 kernel_size=None,\n                 stride=None,\n                 padding=None,\n                 p=0,\n                 groups=1,\n                 residual=False,\n                 norm='bn'):\n        \"\"\"conv-bn-relu.\"\"\"\n        super(ConvNormRelu, self).__init__()\n        self.residual = residual\n        self.norm_type = norm\n        # kernel_size = k\n        # stride = s\n\n        if kernel_size is None and stride is None:\n            if not downsample:\n                kernel_size = 3\n                stride = 1\n            else:\n                kernel_size = 4\n                stride = 2\n\n        if padding is None:\n            if isinstance(kernel_size, int) and isinstance(stride, tuple):\n                padding = tuple(int((kernel_size - st) / 2) for st in stride)\n            elif isinstance(kernel_size, tuple) and isinstance(stride, int):\n                padding = tuple(int((ks - stride) / 2) for ks in kernel_size)\n            elif isinstance(kernel_size, tuple) and isinstance(stride, tuple):\n                padding = tuple(\n                    int((ks - st) / 2) for ks, st in zip(kernel_size, stride))\n            else:\n                padding = int((kernel_size - stride) / 2)\n\n        if self.residual:\n            if downsample:\n                if type == '1d':\n                    self.residual_layer = nn.Sequential(\n                        nn.Conv1d(\n                            in_channels=in_channels,\n                            out_channels=out_channels,\n                            
kernel_size=kernel_size,\n                            stride=stride,\n                            padding=padding))\n                elif type == '2d':\n                    self.residual_layer = nn.Sequential(\n                        nn.Conv2d(\n                            in_channels=in_channels,\n                            out_channels=out_channels,\n                            kernel_size=kernel_size,\n                            stride=stride,\n                            padding=padding))\n            else:\n                if in_channels == out_channels:\n                    self.residual_layer = nn.Identity()\n                else:\n                    if type == '1d':\n                        self.residual_layer = nn.Sequential(\n                            nn.Conv1d(\n                                in_channels=in_channels,\n                                out_channels=out_channels,\n                                kernel_size=kernel_size,\n                                stride=stride,\n                                padding=padding))\n                    elif type == '2d':\n                        self.residual_layer = nn.Sequential(\n                            nn.Conv2d(\n                                in_channels=in_channels,\n                                out_channels=out_channels,\n                                kernel_size=kernel_size,\n                                stride=stride,\n                                padding=padding))\n\n        in_channels = in_channels * groups\n        out_channels = out_channels * groups\n        if type == '1d':\n            self.conv = nn.Conv1d(\n                in_channels=in_channels,\n                out_channels=out_channels,\n                kernel_size=kernel_size,\n                stride=stride,\n                padding=padding,\n                groups=groups)\n            self.norm = nn.BatchNorm1d(out_channels)\n            self.dropout = nn.Dropout(p=p)\n        elif type == '2d':\n            self.conv = nn.Conv2d(\n                in_channels=in_channels,\n                out_channels=out_channels,\n                kernel_size=kernel_size,\n                stride=stride,\n                padding=padding,\n                groups=groups)\n            self.norm = nn.BatchNorm2d(out_channels)\n            self.dropout = nn.Dropout2d(p=p)\n        if norm == 'gn':\n            self.norm = nn.GroupNorm(2, out_channels)\n        elif norm == 'ln':\n            self.norm = nn.LayerNorm(out_channels)\n        if leaky:\n            self.relu = nn.LeakyReLU(negative_slope=0.2)\n        else:\n            self.relu = nn.ReLU()\n\n    def forward(self, x, **kwargs):\n        if self.norm_type == 'ln':\n            out = self.dropout(self.conv(x))\n            out = self.norm(out.transpose(1, 2)).transpose(1, 2)\n        else:\n            out = self.norm(self.dropout(self.conv(x)))\n        if self.residual:\n            residual = self.residual_layer(x)\n            out += residual\n        return self.relu(out)\n"
  },
  {
    "path": "models/pca.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\n\n\nclass PCALayer(nn.Module):\n\n    def __init__(self, pca_info_path: str):\n        super().__init__()\n        pca_info = np.load(pca_info_path)\n        feat_to_data = pca_info['components_to_data'].astype(np.float32)\n        data_to_feat = feat_to_data.T\n        data_bias = pca_info['original_data_mean'].astype(np.float32)\n        feat_mean = pca_info['data_components_mean'].astype(np.float32)\n        feat_std = pca_info['data_components_std'].astype(np.float32)\n        self.register_buffer('feat_to_data', torch.from_numpy(feat_to_data))\n        self.register_buffer('data_to_feat', torch.from_numpy(data_to_feat))\n        self.register_buffer('data_bias', torch.from_numpy(data_bias))\n        self.register_buffer('feat_mean', torch.from_numpy(feat_mean))\n        self.register_buffer('feat_std', torch.from_numpy(feat_std))\n        self.pca_dim = len(feat_mean)\n\n    def encode(\n        self,\n        x: torch.Tensor,\n        pca_dim: int = None,\n    ):\n        x = x - self.data_bias\n        feat = x @ (self.data_to_feat[:, :pca_dim])\n        return feat\n\n    def decode(self, feat: torch.Tensor, pca_dim: int = None):\n        data = feat[..., :pca_dim] @ (self.feat_to_data[:pca_dim])\n        data = data + self.data_bias\n        return data\n\n\nclass PCA:\n\n    def __init__(self, trunc_dim: int = None):\n        super().__init__()\n\n    def buld_pca_for_dataset(self, dataset, trunc_dim: int = 1024):\n        annot_array = self.load_dataset(dataset)\n        self.build_pca(annot_array, trunc_dim)\n\n    def load_dataset(self, dataset):\n        annot_list = []\n        indices = np.arange(len(dataset))\n        for i in indices:\n            data_dict = dataset.__getitem__(i)\n            annot_list.append(data_dict['data'] - data_dict['template'])\n        annot_array = np.concatenate(annot_list, axis=0)\n        return annot_array\n\n    def build_incremental_PCA(\n        self,\n        in_data: np.ndarray,\n        trunc_dim=None,\n        batch_size=1024,\n    ):\n        indices = np.arange(len(in_data))\n        np.random.shuffle(indices)\n        in_data = in_data[indices]\n        if in_data.dtype != np.float32:\n            in_data = in_data.astype(np.float32)\n        from sklearn.decomposition import IncrementalPCA\n        ipca = IncrementalPCA(n_components=trunc_dim, batch_size=batch_size)\n        data_mean = in_data.mean(0)\n        in_data = in_data - data_mean\n        in_data_slices = np.split(\n            in_data, np.arange(batch_size, len(in_data), batch_size))\n        for batch_data in in_data_slices:\n            if len(batch_data) != batch_size:\n                continue\n            ipca.partial_fit(batch_data)\n\n        S = ipca.singular_values_\n        components_to_data = ipca.components_\n        explained_variance_ratio = ipca.explained_variance_ratio_\n        data_components = ipca.transform(in_data)\n        data_components_mean = data_components.mean(0)\n        data_components_std = data_components.std(0)\n        ret_dict = {\n            'explained_variance_ratio': explained_variance_ratio,\n            'components_to_data': components_to_data,\n            'original_data_mean': data_mean,\n            'S': S,\n            'data_components_mean': data_components_mean,\n            'data_components_std': data_components_std\n        }\n        return ret_dict\n\n    def build_pca(\n        self,\n        in_data: np.ndarray,\n        trunc_dim=None,\n     
   save_compactness=True,\n        check_svd=True,\n    ):\n        ret_dict = {}\n        in_data = torch.Tensor(in_data).cuda()\n        count, in_dim = in_data.shape\n        mean_data = in_data.mean(dim=0)\n        in_data_center = in_data - mean_data  # (count, in_dim)\n        in_data_center = in_data_center.cpu()\n        #  U (count, count), S(count, ), Vt (count, in_dim)\n        U, S, Vt = torch.linalg.svd(in_data_center, full_matrices=False)\n\n        if check_svd:\n            S_mat = torch.diag(S)\n            rebuilt = U @ S_mat @ Vt\n            is_precise = torch.allclose(rebuilt, in_data_center)\n            print('svd rebuilt all close ?', is_precise)\n\n        var_S = S**2\n        if save_compactness:\n            # fraction of variance explained by the first trunc_dim components\n            compact_vector = var_S / var_S.sum()\n            compact_value = compact_vector[0:trunc_dim].sum()\n            ret_dict['compactness'] = compact_value\n\n        pca_bases = Vt[0:trunc_dim].reshape(trunc_dim, -1)\n        ret_dict['pca_bases'] = pca_bases\n        ret_dict['pca_std'] = S[0:trunc_dim]\n        ret_dict['mean_data'] = mean_data\n        return ret_dict\n"
  },
  {
    "path": "models/unitalker.py",
    "content": "import os.path as osp\nimport torch.nn as nn\n\nfrom dataset.dataset_config import dataset_config\nfrom models.wav2vec2 import Wav2Vec2Model\nfrom .base_model import BaseModel\n# from models.hubert import HubertModel\nfrom models.wavlm import WavLMModel\nclass OutHead(BaseModel):\n\n    def __init__(self, args, out_dim: int):\n        super().__init__()\n        head_layer = []\n\n        in_dim = args.decoder_dimension\n        for i in range(args.headlayer - 1):\n            head_layer.append(nn.Linear(in_dim, in_dim))\n            head_layer.append(nn.Tanh())\n        head_layer.append(nn.Linear(in_dim, out_dim))\n        self.head_layer = nn.Sequential(*head_layer)\n\n    def forward(self, x):\n        return self.head_layer(x)\n\n\nclass UniTalker(BaseModel):\n\n    def __init__(self, args):\n        super().__init__()\n        self.args = args\n        if 'wav2vec2' in args.audio_encoder_repo:\n            self.audio_encoder = Wav2Vec2Model.from_pretrained(\n                args.audio_encoder_repo)\n        elif 'wavlm' in args.audio_encoder_repo:\n            self.audio_encoder = WavLMModel.from_pretrained(\n                args.audio_encoder_repo)   \n        else:\n            raise ValueError(\"wrong audio_encoder_repo\")\n        self.audio_encoder.feature_extractor._freeze_parameters()\n        if args.freeze_wav2vec:\n            self.audio_encoder._freeze_wav2vec2_parameters()\n\n        if args.decoder_type == 'conv':\n            from .unitalker_decoder import UniTalkerDecoderTCN as Decoder\n        elif args.decoder_type == 'transformer':\n            from .unitalker_decoder import \\\n                UniTalkerDecoderTransformer as Decoder\n        else:\n            ValueError('unknown decoder type ')\n        self.decoder = Decoder(args)\n\n        if args.use_pca:\n            pca_layer_dict = {}\n            from models.pca import PCALayer\n            for dataset_name in args.dataset:\n                if dataset_config[dataset_name]['pca']:\n                    annot_type = dataset_config[dataset_name]['annot_type']\n                    dirname = dataset_config[dataset_name]['dirname']\n                    pca_path = osp.join(args.data_root, dirname, 'pca.npz')\n                    pca_layer_dict[annot_type] = PCALayer(pca_path)\n            self.pca_layer_dict = nn.ModuleDict(pca_layer_dict)\n        else:\n            self.pca_layer_dict = {}\n\n        self.pca_dim = args.pca_dim\n        self.use_pca = args.use_pca\n        self.interpolate_pos = args.interpolate_pos\n\n        out_head_dict = {}\n        for dataset_name in args.dataset:\n            annot_type = dataset_config[dataset_name]['annot_type']\n            out_dim = dataset_config[dataset_name]['annot_dim']\n            if (args.use_pca is True) and (annot_type in self.pca_layer_dict):\n                pca_dim = self.pca_layer_dict[annot_type].pca_dim\n                out_dim = min(args.pca_dim, out_dim, pca_dim)\n            if args.headlayer == 1:\n                out_projection = nn.Linear(args.decoder_dimension, out_dim)\n                nn.init.constant_(out_projection.weight, 0)\n                nn.init.constant_(out_projection.bias, 0)\n            else:\n                out_projection = OutHead(args, out_dim)\n            out_head_dict[annot_type] = out_projection\n        self.out_head_dict = nn.ModuleDict(out_head_dict)\n        self.identity_num = args.identity_num\n        return\n\n    def forward(self,\n                audio,\n                template,\n                
face_motion,\n                style_idx,\n                annot_type: str,\n                fps: float = None):\n        if face_motion is not None:\n            frame_num = face_motion.shape[1]\n        else:\n            frame_num = round(audio.shape[-1] / 16000 * fps)\n\n        hidden_states = self.audio_encoder(\n            audio, frame_num=frame_num, interpolate_pos=self.interpolate_pos)\n        hidden_states = hidden_states.last_hidden_state\n        decoder_out = self.decoder(hidden_states, style_idx, frame_num)\n\n        if (not self.use_pca) or (annot_type not in self.pca_layer_dict):\n            out_motion = self.out_head_dict[annot_type](decoder_out)\n            out_motion = out_motion + template\n            out_pca, gt_pca = None, None\n        else:\n            out_pca = self.out_head_dict[annot_type](decoder_out)\n            out_motion = self.pca_layer_dict[annot_type].decode(\n                out_pca, self.pca_dim)\n            out_motion = out_motion + template\n            if face_motion is not None:\n                gt_pca = self.pca_layer_dict[annot_type].encode(\n                    face_motion - template, self.pca_dim)\n            else:\n                gt_pca = None\n\n        return out_motion, out_pca, gt_pca\n"
  },
  {
    "path": "models/unitalker_decoder.py",
    "content": "# yapf: disable\nimport torch\nimport torch.nn as nn\n\nfrom .base_model import BaseModel\nfrom .model_utils import (\n    TCN, PeriodicPositionalEncoding, enc_dec_mask, init_biased_mask,\n    linear_interpolation,\n)\n\n# yapf: enable\n\n\nclass UniTalkerDecoderTCN(BaseModel):\n\n    def __init__(self, args) -> None:\n        super().__init__()\n        self.learnable_style_emb = nn.Embedding(args.identity_num, 64)\n        in_dim = args.decoder_dimension\n        out_dim = args.decoder_dimension\n        self.audio_feature_map = nn.Linear(args.audio_encoder_feature_dim, in_dim)\n        self.tcn = TCN(in_dim + 64, out_dim)\n        self.interpolate_pos = args.interpolate_pos\n\n    def forward(self,\n                hidden_states: torch.Tensor,\n                style_idx: torch.Tensor,\n                frame_num: int = None):\n        batch_size = len(hidden_states)\n        if style_idx is None:\n            style_embedding = self.learnable_style_emb.weight[-1].unsqueeze(\n                0).repeat(batch_size, 1)\n        else:\n            style_embedding = self.learnable_style_emb(style_idx)\n        feature = self.audio_feature_map(hidden_states).transpose(1, 2)\n        feature = self.tcn(feature, style_embedding)\n        feature = feature.transpose(1, 2)\n        if self.interpolate_pos == 2:\n            feature = linear_interpolation(feature, output_len=frame_num)\n        return feature\n\n\nclass UniTalkerDecoderTransformer(BaseModel):\n\n    def __init__(self, args) -> None:\n        super().__init__()\n        out_dim = args.decoder_dimension\n        self.learnable_style_emb = nn.Embedding(args.identity_num, out_dim)\n        self.audio_feature_map = nn.Linear(args.audio_encoder_feature_dim, out_dim)\n        self.PPE = PeriodicPositionalEncoding(\n            out_dim, period=args.period, max_seq_len=3000)\n        self.biased_mask = init_biased_mask(\n            n_head=4, max_seq_len=3000, period=args.period)\n        decoder_layer = nn.TransformerDecoderLayer(\n            d_model=out_dim,\n            nhead=4,\n            dim_feedforward=2 * out_dim,\n            batch_first=True)\n        self.transformer_decoder = nn.TransformerDecoder(\n            decoder_layer, num_layers=1)\n        self.interpolate_pos = args.interpolate_pos\n\n    def forward(self, hidden_states: torch.Tensor, style_idx: torch.Tensor,\n                frame_num: int):\n        obj_embedding = self.learnable_style_emb(style_idx)\n        obj_embedding = obj_embedding.unsqueeze(1).repeat(1, frame_num, 1)\n        hidden_states = self.audio_feature_map(hidden_states)\n        style_input = self.PPE(obj_embedding)\n        tgt_mask = self.biased_mask[:, :style_input.shape[1], :style_input.\n                                    shape[1]].clone().detach().to(\n                                        device=style_input.device)\n        memory_mask = enc_dec_mask(hidden_states.device, style_input.shape[1],\n                                   frame_num)\n        feat_out = self.transformer_decoder(\n            style_input,\n            hidden_states,\n            tgt_mask=tgt_mask,\n            memory_mask=memory_mask)\n        if self.interpolate_pos == 2:\n            feat_out = linear_interpolation(feat_out, output_len=frame_num)\n        return feat_out\n"
  },
  {
    "path": "models/wav2vec2.py",
    "content": "import numpy as np\nimport torch\nfrom transformers import Wav2Vec2Model\nfrom transformers.modeling_outputs import BaseModelOutput\nfrom typing import Optional, Tuple\n\nfrom .model_utils import linear_interpolation\n\n\n# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model # noqa: E501\n# initialize our encoder with the pre-trained wav2vec 2.0 weights.\ndef _compute_mask_indices(\n    shape: Tuple[int, int],\n    mask_prob: float,\n    mask_length: int,\n    attention_mask: Optional[torch.Tensor] = None,\n    min_masks: int = 0,\n) -> np.ndarray:\n    bsz, all_sz = shape\n    mask = np.full((bsz, all_sz), False)\n\n    all_num_mask = int(mask_prob * all_sz / float(mask_length) +\n                       np.random.rand())\n    all_num_mask = max(min_masks, all_num_mask)\n    mask_idcs = []\n    padding_mask = attention_mask.ne(1) if attention_mask is not None else None\n    for i in range(bsz):\n        if padding_mask is not None:\n            sz = all_sz - padding_mask[i].long().sum().item()\n            num_mask = int(mask_prob * sz / float(mask_length) +\n                           np.random.rand())\n            num_mask = max(min_masks, num_mask)\n        else:\n            sz = all_sz\n            num_mask = all_num_mask\n\n        lengths = np.full(num_mask, mask_length)\n\n        if sum(lengths) == 0:\n            lengths[0] = min(mask_length, sz - 1)\n\n        min_len = min(lengths)\n        if sz - min_len <= num_mask:\n            min_len = sz - num_mask - 1\n\n        mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)\n        mask_idc = np.asarray([\n            mask_idc[j] + offset for j in range(len(mask_idc))\n            for offset in range(lengths[j])\n        ])\n        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))\n\n    min_len = min([len(m) for m in mask_idcs])\n    for i, mask_idc in enumerate(mask_idcs):\n        if len(mask_idc) > min_len:\n            mask_idc = np.random.choice(mask_idc, min_len, replace=False)\n        mask[i, mask_idc] = True\n    return mask\n\n\nclass Wav2Vec2Model(Wav2Vec2Model):\n\n    def __init__(self, config):\n        super().__init__(config)\n\n    def _freeze_wav2vec2_parameters(self, do_freeze: bool = True):\n        for param in self.parameters():\n            param.requires_grad = (not do_freeze)\n\n    def forward(\n        self,\n        input_values,\n        attention_mask=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n        frame_num=None,\n        interpolate_pos: int = 0,\n    ):\n        self.config.output_attentions = True\n        output_attentions = output_attentions if \\\n            output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else\n            self.config.output_hidden_states)\n        return_dict = return_dict if return_dict is not None \\\n            else self.config.use_return_dict\n\n        conv_features = self.feature_extractor(input_values)\n        hidden_states = conv_features.transpose(1, 2)\n        if interpolate_pos == 0:\n            hidden_states = linear_interpolation(\n                hidden_states, output_len=frame_num)\n\n        if attention_mask is not None:\n            output_lengths = self._get_feat_extract_output_lengths(\n                
attention_mask.sum(-1))\n            attention_mask = torch.zeros(\n                hidden_states.shape[:2],\n                dtype=hidden_states.dtype,\n                device=hidden_states.device)\n            attention_mask[(torch.arange(\n                attention_mask.shape[0],\n                device=hidden_states.device), output_lengths - 1)] = 1\n            attention_mask = attention_mask.flip([-1]).cumsum(-1).flip(\n                [-1]).bool()\n\n        hidden_states, _ = self.feature_projection(hidden_states)\n\n        if self.config.apply_spec_augment and self.training:\n            batch_size, sequence_length, hidden_size = hidden_states.size()\n            if self.config.mask_time_prob > 0:\n                mask_time_indices = _compute_mask_indices(\n                    (batch_size, sequence_length),\n                    self.config.mask_time_prob,\n                    self.config.mask_time_length,\n                    attention_mask=attention_mask,\n                    min_masks=2,\n                )\n                hidden_states[torch.from_numpy(\n                    mask_time_indices)] = self.masked_spec_embed.to(\n                        hidden_states.dtype)\n            if self.config.mask_feature_prob > 0:\n                mask_feature_indices = _compute_mask_indices(\n                    (batch_size, hidden_size),\n                    self.config.mask_feature_prob,\n                    self.config.mask_feature_length,\n                )\n                mask_feature_indices = torch.from_numpy(\n                    mask_feature_indices).to(hidden_states.device)\n                hidden_states[mask_feature_indices[:, None].expand(\n                    -1, sequence_length, -1)] = 0\n        encoder_outputs = self.encoder(\n            hidden_states,\n            attention_mask=attention_mask,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = encoder_outputs[0]\n        if interpolate_pos == 1:\n            hidden_states = linear_interpolation(\n                hidden_states, output_len=frame_num)\n\n        if not return_dict:\n            return (hidden_states, ) + encoder_outputs[1:]\n\n        return BaseModelOutput(\n            last_hidden_state=hidden_states,\n            hidden_states=encoder_outputs.hidden_states,\n            attentions=encoder_outputs.attentions,\n        )\n"
  },
  {
    "path": "models/wavlm.py",
    "content": "import numpy as np\nimport torch\nfrom transformers import WavLMModel\nfrom transformers.modeling_outputs import Wav2Vec2BaseModelOutput\nfrom typing import Optional, Tuple, Union\n\n\n\nfrom .model_utils import linear_interpolation\n\n\n# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model # noqa: E501\n# initialize our encoder with the pre-trained wav2vec 2.0 weights.\n\n\nclass WavLMModel(WavLMModel):\n    def __init__(self, config):\n        super().__init__(config)\n\n    def _freeze_wav2vec2_parameters(self, do_freeze: bool = True):\n        for param in self.parameters():\n            param.requires_grad = (not do_freeze)\n\n    def forward(\n        self,\n        input_values: Optional[torch.Tensor],\n        attention_mask: Optional[torch.Tensor] = None,\n        mask_time_indices: Optional[torch.FloatTensor] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        frame_num=None,\n        interpolate_pos: int = 0,\n    ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:\n\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n            )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        extract_features = self.feature_extractor(input_values)\n        extract_features = extract_features.transpose(1, 2)\n\n        if interpolate_pos == 0:\n            extract_features = linear_interpolation(\n                extract_features, output_len=frame_num)\n            \n        if attention_mask is not None:\n            # compute reduced attention_mask corresponding to feature vectors\n            attention_mask = self._get_feature_vector_attention_mask(\n                extract_features.shape[1], attention_mask, add_adapter=False\n            )\n\n        hidden_states, extract_features = self.feature_projection(extract_features)\n        hidden_states = self._mask_hidden_states(\n            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask\n        )\n\n        encoder_outputs = self.encoder(\n            hidden_states,\n            attention_mask=attention_mask,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        hidden_states = encoder_outputs[0]\n\n        if interpolate_pos == 1:\n            hidden_states = linear_interpolation(\n                hidden_states, output_len=frame_num)\n            \n        if self.adapter is not None:\n            hidden_states = self.adapter(hidden_states)\n\n        if not return_dict:\n            return (hidden_states, extract_features) + encoder_outputs[1:]\n\n        return Wav2Vec2BaseModelOutput(\n            last_hidden_state=hidden_states,\n            extract_features=extract_features,\n            hidden_states=encoder_outputs.hidden_states,\n            attentions=encoder_outputs.attentions,\n        )"
  },
  {
    "path": "utils/config.py",
    "content": "# -----------------------------------------------------------------------------\n# Functions for parsing args\n# -----------------------------------------------------------------------------\nimport copy\nimport os\nimport yaml\nfrom ast import literal_eval\n\n\nclass CfgNode(dict):\n    \"\"\"CfgNode represents an internal node in the configuration tree.\n\n    It's a simple dict-like container that allows for attribute-based access to\n    keys.\n    \"\"\"\n\n    def __init__(self, init_dict=None, key_list=None, new_allowed=False):\n        # Recursively convert nested dictionaries in init_dict into CfgNodes\n        init_dict = {} if init_dict is None else init_dict\n        key_list = [] if key_list is None else key_list\n        for k, v in init_dict.items():\n            if type(v) is dict:\n                # Convert dict to CfgNode\n                init_dict[k] = CfgNode(v, key_list=key_list + [k])\n        super(CfgNode, self).__init__(init_dict)\n\n    def __getattr__(self, name):\n        if name in self:\n            return self[name]\n        else:\n            raise AttributeError(name)\n\n    def __setattr__(self, name, value):\n        self[name] = value\n\n    def __str__(self):\n\n        def _indent(s_, num_spaces):\n            s = s_.split('\\n')\n            if len(s) == 1:\n                return s_\n            first = s.pop(0)\n            s = [(num_spaces * ' ') + line for line in s]\n            s = '\\n'.join(s)\n            s = first + '\\n' + s\n            return s\n\n        r = ''\n        s = []\n        for k, v in sorted(self.items()):\n            separator = '\\n' if isinstance(v, CfgNode) else ' '\n            attr_str = '{}:{}{}'.format(str(k), separator, str(v))\n            attr_str = _indent(attr_str, 2)\n            s.append(attr_str)\n        r += '\\n'.join(s)\n        return r\n\n    def __repr__(self):\n        return '{}({})'.format(self.__class__.__name__,\n                               super(CfgNode, self).__repr__())\n\n\ndef load_cfg_from_cfg_file(file):\n    cfg = {}\n    assert os.path.isfile(file) and file.endswith('.yaml'), \\\n        '{} is not a yaml file'.format(file)\n\n    with open(file, 'r') as f:\n        cfg_from_file = yaml.safe_load(f)\n\n    for key in cfg_from_file:\n        for k, v in cfg_from_file[key].items():\n            cfg[k] = v\n\n    cfg = CfgNode(cfg)\n    return cfg\n\n\ndef merge_cfg_from_list(cfg, cfg_list):\n    new_cfg = copy.deepcopy(cfg)\n    assert len(cfg_list) % 2 == 0\n    for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):\n        subkey = full_key.split('.')[-1]\n        assert subkey in cfg, 'Non-existent key: {}'.format(full_key)\n        value = _decode_cfg_value(v)\n        value = _check_and_coerce_cfg_value_type(value, cfg[subkey], subkey,\n                                                 full_key)\n        setattr(new_cfg, subkey, value)\n\n    return new_cfg\n\n\ndef _decode_cfg_value(v):\n    \"\"\"Decodes a raw config value (e.g., from a yaml config files or command\n    line argument) into a Python object.\"\"\"\n    # All remaining processing is only applied to strings\n    if not isinstance(v, str):\n        return v\n    # Try to interpret `v` as a:\n    #   string, number, tuple, list, dict, boolean, or None\n    try:\n        v = literal_eval(v)\n    # The following two excepts allow v to pass through when it represents a\n    # string.\n    #\n    # Longer explanation:\n    # The type of v is always a string (before calling literal_eval), but\n    # 
sometimes it *represents* a string and other times a data structure, like\n    # a list. In the case that v represents a string, what we got back from the\n    # yaml parser is 'foo' *without quotes* (so, not '\"foo\"'). literal_eval is\n    # ok with '\"foo\"', but will raise a ValueError if given 'foo'. In other\n    # cases, like paths (v = 'foo/bar' and not v = '\"foo/bar\"'), literal_eval\n    # will raise a SyntaxError.\n    except ValueError:\n        pass\n    except SyntaxError:\n        pass\n    return v\n\n\ndef _check_and_coerce_cfg_value_type(replacement, original, key, full_key):\n    \"\"\"Checks that `replacement`, which is intended to replace `original` is of\n    the right type.\n\n    The type is correct if it matches exactly or is one of a few cases in which\n    the type can be easily coerced.\n    \"\"\"\n    original_type = type(original)\n    replacement_type = type(replacement)\n\n    # The types must match (with some exceptions)\n    if replacement_type == original_type or original is None:\n        return replacement\n\n    # types match from_type and to_type\n    def conditional_cast(from_type, to_type):\n        if replacement_type == from_type and original_type == to_type:\n            return True, to_type(replacement)\n        else:\n            return False, None\n\n    # Conditionally casts\n    # list <-> tuple\n    casts = [(tuple, list), (list, tuple)]\n    # For py2: allow converting from str (bytes) to a unicode string\n    try:\n        casts.append((str, unicode))  # noqa: F821\n    except Exception:\n        pass\n\n    for (from_type, to_type) in casts:\n        converted, converted_value = conditional_cast(from_type, to_type)\n        if converted:\n            return converted_value\n\n    raise ValueError(\n        'Type mismatch ({} vs. {}) with values ({} vs. {}) for config '\n        'key: {}'.format(original_type, replacement_type, original,\n                         replacement, full_key))\n"
  },
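  {
    "path": "examples/config_usage_example.py",
    "content": "# ---------------------------------------------------------------------------\n# Hypothetical usage sketch, NOT part of the original code base.  It shows how\n# a YAML config is loaded with load_cfg_from_cfg_file, how the resulting\n# CfgNode exposes the flattened keys as attributes, and how merge_cfg_from_list\n# applies command-line style KEY VALUE overrides.  The file path, the sample\n# config values and the override keys are illustrative assumptions; only the\n# functions imported from utils.config come from the repository, and the\n# import assumes `utils` is importable as a package.\n# ---------------------------------------------------------------------------\nimport tempfile\n\nimport yaml\n\nfrom utils.config import load_cfg_from_cfg_file, merge_cfg_from_list\n\n# Two-level dict: load_cfg_from_cfg_file flattens section -> key -> value into\n# a single-level CfgNode.\nSAMPLE_CFG = {\n    'TRAIN': {'base_lr': 0.0001, 'epochs': 100},\n    'DATA': {'dataset': 'vocaset'},\n}\n\nif __name__ == '__main__':\n    # The loader asserts that the path ends with '.yaml', so write the sample\n    # config to a temporary .yaml file first.\n    with tempfile.NamedTemporaryFile('w', suffix='.yaml', delete=False) as f:\n        yaml.safe_dump(SAMPLE_CFG, f)\n        yaml_path = f.name\n\n    cfg = load_cfg_from_cfg_file(yaml_path)\n    print(cfg.base_lr, cfg.dataset)  # attribute-style access: 0.0001 vocaset\n\n    # Overrides come as [KEY1, VALUE1, KEY2, VALUE2, ...]; values are decoded\n    # with literal_eval and type-checked against the existing entry.\n    cfg = merge_cfg_from_list(\n        cfg, ['TRAIN.base_lr', '0.0002', 'DATA.dataset', 'BIWI'])\n    print(cfg.base_lr, cfg.dataset)  # 0.0002 BIWI\n"
  },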
  {
    "path": "utils/utils.py",
    "content": "import argparse\nimport logging\nimport numpy as np\nimport os\nimport random\nimport torch\nfrom copy import deepcopy\n\nfrom . import config\n\n\ndef get_audio_encoder_dim(audio_encoder: str):\n    if audio_encoder == 'microsoft/wavlm-base-plus':\n        return 768\n    elif audio_encoder == 'facebook/wav2vec2-large-xlsr-53':\n        return  1024\n    else:\n        raise ValueError(\"wrong audio_encoder\")\n\ndef filer_list(in_list: list, key: str):\n    ret_list = []\n    for line in in_list:\n        if line.startswith(key):\n            ret_list.append(line)\n    return ret_list\n\n\ndef count_checkpoint_params(ckpt: dict):\n    params = 0\n    pca_params = 0\n    for k, v in ckpt.items():\n        if 'pca' in k:\n            print(k)\n            pca_params += np.prod(v.size())\n        params += np.prod(v.size())\n    return params, pca_params\n\n\ndef read_obj(in_path):\n    with open(in_path, 'r') as obj_file:\n        # Read the lines of the OBJ file\n        lines = obj_file.readlines()\n\n    # Initialize empty lists for vertices and faces\n    verts = []\n    faces = []\n    for line in lines:\n        line = line.strip()  # Remove leading/trailing whitespace\n        elements = line.split()  # Split the line into elements\n\n        if len(elements) == 0:\n            continue  # Skip empty lines\n\n        # Check the type of line (vertex or face)\n        if elements[0] == 'v':\n            # Vertex line\n            x, y, z = map(float,\n                          elements[1:4])  # Extract the vertex coordinates\n            verts.append((x, y, z))  # Add the vertex to the list\n        elif elements[0] == 'f':\n            # Face line\n            face_indices = [\n                int(index.split('/')[0]) for index in elements[1:]\n            ]  # Extract the vertex indices\n            faces.append(face_indices)  # Add the face to the list\n    verts = np.array(verts)\n    faces = np.array(faces)\n    if faces.min() == 1:\n        faces = faces - 1\n    return verts, faces\n\n\ndef get_logger(log_file: str):\n    logger_name = 'main-logger'\n    logger = logging.getLogger(logger_name)\n    logger.setLevel(logging.INFO)\n    handler = logging.FileHandler(log_file, mode='w')\n    fmt = '[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d]=>%(message)s'  # noqa: E501\n    handler.setFormatter(logging.Formatter(fmt))\n    for hdl in logger.handlers:\n        logger.removeHandler(hdl)\n    logger.addHandler(handler)\n    return logger\n\n\ndef get_template_dict(data_root, annot_type: str):\n    import pickle\n    if annot_type == 'BIWI_23370_vertices':\n        template_pkl_path = os.path.join(data_root, 'BIWI', 'templates.pkl')\n    elif annot_type == 'FLAME_5023_vertices':\n        template_pkl_path = os.path.join(data_root, 'vocaset', 'templates.pkl')\n    with open(template_pkl_path, 'rb') as f:\n        template_dict = pickle.load(f, encoding='latin')\n    return template_dict\n\n\ndef filter_unitalker_state_dict(in_dict):\n    out_dict = {}\n    for k, v in in_dict.items():\n        if not 'pca_layer_dict' in k:\n            out_dict[k] = v\n    return out_dict \n\ndef get_template_verts(data_root, dataset_name: str, subject: int):\n    from dataset.dataset_config import dataset_config\n    template_path = os.path.join(data_root, dataset_config[dataset_name]['dirname'], 'id_template.npy')\n    return np.load(template_path)[subject]\n\n\ndef load_ckpt(model, ckpt_path, re_init_decoder_and_head: bool = False):\n    checkpoint = 
torch.load(ckpt_path, map_location='cpu')\n    id_embedding_key = 'decoder.learnable_style_emb.weight'\n    if len(checkpoint[id_embedding_key]) != len(\n            model.state_dict()[id_embedding_key]):\n        target_id_num = len(model.state_dict()[id_embedding_key])\n        checkpoint[id_embedding_key] = checkpoint[id_embedding_key][-1][\n            None].repeat(target_id_num, 1)\n    if re_init_decoder_and_head is True:\n        ckpt_keys = list(checkpoint.keys())\n        encoder_keys = filer_list(ckpt_keys, 'audio_encoder')\n        reinit_keys = list(set(ckpt_keys) - set(encoder_keys))\n        for k in reinit_keys:\n            if k in model.state_dict():\n                checkpoint[k] = model.state_dict()[k]\n    model.load_state_dict(checkpoint, strict=False)\n    return\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value.\"\"\"\n\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n\n    def get_avg(self, ):\n        return self.avg\n\n\ndef get_average_meter_dict(args, phase):\n    if phase == 'train':\n        metric_name_list = ['rec_L2_loss', 'pca_loss']\n    elif phase == 'val' or phase == 'test':\n        metric_name_list = ['rec_L2_loss', 'LVE', 'LVD']\n    else:\n        raise ValueError('wrong phase for average meter')\n    average_meter_dict_template = {}\n    for k in metric_name_list:\n        average_meter_dict_template[k] = AverageMeter()\n    mixed_dataset_name = 'mixed_dataset'\n    dataset_name_list = args.dataset + [mixed_dataset_name]\n    average_meter_dict = {}\n    for dataset_name in dataset_name_list:\n        average_meter_dict[dataset_name] = deepcopy(\n            average_meter_dict_template)\n    return average_meter_dict, metric_name_list\n\n\ndef log_datasetloss(\n    logger,\n    epoch: int,\n    phase: str,\n    dataset_loss_dict: dict,\n    only_mix: bool = True,\n):\n    base_str = f'{phase.upper():<5} Epoch: {epoch:03d} '\n    for dataset_name, loss_info in dataset_loss_dict.items():\n        if only_mix and dataset_name != 'mixed_dataset':\n            continue\n        for loss_type, avg_meter in loss_info.items():\n            base_str += f'{loss_type}_{dataset_name}:{avg_meter.avg: .8f}, '\n    logger.info(base_str)\n    return base_str\n\n\ndef write_to_tensorboard(writer,\n                         epoch: int,\n                         phase: str,\n                         dataset_loss_dict: dict,\n                         only_mix: bool = False):\n    for dataset_name, loss_info in dataset_loss_dict.items():\n        if only_mix and dataset_name != 'mixed_dataset':\n            continue\n        for loss_type, avg_meter in loss_info.items():\n            key = f'{phase}/{loss_type}_{dataset_name}'\n            writer.add_scalar(key, avg_meter.avg, epoch)\n    return 0\n\n\ndef get_parser():\n    parser = argparse.ArgumentParser(description=' ')\n    parser.add_argument(\n        '--config',\n        type=str,\n        default='config/unitalker.yaml',\n        help='config file')\n    parser.add_argument(\n        '--condition_id', type=str, default='common', help='condition id')\n    parser.add_argument(\n        'opts', help=' ', default=None, nargs=argparse.REMAINDER)\n    args = parser.parse_args()\n    assert args.config is not 
None\n    cfg = config.load_cfg_from_cfg_file(args.config)\n    if args.opts is not None:\n        cfg = config.merge_cfg_from_list(cfg, args.opts)\n    cfg.condition_id = args.condition_id\n    cfg.dataset = cfg.dataset.split(',')\n    cfg.demo_dataset = cfg.demo_dataset.split(',')\n    if isinstance(cfg.duplicate_list, str):\n        cfg.duplicate_list = cfg.duplicate_list.split(',')\n        cfg.duplicate_list = [int(i) for i in cfg.duplicate_list]\n    return cfg\n\n\ndef seed_everything(seed: int = 42):\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n    return\n"
  },
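  {
    "path": "examples/read_obj_example.py",
    "content": "# ---------------------------------------------------------------------------\n# Hypothetical usage sketch, NOT part of the original code base.  It writes a\n# minimal Wavefront OBJ file (a single triangle) and parses it back with\n# utils.utils.read_obj to show the returned vertex/face arrays and the shift\n# from 1-based OBJ indices to 0-based indices.  The file path and the sample\n# mesh are illustrative assumptions; only read_obj itself comes from the\n# repository, and the import assumes `utils` is importable as a package.\n# ---------------------------------------------------------------------------\nimport tempfile\n\nfrom utils.utils import read_obj\n\n# One triangle: three vertex lines and one face line with 1-based indices.\nOBJ_LINES = [\n    'v 0.0 0.0 0.0',\n    'v 1.0 0.0 0.0',\n    'v 0.0 1.0 0.0',\n    'f 1 2 3',\n]\n\nif __name__ == '__main__':\n    with tempfile.NamedTemporaryFile('w', suffix='.obj', delete=False) as f:\n        f.write('\\n'.join(OBJ_LINES) + '\\n')\n        obj_path = f.name\n\n    verts, faces = read_obj(obj_path)\n    print(verts.shape)  # (3, 3): one float xyz row per vertex\n    print(faces)  # [[0 1 2]]: shifted to 0-based because the min index was 1\n"
  }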
]