Full Code of X-niper/UniTalker

Repository: X-niper/UniTalker
Branch: main
Commit: 21481130f126
Files: 23
Total size: 117.9 KB

Directory structure:
UniTalker/

├── .gitignore
├── LICENSE
├── README.md
├── config/
│   └── unitalker.yaml
├── dataset/
│   ├── data_item.py
│   ├── dataset.py
│   └── dataset_config.py
├── loss/
│   └── loss.py
├── main/
│   ├── demo.py
│   ├── render.py
│   ├── train.py
│   ├── train_eval_loop.py
│   └── train_pca.py
├── models/
│   ├── base_model.py
│   ├── hubert.py
│   ├── model_utils.py
│   ├── pca.py
│   ├── unitalker.py
│   ├── unitalker_decoder.py
│   ├── wav2vec2.py
│   └── wavlm.py
└── utils/
    ├── config.py
    └── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized files
*__pycache__/
*.py[cod]


# Distribution / packaging
dist/
build/
*.egg-info/
*.egg

# Virtual environments
venv/
env/
.venv/
.pip/
pipenv/

# IDE specific files
.idea/
.vscode/

# Logs and temporary files
*.log
*.swp
*.swo
*.tmp

# Miscellaneous
.DS_Store
records.txt
binary_resources
do_experiment.py
small_tools
scancel_job.py
results
template_resources


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================

# UniTalker: Scaling up Audio-Driven 3D Facial Animation through A Unified Model [ECCV2024]

## Useful Links

<div align="center">
    <a href="https://x-niper.github.io/projects/UniTalker/" class="button"><b>[Homepage]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
    <a href="https://arxiv.org/abs/2408.00762" class="button"><b>[arXiv]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
    <a href="https://www.youtube.com/watch?v=oUUh67ECzig" class="button"><b>[Video]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
</div>

<div align="center">
<img src="https://github.com/X-niper/X-niper.github.io/blob/master/projects/UniTalker/assets/unitalker_architecture.png" alt="UniTalker Architecture" width="80%" align=center/>
</div>


> UniTalker generates realistic facial motion from different audio domains, including clean and noisy voices in various languages, text-to-speech-generated audio, and even noisy songs accompanied by background music.
> 
> UniTalker can output multiple annotation types, e.g., mesh vertices and blendshape weights.
> 
> For datasets with new annotations, one can simply plug new heads into UniTalker and train it with existing datasets or solely with the new ones, avoiding retopology.

## Installation
### Environment
- Linux
- Python 3.10
- Pytorch 2.2.0
- CUDA 12.1
- transformers 4.39.3
- Pytorch3d 0.7.7 (Optional: just for rendering the results)

```bash
  conda create -n unitalker python==3.10
  conda activate unitalker
  conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=12.1 -c pytorch -c nvidia
  pip install transformers librosa tensorboardX smplx chumpy numpy==1.23.5 opencv-python
```

## Inference

### Download checkpoints, PCA models and template resources

[UniTalker-B-[D0-D7]](https://drive.google.com/file/d/1PmF8I6lyo0_64-NgeN5qIQAX6Bg0yw44/view?usp=sharing): The base model in the paper. Download it and place it in `./pretrained_models`.

[UniTalker-L-[D0-D7]](https://drive.google.com/file/d/1sH2T7KLFNjUnTM-V1eRMM1Tytxd2sYAp/view?usp=sharing): The default model in the paper. Please try the base model first to make sure the pipeline runs end to end.

[Unitalker-data-release-V1](https://drive.google.com/file/d/1Un7TB0Z5A1CG6bgeqKlhnSOECFN-C6KK/view?usp=sharing): The released datasets, PCA models, data-split json files and id-template numpy arrays. Download it and unzip it in this repo.

[FLAME2020](https://flame.is.tue.mpg.de/download.php): Please download FLAME 2020 and move generic_model.pkl to resources/binary_resources/flame.pkl.

<!-- [PCA models](https://drive.google.com/file/d/1e0sG2vvdrtAMgwD5njctifhX0ai4eu3g/view?usp=sharing): download the pca models and unzip it in "./unitalker_data_release" -->

Use `git lfs pull` to get `./resources.zip` and `./test_audios.zip`, then unzip them in this repo.

Finally, these files should be organized as follows:

```text
├── pretrained_models
│   ├── UniTalker-B-D0-D7.pt
│   └── UniTalker-L-D0-D7.pt
├── resources
│   ├── binary_resources
│   │   ├── 02_flame_mouth_idx.npy
│   │   ├── ...
│   │   └── vocaset_FDD_wo_eyes.npy
│   └── obj_template
│       ├── 3DETF_blendshape_weight.obj
│       ├── ...
│       └── meshtalk_6172_vertices.obj
└── unitalker_data_release_V1
    ├── D0_BIWI
    │   ├── id_template.npy
    │   └── pca.npz
    ├── D1_vocaset
    │   ├── id_template.npy
    │   └── pca.npz
    ├── D2_meshtalk
    │   ├── id_template.npy
    │   └── pca.npz
    ├── D3D4_3DETF
    │   ├── D3_HDTF
    │   └── D4_RAVDESS
    ├── D5_unitalker_faceforensics++
    │   ├── id_template.npy
    │   ├── test
    │   ├── test.json
    │   ├── train
    │   ├── train.json
    │   ├── val
    │   └── val.json
    ├── D6_unitalker_Chinese_speech
    │   ├── id_template.npy
    │   ├── test
    │   ├── test.json
    │   ├── train
    │   ├── train.json
    │   ├── val
    │   └── val.json
    └── D7_unitalker_song
        ├── id_template.npy
        ├── test
        ├── test.json
        ├── train
        ├── train.json
        ├── val
        └── val.json
```

### Demo

```bash
  python -m main.demo --config config/unitalker.yaml test_out_path ./test_results/demo.npz
  python -m main.render ./test_results/demo.npz ./test_audios ./test_results/
```

## Train

### Download Data
[Unitalker-data-release-V1](https://drive.google.com/file/d/1qRBPsTdOWp72ty04oD1Q_ivtwMjrACLH/view?usp=sharing) contains D5, D6 and D7. The datasets have been processed and grouped into train, validation and test. Please use these three datasets to try the training step.
If you want to train the model on D0-D7, you need to download the remaining datasets from the following links:

- [D0: BIWI](https://github.com/Doubiiu/CodeTalker/blob/main/BIWI/README.md)
- [D1: VOCASET](https://voca.is.tue.mpg.de/)
- [D2: meshtalk](https://github.com/facebookresearch/meshtalk?tab=readme-ov-file)
- [D3, D4: 3DETF](https://github.com/psyai-net/EmoTalk_release)

### Modify Config and Train
Please modify `dataset` and `duplicate_list` in `config/unitalker.yaml` according to the datasets you have prepared, ensuring that both lists maintain the same length.
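
For example, training only on the released D5, D6 and D7 data might use a config like this (a sketch; the duplicate factors below are illustrative, not tuned values):

```yaml
DATA:
  dataset: "D5,D6,D7"
  duplicate_list: "1,1,1"
```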


```bash
python -m main.train --config config/unitalker.yaml 
```

================================================
FILE: config/unitalker.yaml
================================================
DATA:
  dataset: "D0,D1,D2,D3,D4,D5,D6,D7"
  demo_dataset: "D0,D1,D2,D3,D4,D6,D7" # To visualize D5, please download the FLAME2020 model first and place it at "resources/binary_resources/flame.pkl"
  data_root: "./unitalker_data_release_V1/"
  duplicate_list: "10,5,4,1,1,1,1,1"
  train_json_key: train.json

LOSS:
  rec_weight: 1.0
  pca_weight: 0.01
  blendshape_weight: 0.0001


NETWORK:
  use_pca: True
  pca_dim: 512

MODEL:
  audio_encoder_repo: 'microsoft/wavlm-base-plus' # choose from 'microsoft/wavlm-base-plus' and 'facebook/wav2vec2-large-xlsr-53'
  freeze_wav2vec: False
  interpolate_pos: 1
  decoder_dimension: 256
  decoder_type: conv  # choose from conv and transformer
  period: 30
  headlayer: 1

TRAIN:
  fix_encoder_first: True
  random_id_prob: 0.1
  re_init_decoder_and_head: False
  only_head: False
  workers: 4  # data loader workers
  batch_size: 1  # batch size for training
  batch_size_val: 1  # batch size for validation during training, memory and speed tradeoff
  lr: 0.0001
  epochs: 100
  save_every: 1
  save_path: "./results"
  weight_path: "./pretrained_models/UniTalker-B-D0-D7.pt"
  eval_every: 1
  test_every: 1
  log_every: 1
  do_test: True

TEST:
  test_wav_dir: "./test_audios"
  test_out_path: "./test_results/demo.npz"


================================================
FILE: dataset/data_item.py
================================================
import librosa
import numpy as np
from typing import Any

from .dataset_config import dataset_config


class DataItem:

    def __init__(self,
                 annot_path: str,
                 audio_path: str,
                 identity_idx: int,
                 annot_type: str = '',
                 dataset_name: str = '',
                 id_template: np.ndarray = None,
                 fps: float = 30,
                 processor=None) -> None:
        self.annot_path = annot_path
        self.audio_path = audio_path
        self.identity_idx = identity_idx
        self.fps = fps
        self.annot_data = np.load(annot_path)
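        # Downsample annotations to a common frame rate: 60 fps -> 30 fps,
        # 100 fps -> 25 fps; anything other than 25/30 fps after this is rejected.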
        if self.fps == 60:
            self.annot_data = self.annot_data[::2]
            self.fps = 30
        elif self.fps == 100:
            self.annot_data = self.annot_data[::4]
            self.fps = 25
        elif self.fps != 30 and self.fps != 25:
            raise ValueError('wrong fps')
        scale = dataset_config[dataset_name]['scale']
        self.annot_data = self.scale_and_offset(self.annot_data, scale, 0.0)
        self.audio_data, sr = self.load_audio(audio_path)
        self.original_audio_data, self.annot_data = self.truncate(
            self.audio_data, sr, self.annot_data, self.fps)
        self.audio_data = np.squeeze(
            processor(self.original_audio_data, sampling_rate=sr).input_values)
        self.audio_data = self.audio_data.astype(np.float32)
        self.annot_type = annot_type
        self.id_template = self.scale_and_offset(id_template, scale, 0.0)
        self.annot_data = self.annot_data.reshape(len(self.annot_data),
                                                  -1).astype(np.float32)
        self.id_template = self.id_template.reshape(1, -1).astype(np.float32)
        self.dataset_name = dataset_name
        self.data_dict = self.to_dict()
        return

    def load_audio(self, audio_path: str):
        if audio_path.endswith('.npy'):
            return np.load(audio_path), 16000
        else:
            return librosa.load(audio_path, sr=16000)

    def truncate(self, audio_array: np.ndarray, sr: int,
                 annot_data: np.ndarray, fps: int):
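        # Trim audio and annotations to their common duration, capped at 20 s.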
        audio_duration = len(audio_array) / sr
        annot_duration = len(annot_data) / fps
        duration = min(audio_duration, annot_duration, 20)
        audio_length = round(duration * sr)
        annot_length = round(duration * fps)
        self.duration = duration
        return audio_array[:audio_length], annot_data[:annot_length]

    def offset_id(self, offset: int):
        self.identity_idx = self.identity_idx + offset
        self.data_dict['id'] = self.identity_idx
        return

    def scale_and_offset(self,
                         data: np.ndarray,
                         scale: float = 1.0,
                         offset: float = 0.0):
        return data * scale + offset

    def to_dict(self, ) -> Any:
        item = {
            'data': self.annot_data,
            'audio': self.audio_data,
            'original_audio': self.original_audio_data,
            'fps': self.fps,
            'id': self.identity_idx,
            'annot_path': self.annot_path,
            'audio_path': self.audio_path,
            'annot_type': self.annot_type,
            'template': self.id_template,
            'dataset_name': self.dataset_name
        }
        return item

    def get_dict(self, ):
        return self.data_dict


================================================
FILE: dataset/dataset.py
================================================
import json
import numpy as np
import os.path as osp
from torch.utils import data
from tqdm import tqdm
from transformers import Wav2Vec2FeatureExtractor
from typing import Any, List

from .data_item import DataItem
from .dataset_config import dataset_config


class AudioFaceDataset(data.Dataset):

    def __init__(self, data_label_path: str, args: dict = None) -> None:
        super().__init__()
        data_root = osp.dirname(data_label_path, )
        id_template_path = osp.join(data_root, 'id_template.npy')
        id_template_list = np.load(id_template_path)
        with open(data_label_path) as f:
            labels = json.load(f)
        info = labels['info']
        self.id_list = info['id_list']
        self.sr = 16000
        data_info_list = labels['data']
        data_list = []
        for data_info in tqdm(data_info_list):
            data = DataItem(
                annot_path=osp.join(data_root, data_info['annot_path']),
                audio_path=osp.join(data_root, data_info['audio_path']),
                identity_idx=data_info['id'],
                annot_type=data_info['annot_type'],
                dataset_name=data_info['dataset'],
                id_template=id_template_list[data_info['id']],
                fps=data_info['fps'],
                processor=args.processor,
            )
            if data.duration < 0.5:
                continue
            data_list.append(data)
        self.data_list = data_list
        return

    def __len__(self, ):
        return len(self.data_list)

    def __getitem__(self, index: Any) -> Any:
        data = self.data_list[index]
        return data.get_dict()

    def get_identity_num(self, ):
        return len(self.id_list)


class MixAudioFaceDataset(AudioFaceDataset):

    def __init__(self,
                 dataset_list: List[AudioFaceDataset],
                 duplicate_list: list = None) -> None:
        super(AudioFaceDataset, self).__init__()
        self.id_list = []
        self.sr = 16000
        self.data_list = []
        self.annot_type_list = []
        if duplicate_list is None:
            duplicate_list = [1] * len(dataset_list)
        else:
            assert len(duplicate_list) == len(dataset_list)
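        # Merge the datasets: shift each dataset's identity indices into one
        # shared id space, collect the union of annotation types, and replicate
        # each dataset's samples by its duplicate factor.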
        for dup, dataset in zip(duplicate_list, dataset_list):
            id_index_offset = len(self.id_list)
            for d in dataset.data_list:
                d.offset_id(id_index_offset)
                if d.annot_type not in self.annot_type_list:
                    self.annot_type_list.append(d.annot_type)
            self.id_list = self.id_list + dataset.id_list
            self.data_list = self.data_list + dataset.data_list * dup
        return


def get_dataset_list(args):
    dataset_name_list = args.dataset
    dataset_dir_list = [
        dataset_config[name]['dirname'] for name in dataset_name_list
    ]
    train_json_list = [
        osp.join(args.data_root, dataset_name, args.train_json_key)
        for dataset_name in dataset_dir_list
    ]
    args.read_audio = False
    args.processor = None
    train_dataset_list = []
    processor = Wav2Vec2FeatureExtractor.from_pretrained(
        args.audio_encoder_repo)
    args.processor = processor
    for train_json in train_json_list:
        dataset = AudioFaceDataset(
            train_json,
            args,
        )
        train_dataset_list.append(dataset)
    return train_dataset_list


def get_single_dataset(args, json_path: str = ''):
    args.read_audio = True
    processor = Wav2Vec2FeatureExtractor.from_pretrained(
        args.audio_encoder_repo)
    args.processor = processor

    dataset = AudioFaceDataset(
        json_path,
        args,
    )
    test_loader = data.DataLoader(
        dataset=dataset,
        batch_size=1,
        shuffle=False,
        pin_memory=True,
        num_workers=args.workers)
    return test_loader


def get_dataloaders(args):
    dataset_name_list = args.dataset
    print(dataset_name_list)
    dataset_dir_list = [
        dataset_config[name]['dirname'] for name in dataset_name_list
    ]
    train_json_list = [
        osp.join(args.data_root, dataset_dir, args.train_json_key)
        for dataset_dir in dataset_dir_list
    ]
    val_json_list = [
        osp.join(args.data_root, dataset_dir, 'val.json')
        for dataset_dir in dataset_dir_list
    ]
    test_json_list = [
        osp.join(args.data_root, dataset_dir, 'test.json')
        for dataset_dir in dataset_dir_list
    ]
    processor = Wav2Vec2FeatureExtractor.from_pretrained(
        args.audio_encoder_repo)
    args.processor = processor
    if getattr(args, 'do_train', True) is True:
        train_dataset_list = []
        for train_json in train_json_list:
            dataset = AudioFaceDataset(train_json, args)
            train_dataset_list.append(dataset)
        mix_train_dataset = MixAudioFaceDataset(train_dataset_list,
                                                args.duplicate_list)
        train_loader = data.DataLoader(
            dataset=mix_train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=args.workers)
    else:
        train_loader = None
    if getattr(args, 'do_validate', True) is True:
        val_dataset_list = []
        for val_json in val_json_list:
            dataset = AudioFaceDataset(val_json, args)
            val_dataset_list.append(dataset)
        mix_val_dataset = MixAudioFaceDataset(val_dataset_list)
        val_loader = data.DataLoader(
            dataset=mix_val_dataset,
            batch_size=1,
            shuffle=False,
            pin_memory=True,
            num_workers=args.workers)
    else:
        val_loader = None
    if getattr(args, 'do_test', True) is True:
        test_dataset_list = []
        for test_json in test_json_list:
            dataset = AudioFaceDataset(test_json, args)
            test_dataset_list.append(dataset)
        mix_test_dataset = MixAudioFaceDataset(test_dataset_list)
        test_loader = data.DataLoader(
            dataset=mix_test_dataset,
            batch_size=1,
            shuffle=False,
            pin_memory=True,
            num_workers=args.workers)
    else:
        test_loader = None
    return train_loader, val_loader, test_loader


================================================
FILE: dataset/dataset_config.py
================================================
dataset_config = {
    'D0': {
        'dirname': 'D0_BIWI',
        'annot_type': 'BIWI_23370_vertices',
        'scale': 0.2,
        'annot_dim': 23370 * 3,
        'subjects': 6,
        'pca': True,
    },
    'D1': {
        'dirname': 'D1_vocaset',
        'annot_type': 'FLAME_5023_vertices',
        'scale': 1.0,
        'annot_dim': 5023 * 3,
        'subjects': 12,
        'pca': True,
    },
    'D2': {
        'dirname': 'D2_meshtalk',
        'annot_type': 'meshtalk_6172_vertices',
        'scale': 0.001,
        'annot_dim': 6172 * 3,
        'subjects': 13,
        'pca': True,
    },
    'D3': {
        'dirname': 'D3D4_3DETF/D3_HDTF',
        'annot_type': '3DETF_blendshape_weight',
        'scale': 1.0,
        'annot_dim': 52,
        'subjects': 141,
        'pca': False,
    },
    'D4': {
        'dirname': 'D3D4_3DETF/D4_RAVDESS',
        'annot_type': '3DETF_blendshape_weight',
        'scale': 1.0,
        'annot_dim': 52,
        'subjects': 24,
        'pca': False,
    },
    'D5': {
        'dirname': 'D5_unitalker_faceforensics++',
        'annot_type': 'flame_params_from_dadhead',
        'scale': 1.0,
        'annot_dim': 413,
        'subjects': 719,
        'pca': False,
    },
    'D6': {
        'dirname': 'D6_unitalker_Chinese_speech',
        'annot_type': 'inhouse_blendshape_weight',
        'scale': 1.0,
        'annot_dim': 51,
        'subjects': 8,
        'pca': False,
    },
    'D7': {
        'dirname': 'D7_unitalker_song',
        'annot_type': 'inhouse_blendshape_weight',
        'scale': 1.0,
        'annot_dim': 51,
        'subjects': 11,
        'pca': False,
    },
}


================================================
FILE: loss/loss.py
================================================
import numpy as np
import pickle
import torch
from dataclasses import dataclass
from smplx.lbs import lbs
from smplx.utils import Struct, to_np, to_tensor
from torch import nn


@dataclass
class FlameParams:
    betas: torch.Tensor
    jaw: torch.Tensor
    eyeballs: torch.Tensor
    neck: torch.Tensor


class FlameLayer(nn.Module):

    def __init__(self, model_path: str) -> None:
        super().__init__()
        self.dtype = torch.float32
        with open(model_path, 'rb') as f:
            self.flame_model = Struct(**pickle.load(f, encoding='latin1'))
        shapedirs = self.flame_model.shapedirs
        # The shape components
        self.register_buffer('shapedirs',
                             to_tensor(to_np(shapedirs), dtype=self.dtype))

        j_regressor = to_tensor(
            to_np(self.flame_model.J_regressor), dtype=self.dtype)
        self.register_buffer('J_regressor', j_regressor)
        self.register_buffer(
            'v_template',
            to_tensor(to_np(self.flame_model.v_template), dtype=self.dtype),
        )
        num_pose_basis = self.flame_model.posedirs.shape[-1]
        posedirs = np.reshape(self.flame_model.posedirs,
                              [-1, num_pose_basis]).T
        self.register_buffer('posedirs',
                             to_tensor(to_np(posedirs), dtype=self.dtype))
        parents = to_tensor(to_np(self.flame_model.kintree_table[0])).long()
        parents[0] = -1
        self.register_buffer('parents', parents)

        self.register_buffer(
            'lbs_weights',
            to_tensor(to_np(self.flame_model.weights), dtype=self.dtype))
        default_neck_pose = torch.zeros([1, 3],
                                        dtype=self.dtype,
                                        requires_grad=False)
        self.register_parameter(
            'neck_pose', nn.Parameter(default_neck_pose, requires_grad=False))
        default_eyeball_pose = torch.zeros([1, 6],
                                           dtype=self.dtype,
                                           requires_grad=False)
        self.register_parameter(
            'eyeballs',
            nn.Parameter(default_eyeball_pose, requires_grad=False))
        default_jaw = torch.zeros([1, 3],
                                  dtype=self.dtype,
                                  requires_grad=False)
        self.register_parameter('jaw',
                                nn.Parameter(default_jaw, requires_grad=False))

        self.FLAME_CONSTS = {
            'betas': 400,
            'rotation': 6,
            'jaw': 3,
            'eyeballs': 0,
            'neck': 0,
            'translation': 3,
            'scale': 1,
        }

    def flame_params_from_3dmm(self,
                               tensor_3dmm: torch.Tensor,
                               zero_expr: bool = False):
        constants = self.FLAME_CONSTS

        assert tensor_3dmm.ndim == 2
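        # Packed 3DMM layout per FLAME_CONSTS:
        # [betas(400) | jaw(3) | rotation(6) | eyeballs(0) | neck(0) |
        #  translation(3) | scale(1)]; rotation, translation and scale are skipped.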

        cur_index: int = 0
        betas = tensor_3dmm[:, :constants['betas']]
        cur_index += constants['betas']
        jaw = tensor_3dmm[:, cur_index:cur_index + constants['jaw']]
        cur_index += constants['jaw']
        cur_index += constants['rotation']
        eyeballs = tensor_3dmm[:, cur_index:cur_index + constants['eyeballs']]
        cur_index += constants['eyeballs']
        neck = tensor_3dmm[:, cur_index:cur_index + constants['neck']]
        cur_index += constants['neck']
        cur_index += constants['translation']
        cur_index += constants['scale']
        return FlameParams(betas=betas, jaw=jaw, eyeballs=eyeballs, neck=neck)

    def forward(self, tensor_3dmm: torch.Tensor):
        bs = tensor_3dmm.shape[0]
        flame_params = self.flame_params_from_3dmm(tensor_3dmm)
        betas = flame_params.betas
        rotation = torch.zeros([bs, 3], device=tensor_3dmm.device)
        neck_pose = flame_params.neck if not (
            0 in flame_params.neck.shape) else self.neck_pose[[0]].expand(
                bs, -1)
        eyeballs = flame_params.eyeballs if not (
            0 in flame_params.eyeballs.shape) else self.eyeballs[[0]].expand(
                bs, -1)
        jaw = flame_params.jaw if not (
            0 in flame_params.jaw.shape) else self.jaw[[0]].expand(bs, -1)

        full_pose = torch.cat([rotation, neck_pose, jaw, eyeballs], dim=1)
        template_vertices = self.v_template.unsqueeze(0).repeat(bs, 1, 1)
        vertices, _ = lbs(
            betas,
            full_pose,
            template_vertices,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
        )
        return vertices


class FlameParamLossForDADHead(nn.Module):

    def __init__(self, mouth_indices_path: str, loss_mask_path: str,
                 flame_model_path: str) -> None:
        super().__init__()
        mouth_indices = np.load(mouth_indices_path)
        loss_indices = np.load(loss_mask_path)
        self.register_buffer('mouth_idx', torch.from_numpy(mouth_indices))
        self.register_buffer('loss_indices', torch.from_numpy(loss_indices))
        self.mse = nn.MSELoss()
        self.flame_layer = FlameLayer(flame_model_path)

    def forward(self, x: torch.Tensor, target: torch.Tensor):
        ori_shape = x.shape
        x = self.flame_layer(x.reshape(-1, ori_shape[-1]))
        target = self.flame_layer(target.reshape(-1, ori_shape[-1]))
        x = x.reshape((ori_shape[0], ori_shape[1], -1, 3))
        target = target.reshape((ori_shape[0], ori_shape[1], -1, 3))
        return self.mse(x[:, :, self.loss_indices], target[:, :,
                                                           self.loss_indices])

    def get_vertices(
        self,
        x: torch.Tensor,
    ):
        x = self.flame_layer(x.reshape(-1, x.shape[-1]))
        return x[:, self.loss_indices]

    def mouth_metric(self, x: torch.Tensor, target: torch.Tensor):
        ori_shape = x.shape
        x = self.flame_layer(x.reshape(-1, ori_shape[-1]))
        target = self.flame_layer(target.reshape(-1, ori_shape[-1]))
        metric_L2 = (target[:, self.mouth_idx] - x[:, self.mouth_idx])**2
        metric_L2 = metric_L2.sum(-1)
        metric_L2 = metric_L2.max(-1)[0]
        metric_L2norm = metric_L2**0.5
        return metric_L2.mean(), metric_L2norm.mean()


class VerticesLoss(nn.Module):

    def __init__(self, mouth_indices_path: str) -> None:
        super().__init__()
        mouth_indices = np.load(mouth_indices_path)
        self.register_buffer('mouth_idx', torch.from_numpy(mouth_indices))
        self.mse = nn.MSELoss()

    def forward(self, x: torch.Tensor, target: torch.Tensor):
        return self.mse(x, target)

    def get_vertices(
        self,
        x: torch.Tensor,
    ):
        return x.reshape(-1, x.shape[-1] // 3, 3)

    def mouth_metric(self, x: torch.Tensor, target: torch.Tensor):
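        # Mouth metric: per-frame maximum squared L2 error over mouth vertices,
        # averaged over frames; also returned as its square root (L2 norm).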
        ori_shape = x.shape
        target = target.reshape(*ori_shape[:-1], -1, 3)
        x = x.reshape(*ori_shape[:-1], -1, 3)
        metric_L2 = (target[:, :, self.mouth_idx] - x[:, :, self.mouth_idx])**2
        metric_L2 = metric_L2.sum(-1)
        metric_L2 = metric_L2.max(-1)[0]
        metric_L2norm = metric_L2**0.5
        return metric_L2.mean(), metric_L2norm.mean()


class BlendShapeLoss(nn.Module):

    def __init__(self,
                 mouth_indices_path: str,
                 blendshape_path: str,
                 bs_beta: float = 0.0,
                 components_num: int = None) -> None:
        super().__init__()
        mouth_indices = np.load(mouth_indices_path)
        blendshape = np.load(blendshape_path)
        self.register_buffer('mouth_idx', torch.from_numpy(mouth_indices))
        mean_shape = blendshape['meanshape']
        mean_shape = mean_shape.reshape(-1).astype(np.float32)
        blend_shape = blendshape['blendshape']
        blend_shape = blend_shape.reshape(len(blend_shape),
                                          -1).astype(np.float32)
        self.register_buffer('mean_shape', torch.from_numpy(mean_shape))
        self.register_buffer('blend_shape', torch.from_numpy(blend_shape))
        self.bs_beta = bs_beta
        self.mse = nn.MSELoss()
        self.components_num = components_num

    def forward(self, x: torch.Tensor, target: torch.Tensor):
        return self.bs_beta * self.mse(x, target) + self.mse(
            self.bs2vertices(x), self.bs2vertices(target))

    def bs2vertices(self, x: torch.Tensor):
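        # Linear blendshape model: vertices = weights @ basis + mean shape,
        # optionally truncated to the first `components_num` components.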
        return x[:, :self.components_num] @ \
            self.blend_shape[:self.components_num] + \
            self.mean_shape

    def get_vertices(
        self,
        x: torch.Tensor,
    ):
        x = self.bs2vertices(x)
        return x.reshape(-1, x.shape[-1] // 3, 3)

    def mouth_metric(self, x: torch.Tensor, target: torch.Tensor):
        target = self.bs2vertices(target)
        x = self.bs2vertices(x)
        ori_shape = x.shape
        target = target.reshape(*ori_shape[:-1], -1, 3)
        x = x.reshape(*ori_shape[:-1], -1, 3)
        metric_L2 = (target[:, :, self.mouth_idx] - x[:, :, self.mouth_idx])**2
        metric_L2 = metric_L2.sum(-1)
        metric_L2 = metric_L2.max(-1)[0]
        metric_L2norm = metric_L2**0.5
        return metric_L2.mean(), metric_L2norm.mean()


class ParamLoss(nn.Module):

    def __init__(self, weight: float):
        super().__init__()
        self.weight = weight
        self.mse = nn.MSELoss()
        self.L1 = nn.L1Loss()

    def forward(self, x: torch.Tensor, target: torch.Tensor):
        return self.weight * self.mse(x, target)

    def mouth_metric(self, x: torch.Tensor, target: torch.Tensor):
        return self.mse(x, target), self.L1(x, target)

    def get_vertices(
        self,
        x: torch.Tensor,
    ):
        return x.squeeze(0)


class UniTalkerLoss(nn.Module):

    def __init__(self, args) -> None:
        super().__init__()
        loss_config = {
            'flame_params_from_dadhead': {
                'class': FlameParamLossForDADHead,
                'args': {
                    'mouth_indices_path':
                    'resources/binary_resources/02_flame_mouth_idx.npy',
                    'loss_mask_path':
                    'resources/binary_resources/flame_humanface_index.npy',
                    'flame_model_path': 'resources/binary_resources/flame.pkl',
                }
            },
            '3DETF_blendshape_weight': {
                'class': BlendShapeLoss,
                'args': {
                    'mouth_indices_path':
                    'resources/binary_resources/03_emo_talk_mouth_idx.npy',
                    'blendshape_path':
                    'resources/binary_resources/EmoTalk.npz',
                    'bs_beta': args.blendshape_weight,
                },
            },
            'inhouse_blendshape_weight': {
                'class': BlendShapeLoss,
                'args': {
                    'mouth_indices_path':
                    'resources/binary_resources/05_inhouse_arkit_mouth_idx.npy',
                    'blendshape_path':
                    'resources/binary_resources/inhouse_arkit.npz',
                    'bs_beta': args.blendshape_weight,
                },
            },
            'FLAME_5023_vertices': {
                'class': VerticesLoss,
                'args': {
                    'mouth_indices_path':
                    'resources/binary_resources/02_flame_mouth_idx.npy',
                },
            },
            'BIWI_23370_vertices': {
                'class': VerticesLoss,
                'args': {
                    'mouth_indices_path':
                    'resources/binary_resources/04_BIWI_mouth_idx.npy',
                },
            },
            'FLAME_SUB_2055_vertices': {
                'class': VerticesLoss,
                'args': {
                    'mouth_indices_path':
                    'resources/binary_resources/mouth_idx_FLAME_5023_vertices_wo_eyes_wo_head.npy',
                },
            },
            'meshtalk_6172_vertices': {
                'class': VerticesLoss,
                'args': {
                    'mouth_indices_path':
                    'resources/binary_resources/06_mesh_talk_mouth_idx.npy',
                },
            },
            '24_viseme': {
                'class': ParamLoss,
                'args': {
                    'weight': args.blendshape_weight
                },
            }
        }

        loss_module_dict = {}
        for k, v in loss_config.items():
            loss_module = v['class'](**v['args'])
            loss_module_dict[k] = loss_module
        self.loss_module_dict = nn.ModuleDict(loss_module_dict)
        self.mse = nn.MSELoss()
        return

    def forward(
        self,
        x: torch.Tensor,
        target: torch.Tensor,
        annot_type: str,
    ):
        return self.loss_module_dict[annot_type](x, target)

    def pca_loss(self, x: torch.Tensor, target: torch.Tensor):
        return self.mse(x, target)

    def mouth_metric(self, x: torch.Tensor, target: torch.Tensor,
                     annot_type: str):
        return self.loss_module_dict[annot_type].mouth_metric(x, target)

    def get_vertices(self, x: torch.Tensor, annot_type: str):
        return self.loss_module_dict[annot_type].get_vertices(x)


================================================
FILE: main/demo.py
================================================
import librosa
import numpy as np
import os
import torch
import torch.optim
import torch.utils.data
from transformers import Wav2Vec2FeatureExtractor

from dataset.dataset_config import dataset_config
from loss.loss import UniTalkerLoss
from models.unitalker import UniTalker
from utils.utils import get_parser, get_template_verts, get_audio_encoder_dim

def get_all_audios(audio_root:str):
    wav_f_names = []
    for r, _, f in os.walk(audio_root):
        for wav in f:
            if wav.endswith('.wav') or wav.endswith('.mp3'):
                relative_path = os.path.join(r, wav)
                relative_path = os.path.relpath(relative_path,
                                                audio_root)
                wav_f_names.append(relative_path)
    wav_f_names = sorted(wav_f_names)
    return wav_f_names

def split_long_audio(
    audio: np.ndarray,
    processor:Wav2Vec2FeatureExtractor
):
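    # Split the waveform into overlapping windows (25 s long with 5 s overlap)
    # so that arbitrarily long audio can be processed in fixed-size chunks.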
    # audio = audio.squeeze(0)
    a, b = 25, 5
    sr = 16000
    total_length = len(audio) / sr
    reps = max(0, int(np.ceil((total_length - a) / (a - b)))) + 1
    in_audio_split_list = []
    start, end = 0, int(a * sr)
    step = int((a - b) * sr)
    for i in range(reps):
        audio_split = audio[start:end]
        audio_split = np.squeeze(
            processor(audio_split, sampling_rate=sr).input_values)
        in_audio_split_list.append(audio_split)
        start += step
        end += step
    return in_audio_split_list

def merge_out_list(out_list: list, fps:int):
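    # Stitch the per-window predictions back together, linearly cross-fading
    # the 5 s overlap regions to avoid seams at window boundaries.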
    if len(out_list) == 1:
        return out_list[0]
    a, b = 25, 5
    left_weight = np.linspace(1, 0, b * fps)[:, np.newaxis]
    right_weight = 1 - left_weight
    a = a * fps
    b = b * fps
    offset = a - b

    out_length = len(out_list[-1]) + offset * (len(out_list) - 1)
    merged_out = np.empty((out_length, out_list[-1].shape[-1]),
                            dtype=out_list[-1].dtype)
    merged_out[:a] = out_list[0]
    for out_piece in out_list[1:]:
        merged_out[a - b:a] = left_weight * merged_out[
            a - b:a] + right_weight * out_piece[:b]
        merged_out[a:a + offset] = out_piece[b:]
        a += offset
    return merged_out


def main():
    args = get_parser()
    training_dataset_name_list = args.dataset
    demo_dataset_name_list = args.demo_dataset
    
    condition_id_config = {
        'D0': 3,
        'D1': 3,
        'D2': 0,
        'D3': 0,
        'D4': 0,
        'D5': 0,
        'D6': 4,
        'D7': 0,
    }
    template_id_config = {
        'D0': 3,
        'D1': 3,
        'D2': 0,
        'D3': 0,
        'D4': 0,
        'D5': 0,
        'D6': 4,
        'D7': 0,
    }
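    # The integer values above are per-dataset local indices; the loop below
    # replaces them in place with the global condition-id tensors and the
    # template vertex tensors used at inference time.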

    checkpoint = torch.load(args.weight_path, map_location='cpu')
    args.identity_num = len(checkpoint['decoder.learnable_style_emb.weight'])


    start_idx = 0
    for dataset_name in training_dataset_name_list:
        annot_type = dataset_config[dataset_name]['annot_type']
        id_num = dataset_config[dataset_name]['subjects']
        end_idx = start_idx + id_num
        local_condition_idx = condition_id_config[dataset_name]
        template_idx = template_id_config[dataset_name]
        template = get_template_verts(args.data_root, dataset_name, template_idx)
        template_id_config[dataset_name] = torch.Tensor(template.reshape(1, -1))
        if args.condition_id == 'each':
            condition_id_config[dataset_name] = torch.tensor(
                start_idx + local_condition_idx).reshape(1)
        elif args.condition_id == 'common':
            condition_id_config[dataset_name] = torch.tensor(args.identity_num -
                                                      1).reshape(1)
        else:
            try:
                condition_id = int(args.condition_id)
                condition_id_config[dataset_name] = torch.tensor(
                    condition_id).reshape(1)
            except ValueError:
                assert args.condition_id in dataset_config.keys()
                condition_id_config[dataset_name] = torch.tensor(
                    start_idx + local_condition_idx).reshape(1)
        start_idx = end_idx
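    # With random_id_prob > 0, the identity embedding table has one extra slot
    # at the end (the shared identity selected by condition_id == 'common').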
    if args.random_id_prob > 0:
        assert end_idx == (args.identity_num - 1)
    else:
        assert end_idx == args.identity_num

    args.audio_encoder_feature_dim = get_audio_encoder_dim(
        args.audio_encoder_repo)

    model = UniTalker(args)
    model.load_state_dict(checkpoint, strict=False)
    model.eval()
    model.cuda()

    wav_f_names = get_all_audios(args.test_wav_dir)
    out_npz_path = args.test_out_path
    dirname = os.path.dirname(out_npz_path)
    os.makedirs(dirname, exist_ok=True)


    out_dict = {}
    loss_module = UniTalkerLoss(args).cuda()
    processor = Wav2Vec2FeatureExtractor.from_pretrained(
        args.audio_encoder_repo,)
    
    with torch.no_grad():
        for wav_f in wav_f_names:
            out_dict[wav_f] = {}
            wav_path = os.path.join(args.test_wav_dir, wav_f)
            audio_data, sr = librosa.load(wav_path, sr=16000)
            audio_data = np.squeeze(
                processor(audio_data, sampling_rate=sr).input_values)

            audio_data_splits = split_long_audio(audio_data, processor)
            for dataset_name in demo_dataset_name_list:
                template = template_id_config[dataset_name].cuda()
                scale = dataset_config[dataset_name]['scale']
                template = scale * template
                condition_id = condition_id_config[dataset_name].cuda()
                annot_type = dataset_config[dataset_name]['annot_type']
                if annot_type == 'BIWI_23370_vertices':
                    fps = 25
                else:
                    fps = 30

                out_list = []
                for audio_data in audio_data_splits:
                    audio_data = torch.Tensor(audio_data[None]).cuda()
                    out, _, _ = model(
                        audio_data,
                        template,
                        face_motion=None,
                        style_idx=condition_id,
                        annot_type=annot_type,
                        fps=fps,
                    )
                    out_list.append(out.cpu().numpy().squeeze(0))
                out = merge_out_list(out_list, fps)
                out = loss_module.get_vertices(torch.from_numpy(out).cuda(), annot_type)
                out_dict[wav_f][dataset_name] = out.cpu() 
        print(f"save results to {out_npz_path}")
        np.savez(out_npz_path, **out_dict)


if __name__ == '__main__':
    main()


================================================
FILE: main/render.py
================================================
"""Max-Planck-Gesellschaft zur Foerderung der Wissenschaften e.V.

(MPG) is holder of all proprietary rights on this computer program. You can
only use this computer program if you have closed a license agreement with MPG
or you get the right to use the computer program from someone who is authorized
to grant you that right. Any use of the computer program without a valid
license is prohibited and liable to prosecution. Copyright 2019 Max-Planck-
Gesellschaft zur Foerderung der Wissenschaften e.V. (MPG). acting on behalf of
its Max Planck Institute for Intelligent Systems and the Max Planck Institute
for Biological Cybernetics. All rights reserved. More information about VOCA is
available at http://voca.is.tue.mpg.de. For comments or questions, please email
us at voca@tue.mpg.de
"""

import cv2
import numpy as np
import os
import subprocess
import torch
from tqdm import tqdm
from dataset.dataset_config import dataset_config

def render_meshes(mesh_vertices: torch.Tensor,
                  faces: torch.Tensor,
                  img_size: tuple = (256, 256),
                  aa_factor: int = 1,
                  vis_bs: int = 100):
    assert len(mesh_vertices) == len(faces) or len(faces) == 1
    if len(faces) == 1:
        faces = faces.repeat(len(mesh_vertices), 1, 1)

    x_max = y_max = mesh_vertices[..., 0:2].abs().max()
    from pytorch3d.renderer import (
        DirectionalLights,
        FoVOrthographicCameras,
        HardPhongShader,
        Materials,
        MeshRasterizer,
        MeshRenderer,
        RasterizationSettings,
        TexturesVertex,
    )
    from pytorch3d.renderer.blending import BlendParams
    from pytorch3d.renderer.materials import Materials
    from pytorch3d.structures import Meshes
    aa_factor = int(aa_factor)
    device = mesh_vertices.device
    verts_batch_lst = torch.split(mesh_vertices, vis_bs)
    faces_batch_lst = torch.split(faces, vis_bs)
    R = torch.eye(3).to(mesh_vertices.device)
    R[0, 0] = -1
    R[2, 2] = -1
    T = torch.zeros(3).to(mesh_vertices.device)
    T[2] = 10
    R, T = R[None], T[None]
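    # Orthographic camera: mirror the x and z axes, place the camera 10 units
    # along z, and fit the frustum to the mesh's xy extent with a 20% margin.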
    W, H = img_size
    cameras = FoVOrthographicCameras(
        device=device,
        R=R,
        T=T,
        znear=0.01,
        zfar=3,
        max_x=x_max * 1.2,
        min_x=-x_max * 1.2,
        max_y=y_max * 1.2,
        min_y=-y_max * 1.2,
    )
    # cameras = FoVPerspectiveCameras(
    # 	device=device,
    # 	R=R,
    # 	T=T,
    # 	znear=0.01,
    # 	zfar=3,
    # 	aspect_ratio=1,
    # 	fov=0.3,
    # 	degrees=False
    # 	)

    raster_settings = RasterizationSettings(
        image_size=(W * aa_factor, H * aa_factor),
        blur_radius=0.0,
        faces_per_pixel=1,
    )
    raster = MeshRasterizer(
        cameras=cameras,
        raster_settings=raster_settings,
    )
    blend_params = BlendParams(
        sigma=0.0, gamma=0.0, background_color=(0.0, 0.0, 0.0))
    lights = DirectionalLights(
        device=device,
        direction=((0, 0, 1), ),
        ambient_color=((0.3, 0.3, 0.3), ),
        diffuse_color=((0.6, 0.6, 0.6), ),
        specular_color=((0.1, 0.1, 0.1), ))
    materials = Materials(
        ambient_color=((1, 1, 1), ),
        diffuse_color=((1, 1, 1), ),
        specular_color=((1, 1, 1), ),
        shininess=15,
        device=device)
    shader = HardPhongShader(
        device=device,
        cameras=cameras,
        lights=lights,
        materials=materials,
        blend_params=blend_params)
    renderer = MeshRenderer(
        rasterizer=raster,
        shader=shader,
    )
    rendered_imgs = []
    for verts_batch, faces_batch in tqdm(
            zip(verts_batch_lst, faces_batch_lst),
            total=(len(verts_batch_lst))):
        textures = TexturesVertex(verts_features=torch.ones_like(verts_batch))
        meshes = Meshes(
            verts=verts_batch, faces=faces_batch, textures=textures)
        with torch.no_grad():
            imgs = renderer(meshes)
        rendered_imgs.append(imgs.cpu())
    rendered_imgs = torch.cat(rendered_imgs, dim=0)
    if aa_factor > 1:
        rendered_imgs = rendered_imgs.permute(0, 3, 1, 2)  # NHWC -> NCHW
        rendered_imgs = torch.nn.functional.interpolate(
            rendered_imgs, scale_factor=1 / aa_factor, mode='bicubic')
        rendered_imgs = rendered_imgs.permute(0, 2, 3, 1)  # NCHW -> NHWC
    return rendered_imgs


def read_obj(in_path):
    with open(in_path, 'r') as obj_file:
        # Read the lines of the OBJ file
        lines = obj_file.readlines()

    # Initialize empty lists for vertices and faces
    verts = []
    faces = []
    for line in lines:
        line = line.strip()  # Remove leading/trailing whitespace
        elements = line.split()  # Split the line into elements

        if len(elements) == 0:
            continue  # Skip empty lines

        # Check the type of line (vertex or face)
        if elements[0] == 'v':
            # Vertex line
            x, y, z = map(float,
                          elements[1:4])  # Extract the vertex coordinates
            verts.append((x, y, z))  # Add the vertex to the list
        elif elements[0] == 'f':
            # Face line
            face_indices = [
                int(index.split('/')[0]) for index in elements[1:]
            ]  # Extract the vertex indices
            faces.append(face_indices)  # Add the face to the list
    return np.array(verts), np.array(faces)


def pad_for_libx264(image_array):
    """Pad zeros if width or height of image_array is not divisible by 2.
    Otherwise you will get.

    \"[libx264 @ 0x1b1d560] width not divisible by 2 \"

    Args:
            image_array (np.ndarray):
                    Image or images load by cv2.imread().
                    Possible shapes:
                    1. [height, width]
                    2. [height, width, channels]
                    3. [images, height, width]
                    4. [images, height, width, channels]

    Returns:
            np.ndarray:
                    An image array with both edges divisible by 2.
    """
    if image_array.ndim == 2 or \
      (image_array.ndim == 3 and image_array.shape[2] == 3):
        hei_index = 0
        wid_index = 1
    elif image_array.ndim == 4 or \
      (image_array.ndim == 3 and image_array.shape[2] != 3):
        hei_index = 1
        wid_index = 2
    else:
        return image_array
    hei_pad = image_array.shape[hei_index] % 2
    wid_pad = image_array.shape[wid_index] % 2
    if hei_pad + wid_pad > 0:
        pad_width = []
        for dim_index in range(image_array.ndim):
            if dim_index == hei_index:
                pad_width.append((0, hei_pad))
            elif dim_index == wid_index:
                pad_width.append((0, wid_pad))
            else:
                pad_width.append((0, 0))
        values = 0
        image_array = \
         np.pad(image_array,
             pad_width,
             mode='constant', constant_values=values)
    return image_array
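

# Minimal sketch of pad_for_libx264 on a synthetic odd-sized frame stack:
# a (frames, 511, 513, 3) array comes back as (frames, 512, 514, 3), i.e.
# both spatial edges padded to even sizes. Values are illustrative.
def _pad_for_libx264_example():  # pragma: no cover
    frames = np.zeros((2, 511, 513, 3), dtype=np.uint8)
    padded = pad_for_libx264(frames)
    assert padded.shape == (2, 512, 514, 3)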


def array_to_video(
    image_array: np.ndarray,
    output_path: str,
    fps=30,
    resolution=None,
    disable_log: bool = False,
) -> None:
    """Convert an array to a video directly, gif not supported.

    Args:
            image_array (np.ndarray): shape should be (f * h * w * 3).
            output_path (str): output video file path.
            fps (Union[int, float], optional): fps. Defaults to 30.
            resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
                    optional): (height, width) of the output video.
                    Defaults to None.
            disable_log (bool, optional): whether to suppress the ffmpeg
                    command log. Defaults to False.
    Raises:
            FileNotFoundError: check output path.
            TypeError: check input array.

    Returns:
            None.
    """
    if not isinstance(image_array, np.ndarray):
        raise TypeError('Input should be np.ndarray.')
    assert image_array.ndim == 4
    assert image_array.shape[-1] == 3
    if resolution:
        height, width = resolution
        width += width % 2
        height += height % 2
    else:
        image_array = pad_for_libx264(image_array)
        height, width = image_array.shape[1], image_array.shape[2]
    command = [
        '/usr/bin/ffmpeg',
        '-y',  # (optional) overwrite output file if it exists
        '-f',
        'rawvideo',
        '-s',
        f'{int(width)}x{int(height)}',  # size of one frame
        '-pix_fmt',
        'bgr24',
        '-r',
        f'{fps}',  # frames per second
        '-loglevel',
        'error',
        '-threads',
        '4',
        '-i',
        '-',  # The input comes from a pipe
        '-vcodec',
        'libx264',
        '-an',  # Tells FFMPEG not to expect any audio
        output_path,
    ]
    if not disable_log:
        print(f'Running \"{" ".join(command)}\"')
    process = subprocess.Popen(
        command,
        stdin=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if process.stdin is None or process.stderr is None:
        raise BrokenPipeError('No buffer received.')
    index = 0
    while True:
        if index >= image_array.shape[0]:
            break
        process.stdin.write(image_array[index].tobytes())
        index += 1
    process.stdin.close()
    process.stderr.close()
    process.wait()
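

# Hedged usage sketch for array_to_video: encodes ten black 64x64 BGR frames
# to an mp4. Assumes ffmpeg lives at the /usr/bin/ffmpeg path hardcoded
# above; the output filename is illustrative.
def _array_to_video_example():  # pragma: no cover
    frames = np.zeros((10, 64, 64, 3), dtype=np.uint8)
    array_to_video(frames, output_path='black_demo.mp4', fps=30)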


def get_obj_faces(annot_type: str):
    template_obj_path_dict = {
        '3DETF_blendshape_weight': 'resources/obj_template/3DETF_blendshape_weight.obj',
        'FLAME_5023_vertices': 'resources/obj_template/FLAME_5023_vertices.obj',
        'BIWI_23370_vertices': 'resources/obj_template/BIWI_23370_vertices.obj',
        'flame_params_from_dadhead': 'resources/obj_template/flame_params_from_dadhead.obj',
        'inhouse_blendshape_weight': 'resources/obj_template/inhouse_blendshape_weight.obj',
        'meshtalk_6172_vertices': 'resources/obj_template/meshtalk_6172_vertices.obj'
    }

    faces = np.array(read_obj(template_obj_path_dict[annot_type])[1])
    if faces.min() == 1:
        faces = faces - 1
    return faces


def vis_model_out_proxy():
    import sys
    npz_path = sys.argv[1]
    audio_dir = sys.argv[2]
    out_dir = sys.argv[3]
    save_images = False
    os.makedirs(out_dir, exist_ok=True)
    out_dict = np.load(npz_path, allow_pickle=True)
    print(list(out_dict.keys()))
    for wav_f in out_dict.keys():
        wav_path = os.path.join(audio_dir, wav_f)
        cur_out_dict = out_dict[wav_f]
        for datasetname, out in cur_out_dict.item().items():
            out = torch.Tensor(out)
            annot_type = dataset_config[datasetname]['annot_type']
            faces = get_obj_faces(annot_type)
            with torch.no_grad():
                if save_images:
                    out_size = (1024, 1024)
                else:
                    out_size = (512, 512)
                img_array = render_meshes(out.cuda(),
                                          torch.from_numpy(faces[None]).cuda(),
                                          out_size)
                if save_images:
                    channels = 4
                else:
                    channels = 3
                img_array = (img_array[..., :channels].cpu().numpy() *
                             255).astype(np.uint8)
            out_key = f'{wav_f[:-4]}'
            if save_images:
                out_images_dir = os.path.join(out_dir, datasetname, out_key)
                os.makedirs(out_images_dir, exist_ok=True)
                for imgidx, img in enumerate(img_array):
                    img_path = os.path.join(out_images_dir,
                                            f'{imgidx:04d}.png')
                    cv2.imwrite(img_path, img)
                continue
            out_video_path = os.path.join(out_dir,
                                          f'{out_key}_{datasetname}.mp4')
            out_video_dir = os.path.dirname(out_video_path)
            os.makedirs(out_video_dir, exist_ok=True)
            tmp_path = os.path.join(out_video_dir, 'tmp.mp4')
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            if annot_type == 'BIWI_23370_vertices':
                fps = 25
            else:
                fps = 30
            print(img_array.shape, img_array.dtype)
            array_to_video(img_array, output_path=tmp_path, fps=fps)

            cmd = (f'ffmpeg -i {tmp_path} -i {wav_path} -c:v copy -c:a aac '
                   f'-shortest -strict -2 {out_video_path} -y')
            print(cmd)
            subprocess.run(cmd, shell=True)
            os.remove(tmp_path)


if __name__ == '__main__':
    vis_model_out_proxy()


================================================
FILE: main/train.py
================================================
#!/usr/bin/env python
# yapf: disable
import os
import torch
import torch.optim
import torch.utils.data
from tensorboardX import SummaryWriter

from dataset.dataset import get_dataloaders
from loss.loss import UniTalkerLoss
from models.unitalker import UniTalker
from utils.utils import (
    get_average_meter_dict, get_logger, get_parser, get_audio_encoder_dim,
    load_ckpt, seed_everything, filter_unitalker_state_dict
)
from .train_eval_loop import train_epoch, validate_epoch

# yapf: enable


def main():
    seed_everything(42)
    args = get_parser()
    args.audio_encoder_feature_dim = get_audio_encoder_dim(args.audio_encoder_repo)
    train_loader, val_loader, test_loader = get_dataloaders(args)
    args.identity_num = train_loader.dataset.get_identity_num()
    args.annot_type_list = train_loader.dataset.annot_type_list
    if args.random_id_prob > 0.0:
        args.identity_num = args.identity_num + 1

    os.makedirs(args.save_path, exist_ok=True)
    log_file = os.path.join(args.save_path, 'train.log')
    logger = get_logger(log_file)
    writer = SummaryWriter(args.save_path)
    logger.info('=> creating model ...')
    model = UniTalker(args)
    if args.weight_path is not None:
        logger.info(f'=> loading model {args.weight_path} ...')
        load_ckpt(model, args.weight_path, args.re_init_decoder_and_head)
    logger.info(args)
    model = model.cuda()
    model.summary(logger)

    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)

    loss_module = UniTalkerLoss(args).cuda()

    args.train_meter, args.train_metric_name_list = \
        get_average_meter_dict(args, 'train')
    args.val_meter, args.val_metric_name_list = \
        get_average_meter_dict(args, 'val')
    args.logger = logger
    args.writer = writer

    for epoch in range(args.epochs):
        if args.fix_encoder_first:
            if epoch == 0:
                model.audio_encoder._freeze_wav2vec2_parameters(True)
                optimizer = torch.optim.Adam(
                    filter(lambda p: p.requires_grad, model.parameters()),
                    lr=args.lr)
            elif epoch == 1:
                model.audio_encoder._freeze_wav2vec2_parameters(
                    args.freeze_wav2vec)
                model.audio_encoder.feature_extractor._freeze_parameters()
                optimizer = torch.optim.Adam(
                    filter(lambda p: p.requires_grad, model.parameters()),
                    lr=args.lr)
        epoch_plus = epoch + 1
        train_epoch(train_loader, model, loss_module, optimizer, epoch_plus,
                    args)

        if epoch_plus % args.eval_every == 0:
            validate_epoch(val_loader, model, loss_module, epoch_plus, 'val',
                           args)

        if epoch_plus % args.test_every == 0:
            validate_epoch(test_loader, model, loss_module, epoch_plus, 'test',
                           args)

        if epoch_plus % args.save_every == 0:
            model_path = os.path.join(args.save_path, f'{epoch_plus:03d}.pt')
            torch.save(filter_unitalker_state_dict(model.state_dict()), model_path)


if __name__ == '__main__':
    main()


================================================
FILE: main/train_eval_loop.py
================================================
import numpy as np
import torch
from tqdm import tqdm

from utils.utils import log_datasetloss, write_to_tensorboard


def train_epoch(data_loader, model, loss_module, optimizer, epoch, args):
    model.train()
    train_meter = args.train_meter
    metric_name_list = args.train_metric_name_list
    zero_loss = torch.tensor(0.0)
    mixed_dataset_name = 'mixed_dataset'
    pbar = tqdm(data_loader)
    pbar.set_description(f'Epoch: {epoch} train')
    for batch in pbar:
        data = batch['data'].cuda(non_blocking=True)
        template = batch['template'].cuda(non_blocking=True)
        audio = batch['audio'].cuda(non_blocking=True)
        identity = batch['id'].cuda(non_blocking=True)
        if args.random_id_prob > 0.0:
            if np.random.random() < args.random_id_prob:
                identity = identity.fill_(args.identity_num - 1)
        annot_type = batch['annot_type'][0]
        dataset_name = batch['dataset_name'][0]

        out_motion, out_pca, gt_pca = model(
            audio, template, data, style_idx=identity, annot_type=annot_type)

        rec_loss = loss_module(out_motion, data, annot_type)
        if out_pca is not None:
            pca_loss = loss_module.pca_loss(out_pca, gt_pca)
        else:
            pca_loss = zero_loss
        loss = rec_loss + args.pca_weight * pca_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        for k, v in zip(metric_name_list, (rec_loss, pca_loss)):
            train_meter[dataset_name][k].update(v.item())
            train_meter[mixed_dataset_name][k].update(v.item())

    log_datasetloss(args.logger, epoch, 'train', train_meter, only_mix=True)
    write_to_tensorboard(
        args.writer, epoch, 'train', train_meter, only_mix=False)

    for dataset_name, sub_metric_dict in train_meter.items():
        for metric_name in sub_metric_dict.keys():
            sub_metric_dict[metric_name].reset()
    return


def validate_epoch(data_loader, model, loss_module, epoch, phase, args):
    assert phase in ('val', 'test')
    model.eval()
    val_meter = args.val_meter
    metric_name_list = args.val_metric_name_list
    mixed_dataset_name = 'mixed_dataset'
    with torch.no_grad():
        pbar = tqdm(data_loader)
        pbar.set_description(f'Epoch: {epoch} {phase:<5}')
        for batch in pbar:
            data = batch['data'].cuda(non_blocking=True)
            template = batch['template'].cuda(non_blocking=True)
            audio = batch['audio'].cuda(non_blocking=True)
            identity = batch['id'].cuda(non_blocking=True)
            annot_type = batch['annot_type'][0]
            dataset_name = batch['dataset_name'][0]
            if args.random_id_prob >= 1.0:
                identity = identity.fill_(args.identity_num - 1)
            out_motion, _, _ = model(
                audio,
                template,
                data,
                style_idx=identity,
                annot_type=annot_type)
            rec_loss = loss_module(out_motion, data, annot_type)
            mouth_metric_L2, mouth_metric_L2_norm = loss_module.mouth_metric(
                out_motion, data, annot_type)

            for k, v in zip(metric_name_list,
                            (rec_loss, mouth_metric_L2, mouth_metric_L2_norm)):
                val_meter[dataset_name][k].update(v.item())
                val_meter[mixed_dataset_name][k].update(v.item())
        if getattr(args, 'logger', None) is not None:
            log_datasetloss(
                args.logger, epoch, phase, val_meter, only_mix=True)
        if getattr(args, 'writer', None) is not None:
            write_to_tensorboard(
                args.writer, epoch, phase, val_meter, only_mix=False)
        for dataset_name, sub_metric_dict in val_meter.items():
            for metric_name in sub_metric_dict.keys():
                sub_metric_dict[metric_name].reset()
    return


================================================
FILE: main/train_pca.py
================================================
#!/usr/bin/env python
import numpy as np

from models.pca import PCA
from utils.utils import get_parser, get_audio_encoder_dim


def main():
    args = get_parser()
    args.audio_encoder_feature_dim = get_audio_encoder_dim(
        args.audio_encoder_repo)
    main_worker(args)


def main_worker(cfg):
    pca_builder = PCA()
    from dataset.dataset import get_dataset_list
    train_dataset_list = get_dataset_list(cfg)
    for dataset in train_dataset_list:
        data = pca_builder.load_dataset(dataset)
        pca_info = pca_builder.build_incremental_PCA(
            data, batch_size=1095, trunc_dim=647)
        np.savez('mesh_talk_pca.npz', **pca_info)


if __name__ == '__main__':
    main()


================================================
FILE: models/base_model.py
================================================
import numpy as np
import torch.nn as nn


class BaseModel(nn.Module):
    """Base class for all models."""

    def __init__(self):
        super(BaseModel, self).__init__()
        # self.logger = logging.getLogger(self.__class__.__name__)

    def forward(self, *x):
        """Forward pass logic.

        :return: Model output
        """
        raise NotImplementedError

    def freeze_model(self, do_freeze: bool = True):
        for param in self.parameters():
            param.requires_grad = (not do_freeze)

    def summary(self, logger, writer=None):
        """Model summary."""
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size())
                      for p in model_parameters]) / 1e6  # Unit is Mega
        logger.info('===>Trainable parameters: %.3f M' % params)
        if writer is not None:
            writer.add_text('Model Summary',
                            'Trainable parameters: %.3f M' % params)


================================================
FILE: models/hubert.py
================================================
import numpy as np
import torch
from transformers import HubertModel
from transformers.modeling_outputs import BaseModelOutput
from typing import Optional, Tuple

from .model_utils import linear_interpolation


# This HubertModel wrapper follows the Wav2Vec2Model implementation from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model # noqa: E501
# and initializes the encoder with pre-trained HuBERT weights.
def _compute_mask_indices(
    shape: Tuple[int, int],
    mask_prob: float,
    mask_length: int,
    attention_mask: Optional[torch.Tensor] = None,
    min_masks: int = 0,
) -> np.ndarray:
    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    all_num_mask = int(mask_prob * all_sz / float(mask_length) +
                       np.random.rand())
    all_num_mask = max(min_masks, all_num_mask)
    mask_idcs = []
    padding_mask = attention_mask.ne(1) if attention_mask is not None else None
    for i in range(bsz):
        if padding_mask is not None:
            sz = all_sz - padding_mask[i].long().sum().item()
            num_mask = int(mask_prob * sz / float(mask_length) +
                           np.random.rand())
            num_mask = max(min_masks, num_mask)
        else:
            sz = all_sz
            num_mask = all_num_mask

        lengths = np.full(num_mask, mask_length)

        if sum(lengths) == 0:
            lengths[0] = min(mask_length, sz - 1)

        min_len = min(lengths)
        if sz - min_len <= num_mask:
            min_len = sz - num_mask - 1

        mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
        mask_idc = np.asarray([
            mask_idc[j] + offset for j in range(len(mask_idc))
            for offset in range(lengths[j])
        ])
        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))

    min_len = min([len(m) for m in mask_idcs])
    for i, mask_idc in enumerate(mask_idcs):
        if len(mask_idc) > min_len:
            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
        mask[i, mask_idc] = True
    return mask
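

# Quick sanity sketch for _compute_mask_indices: on a (2, 50) batch with
# mask_prob=0.2 and mask_length=5, each row gets a handful of contiguous
# masked spans. The mask is random per call; numbers are illustrative.
def _compute_mask_indices_example():  # pragma: no cover
    mask = _compute_mask_indices((2, 50), mask_prob=0.2, mask_length=5)
    assert mask.shape == (2, 50) and mask.dtype == bool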


class HubertModel(HubertModel):

    def __init__(self, config):
        super().__init__(config)

    def _freeze_parameters(self, do_freeze: bool = True):
        for param in self.parameters():
            param.requires_grad = (not do_freeze)

    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        frame_num=None,
        interpolate_pos: int = 0,
    ):
        self.config.output_attentions = True
        output_attentions = output_attentions if \
            output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else
            self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None \
            else self.config.use_return_dict

        conv_features = self.feature_extractor(input_values)
        hidden_states = conv_features.transpose(1, 2)
        if interpolate_pos == 0:
            hidden_states = linear_interpolation(
                hidden_states, output_len=frame_num)

        if attention_mask is not None:
            output_lengths = self._get_feat_extract_output_lengths(
                attention_mask.sum(-1))
            attention_mask = torch.zeros(
                hidden_states.shape[:2],
                dtype=hidden_states.dtype,
                device=hidden_states.device)
            attention_mask[(torch.arange(
                attention_mask.shape[0],
                device=hidden_states.device), output_lengths - 1)] = 1
            attention_mask = attention_mask.flip([-1]).cumsum(-1).flip(
                [-1]).bool()

        hidden_states = self.feature_projection(hidden_states)

        if self.config.apply_spec_augment and self.training:
            batch_size, sequence_length, hidden_size = hidden_states.size()
            if self.config.mask_time_prob > 0:
                mask_time_indices = _compute_mask_indices(
                    (batch_size, sequence_length),
                    self.config.mask_time_prob,
                    self.config.mask_time_length,
                    attention_mask=attention_mask,
                    min_masks=2,
                )
                hidden_states[torch.from_numpy(
                    mask_time_indices)] = self.masked_spec_embed.to(
                        hidden_states.dtype)
            if self.config.mask_feature_prob > 0:
                mask_feature_indices = _compute_mask_indices(
                    (batch_size, hidden_size),
                    self.config.mask_feature_prob,
                    self.config.mask_feature_length,
                )
                mask_feature_indices = torch.from_numpy(
                    mask_feature_indices).to(hidden_states.device)
                hidden_states[mask_feature_indices[:, None].expand(
                    -1, sequence_length, -1)] = 0
        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = encoder_outputs[0]
        if interpolate_pos == 1:
            hidden_states = linear_interpolation(
                hidden_states, output_len=frame_num)

        if not return_dict:
            return (hidden_states, ) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


================================================
FILE: models/model_utils.py
================================================
# Borrowed from https://github.com/EvelynFan/FaceFormer/blob/main/main.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# linear interpolation layer
def linear_interpolation(features, output_len: int):
    features = features.transpose(1, 2)
    output_features = F.interpolate(
        features, size=output_len, align_corners=True, mode='linear')
    return output_features.transpose(1, 2)
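

# Minimal sketch of linear_interpolation: resamples a (batch, time, channel)
# feature sequence, e.g. from 49 audio-encoder steps down to a 30-frame
# motion clip. Shapes are illustrative.
def _linear_interpolation_example():  # pragma: no cover
    feats = torch.randn(1, 49, 768)
    out = linear_interpolation(feats, output_len=30)
    assert out.shape == (1, 30, 768)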


# Temporal Bias
def init_biased_mask(n_head, max_seq_len, period):

    def get_slopes(n):

        def get_slopes_power_of_2(n):
            start = (2**(-2**-(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2**math.floor(math.log2(n))
            return get_slopes_power_of_2(closest_power_of_2) + get_slopes(
                2 * closest_power_of_2)[0::2][:n - closest_power_of_2]

    slopes = torch.Tensor(get_slopes(n_head))
    bias = torch.div(
        torch.arange(start=0, end=max_seq_len,
                     step=period).unsqueeze(1).repeat(1, period).view(-1),
        period,
        rounding_mode='floor')
    bias = -torch.flip(bias, dims=[0])
    alibi = torch.zeros(max_seq_len, max_seq_len)
    for i in range(max_seq_len):
        alibi[i, :i + 1] = bias[-(i + 1):]
    alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0)
    mask = (torch.triu(torch.ones(max_seq_len,
                                  max_seq_len)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
        mask == 1, float(0.0))
    mask = mask.unsqueeze(0) + alibi
    return mask


# Alignment Bias
def enc_dec_mask(device, T, S):
    mask = torch.ones(T, S)
    for i in range(T):
        mask[i, i] = 0
    return (mask == 1).to(device=device)
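

# Hedged shape sketch for the two masks above: init_biased_mask yields one
# (max_seq_len, max_seq_len) ALiBi-style bias per head, and enc_dec_mask
# lets frame i attend only to audio step i. Sizes are illustrative.
def _mask_shapes_example():  # pragma: no cover
    tgt_mask = init_biased_mask(n_head=4, max_seq_len=8, period=2)
    assert tgt_mask.shape == (4, 8, 8)
    memory_mask = enc_dec_mask(torch.device('cpu'), T=8, S=8)
    # False on the diagonal means "not masked": attention is allowed there.
    assert memory_mask.shape == (8, 8) and not memory_mask[0, 0].item()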


# Periodic Positional Encoding
class PeriodicPositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, period=25, max_seq_len=3000):
        super(PeriodicPositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(period, d_model)
        position = torch.arange(0, period, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() *
            (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, period, d_model)
        repeat_num = (max_seq_len // period) + 1
        pe = pe.repeat(1, repeat_num, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
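

# Hedged usage sketch for PeriodicPositionalEncoding: with dropout disabled,
# positions t and t + period receive identical offsets, so the encoding of a
# zero input repeats every `period` steps. Dimensions are illustrative.
def _ppe_example():  # pragma: no cover
    ppe = PeriodicPositionalEncoding(
        d_model=16, dropout=0.0, period=4, max_seq_len=32)
    y = ppe(torch.zeros(1, 12, 16))
    assert torch.allclose(y[0, 0], y[0, 4])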


class TCN(nn.Module):

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.first_net = SeqTranslator1D(
            in_dim, out_dim, min_layers_num=3, residual=True, norm='ln')
        self.dropout = nn.Dropout(0.1)

    def forward(self, spectrogram, style_embed, time_steps=None):

        spectrogram = self.dropout(spectrogram)
        style_embed = style_embed.unsqueeze(2)
        style_embed = style_embed.repeat(1, 1, spectrogram.shape[2])
        spectrogram = torch.cat([spectrogram, style_embed], dim=1)
        x1 = self.first_net(spectrogram)  # .permute(0, 2, 1)
        if time_steps is not None:
            x1 = F.interpolate(
                x1, size=time_steps, align_corners=False, mode='linear')

        return x1


class SeqTranslator1D(nn.Module):
    """(B, C, T)->(B, C_out, T)"""

    def __init__(self,
                 C_in,
                 C_out,
                 kernel_size=None,
                 stride=None,
                 min_layers_num=None,
                 residual=True,
                 norm='bn'):
        super(SeqTranslator1D, self).__init__()

        conv_layers = nn.ModuleList([])
        conv_layers.append(
            ConvNormRelu(
                in_channels=C_in,
                out_channels=C_out,
                type='1d',
                kernel_size=kernel_size,
                stride=stride,
                residual=residual,
                norm=norm))
        self.num_layers = 1
        if min_layers_num is not None and self.num_layers < min_layers_num:
            while self.num_layers < min_layers_num:
                conv_layers.append(
                    ConvNormRelu(
                        in_channels=C_out,
                        out_channels=C_out,
                        type='1d',
                        kernel_size=kernel_size,
                        stride=stride,
                        residual=residual,
                        norm=norm))
                self.num_layers += 1
        self.conv_layers = nn.Sequential(*conv_layers)

    def forward(self, x):
        return self.conv_layers(x)


class ConvNormRelu(nn.Module):
    """(B,C_in,H,W) -> (B, C_out, H, W) there exist some kernel size that makes
    the result is not H/s.

    #TODO: there might some problems with residual
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 type='1d',
                 leaky=False,
                 downsample=False,
                 kernel_size=None,
                 stride=None,
                 padding=None,
                 p=0,
                 groups=1,
                 residual=False,
                 norm='bn'):
        """conv-bn-relu."""
        super(ConvNormRelu, self).__init__()
        self.residual = residual
        self.norm_type = norm
        # kernel_size = k
        # stride = s

        if kernel_size is None and stride is None:
            if not downsample:
                kernel_size = 3
                stride = 1
            else:
                kernel_size = 4
                stride = 2

        if padding is None:
            if isinstance(kernel_size, int) and isinstance(stride, tuple):
                padding = tuple(int((kernel_size - st) / 2) for st in stride)
            elif isinstance(kernel_size, tuple) and isinstance(stride, int):
                padding = tuple(int((ks - stride) / 2) for ks in kernel_size)
            elif isinstance(kernel_size, tuple) and isinstance(stride, tuple):
                padding = tuple(
                    int((ks - st) / 2) for ks, st in zip(kernel_size, stride))
            else:
                padding = int((kernel_size - stride) / 2)

        if self.residual:
            if downsample:
                if type == '1d':
                    self.residual_layer = nn.Sequential(
                        nn.Conv1d(
                            in_channels=in_channels,
                            out_channels=out_channels,
                            kernel_size=kernel_size,
                            stride=stride,
                            padding=padding))
                elif type == '2d':
                    self.residual_layer = nn.Sequential(
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=out_channels,
                            kernel_size=kernel_size,
                            stride=stride,
                            padding=padding))
            else:
                if in_channels == out_channels:
                    self.residual_layer = nn.Identity()
                else:
                    if type == '1d':
                        self.residual_layer = nn.Sequential(
                            nn.Conv1d(
                                in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=kernel_size,
                                stride=stride,
                                padding=padding))
                    elif type == '2d':
                        self.residual_layer = nn.Sequential(
                            nn.Conv2d(
                                in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=kernel_size,
                                stride=stride,
                                padding=padding))

        in_channels = in_channels * groups
        out_channels = out_channels * groups
        if type == '1d':
            self.conv = nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=groups)
            self.norm = nn.BatchNorm1d(out_channels)
            self.dropout = nn.Dropout(p=p)
        elif type == '2d':
            self.conv = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=groups)
            self.norm = nn.BatchNorm2d(out_channels)
            self.dropout = nn.Dropout2d(p=p)
        if norm == 'gn':
            self.norm = nn.GroupNorm(2, out_channels)
        elif norm == 'ln':
            self.norm = nn.LayerNorm(out_channels)
        if leaky:
            self.relu = nn.LeakyReLU(negative_slope=0.2)
        else:
            self.relu = nn.ReLU()

    def forward(self, x, **kwargs):
        if self.norm_type == 'ln':
            out = self.dropout(self.conv(x))
            out = self.norm(out.transpose(1, 2)).transpose(1, 2)
        else:
            out = self.norm(self.dropout(self.conv(x)))
        if self.residual:
            residual = self.residual_layer(x)
            out += residual
        return self.relu(out)
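

# Shape sketch for the conv stack above: SeqTranslator1D keeps the temporal
# length with stride-1 convs while mapping channels C_in -> C_out. All
# values here are illustrative.
def _seq_translator_example():  # pragma: no cover
    net = SeqTranslator1D(
        C_in=8, C_out=16, min_layers_num=3, residual=True, norm='ln')
    x = torch.randn(2, 8, 10)
    assert net(x).shape == (2, 16, 10)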


================================================
FILE: models/pca.py
================================================
import numpy as np
import torch
import torch.nn as nn


class PCALayer(nn.Module):

    def __init__(self, pca_info_path: str):
        super().__init__()
        pca_info = np.load(pca_info_path)
        feat_to_data = pca_info['components_to_data'].astype(np.float32)
        data_to_feat = feat_to_data.T
        data_bias = pca_info['original_data_mean'].astype(np.float32)
        feat_mean = pca_info['data_components_mean'].astype(np.float32)
        feat_std = pca_info['data_components_std'].astype(np.float32)
        self.register_buffer('feat_to_data', torch.from_numpy(feat_to_data))
        self.register_buffer('data_to_feat', torch.from_numpy(data_to_feat))
        self.register_buffer('data_bias', torch.from_numpy(data_bias))
        self.register_buffer('feat_mean', torch.from_numpy(feat_mean))
        self.register_buffer('feat_std', torch.from_numpy(feat_std))
        self.pca_dim = len(feat_mean)

    def encode(
        self,
        x: torch.Tensor,
        pca_dim: int = None,
    ):
        x = x - self.data_bias
        feat = x @ (self.data_to_feat[:, :pca_dim])
        return feat

    def decode(self, feat: torch.Tensor, pca_dim: int = None):
        data = feat[..., :pca_dim] @ (self.feat_to_data[:pca_dim])
        data = data + self.data_bias
        return data
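

# Hedged round-trip sketch for PCALayer: builds a synthetic pca npz with the
# keys the constructor loads ('components_to_data', 'original_data_mean',
# 'data_components_mean', 'data_components_std') and checks that
# decode(encode(x)) recovers x for a complete orthonormal (identity) basis.
# The identity basis and file path are purely illustrative.
def _pca_layer_example():  # pragma: no cover
    import os
    import tempfile
    dim = 6
    path = os.path.join(tempfile.mkdtemp(), 'pca.npz')
    np.savez(
        path,
        components_to_data=np.eye(dim, dtype=np.float32),
        original_data_mean=np.zeros(dim, dtype=np.float32),
        data_components_mean=np.zeros(dim, dtype=np.float32),
        data_components_std=np.ones(dim, dtype=np.float32))
    layer = PCALayer(path)
    x = torch.randn(4, dim)
    assert torch.allclose(layer.decode(layer.encode(x)), x, atol=1e-5)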


class PCA:

    def __init__(self, trunc_dim: int = None):
        super().__init__()

    def build_pca_for_dataset(self, dataset, trunc_dim: int = 1024):
        annot_array = self.load_dataset(dataset)
        self.build_pca(annot_array, trunc_dim)

    def load_dataset(self, dataset):
        annot_list = []
        indices = np.arange(len(dataset))
        for i in indices:
            data_dict = dataset.__getitem__(i)
            annot_list.append(data_dict['data'] - data_dict['template'])
        annot_array = np.concatenate(annot_list, axis=0)
        return annot_array

    def build_incremental_PCA(
        self,
        in_data: np.ndarray,
        trunc_dim=None,
        batch_size=1024,
    ):
        indices = np.arange(len(in_data))
        np.random.shuffle(indices)
        in_data = in_data[indices]
        if in_data.dtype != np.float32:
            in_data = in_data.astype(np.float32)
        from sklearn.decomposition import IncrementalPCA
        ipca = IncrementalPCA(n_components=trunc_dim, batch_size=batch_size)
        data_mean = in_data.mean(0)
        in_data = in_data - data_mean
        in_data_slices = np.split(
            in_data, np.arange(batch_size, len(in_data), batch_size))
        for batch_data in in_data_slices:
            if len(batch_data) != batch_size:
                continue
            ipca.partial_fit(batch_data)

        S = ipca.singular_values_
        components_to_data = ipca.components_
        explained_variance_ratio = ipca.explained_variance_ratio_
        data_components = ipca.transform(in_data)
        data_components_mean = data_components.mean(0)
        data_components_std = data_components.std(0)
        ret_dict = {
            'explained_variance_ratio': explained_variance_ratio,
            'components_to_data': components_to_data,
            'original_data_mean': data_mean,
            'S': S,
            'data_components_mean': data_components_mean,
            'data_components_std': data_components_std
        }
        return ret_dict

    def build_pca(
        self,
        in_data: np.ndarray,
        trunc_dim=None,
        save_compactness=True,
        check_svd=True,
    ):
        ret_dict = {}
        in_data = torch.Tensor(in_data).cuda()
        count, in_dim = in_data.shape
        mean_data = in_data.mean(dim=0)
        in_data_center = in_data - mean_data  # (count, in_dim)
        in_data_center = in_data_center.cpu()
        #  U (count, count), S(count, ), Vt (count, in_dim)
        U, S, Vt = torch.linalg.svd(in_data_center, full_matrices=False)

        if check_svd:
            S_mat = torch.diag(S)
            rebuilt = U @ S_mat @ Vt
            is_precise = torch.allclose(rebuilt, in_data_center)
            print('svd rebuilt all close ?', is_precise)

        var_S = S**2
        if save_compactness:
            compact_vector = var_S / var_S.sum()
            compact_value = compact_vector[0:trunc_dim].sum()
            ret_dict['compactness'] = compact_value

        pca_bases = Vt[0:trunc_dim].reshape(trunc_dim, -1)
        ret_dict['pca_bases'] = pca_bases
        ret_dict['pca_std'] = S[0:trunc_dim]
        ret_dict['mean_data'] = mean_data
        return ret_dict
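

# Hedged usage sketch for build_incremental_PCA on synthetic data: 2048
# samples of 32-dim noise reduced to 8 components. batch_size must be at
# least n_components for sklearn's IncrementalPCA; all numbers are
# illustrative.
def _incremental_pca_example():  # pragma: no cover
    rng = np.random.default_rng(0)
    data = rng.normal(size=(2048, 32)).astype(np.float32)
    info = PCA().build_incremental_PCA(data, trunc_dim=8, batch_size=256)
    assert info['components_to_data'].shape == (8, 32)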


================================================
FILE: models/unitalker.py
================================================
import os.path as osp
import torch.nn as nn

from dataset.dataset_config import dataset_config
from models.wav2vec2 import Wav2Vec2Model
from .base_model import BaseModel
# from models.hubert import HubertModel
from models.wavlm import WavLMModel


class OutHead(BaseModel):

    def __init__(self, args, out_dim: int):
        super().__init__()
        head_layer = []

        in_dim = args.decoder_dimension
        for i in range(args.headlayer - 1):
            head_layer.append(nn.Linear(in_dim, in_dim))
            head_layer.append(nn.Tanh())
        head_layer.append(nn.Linear(in_dim, out_dim))
        self.head_layer = nn.Sequential(*head_layer)

    def forward(self, x):
        return self.head_layer(x)


class UniTalker(BaseModel):

    def __init__(self, args):
        super().__init__()
        self.args = args
        if 'wav2vec2' in args.audio_encoder_repo:
            self.audio_encoder = Wav2Vec2Model.from_pretrained(
                args.audio_encoder_repo)
        elif 'wavlm' in args.audio_encoder_repo:
            self.audio_encoder = WavLMModel.from_pretrained(
                args.audio_encoder_repo)
        else:
            raise ValueError(
                f'unsupported audio_encoder_repo: {args.audio_encoder_repo}')
        self.audio_encoder.feature_extractor._freeze_parameters()
        if args.freeze_wav2vec:
            self.audio_encoder._freeze_wav2vec2_parameters()

        if args.decoder_type == 'conv':
            from .unitalker_decoder import UniTalkerDecoderTCN as Decoder
        elif args.decoder_type == 'transformer':
            from .unitalker_decoder import \
                UniTalkerDecoderTransformer as Decoder
        else:
            raise ValueError(f'unknown decoder type: {args.decoder_type}')
        self.decoder = Decoder(args)

        if args.use_pca:
            pca_layer_dict = {}
            from models.pca import PCALayer
            for dataset_name in args.dataset:
                if dataset_config[dataset_name]['pca']:
                    annot_type = dataset_config[dataset_name]['annot_type']
                    dirname = dataset_config[dataset_name]['dirname']
                    pca_path = osp.join(args.data_root, dirname, 'pca.npz')
                    pca_layer_dict[annot_type] = PCALayer(pca_path)
            self.pca_layer_dict = nn.ModuleDict(pca_layer_dict)
        else:
            self.pca_layer_dict = {}

        self.pca_dim = args.pca_dim
        self.use_pca = args.use_pca
        self.interpolate_pos = args.interpolate_pos

        out_head_dict = {}
        for dataset_name in args.dataset:
            annot_type = dataset_config[dataset_name]['annot_type']
            out_dim = dataset_config[dataset_name]['annot_dim']
            if (args.use_pca is True) and (annot_type in self.pca_layer_dict):
                pca_dim = self.pca_layer_dict[annot_type].pca_dim
                out_dim = min(args.pca_dim, out_dim, pca_dim)
            if args.headlayer == 1:
                out_projection = nn.Linear(args.decoder_dimension, out_dim)
                nn.init.constant_(out_projection.weight, 0)
                nn.init.constant_(out_projection.bias, 0)
            else:
                out_projection = OutHead(args, out_dim)
            out_head_dict[annot_type] = out_projection
        self.out_head_dict = nn.ModuleDict(out_head_dict)
        self.identity_num = args.identity_num
        return

    def forward(self,
                audio,
                template,
                face_motion,
                style_idx,
                annot_type: str,
                fps: float = None):
        if face_motion is not None:
            frame_num = face_motion.shape[1]
        else:
            frame_num = round(audio.shape[-1] / 16000 * fps)

        hidden_states = self.audio_encoder(
            audio, frame_num=frame_num, interpolate_pos=self.interpolate_pos)
        hidden_states = hidden_states.last_hidden_state
        decoder_out = self.decoder(hidden_states, style_idx, frame_num)

        if (not self.use_pca) or (annot_type not in self.pca_layer_dict):
            out_motion = self.out_head_dict[annot_type](decoder_out)
            out_motion = out_motion + template
            out_pca, gt_pca = None, None
        else:
            out_pca = self.out_head_dict[annot_type](decoder_out)
            out_motion = self.pca_layer_dict[annot_type].decode(
                out_pca, self.pca_dim)
            out_motion = out_motion + template
            if face_motion is not None:
                gt_pca = self.pca_layer_dict[annot_type].encode(
                    face_motion - template, self.pca_dim)
            else:
                gt_pca = None

        return out_motion, out_pca, gt_pca


================================================
FILE: models/unitalker_decoder.py
================================================
# yapf: disable
import torch
import torch.nn as nn

from .base_model import BaseModel
from .model_utils import (
    TCN, PeriodicPositionalEncoding, enc_dec_mask, init_biased_mask,
    linear_interpolation,
)

# yapf: enable


class UniTalkerDecoderTCN(BaseModel):

    def __init__(self, args) -> None:
        super().__init__()
        self.learnable_style_emb = nn.Embedding(args.identity_num, 64)
        in_dim = args.decoder_dimension
        out_dim = args.decoder_dimension
        self.audio_feature_map = nn.Linear(args.audio_encoder_feature_dim, in_dim)
        self.tcn = TCN(in_dim + 64, out_dim)
        self.interpolate_pos = args.interpolate_pos

    def forward(self,
                hidden_states: torch.Tensor,
                style_idx: torch.Tensor,
                frame_num: int = None):
        batch_size = len(hidden_states)
        if style_idx is None:
            style_embedding = self.learnable_style_emb.weight[-1].unsqueeze(
                0).repeat(batch_size, 1)
        else:
            style_embedding = self.learnable_style_emb(style_idx)
        feature = self.audio_feature_map(hidden_states).transpose(1, 2)
        feature = self.tcn(feature, style_embedding)
        feature = feature.transpose(1, 2)
        if self.interpolate_pos == 2:
            feature = linear_interpolation(feature, output_len=frame_num)
        return feature


class UniTalkerDecoderTransformer(BaseModel):

    def __init__(self, args) -> None:
        super().__init__()
        out_dim = args.decoder_dimension
        self.learnable_style_emb = nn.Embedding(args.identity_num, out_dim)
        self.audio_feature_map = nn.Linear(args.audio_encoder_feature_dim, out_dim)
        self.PPE = PeriodicPositionalEncoding(
            out_dim, period=args.period, max_seq_len=3000)
        self.biased_mask = init_biased_mask(
            n_head=4, max_seq_len=3000, period=args.period)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=out_dim,
            nhead=4,
            dim_feedforward=2 * out_dim,
            batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer, num_layers=1)
        self.interpolate_pos = args.interpolate_pos

    def forward(self, hidden_states: torch.Tensor, style_idx: torch.Tensor,
                frame_num: int):
        obj_embedding = self.learnable_style_emb(style_idx)
        obj_embedding = obj_embedding.unsqueeze(1).repeat(1, frame_num, 1)
        hidden_states = self.audio_feature_map(hidden_states)
        style_input = self.PPE(obj_embedding)
        tgt_mask = self.biased_mask[:, :style_input.shape[1], :style_input.
                                    shape[1]].clone().detach().to(
                                        device=style_input.device)
        memory_mask = enc_dec_mask(hidden_states.device, style_input.shape[1],
                                   frame_num)
        feat_out = self.transformer_decoder(
            style_input,
            hidden_states,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask)
        if self.interpolate_pos == 2:
            feat_out = linear_interpolation(feat_out, output_len=frame_num)
        return feat_out
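

# Hedged smoke-test sketch for UniTalkerDecoderTCN: a SimpleNamespace stands
# in for the parsed config and provides only the attributes the constructor
# reads (identity_num, decoder_dimension, audio_encoder_feature_dim,
# interpolate_pos). All values are illustrative.
def _decoder_tcn_example():  # pragma: no cover
    from types import SimpleNamespace
    args = SimpleNamespace(
        identity_num=4,
        decoder_dimension=32,
        audio_encoder_feature_dim=768,
        interpolate_pos=0)
    decoder = UniTalkerDecoderTCN(args)
    feats = torch.randn(2, 30, 768)  # (batch, frames, audio feature dim)
    style = torch.tensor([0, 1])
    assert decoder(feats, style, frame_num=30).shape == (2, 30, 32)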


================================================
FILE: models/wav2vec2.py
================================================
import numpy as np
import torch
from transformers import Wav2Vec2Model
from transformers.modeling_outputs import BaseModelOutput
from typing import Optional, Tuple

from .model_utils import linear_interpolation


# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model # noqa: E501
# initialize our encoder with the pre-trained wav2vec 2.0 weights.
def _compute_mask_indices(
    shape: Tuple[int, int],
    mask_prob: float,
    mask_length: int,
    attention_mask: Optional[torch.Tensor] = None,
    min_masks: int = 0,
) -> np.ndarray:
    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    all_num_mask = int(mask_prob * all_sz / float(mask_length) +
                       np.random.rand())
    all_num_mask = max(min_masks, all_num_mask)
    mask_idcs = []
    padding_mask = attention_mask.ne(1) if attention_mask is not None else None
    for i in range(bsz):
        if padding_mask is not None:
            sz = all_sz - padding_mask[i].long().sum().item()
            num_mask = int(mask_prob * sz / float(mask_length) +
                           np.random.rand())
            num_mask = max(min_masks, num_mask)
        else:
            sz = all_sz
            num_mask = all_num_mask

        lengths = np.full(num_mask, mask_length)

        if sum(lengths) == 0:
            lengths[0] = min(mask_length, sz - 1)

        min_len = min(lengths)
        if sz - min_len <= num_mask:
            min_len = sz - num_mask - 1

        mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
        mask_idc = np.asarray([
            mask_idc[j] + offset for j in range(len(mask_idc))
            for offset in range(lengths[j])
        ])
        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))

    min_len = min([len(m) for m in mask_idcs])
    for i, mask_idc in enumerate(mask_idcs):
        if len(mask_idc) > min_len:
            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
        mask[i, mask_idc] = True
    return mask


class Wav2Vec2Model(Wav2Vec2Model):

    def __init__(self, config):
        super().__init__(config)

    def _freeze_wav2vec2_parameters(self, do_freeze: bool = True):
        for param in self.parameters():
            param.requires_grad = (not do_freeze)

    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        frame_num=None,
        interpolate_pos: int = 0,
    ):
        self.config.output_attentions = True
        output_attentions = output_attentions if \
            output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else
            self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None \
            else self.config.use_return_dict

        conv_features = self.feature_extractor(input_values)
        hidden_states = conv_features.transpose(1, 2)
        if interpolate_pos == 0:
            hidden_states = linear_interpolation(
                hidden_states, output_len=frame_num)

        if attention_mask is not None:
            output_lengths = self._get_feat_extract_output_lengths(
                attention_mask.sum(-1))
            attention_mask = torch.zeros(
                hidden_states.shape[:2],
                dtype=hidden_states.dtype,
                device=hidden_states.device)
            attention_mask[(torch.arange(
                attention_mask.shape[0],
                device=hidden_states.device), output_lengths - 1)] = 1
            attention_mask = attention_mask.flip([-1]).cumsum(-1).flip(
                [-1]).bool()

        hidden_states, _ = self.feature_projection(hidden_states)

        if self.config.apply_spec_augment and self.training:
            batch_size, sequence_length, hidden_size = hidden_states.size()
            if self.config.mask_time_prob > 0:
                mask_time_indices = _compute_mask_indices(
                    (batch_size, sequence_length),
                    self.config.mask_time_prob,
                    self.config.mask_time_length,
                    attention_mask=attention_mask,
                    min_masks=2,
                )
                hidden_states[torch.from_numpy(
                    mask_time_indices)] = self.masked_spec_embed.to(
                        hidden_states.dtype)
            if self.config.mask_feature_prob > 0:
                mask_feature_indices = _compute_mask_indices(
                    (batch_size, hidden_size),
                    self.config.mask_feature_prob,
                    self.config.mask_feature_length,
                )
                mask_feature_indices = torch.from_numpy(
                    mask_feature_indices).to(hidden_states.device)
                hidden_states[mask_feature_indices[:, None].expand(
                    -1, sequence_length, -1)] = 0
        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = encoder_outputs[0]
        if interpolate_pos == 1:
            hidden_states = linear_interpolation(
                hidden_states, output_len=frame_num)

        if not return_dict:
            return (hidden_states, ) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


================================================
FILE: models/wavlm.py
================================================
import numpy as np
import torch
from transformers import WavLMModel
from transformers.modeling_outputs import Wav2Vec2BaseModelOutput
from typing import Optional, Tuple, Union

from .model_utils import linear_interpolation


# This WavLMModel wrapper follows the same pattern as models/wav2vec2.py,
# adding frame-rate interpolation of the extracted features.


class WavLMModel(WavLMModel):
    def __init__(self, config):
        super().__init__(config)

    def _freeze_wav2vec2_parameters(self, do_freeze: bool = True):
        for param in self.parameters():
            param.requires_grad = (not do_freeze)

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        frame_num=None,
        interpolate_pos: int = 0,
    ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)

        if interpolate_pos == 0:
            extract_features = linear_interpolation(
                extract_features, output_len=frame_num)
            
        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if interpolate_pos == 1:
            hidden_states = linear_interpolation(
                hidden_states, output_len=frame_num)
            
        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, extract_features) + encoder_outputs[1:]

        return Wav2Vec2BaseModelOutput(
            last_hidden_state=hidden_states,
            extract_features=extract_features,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

================================================
FILE: utils/config.py
================================================
# -----------------------------------------------------------------------------
# Functions for parsing args
# -----------------------------------------------------------------------------
import copy
import os
import yaml
from ast import literal_eval


class CfgNode(dict):
    """CfgNode represents an internal node in the configuration tree.

    It's a simple dict-like container that allows for attribute-based access to
    keys.
    """

    def __init__(self, init_dict=None, key_list=None, new_allowed=False):
        # Recursively convert nested dictionaries in init_dict into CfgNodes
        init_dict = {} if init_dict is None else init_dict
        key_list = [] if key_list is None else key_list
        for k, v in init_dict.items():
            if type(v) is dict:
                # Convert dict to CfgNode
                init_dict[k] = CfgNode(v, key_list=key_list + [k])
        super(CfgNode, self).__init__(init_dict)

    def __getattr__(self, name):
        if name in self:
            return self[name]
        else:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __str__(self):

        def _indent(s_, num_spaces):
            s = s_.split('\n')
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * ' ') + line for line in s]
            s = '\n'.join(s)
            s = first + '\n' + s
            return s

        r = ''
        s = []
        for k, v in sorted(self.items()):
            separator = '\n' if isinstance(v, CfgNode) else ' '
            attr_str = '{}:{}{}'.format(str(k), separator, str(v))
            attr_str = _indent(attr_str, 2)
            s.append(attr_str)
        r += '\n'.join(s)
        return r

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__,
                               super(CfgNode, self).__repr__())
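

# Hedged usage sketch for CfgNode: nested dicts become attribute-accessible
# nodes, and attribute assignment writes through to the dict. Keys are
# illustrative.
def _cfgnode_example():  # pragma: no cover
    cfg = CfgNode({'model': {'lr': 0.001}, 'epochs': 10})
    assert cfg.model.lr == 0.001 and cfg.epochs == 10
    cfg.epochs = 20
    assert cfg['epochs'] == 20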


def load_cfg_from_cfg_file(file):
    cfg = {}
    assert os.path.isfile(file) and file.endswith('.yaml'), \
        '{} is not a yaml file'.format(file)

    with open(file, 'r') as f:
        cfg_from_file = yaml.safe_load(f)

    for key in cfg_from_file:
        for k, v in cfg_from_file[key].items():
            cfg[k] = v

    cfg = CfgNode(cfg)
    return cfg


def merge_cfg_from_list(cfg, cfg_list):
    new_cfg = copy.deepcopy(cfg)
    assert len(cfg_list) % 2 == 0
    for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
        subkey = full_key.split('.')[-1]
        assert subkey in cfg, 'Non-existent key: {}'.format(full_key)
        value = _decode_cfg_value(v)
        value = _check_and_coerce_cfg_value_type(value, cfg[subkey], subkey,
                                                 full_key)
        setattr(new_cfg, subkey, value)

    return new_cfg


def _decode_cfg_value(v):
    """Decodes a raw config value (e.g., from a yaml config files or command
    line argument) into a Python object."""
    # All remaining processing is only applied to strings
    if not isinstance(v, str):
        return v
    # Try to interpret `v` as a:
    #   string, number, tuple, list, dict, boolean, or None
    try:
        v = literal_eval(v)
    # The following two excepts allow v to pass through when it represents a
    # string.
    #
    # Longer explanation:
    # The type of v is always a string (before calling literal_eval), but
    # sometimes it *represents* a string and other times a data structure, like
    # a list. In the case that v represents a string, what we got back from the
    # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
    # ok with '"foo"', but will raise a ValueError if given 'foo'. In other
    # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
    # will raise a SyntaxError.
    except ValueError:
        pass
    except SyntaxError:
        pass
    return v
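
# Illustration of the decoding above (inputs are hypothetical):
#     _decode_cfg_value('1e-4')    -> 0.0001   (parsed as a float)
#     _decode_cfg_value('[1, 2]')  -> [1, 2]   (parsed as a list)
#     _decode_cfg_value('foo')     -> 'foo'    (literal_eval raises; string passes through)
#     _decode_cfg_value(3)         -> 3        (non-strings are returned unchanged)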


def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
    """Checks that `replacement`, which is intended to replace `original` is of
    the right type.

    The type is correct if it matches exactly or is one of a few cases in which
    the type can be easily coerced.
    """
    original_type = type(original)
    replacement_type = type(replacement)

    # The types must match (with some exceptions)
    if replacement_type == original_type or original is None:
        return replacement

    # Cast `replacement` from from_type to to_type when the runtime types match
    def conditional_cast(from_type, to_type):
        if replacement_type == from_type and original_type == to_type:
            return True, to_type(replacement)
        else:
            return False, None

    # Conditionally casts
    # list <-> tuple
    casts = [(tuple, list), (list, tuple)]
    # For py2: allow converting from str (bytes) to a unicode string
    try:
        casts.append((str, unicode))  # noqa: F821
    except Exception:
        pass

    for (from_type, to_type) in casts:
        converted, converted_value = conditional_cast(from_type, to_type)
        if converted:
            return converted_value

    raise ValueError(
        'Type mismatch ({} vs. {}) with values ({} vs. {}) for config '
        'key: {}'.format(original_type, replacement_type, original,
                         replacement, full_key))
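

# End-to-end sketch of the helpers above, mirroring how get_parser in
# utils/utils.py uses them: load_cfg_from_cfg_file flattens the two-level YAML
# (section names such as DATA are dropped and their keys promoted to the top
# level), and merge_cfg_from_list applies raw string KEY VALUE pairs from the
# command line. Guarded so it only runs when executed directly; assumes the
# default config file exists at the path below.
if __name__ == '__main__':
    cfg = load_cfg_from_cfg_file('config/unitalker.yaml')
    print(cfg.dataset)  # top-level attribute, although it sits under DATA in the YAML

    # Overrides arrive as strings; _decode_cfg_value re-parses each value and
    # _check_and_coerce_cfg_value_type enforces the original value's type.
    cfg = merge_cfg_from_list(cfg, ['dataset', 'D0,D1'])
    print(cfg.dataset)  # 'D0,D1'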


================================================
FILE: utils/utils.py
================================================
import argparse
import logging
import numpy as np
import os
import random
import torch
from copy import deepcopy

from . import config


def get_audio_encoder_dim(audio_encoder: str):
    if audio_encoder == 'microsoft/wavlm-base-plus':
        return 768
    elif audio_encoder == 'facebook/wav2vec2-large-xlsr-53':
        return 1024
    else:
        raise ValueError(
            'unsupported audio_encoder: {}'.format(audio_encoder))


def filer_list(in_list: list, key: str):
    """Return the entries of `in_list` that start with `key`."""
    ret_list = []
    for line in in_list:
        if line.startswith(key):
            ret_list.append(line)
    return ret_list


def count_checkpoint_params(ckpt: dict):
    """Count total and PCA-layer parameters in a state dict."""
    params = 0
    pca_params = 0
    for k, v in ckpt.items():
        if 'pca' in k:
            print(k)  # log which keys are counted as PCA parameters
            pca_params += np.prod(v.size())
        params += np.prod(v.size())
    return params, pca_params


def read_obj(in_path):
    with open(in_path, 'r') as obj_file:
        # Read the lines of the OBJ file
        lines = obj_file.readlines()

    # Initialize empty lists for vertices and faces
    verts = []
    faces = []
    for line in lines:
        line = line.strip()  # Remove leading/trailing whitespace
        elements = line.split()  # Split the line into elements

        if len(elements) == 0:
            continue  # Skip empty lines

        # Check the type of line (vertex or face)
        if elements[0] == 'v':
            # Vertex line
            x, y, z = map(float,
                          elements[1:4])  # Extract the vertex coordinates
            verts.append((x, y, z))  # Add the vertex to the list
        elif elements[0] == 'f':
            # Face line
            face_indices = [
                int(index.split('/')[0]) for index in elements[1:]
            ]  # Extract the vertex indices
            faces.append(face_indices)  # Add the face to the list
    verts = np.array(verts)
    faces = np.array(faces)
    if faces.min() == 1:
        faces = faces - 1
    return verts, faces
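
# Illustration of the OBJ lines parsed above (values are hypothetical). OBJ
# face indices are 1-based, which is why the -1 shift is applied when needed:
#     v 0.1 -0.2 0.3         -> verts row (0.1, -0.2, 0.3)
#     f 1/1/1 2/2/2 3/3/3    -> faces row [0, 1, 2] after the shift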


def get_logger(log_file: str):
    logger_name = 'main-logger'
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)
    handler = logging.FileHandler(log_file, mode='w')
    fmt = '[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d]=>%(message)s'  # noqa: E501
    handler.setFormatter(logging.Formatter(fmt))
    for hdl in logger.handlers:
        logger.removeHandler(hdl)
    logger.addHandler(handler)
    return logger


def get_template_dict(data_root, annot_type: str):
    import pickle
    if annot_type == 'BIWI_23370_vertices':
        template_pkl_path = os.path.join(data_root, 'BIWI', 'templates.pkl')
    elif annot_type == 'FLAME_5023_vertices':
        template_pkl_path = os.path.join(data_root, 'vocaset', 'templates.pkl')
    else:
        raise ValueError(
            'no template available for annot_type: {}'.format(annot_type))
    with open(template_pkl_path, 'rb') as f:
        template_dict = pickle.load(f, encoding='latin')
    return template_dict


def filter_unitalker_state_dict(in_dict):
    out_dict = {}
    for k, v in in_dict.items():
        if 'pca_layer_dict' not in k:
            out_dict[k] = v
    return out_dict


def get_template_verts(data_root, dataset_name: str, subject: int):
    from dataset.dataset_config import dataset_config
    template_path = os.path.join(data_root, dataset_config[dataset_name]['dirname'], 'id_template.npy')
    return np.load(template_path)[subject]


def load_ckpt(model, ckpt_path, re_init_decoder_and_head: bool = False):
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    id_embedding_key = 'decoder.learnable_style_emb.weight'
    if len(checkpoint[id_embedding_key]) != len(
            model.state_dict()[id_embedding_key]):
        target_id_num = len(model.state_dict()[id_embedding_key])
        checkpoint[id_embedding_key] = checkpoint[id_embedding_key][-1][
            None].repeat(target_id_num, 1)
    if re_init_decoder_and_head is True:
        ckpt_keys = list(checkpoint.keys())
        encoder_keys = filer_list(ckpt_keys, 'audio_encoder')
        reinit_keys = list(set(ckpt_keys) - set(encoder_keys))
        for k in reinit_keys:
            if k in model.state_dict():
                checkpoint[k] = model.state_dict()[k]
    model.load_state_dict(checkpoint, strict=False)
    return
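
# Usage sketch (hypothetical checkpoint path; `model` is a UniTalker instance):
#     load_ckpt(model, 'pretrained/unitalker.pt', re_init_decoder_and_head=True)
# If the checkpoint holds a different number of identities, its last identity-
# embedding row is tiled to the new size; with re_init_decoder_and_head=True,
# every non-audio-encoder weight is reset to the fresh model's initialization.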


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def get_avg(self, ):
        return self.avg
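
# Usage sketch: `update` takes a (batch-mean) value and the sample count `n`,
# so the running average is properly weighted:
#     meter = AverageMeter()
#     meter.update(0.5, n=4)
#     meter.update(0.3, n=2)
#     meter.get_avg()  # (0.5 * 4 + 0.3 * 2) / 6 = 0.4333...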


def get_average_meter_dict(args, phase):
    if phase == 'train':
        metric_name_list = ['rec_L2_loss', 'pca_loss']
    elif phase == 'val' or phase == 'test':
        metric_name_list = ['rec_L2_loss', 'LVE', 'LVD']
    else:
        raise ValueError('wrong phase for average meter')
    average_meter_dict_template = {}
    for k in metric_name_list:
        average_meter_dict_template[k] = AverageMeter()
    mixed_dataset_name = 'mixed_dataset'
    dataset_name_list = args.dataset + [mixed_dataset_name]
    average_meter_dict = {}
    for dataset_name in dataset_name_list:
        average_meter_dict[dataset_name] = deepcopy(
            average_meter_dict_template)
    return average_meter_dict, metric_name_list
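
# Shape sketch: for args.dataset == ['D0', 'D1'] and phase == 'val', this returns
#     ({'D0': {'rec_L2_loss': AverageMeter(), 'LVE': AverageMeter(), 'LVD': AverageMeter()},
#       'D1': {...},
#       'mixed_dataset': {...}},
#      ['rec_L2_loss', 'LVE', 'LVD'])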


def log_datasetloss(
    logger,
    epoch: int,
    phase: str,
    dataset_loss_dict: dict,
    only_mix: bool = True,
):
    base_str = f'{phase.upper():<5} Epoch: {epoch:03d} '
    for dataset_name, loss_info in dataset_loss_dict.items():
        if only_mix and dataset_name != 'mixed_dataset':
            continue
        for loss_type, avg_meter in loss_info.items():
            base_str += f'{loss_type}_{dataset_name}:{avg_meter.avg: .8f}, '
    logger.info(base_str)
    return base_str


def write_to_tensorboard(writer,
                         epoch: int,
                         phase: str,
                         dataset_loss_dict: dict,
                         only_mix: bool = False):
    for dataset_name, loss_info in dataset_loss_dict.items():
        if only_mix and dataset_name != 'mixed_dataset':
            continue
        for loss_type, avg_meter in loss_info.items():
            key = f'{phase}/{loss_type}_{dataset_name}'
            writer.add_scalar(key, avg_meter.avg, epoch)
    return 0


def get_parser():
    """Parse command-line arguments and return the merged config as a CfgNode."""
    parser = argparse.ArgumentParser(description=' ')
    parser.add_argument(
        '--config',
        type=str,
        default='config/unitalker.yaml',
        help='config file')
    parser.add_argument(
        '--condition_id', type=str, default='common', help='condition id')
    parser.add_argument(
        'opts', help=' ', default=None, nargs=argparse.REMAINDER)
    args = parser.parse_args()
    assert args.config is not None
    cfg = config.load_cfg_from_cfg_file(args.config)
    if args.opts is not None:
        cfg = config.merge_cfg_from_list(cfg, args.opts)
    cfg.condition_id = args.condition_id
    cfg.dataset = cfg.dataset.split(',')
    cfg.demo_dataset = cfg.demo_dataset.split(',')
    if isinstance(cfg.duplicate_list, str):
        cfg.duplicate_list = cfg.duplicate_list.split(',')
        cfg.duplicate_list = [int(i) for i in cfg.duplicate_list]
    return cfg


def seed_everything(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    return
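

# A minimal setup sketch combining the helpers above (the log path is
# illustrative); guarded so it only runs when executed directly.
if __name__ == '__main__':
    seed_everything(42)               # seeded RNGs plus deterministic cudnn
    logger = get_logger('train.log')  # opened with mode='w'

    meter = AverageMeter()
    for step in range(3):
        loss = 1.0 / (step + 1)       # stand-in for a real training loss
        meter.update(loss, n=8)       # n = batch size
    logger.info('mean loss: {:.4f}'.format(meter.get_avg()))
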
================================================
SYMBOL INDEX (150 symbols across 18 files)
================================================

FILE: dataset/data_item.py
  class DataItem (line 8) | class DataItem:
    method __init__ (line 10) | def __init__(self,
    method load_audio (line 49) | def load_audio(self, audio_path: str):
    method truncate (line 55) | def truncate(self, audio_array: np.ndarray, sr: int,
    method offset_id (line 65) | def offset_id(self, offset: int):
    method scale_and_offset (line 70) | def scale_and_offset(self,
    method to_dict (line 76) | def to_dict(self, ) -> Any:
    method get_dict (line 91) | def get_dict(self, ):

FILE: dataset/dataset.py
  class AudioFaceDataset (line 13) | class AudioFaceDataset(data.Dataset):
    method __init__ (line 15) | def __init__(self, data_label_path: str, args: dict = None) -> None:
    method __len__ (line 44) | def __len__(self, ):
    method __getitem__ (line 47) | def __getitem__(self, index: Any) -> Any:
    method get_identity_num (line 51) | def get_identity_num(self, ):
  class MixAudioFaceDataset (line 55) | class MixAudioFaceDataset(AudioFaceDataset):
    method __init__ (line 57) | def __init__(self,
  function get_dataset_list (line 80) | def get_dataset_list(args):
  function get_single_dataset (line 104) | def get_single_dataset(args, json_path: str = ''):
  function get_dataloaders (line 123) | def get_dataloaders(args):

FILE: loss/loss.py
  class FlameParams (line 11) | class FlameParams:
  class FlameLayer (line 18) | class FlameLayer(nn.Module):
    method __init__ (line 20) | def __init__(self, model_path: str) -> None:
    method flame_params_from_3dmm (line 76) | def flame_params_from_3dmm(self,
    method forward (line 97) | def forward(self, tensor_3dmm: torch.Tensor):
  class FlameParamLossForDADHead (line 126) | class FlameParamLossForDADHead(nn.Module):
    method __init__ (line 128) | def __init__(self, mouth_indices_path: str, loss_mask_path: str,
    method forward (line 138) | def forward(self, x: torch.Tensor, target: torch.Tensor):
    method get_vertices (line 147) | def get_vertices(
    method mouth_metric (line 154) | def mouth_metric(self, x: torch.Tensor, target: torch.Tensor):
  class VerticesLoss (line 165) | class VerticesLoss(nn.Module):
    method __init__ (line 167) | def __init__(self, mouth_indices_path: str) -> None:
    method forward (line 173) | def forward(self, x: torch.Tensor, target: torch.Tensor):
    method get_vertices (line 176) | def get_vertices(
    method mouth_metric (line 182) | def mouth_metric(self, x: torch.Tensor, target: torch.Tensor):
  class BlendShapeLoss (line 193) | class BlendShapeLoss(nn.Module):
    method __init__ (line 195) | def __init__(self,
    method forward (line 215) | def forward(self, x: torch.Tensor, target: torch.Tensor):
    method bs2vertices (line 219) | def bs2vertices(self, x: torch.Tensor):
    method get_vertices (line 224) | def get_vertices(
    method mouth_metric (line 231) | def mouth_metric(self, x: torch.Tensor, target: torch.Tensor):
  class ParamLoss (line 244) | class ParamLoss(nn.Module):
    method __init__ (line 246) | def __init__(self, weight: float):
    method forward (line 252) | def forward(self, x: torch.Tensor, target: torch.Tensor):
    method mouth_metric (line 255) | def mouth_metric(self, x: torch.Tensor, target: torch.Tensor):
    method get_vertices (line 258) | def get_vertices(
  class UniTalkerLoss (line 265) | class UniTalkerLoss(nn.Module):
    method __init__ (line 267) | def __init__(self, args) -> None:
    method forward (line 344) | def forward(
    method pca_loss (line 352) | def pca_loss(self, x: torch.Tensor, target: torch.Tensor):
    method mouth_metric (line 355) | def mouth_metric(self, x: torch.Tensor, target: torch.Tensor,
    method get_vertices (line 359) | def get_vertices(self, x: torch.Tensor, annot_type: str):

FILE: main/demo.py
  function get_all_audios (line 14) | def get_all_audios(audio_root:str):
  function split_long_audio (line 26) | def split_long_audio(
  function merge_out_list (line 47) | def merge_out_list(out_list: list, fps:int):
  function main (line 69) | def main():

FILE: main/render.py
  function render_meshes (line 23) | def render_meshes(mesh_vertices: torch.Tensor,
  function read_obj (line 131) | def read_obj(in_path):
  function pad_for_libx264 (line 161) | def pad_for_libx264(image_array):
  function array_to_video (line 209) | def array_to_video(
  function get_obj_faces (line 287) | def get_obj_faces(annot_type: str):
  function vis_model_out_proxy (line 303) | def vis_model_out_proxy():

FILE: main/train.py
  function main (line 21) | def main():

FILE: main/train_eval_loop.py
  function train_epoch (line 8) | def train_epoch(data_loader, model, loss_module, optimizer, epoch, args):
  function validate_epoch (line 53) | def validate_epoch(data_loader, model, loss_module, epoch, phase, args):

FILE: main/train_pca.py
  function main (line 8) | def main():
  function main_worker (line 15) | def main_worker(cfg):

FILE: models/base_model.py
  class BaseModel (line 5) | class BaseModel(nn.Module):
    method __init__ (line 8) | def __init__(self):
    method forward (line 12) | def forward(self, *x):
    method freeze_model (line 19) | def freeze_model(self, do_freeze: bool = True):
    method summary (line 23) | def summary(self, logger, writer=None):

FILE: models/hubert.py
  function _compute_mask_indices (line 12) | def _compute_mask_indices(
  class HubertModel (line 61) | class HubertModel(HubertModel):
    method __init__ (line 63) | def __init__(self, config):
    method _freeze_parameters (line 66) | def _freeze_parameters(self, do_freeze: bool = True):
    method forward (line 70) | def forward(

FILE: models/model_utils.py
  function linear_interpolation (line 9) | def linear_interpolation(features, output_len: int):
  function init_biased_mask (line 17) | def init_biased_mask(n_head, max_seq_len, period):
  function enc_dec_mask (line 53) | def enc_dec_mask(device, T, S):
  class PeriodicPositionalEncoding (line 61) | class PeriodicPositionalEncoding(nn.Module):
    method __init__ (line 63) | def __init__(self, d_model, dropout=0.1, period=25, max_seq_len=3000):
    method forward (line 78) | def forward(self, x):
  class TCN (line 83) | class TCN(nn.Module):
    method __init__ (line 85) | def __init__(self, in_dim, out_dim):
    method forward (line 91) | def forward(self, spectrogram, style_embed, time_steps=None):
  class SeqTranslator1D (line 106) | class SeqTranslator1D(nn.Module):
    method __init__ (line 109) | def __init__(self,
    method forward (line 144) | def forward(self, x):
  class ConvNormRelu (line 148) | class ConvNormRelu(nn.Module):
    method __init__ (line 155) | def __init__(self,
    method forward (line 264) | def forward(self, x, **kwargs):

FILE: models/pca.py
  class PCALayer (line 6) | class PCALayer(nn.Module):
    method __init__ (line 8) | def __init__(self, pca_info_path: str):
    method encode (line 23) | def encode(
    method decode (line 32) | def decode(self, feat: torch.Tensor, pca_dim: int = None):
  class PCA (line 38) | class PCA:
    method __init__ (line 40) | def __init__(self, trunc_dim: int = None):
    method buld_pca_for_dataset (line 43) | def buld_pca_for_dataset(self, dataset, trunc_dim: int = 1024):
    method load_dataset (line 47) | def load_dataset(self, dataset):
    method build_incremental_PCA (line 56) | def build_incremental_PCA(
    method build_pca (line 94) | def build_pca(

FILE: models/unitalker.py
  class OutHead (line 9) | class OutHead(BaseModel):
    method __init__ (line 11) | def __init__(self, args, out_dim: int):
    method forward (line 22) | def forward(self, x):
  class UniTalker (line 26) | class UniTalker(BaseModel):
    method __init__ (line 28) | def __init__(self, args):
    method forward (line 87) | def forward(self,

FILE: models/unitalker_decoder.py
  class UniTalkerDecoderTCN (line 14) | class UniTalkerDecoderTCN(BaseModel):
    method __init__ (line 16) | def __init__(self, args) -> None:
    method forward (line 25) | def forward(self,
  class UniTalkerDecoderTransformer (line 43) | class UniTalkerDecoderTransformer(BaseModel):
    method __init__ (line 45) | def __init__(self, args) -> None:
    method forward (line 63) | def forward(self, hidden_states: torch.Tensor, style_idx: torch.Tensor,

FILE: models/wav2vec2.py
  function _compute_mask_indices (line 12) | def _compute_mask_indices(
  class Wav2Vec2Model (line 61) | class Wav2Vec2Model(Wav2Vec2Model):
    method __init__ (line 63) | def __init__(self, config):
    method _freeze_wav2vec2_parameters (line 66) | def _freeze_wav2vec2_parameters(self, do_freeze: bool = True):
    method forward (line 70) | def forward(

FILE: models/wavlm.py
  class WavLMModel (line 16) | class WavLMModel(WavLMModel):
    method __init__ (line 17) | def __init__(self, config):
    method _freeze_wav2vec2_parameters (line 20) | def _freeze_wav2vec2_parameters(self, do_freeze: bool = True):
    method forward (line 24) | def forward(

FILE: utils/config.py
  class CfgNode (line 10) | class CfgNode(dict):
    method __init__ (line 17) | def __init__(self, init_dict=None, key_list=None, new_allowed=False):
    method __getattr__ (line 27) | def __getattr__(self, name):
    method __setattr__ (line 33) | def __setattr__(self, name, value):
    method __str__ (line 36) | def __str__(self):
    method __repr__ (line 58) | def __repr__(self):
  function load_cfg_from_cfg_file (line 63) | def load_cfg_from_cfg_file(file):
  function merge_cfg_from_list (line 79) | def merge_cfg_from_list(cfg, cfg_list):
  function _decode_cfg_value (line 93) | def _decode_cfg_value(v):
  function _check_and_coerce_cfg_value_type (line 121) | def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):

FILE: utils/utils.py
  function get_audio_encoder_dim (line 12) | def get_audio_encoder_dim(audio_encoder: str):
  function filer_list (line 20) | def filer_list(in_list: list, key: str):
  function count_checkpoint_params (line 28) | def count_checkpoint_params(ckpt: dict):
  function read_obj (line 39) | def read_obj(in_path):
  function get_logger (line 73) | def get_logger(log_file: str):
  function get_template_dict (line 86) | def get_template_dict(data_root, annot_type: str):
  function filter_unitalker_state_dict (line 97) | def filter_unitalker_state_dict(in_dict):
  function get_template_verts (line 104) | def get_template_verts(data_root, dataset_name: str, subject: int):
  function load_ckpt (line 110) | def load_ckpt(model, ckpt_path, re_init_decoder_and_head: bool = False):
  class AverageMeter (line 129) | class AverageMeter(object):
    method __init__ (line 132) | def __init__(self):
    method reset (line 135) | def reset(self):
    method update (line 141) | def update(self, val, n=1):
    method get_avg (line 147) | def get_avg(self, ):
  function get_average_meter_dict (line 151) | def get_average_meter_dict(args, phase):
  function log_datasetloss (line 170) | def log_datasetloss(
  function write_to_tensorboard (line 187) | def write_to_tensorboard(writer,
  function get_parser (line 201) | def get_parser():
  function seed_everything (line 226) | def seed_everything(seed: int = 42):