Repository: X-niper/UniTalker
Branch: main
Commit: 21481130f126
Files: 23
Total size: 117.9 KB
Directory structure:
gitextract_hq8r3so5/
├── .gitignore
├── LICENSE
├── README.md
├── config/
│ └── unitalker.yaml
├── dataset/
│ ├── data_item.py
│ ├── dataset.py
│ └── dataset_config.py
├── loss/
│ └── loss.py
├── main/
│ ├── demo.py
│ ├── render.py
│ ├── train.py
│ ├── train_eval_loop.py
│ └── train_pca.py
├── models/
│ ├── base_model.py
│ ├── hubert.py
│ ├── model_utils.py
│ ├── pca.py
│ ├── unitalker.py
│ ├── unitalker_decoder.py
│ ├── wav2vec2.py
│ └── wavlm.py
└── utils/
├── config.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized files
*__pycache__/
*.py[cod]
# Distribution / packaging
dist/
build/
*.egg-info/
*.egg
# Virtual environments
venv/
env/
.venv/
.pip/
pipenv/
# IDE specific files
.idea/
.vscode/
# Logs and temporary files
*.log
*.swp
*.swo
*.tmp
# Miscellaneous
.DS_Store
records.txt
binary_resources
do_experiment.py
small_tools
scancel_job.py
results
template_resources
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# UniTalker: Scaling up Audio-Driven 3D Facial Animation through A Unified Model [ECCV2024]
## Useful Links
<div align="center">
<a href="https://x-niper.github.io/projects/UniTalker/" class="button"><b>[Homepage]</b></a>
<a href="https://arxiv.org/abs/2408.00762" class="button"><b>[arXiv]</b></a>
<a href="https://www.youtube.com/watch?v=oUUh67ECzig" class="button"><b>[Video]</b></a>
</div>
<div align="center">
<img src="https://github.com/X-niper/X-niper.github.io/blob/master/projects/UniTalker/assets/unitalker_architecture.png" alt="UniTalker Architecture" width="80%" align=center/>
</div>
> UniTalker generates realistic facial motion from different audio domains, including clean and noisy voices in various languages, text-to-speech-generated audio, and even noisy songs accompanied by background music.
>
> UniTalker can output multiple annotation types.
>
> For datasets with new annotations, one can simply plug new heads into UniTalker and train it with existing datasets or solely with new ones, avoiding retopology.
## Installation
### Environment
- Linux
- Python 3.10
- PyTorch 2.2.0
- CUDA 12.1
- transformers 4.39.3
- PyTorch3D 0.7.7 (optional; only needed for rendering the results)
```bash
conda create -n unitalker python==3.10
conda activate unitalker
conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=12.1 -c pytorch -c nvidia
pip install transformers librosa tensorboardX smplx chumpy numpy==1.23.5 opencv-python
```
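After installing, it may be worth confirming that the installed versions match the list above. The snippet below is a minimal sketch and not part of the repository: the helper name `version_report` and the choice to check only `torch` and `transformers` are assumptions made for illustration.

```python
from importlib import metadata

# Versions taken from the environment list above.
EXPECTED = {"torch": "2.2.0", "transformers": "4.39.3"}


def version_report(expected=EXPECTED, lookup=metadata.version):
    """Map each package to (expected, installed); installed is None if absent."""
    report = {}
    for package, wanted in expected.items():
        try:
            installed = lookup(package)
        except metadata.PackageNotFoundError:
            installed = None
        report[package] = (wanted, installed)
    return report


if __name__ == "__main__":
    for package, (wanted, installed) in version_report().items():
        print(f"{package}: expected {wanted}, installed {installed}")
```

Note that the `pip install` line above does not pin `transformers`, so the installed version may differ from 4.39.3 unless you pin it yourself.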
## Inference
### Download checkpoints, PCA models and template resources
[UniTalker-B-[D0-D7]](https://drive.google.com/file/d/1PmF8I6lyo0_64-NgeN5qIQAX6Bg0yw44/view?usp=sharing): The base model in the paper. Download it and place it in `./pretrained_models`.
[UniTalker-L-[D0-D7]](https://drive.google.com/file/d/1sH2T7KLFNjUnTM-V1eRMM1Tytxd2sYAp/view?usp=sharing): The default model in the paper. Please try the base model first to check that the pipeline runs end to end.
[Unitalker-data-release-V1](https://drive.google.com/file/d/1Un7TB0Z5A1CG6bgeqKlhnSOECFN-C6KK/view?usp=sharing): The released datasets, PCA models, data-split JSON files and id-template NumPy arrays. Download it and unzip it in this repo.
[FLAME2020](https://flame.is.tue.mpg.de/download.php): Please download FLAME 2020 and move `generic_model.pkl` to `resources/binary_resources/flame.pkl`.
<!-- [PCA models](https://drive.google.com/file/d/1e0sG2vvdrtAMgwD5njctifhX0ai4eu3g/view?usp=sharing): download the pca models and unzip it in "./unitalker_data_release" -->
Use `git lfs pull` to get `./resources.zip` and `./test_audios.zip`, then unzip both in this repo.
Finally, these files should be organized as follows:
```text
├── pretrained_models
│ ├── UniTalker-B-D0-D7.pt
│ ├── UniTalker-L-D0-D7.pt
├── resources
│ ├── binary_resources
│ │ ├── 02_flame_mouth_idx.npy
│ │ ├── ...
│ │ └── vocaset_FDD_wo_eyes.npy
│ └── obj_template
│ ├── 3DETF_blendshape_weight.obj
│ ├── ...
│ └── meshtalk_6172_vertices.obj
└── unitalker_data_release_V1
│ ├── D0_BIWI
│ │ ├── id_template.npy
│ │ └── pca.npz
│ ├── D1_vocaset
│ │ ├── id_template.npy
│ │ └── pca.npz
│ ├── D2_meshtalk
│ │ ├── id_template.npy
│ │ └── pca.npz
│ ├── D3D4_3DETF
│ │ ├── D3_HDTF
│ │ └── D4_RAVDESS
│ ├── D5_unitalker_faceforensics++
│ │ ├── id_template.npy
│ │ ├── test
│ │ ├── test.json
│ │ ├── train
│ │ ├── train.json
│ │ ├── val
│ │ └── val.json
│ ├── D6_unitalker_Chinese_speech
│ │ ├── id_template.npy
│ │ ├── test
│ │ ├── test.json
│ │ ├── train
│ │ ├── train.json
│ │ ├── val
│ │ └── val.json
│ └── D7_unitalker_song
│ ├── id_template.npy
│ ├── test
│ ├── test.json
│ ├── train
│ ├── train.json
│ ├── val
│ └── val.json
```
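Before running the demo, it can help to confirm that the layout above is in place. The following is a minimal sketch, not part of the repository: the helper name `missing_paths` and the subset of paths it checks are assumptions chosen for illustration, so extend the list to match the datasets you actually use.

```python
from pathlib import Path

# A subset of the layout listed above (illustrative, not exhaustive).
EXPECTED = [
    "pretrained_models/UniTalker-B-D0-D7.pt",
    "resources/binary_resources",
    "resources/obj_template",
    "unitalker_data_release_V1/D0_BIWI/id_template.npy",
    "unitalker_data_release_V1/D0_BIWI/pca.npz",
]


def missing_paths(repo_root: str, expected=EXPECTED) -> list[str]:
    """Return the expected relative paths that do not exist under repo_root."""
    root = Path(repo_root)
    return [rel for rel in expected if not (root / rel).exists()]


if __name__ == "__main__":
    for rel in missing_paths("."):
        print(f"missing: {rel}")
```

Run it from the repository root; no output means every listed path was found.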
### Demo
```bash
python -m main.demo --config config/unitalker.yaml test_out_path ./test_results/demo.npz
python -m main.render ./test_results/demo.npz ./test_audios ./test_results/
```
## Train
### Download Data
[Unitalker-data-release-V1](https://drive.google.com/file/d/1qRBPsTdOWp72ty04oD1Q_ivtwMjrACLH/view?usp=sharing) contains D5, D6 and D7. These datasets have already been processed and split into train, validation and test sets. Please use these three datasets to try out the training step.
If you want to train the model on D0-D7, you need to download the remaining datasets from the following links:
[D0: BIWI](https://github.com/Doubiiu/CodeTalker/blob/main/BIWI/README.md).
[D1: VOCASET](https://voca.is.tue.mpg.de/).
[D2: meshtalk](https://github.com/facebookresearch/meshtalk?tab=readme-ov-file).
[D3,D4: 3DETF](https://github.com/psyai-net/EmoTalk_release).
### Modify Config and Train
Please modify `dataset` and `duplicate_list` in `config/unitalker.yaml` according to the datasets you have prepared, ensuring that both lists maintain the same length.
```bash
python -m main.train --config config/unitalker.yaml
```
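To see what `duplicate_list` does: `MixAudioFaceDataset` in `dataset/dataset.py` repeats each dataset's clip list `dup` times, so a larger value oversamples that dataset within one epoch. The sketch below illustrates the length check and the resulting per-dataset sizes. It is not the repository's config parser: the helper names `parse_csv` and `mixed_sizes` are hypothetical, the clip counts are invented for the example, and it assumes both settings are comma-separated strings as in `config/unitalker.yaml`.

```python
def parse_csv(value: str) -> list[str]:
    """Split a comma-separated config string such as "D5,D6,D7"."""
    return [part.strip() for part in value.split(",") if part.strip()]


def mixed_sizes(dataset: str, duplicate_list: str,
                clip_counts: dict[str, int]) -> dict[str, int]:
    """Clips each dataset contributes once its list is repeated `dup` times."""
    names = parse_csv(dataset)
    dups = [int(part) for part in parse_csv(duplicate_list)]
    if len(names) != len(dups):
        raise ValueError(
            f"dataset has {len(names)} entries, duplicate_list has {len(dups)}")
    return {name: clip_counts[name] * dup for name, dup in zip(names, dups)}


# Hypothetical clip counts, for illustration only.
counts = {"D5": 100, "D6": 40, "D7": 20}
print(mixed_sizes("D5,D6,D7", "1,2,4", counts))
# {'D5': 100, 'D6': 80, 'D7': 80}
```

With the default `duplicate_list: "10,5,4,1,1,1,1,1"`, D0, D1 and D2 are repeated 10, 5 and 4 times respectively, while D3 through D7 are used once.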
================================================
FILE: config/unitalker.yaml
================================================
DATA:
dataset: "D0,D1,D2,D3,D4,D5,D6,D7"
demo_dataset: "D0,D1,D2,D3,D4,D6,D7" # To visualize D5, please download FLAME2020 model first and place it in "resources/binary_resources/flame.pkl"
data_root: "./unitalker_data_release_V1/"
duplicate_list: "10,5,4,1,1,1,1,1"
train_json_key: train.json
LOSS:
rec_weight: 1.0
pca_weight: 0.01
blendshape_weight: 0.0001
NETWORK:
use_pca: True
pca_dim: 512
MODEL:
audio_encoder_repo: 'microsoft/wavlm-base-plus' # choose from 'microsoft/wavlm-base-plus' and 'facebook/wav2vec2-large-xlsr-53'
freeze_wav2vec: False
interpolate_pos: 1
decoder_dimension: 256
decoder_type: conv # choose from conv and transformer
period: 30
headlayer: 1
TRAIN:
fix_encoder_first: True
random_id_prob: 0.1
re_init_decoder_and_head: False
only_head: False
workers: 4 # data loader workers
batch_size: 1 # batch size for training
batch_size_val: 1 # batch size for validation during training, memory and speed tradeoff
lr: 0.0001
epochs: 100
save_every: 1
save_path: "./results"
weight_path: "./pretrained_models/UniTalker-B-D0-D7.pt"
eval_every: 1
test_every: 1
log_every: 1
do_test: True
TEST:
test_wav_dir: "./test_audios"
test_out_path: "./test_results/demo.npz"
================================================
FILE: dataset/data_item.py
================================================
import librosa
import numpy as np
from typing import Any
from .dataset_config import dataset_config
class DataItem:
def __init__(self,
annot_path: str,
audio_path: str,
identity_idx: int,
annot_type: str = '',
dataset_name: str = '',
id_template: np.ndarray = None,
fps: float = 30,
processor=None) -> None:
self.annot_path = annot_path
self.audio_path = audio_path
self.identity_idx = identity_idx
self.fps = fps
self.annot_data = np.load(annot_path)
if self.fps == 60:
self.annot_data = self.annot_data[::2]
self.fps = 30
elif self.fps == 100:
self.annot_data = self.annot_data[::4]
self.fps = 25
elif self.fps != 30 and self.fps != 25:
raise ValueError(f'unsupported fps: {self.fps} (expected 25, 30, 60 or 100)')
scale = dataset_config[dataset_name]['scale']
self.annot_data = self.scale_and_offset(self.annot_data, scale, 0.0)
self.audio_data, sr = self.load_audio(audio_path)
self.original_audio_data, self.annot_data = self.truncate(
self.audio_data, sr, self.annot_data, self.fps)
self.audio_data = np.squeeze(
processor(self.original_audio_data, sampling_rate=sr).input_values)
self.audio_data = self.audio_data.astype(np.float32)
self.annot_type = annot_type
self.id_template = self.scale_and_offset(id_template, scale, 0.0)
self.annot_data = self.annot_data.reshape(len(self.annot_data),
-1).astype(np.float32)
self.id_template = self.id_template.reshape(1, -1).astype(np.float32)
self.dataset_name = dataset_name
self.data_dict = self.to_dict()
return
def load_audio(self, audio_path: str):
if audio_path.endswith('.npy'):
return np.load(audio_path), 16000
else:
return librosa.load(audio_path, sr=16000)
def truncate(self, audio_array: np.ndarray, sr: int,
annot_data: np.ndarray, fps: int):
audio_duration = len(audio_array) / sr
annot_duration = len(annot_data) / fps
duration = min(audio_duration, annot_duration, 20)
audio_length = round(duration * sr)
annot_length = round(duration * fps)
self.duration = duration
return audio_array[:audio_length], annot_data[:annot_length]
def offset_id(self, offset: int):
self.identity_idx = self.identity_idx + offset
self.data_dict['id'] = self.identity_idx
return
def scale_and_offset(self,
data: np.ndarray,
scale: float = 1.0,
offset: np.ndarray = 0.0):
return data * scale + offset
def to_dict(self, ) -> Any:
item = {
'data': self.annot_data,
'audio': self.audio_data,
'original_audio': self.original_audio_data,
'fps': self.fps,
'id': self.identity_idx,
'annot_path': self.annot_path,
'audio_path': self.audio_path,
'annot_type': self.annot_type,
'template': self.id_template,
'dataset_name': self.dataset_name
}
return item
def get_dict(self, ):
return self.data_dict
================================================
FILE: dataset/dataset.py
================================================
import json
import numpy as np
import os.path as osp
from torch.utils import data
from tqdm import tqdm
from transformers import Wav2Vec2FeatureExtractor
from typing import Any, List
from .data_item import DataItem
from .dataset_config import dataset_config
class AudioFaceDataset(data.Dataset):
def __init__(self, data_label_path: str, args: dict = None) -> None:
super().__init__()
data_root = osp.dirname(data_label_path, )
id_template_path = osp.join(data_root, 'id_template.npy')
id_template_list = np.load(id_template_path)
with open(data_label_path) as f:
labels = json.load(f)
info = labels['info']
self.id_list = info['id_list']
self.sr = 16000
data_info_list = labels['data']
data_list = []
for data_info in tqdm(data_info_list):
data = DataItem(
annot_path=osp.join(data_root, data_info['annot_path']),
audio_path=osp.join(data_root, data_info['audio_path']),
identity_idx=data_info['id'],
annot_type=data_info['annot_type'],
dataset_name=data_info['dataset'],
id_template=id_template_list[data_info['id']],
fps=data_info['fps'],
processor=args.processor,
)
if data.duration < 0.5:
continue
data_list.append(data)
self.data_list = data_list
return
def __len__(self, ):
return len(self.data_list)
def __getitem__(self, index: Any) -> Any:
data = self.data_list[index]
return data.get_dict()
def get_identity_num(self, ):
return len(self.id_list)
class MixAudioFaceDataset(AudioFaceDataset):
def __init__(self,
dataset_list: List[AudioFaceDataset],
duplicate_list: list = None) -> None:
super(AudioFaceDataset, self).__init__()  # skip AudioFaceDataset.__init__, init data.Dataset only
self.id_list = []
self.sr = 16000
self.data_list = []
self.annot_type_list = []
if duplicate_list is None:
duplicate_list = [1] * len(dataset_list)
else:
assert len(duplicate_list) == len(dataset_list)
for dup, dataset in zip(duplicate_list, dataset_list):
id_index_offset = len(self.id_list)
for d in dataset.data_list:
d.offset_id(id_index_offset)
if d.annot_type not in self.annot_type_list:
self.annot_type_list.append(d.annot_type)
self.id_list = self.id_list + dataset.id_list
self.data_list = self.data_list + dataset.data_list * dup
return
def get_dataset_list(args):
dataset_name_list = args.dataset
dataset_dir_list = [
dataset_config[name]['dirname'] for name in dataset_name_list
]
train_json_list = [
osp.join(args.data_root, dataset_name, args.train_json_key)
for dataset_name in dataset_dir_list
]
args.read_audio = False
args.processor = None
train_dataset_list = []
processor = Wav2Vec2FeatureExtractor.from_pretrained(
args.audio_encoder_repo)
args.processor = processor
for train_json in train_json_list:
dataset = AudioFaceDataset(
train_json,
args,
)
train_dataset_list.append(dataset)
return train_dataset_list
def get_single_dataset(args, json_path: str = ''):
args.read_audio = True
processor = Wav2Vec2FeatureExtractor.from_pretrained(
args.audio_encoder_repo)
args.processor = processor
dataset = AudioFaceDataset(
json_path,
args,
)
test_loader = data.DataLoader(
dataset=dataset,
batch_size=1,
shuffle=False,
pin_memory=True,
num_workers=args.workers)
return test_loader
def get_dataloaders(args):
dataset_name_list = args.dataset
print(dataset_name_list)
dataset_dir_list = [
dataset_config[name]['dirname'] for name in dataset_name_list
]
train_json_list = [
osp.join(args.data_root, dataset_dir, args.train_json_key)
for dataset_dir in dataset_dir_list
]
val_json_list = [
osp.join(args.data_root, dataset_dir, 'val.json')
for dataset_dir in dataset_dir_list
]
test_json_list = [
osp.join(args.data_root, dataset_dir, 'test.json')
for dataset_dir in dataset_dir_list
]
processor = Wav2Vec2FeatureExtractor.from_pretrained(
args.audio_encoder_repo)
args.processor = processor
if getattr(args, 'do_train', True) is True:
train_dataset_list = []
for train_json in train_json_list:
dataset = AudioFaceDataset(train_json, args)
train_dataset_list.append(dataset)
mix_train_dataset = MixAudioFaceDataset(train_dataset_list,
args.duplicate_list)
train_loader = data.DataLoader(
dataset=mix_train_dataset,
batch_size=args.batch_size,
shuffle=True,
pin_memory=True,
num_workers=args.workers)
else:
train_loader = None
if getattr(args, 'do_validate', True) is True:
val_dataset_list = []
for val_json in val_json_list:
dataset = AudioFaceDataset(val_json, args)
val_dataset_list.append(dataset)
mix_val_dataset = MixAudioFaceDataset(val_dataset_list)
val_loader = data.DataLoader(
dataset=mix_val_dataset,
batch_size=1,
shuffle=False,
pin_memory=True,
num_workers=args.workers)
else:
val_loader = None
if getattr(args, 'do_test', True) is True:
test_dataset_list = []
for test_json in test_json_list:
dataset = AudioFaceDataset(test_json, args)
test_dataset_list.append(dataset)
mix_test_dataset = MixAudioFaceDataset(test_dataset_list)
test_loader = data.DataLoader(
dataset=mix_test_dataset,
batch_size=1,
shuffle=False,
pin_memory=True,
num_workers=args.workers)
else:
test_loader = None
return train_loader, val_loader, test_loader
================================================
FILE: dataset/dataset_config.py
================================================
dataset_config = {
'D0': {
'dirname': 'D0_BIWI',
'annot_type': 'BIWI_23370_vertices',
'scale': 0.2,
'annot_dim': 23370 * 3,
'subjects': 6,
'pca': True,
},
'D1': {
'dirname': 'D1_vocaset',
'annot_type': 'FLAME_5023_vertices',
'scale': 1.0,
'annot_dim': 5023 * 3,
'subjects': 12,
'pca': True,
},
'D2': {
'dirname': 'D2_meshtalk',
'annot_type': 'meshtalk_6172_vertices',
'scale': 0.001,
'annot_dim': 6172 * 3,
'subjects': 13,
'pca': True,
},
'D3': {
'dirname': 'D3D4_3DETF/D3_HDTF',
'annot_type': '3DETF_blendshape_weight',
'scale': 1.0,
'annot_dim': 52,
'subjects': 141,
'pca': False,
},
'D4': {
'dirname': 'D3D4_3DETF/D4_RAVDESS',
'annot_type': '3DETF_blendshape_weight',
'scale': 1.0,
'annot_dim': 52,
'subjects': 24,
'pca': False,
},
'D5': {
'dirname': 'D5_unitalker_faceforensics++',
'annot_type': 'flame_params_from_dadhead',
'scale': 1.0,
'annot_dim': 413,
'subjects': 719,
'pca': False,
},
'D6': {
'dirname': 'D6_unitalker_Chinese_speech',
'annot_type': 'inhouse_blendshape_weight',
'scale': 1.0,
'annot_dim': 51,
'subjects': 8,
'pca': False,
},
'D7': {
'dirname': 'D7_unitalker_song',
'annot_type': 'inhouse_blendshape_weight',
'scale': 1.0,
'annot_dim': 51,
'subjects': 11,
'pca': False,
},
}
================================================
FILE: loss/loss.py
================================================
import numpy as np
import pickle
import torch
from dataclasses import dataclass
from smplx.lbs import lbs
from smplx.utils import Struct, to_np, to_tensor
from torch import nn
@dataclass
class FlameParams:
betas: torch.Tensor
jaw: torch.Tensor
eyeballs: torch.Tensor
neck: torch.Tensor
class FlameLayer(nn.Module):
def __init__(self, model_path: str) -> None:
super().__init__()
self.dtype = torch.float32
with open(model_path, 'rb') as f:
self.flame_model = Struct(**pickle.load(f, encoding='latin1'))
shapedirs = self.flame_model.shapedirs
# The shape components
self.register_buffer('shapedirs',
to_tensor(to_np(shapedirs), dtype=self.dtype))
j_regressor = to_tensor(
to_np(self.flame_model.J_regressor), dtype=self.dtype)
self.register_buffer('J_regressor', j_regressor)
self.register_buffer(
'v_template',
to_tensor(to_np(self.flame_model.v_template), dtype=self.dtype),
)
num_pose_basis = self.flame_model.posedirs.shape[-1]
posedirs = np.reshape(self.flame_model.posedirs,
[-1, num_pose_basis]).T
self.register_buffer('posedirs',
to_tensor(to_np(posedirs), dtype=self.dtype))
parents = to_tensor(to_np(self.flame_model.kintree_table[0])).long()
parents[0] = -1
self.register_buffer('parents', parents)
self.register_buffer(
'lbs_weights',
to_tensor(to_np(self.flame_model.weights), dtype=self.dtype))
default_neck_pose = torch.zeros([1, 3],
dtype=self.dtype,
requires_grad=False)
self.register_parameter(
'neck_pose', nn.Parameter(default_neck_pose, requires_grad=False))
default_eyeball_pose = torch.zeros([1, 6],
dtype=self.dtype,
requires_grad=False)
self.register_parameter(
'eyeballs',
nn.Parameter(default_eyeball_pose, requires_grad=False))
default_jaw = torch.zeros([1, 3],
dtype=self.dtype,
requires_grad=False)
self.register_parameter('jaw',
nn.Parameter(default_jaw, requires_grad=False))
self.FLAME_CONSTS = {
'betas': 400,
'rotation': 6,
'jaw': 3,
'eyeballs': 0,
'neck': 0,
'translation': 3,
'scale': 1,
}
def flame_params_from_3dmm(self,
tensor_3dmm: torch.Tensor,
zero_expr: bool = False):
constants = self.FLAME_CONSTS
assert tensor_3dmm.ndim == 2
cur_index: int = 0
betas = tensor_3dmm[:, :constants['betas']]
cur_index += constants['betas']
jaw = tensor_3dmm[:, cur_index:cur_index + constants['jaw']]
cur_index += constants['jaw']
cur_index += constants['rotation']
eyeballs = tensor_3dmm[:, cur_index:cur_index + constants['eyeballs']]
cur_index += constants['eyeballs']
neck = tensor_3dmm[:, cur_index:cur_index + constants['neck']]
cur_index += constants['neck']
cur_index += constants['translation']
cur_index += constants['scale']
return FlameParams(betas=betas, jaw=jaw, eyeballs=eyeballs, neck=neck)
def forward(self, tensor_3dmm: torch.Tensor):
bs = tensor_3dmm.shape[0]
flame_params = self.flame_params_from_3dmm(tensor_3dmm)
betas = flame_params.betas
rotation = torch.zeros([bs, 3], device=tensor_3dmm.device)
neck_pose = flame_params.neck if not (
0 in flame_params.neck.shape) else self.neck_pose[[0]].expand(
bs, -1)
eyeballs = flame_params.eyeballs if not (
0 in flame_params.eyeballs.shape) else self.eyeballs[[0]].expand(
bs, -1)
jaw = flame_params.jaw if not (
0 in flame_params.jaw.shape) else self.jaw[[0]].expand(bs, -1)
full_pose = torch.cat([rotation, neck_pose, jaw, eyeballs], dim=1)
template_vertices = self.v_template.unsqueeze(0).repeat(bs, 1, 1)
vertices, _ = lbs(
betas,
full_pose,
template_vertices,
self.shapedirs,
self.posedirs,
self.J_regressor,
self.parents,
self.lbs_weights,
)
return vertices
class FlameParamLossForDADHead(nn.Module):
def __init__(self, mouth_indices_path: str, loss_mask_path: str,
flame_model_path: str) -> None:
super().__init__()
mouth_indices = np.load(mouth_indices_path)
loss_indices = np.load(loss_mask_path)
self.register_buffer('mouth_idx', torch.from_numpy(mouth_indices))
self.register_buffer('loss_indices', torch.from_numpy(loss_indices))
self.mse = nn.MSELoss()
self.flame_layer = FlameLayer(flame_model_path)
def forward(self, x: torch.Tensor, target: torch.Tensor):
ori_shape = x.shape
x = self.flame_layer(x.reshape(-1, ori_shape[-1]))
target = self.flame_layer(target.reshape(-1, ori_shape[-1]))
x = x.reshape((ori_shape[0], ori_shape[1], -1, 3))
target = target.reshape((ori_shape[0], ori_shape[1], -1, 3))
return self.mse(x[:, :, self.loss_indices], target[:, :,
self.loss_indices])
def get_vertices(
self,
x: torch.Tensor,
):
x = self.flame_layer(x.reshape(-1, x.shape[-1]))
return x[:, self.loss_indices]
def mouth_metric(self, x: torch.Tensor, target: torch.Tensor):
ori_shape = x.shape
x = self.flame_layer(x.reshape(-1, ori_shape[-1]))
target = self.flame_layer(target.reshape(-1, ori_shape[-1]))
metric_L2 = (target[:, self.mouth_idx] - x[:, self.mouth_idx])**2
metric_L2 = metric_L2.sum(-1)
metric_L2 = metric_L2.max(-1)[0]
metric_L2norm = metric_L2**0.5
return metric_L2.mean(), metric_L2norm.mean()
class VerticesLoss(nn.Module):
def __init__(self, mouth_indices_path: str) -> None:
super().__init__()
mouth_indices = np.load(mouth_indices_path)
self.register_buffer('mouth_idx', torch.from_numpy(mouth_indices))
self.mse = nn.MSELoss()
def forward(self, x: torch.Tensor, target: torch.Tensor):
return self.mse(x, target)
def get_vertices(
self,
x: torch.Tensor,
):
return x.reshape(-1, x.shape[-1] // 3, 3)
def mouth_metric(self, x: torch.Tensor, target: torch.Tensor):
ori_shape = x.shape
target = target.reshape(*ori_shape[:-1], -1, 3)
x = x.reshape(*ori_shape[:-1], -1, 3)
metric_L2 = (target[:, :, self.mouth_idx] - x[:, :, self.mouth_idx])**2
metric_L2 = metric_L2.sum(-1)
metric_L2 = metric_L2.max(-1)[0]
metric_L2norm = metric_L2**0.5
return metric_L2.mean(), metric_L2norm.mean()
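# Illustrative sketch, not part of the original module: the mouth metric
# computed by the loss classes above, rewritten with NumPy for one sequence.
# It assumes flattened vertices of shape (frames, n_verts * 3); the name
# mouth_metric_np and the toy data are hypothetical.

```python
import numpy as np


def mouth_metric_np(pred, target, mouth_idx):
    """Return (mean of per-frame max squared error, mean of per-frame max distance).

    pred, target: arrays of shape (frames, n_verts * 3).
    mouth_idx: integer indices of the mouth vertices.
    """
    pred = pred.reshape(pred.shape[0], -1, 3)
    target = target.reshape(target.shape[0], -1, 3)
    # Squared Euclidean error of every mouth vertex in every frame.
    sq_err = ((target[:, mouth_idx] - pred[:, mouth_idx]) ** 2).sum(-1)
    # Keep only the worst mouth vertex of each frame.
    max_sq = sq_err.max(-1)
    return max_sq.mean(), np.sqrt(max_sq).mean()


# Toy data: 2 frames, 3 vertices, one mouth vertex displaced by (3, 4, 0).
target = np.zeros((2, 9))
pred = np.zeros((2, 9))
pred[0, 3:6] = [3.0, 4.0, 0.0]
l2, l2_norm = mouth_metric_np(pred, target, np.array([1, 2]))
```

# The metric takes the maximum over mouth vertices before averaging, so one
# badly placed lip vertex dominates the score for that frame.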
class BlendShapeLoss(nn.Module):
def __init__(self,
mouth_indices_path: str,
blendshape_path: str,
bs_beta: float = 0.0,
components_num: int = None) -> None:
super().__init__()
mouth_indices = np.load(mouth_indices_path)
blendshape = np.load(blendshape_path)
self.register_buffer('mouth_idx', torch.from_numpy(mouth_indices))
mean_shape = blendshape['meanshape']
mean_shape = mean_shape.reshape(-1).astype(np.float32)
blend_shape = blendshape['blendshape']
blend_shape = blend_shape.reshape(len(blend_shape),
-1).astype(np.float32)
self.register_buffer('mean_shape', torch.from_numpy(mean_shape))
self.register_buffer('blend_shape', torch.from_numpy(blend_shape))
self.bs_beta = bs_beta
self.mse = nn.MSELoss()
self.components_num = components_num
def forward(self, x: torch.Tensor, target: torch.Tensor):
return self.bs_beta * self.mse(x, target) + self.mse(
self.bs2vertices(x), self.bs2vertices(target))
def bs2vertices(self, x: torch.Tensor):
return x[:, :self.components_num] @ \
self.blend_shape[:self.components_num] + \
self.mean_shape
def get_vertices(
self,
x: torch.Tensor,
):
x = self.bs2vertices(x)
return x.reshape(-1, x.shape[-1] // 3, 3)
def mouth_metric(self, x: torch.Tensor, target: torch.Tensor):
target = self.bs2vertices(target)
x = self.bs2vertices(x)
ori_shape = x.shape
target = target.reshape(*ori_shape[:-1], -1, 3)
x = x.reshape(*ori_shape[:-1], -1, 3)
metric_L2 = (target[:, :, self.mouth_idx] - x[:, :, self.mouth_idx])**2
metric_L2 = metric_L2.sum(-1)
metric_L2 = metric_L2.max(-1)[0]
metric_L2norm = metric_L2**0.5
return metric_L2.mean(), metric_L2norm.mean()
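# Illustrative sketch, not part of the original module: how blendshape weights
# map to vertices in bs2vertices / get_vertices above, using NumPy and a toy
# two-blendshape basis. The function name and all array values are hypothetical.

```python
import numpy as np


def bs_to_vertices(weights, blend_shape, mean_shape, components_num=None):
    """weights: (frames, n_bs); blend_shape: (n_bs, n_verts * 3); mean_shape: (n_verts * 3,)."""
    # Slicing with None keeps every component, mirroring the default above.
    flat = weights[:, :components_num] @ blend_shape[:components_num] + mean_shape
    return flat.reshape(-1, flat.shape[-1] // 3, 3)


blend_shape = np.array([[1, 0, 0, 0, 0, 0],
                        [0, 0, 0, 0, 2, 0]], dtype=np.float32)
mean_shape = np.ones(6, dtype=np.float32)
weights = np.array([[0.5, 1.0]], dtype=np.float32)

verts = bs_to_vertices(weights, blend_shape, mean_shape)
verts_first_only = bs_to_vertices(weights, blend_shape, mean_shape, components_num=1)
```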
class ParamLoss(nn.Module):
def __init__(self, weight: float):
super().__init__()
self.weight = weight
self.mse = nn.MSELoss()
self.L1 = nn.L1Loss()
def forward(self, x: torch.Tensor, target: torch.Tensor):
return self.weight * self.mse(x, target)
def mouth_metric(self, x: torch.Tensor, target: torch.Tensor):
return self.mse(x, target), self.L1(x, target)
def get_vertices(
self,
x: torch.Tensor,
):
return x.squeeze(0)
class UniTalkerLoss(nn.Module):
def __init__(self, args) -> None:
super().__init__()
loss_config = {
'flame_params_from_dadhead': {
'class': FlameParamLossForDADHead,
'args': {
'mouth_indices_path':
'resources/binary_resources/02_flame_mouth_idx.npy',
'loss_mask_path':
'resources/binary_resources/flame_humanface_index.npy',
'flame_model_path': 'resources/binary_resources/flame.pkl',
}
},
'3DETF_blendshape_weight': {
'class': BlendShapeLoss,
'args': {
'mouth_indices_path':
'resources/binary_resources/03_emo_talk_mouth_idx.npy',
'blendshape_path':
'resources/binary_resources/EmoTalk.npz',
'bs_beta': args.blendshape_weight,
},
},
'inhouse_blendshape_weight': {
'class': BlendShapeLoss,
'args': {
'mouth_indices_path':
'resources/binary_resources/05_inhouse_arkit_mouth_idx.npy',
'blendshape_path':
'resources/binary_resources/inhouse_arkit.npz',
'bs_beta': args.blendshape_weight,
},
},
'FLAME_5023_vertices': {
'class': VerticesLoss,
'args': {
'mouth_indices_path':
'resources/binary_resources/02_flame_mouth_idx.npy',
},
},
'BIWI_23370_vertices': {
'class': VerticesLoss,
'args': {
'mouth_indices_path':
'resources/binary_resources/04_BIWI_mouth_idx.npy',
},
},
'FLAME_SUB_2055_vertices': {
'class': VerticesLoss,
'args': {
'mouth_indices_path':
'resources/binary_resources/mouth_idx_FLAME_5023_vertices_wo_eyes_wo_head.npy',
},
},
'meshtalk_6172_vertices': {
'class': VerticesLoss,
'args': {
'mouth_indices_path':
'resources/binary_resources/06_mesh_talk_mouth_idx.npy',
},
},
'24_viseme': {
'class': ParamLoss,
'args': {
'weight': args.blendshape_weight
},
}
}
loss_module_dict = {}
for k, v in loss_config.items():
loss_module = v['class'](**v['args'])
loss_module_dict[k] = loss_module
self.loss_module_dict = nn.ModuleDict(loss_module_dict)
self.mse = nn.MSELoss()
return
def forward(
self,
x: torch.Tensor,
target: torch.Tensor,
annot_type: str,
):
return self.loss_module_dict[annot_type](x, target)
def pca_loss(self, x: torch.Tensor, target: torch.Tensor):
return self.mse(x, target)
def mouth_metric(self, x: torch.Tensor, target: torch.Tensor,
annot_type: str):
return self.loss_module_dict[annot_type].mouth_metric(x, target)
def get_vertices(self, x: torch.Tensor, annot_type: str):
return self.loss_module_dict[annot_type].get_vertices(x)
================================================
FILE: main/demo.py
================================================
import librosa
import numpy as np
import os
import torch
import torch.optim
import torch.utils.data
from transformers import Wav2Vec2FeatureExtractor
from dataset.dataset_config import dataset_config
from loss.loss import UniTalkerLoss
from models.unitalker import UniTalker
from utils.utils import get_parser, get_template_verts, get_audio_encoder_dim
def get_all_audios(audio_root:str):
wav_f_names = []
for r, _, f in os.walk(audio_root):
for wav in f:
if wav.endswith('.wav') or wav.endswith('.mp3'):
relative_path = os.path.join(r, wav)
relative_path = os.path.relpath(relative_path,
audio_root)
wav_f_names.append(relative_path)
wav_f_names = sorted(wav_f_names)
return wav_f_names
def split_long_audio(
audio: np.ndarray,
processor:Wav2Vec2FeatureExtractor
):
# audio = audio.squeeze(0)
a, b = 25, 5  # window length and overlap between consecutive windows, in seconds
sr = 16000
total_length = len(audio) /sr
reps = max(0, int(np.ceil((total_length - a) / (a - b)))) + 1
in_audio_split_list = []
start, end = 0, int(a * sr)
step = int((a - b) * sr)
for i in range(reps):
audio_split = audio[start:end]
audio_split = np.squeeze(
processor(audio_split, sampling_rate=sr).input_values)
in_audio_split_list.append(audio_split)
start += step
end += step
return in_audio_split_list
def merge_out_list(out_list: list, fps:int):
if len(out_list) == 1:
return out_list[0]
a, b = 25, 5  # window length and overlap in seconds; must match split_long_audio
left_weight = np.linspace(1, 0, b * fps)[:, np.newaxis]
right_weight = 1 - left_weight
a = a * fps
b = b * fps
offset = a - b
out_length = len(out_list[-1]) + offset * (len(out_list) - 1)
merged_out = np.empty((out_length, out_list[-1].shape[-1]),
dtype=out_list[-1].dtype)
merged_out[:a] = out_list[0]
for out_piece in out_list[1:]:
merged_out[a - b:a] = left_weight * merged_out[
a - b:a] + right_weight * out_piece[:b]
merged_out[a:a + offset] = out_piece[b:]
a += offset
return merged_out
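# Illustrative sketch, not part of the original script: the window arithmetic
# shared by split_long_audio and merge_out_list (25 s windows, 5 s overlap).
# The helper names are hypothetical, and the frame counts assume the model
# emits exactly duration * fps frames per chunk.

```python
import math


def chunk_bounds(num_samples, sr=16000, win_s=25, overlap_s=5):
    """Sample ranges of the overlapping windows, clamped to the audio length."""
    total_s = num_samples / sr
    reps = max(0, math.ceil((total_s - win_s) / (win_s - overlap_s))) + 1
    step = (win_s - overlap_s) * sr
    return [(i * step, min(i * step + win_s * sr, num_samples))
            for i in range(reps)]


def merged_length(chunk_frames, fps, win_s=25, overlap_s=5):
    """Number of frames after cross-fading the overlapping chunk outputs."""
    offset = (win_s - overlap_s) * fps
    return chunk_frames[-1] + offset * (len(chunk_frames) - 1)


bounds = chunk_bounds(60 * 16000)  # a 60 s clip at 16 kHz
frames = [(end - start) * 30 // 16000 for start, end in bounds]
total = merged_length(frames, fps=30)
short = chunk_bounds(10 * 16000)  # clips shorter than one window give one chunk
```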
def main():
args = get_parser()
training_dataset_name_list = args.dataset
demo_dataset_name_list = args.demo_dataset
condition_id_config = {
'D0': 3,
'D1': 3,
'D2': 0,
'D3': 0,
'D4': 0,
'D5': 0,
'D6': 4,
'D7': 0,
}
template_id_config = {
'D0': 3,
'D1': 3,
'D2': 0,
'D3': 0,
'D4': 0,
'D5': 0,
'D6': 4,
'D7': 0,
}
checkpoint = torch.load(args.weight_path, map_location='cpu')
args.identity_num = len(checkpoint['decoder.learnable_style_emb.weight'])
start_idx = 0
for dataset_name in training_dataset_name_list:
annot_type = dataset_config[dataset_name]['annot_type']
id_num = dataset_config[dataset_name]['subjects']
end_idx = start_idx + id_num
local_condition_idx = condition_id_config[dataset_name]
template_idx = template_id_config[dataset_name]
template = get_template_verts(args.data_root, dataset_name, template_idx)
template_id_config[dataset_name] = torch.Tensor(template.reshape(1, -1))
if args.condition_id == 'each':
condition_id_config[dataset_name] = torch.tensor(
start_idx + local_condition_idx).reshape(1)
elif args.condition_id == 'common':
condition_id_config[dataset_name] = torch.tensor(args.identity_num -
1).reshape(1)
else:
try:
condition_id = int(args.condition_id)
condition_id_config[dataset_name] = torch.tensor(
condition_id).reshape(1)
except ValueError:
assert args.condition_id in dataset_config.keys()
condition_id_config[dataset_name] = torch.tensor(
start_idx + local_condition_idx).reshape(1)
start_idx = end_idx
if args.random_id_prob > 0:
assert end_idx == (args.identity_num - 1)
else:
assert end_idx == args.identity_num
args.audio_encoder_feature_dim = get_audio_encoder_dim(
args.audio_encoder_repo)
model = UniTalker(args)
model.load_state_dict(checkpoint, strict=False)
model.eval()
model.cuda()
wav_f_names = get_all_audios(args.test_wav_dir)
out_npz_path = args.test_out_path
dirname = os.path.dirname(out_npz_path)
os.makedirs(dirname, exist_ok=True)
out_dict = {}
loss_module = UniTalkerLoss(args).cuda()
processor = Wav2Vec2FeatureExtractor.from_pretrained(
args.audio_encoder_repo,)
with torch.no_grad():
for wav_f in wav_f_names:
out_dict[wav_f] = {}
wav_path = os.path.join(args.test_wav_dir, wav_f)
audio_data, sr = librosa.load(wav_path, sr=16000)
audio_data = np.squeeze(
processor(audio_data, sampling_rate=sr).input_values)
audio_data_splits = split_long_audio(audio_data, processor)
for dataset_name in demo_dataset_name_list:
template = template_id_config[dataset_name].cuda()
scale = dataset_config[dataset_name]['scale']
template = scale * template
condition_id = condition_id_config[dataset_name].cuda()
annot_type = dataset_config[dataset_name]['annot_type']
if annot_type == 'BIWI_23370_vertices':
fps = 25
else:
fps = 30
out_list = []
for audio_data in audio_data_splits:
audio_data = torch.Tensor(audio_data[None]).cuda()
out, _, _ = model(
audio_data,
template,
face_motion=None,
style_idx=condition_id,
annot_type=annot_type,
fps=fps,
)
out_list.append(out.cpu().numpy().squeeze(0))
out = merge_out_list(out_list, fps)
out = loss_module.get_vertices(torch.from_numpy(out).cuda(), annot_type)
out_dict[wav_f][dataset_name] = out.cpu()
print(f"save results to {out_npz_path}")
np.savez(out_npz_path, **out_dict)
if __name__ == '__main__':
main()
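# Illustrative sketch, not part of the original script: how main() turns a
# per-dataset local condition index into a global style index when
# args.condition_id == 'each'. Dataset names and subject counts here are
# hypothetical placeholders, not values from dataset_config.

```python
def global_style_indices(dataset_names, subjects, local_idx):
    """Offset each dataset's local index by the subjects of the datasets before it."""
    start, out = 0, {}
    for name in dataset_names:
        out[name] = start + local_idx[name]
        start += subjects[name]
    return out, start


indices, total_ids = global_style_indices(
    ['A', 'B'],
    subjects={'A': 8, 'B': 12},
    local_idx={'A': 3, 'B': 0},
)
```

# total_ids corresponds to the value that main() asserts against
# args.identity_num (or args.identity_num - 1 when random_id_prob > 0).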
================================================
FILE: main/render.py
================================================
"""Max-Planck-Gesellschaft zur Foerderung der Wissenschaften e.V.
(MPG) is holder of all proprietary rights on this computer program. You can
only use this computer program if you have closed a license agreement with MPG
or you get the right to use the computer program from someone who is authorized
to grant you that right. Any use of the computer program without a valid
license is prohibited and liable to prosecution. Copyright 2019 Max-Planck-
Gesellschaft zur Foerderung der Wissenschaften e.V. (MPG). acting on behalf of
its Max Planck Institute for Intelligent Systems and the Max Planck Institute
for Biological Cybernetics. All rights reserved. More information about VOCA is
available at http://voca.is.tue.mpg.de. For comments or questions, please email
us at voca@tue.mpg.de
"""
import cv2
import numpy as np
import os
import subprocess
import torch
from tqdm import tqdm
from dataset.dataset_config import dataset_config
def render_meshes(mesh_vertices: torch.Tensor,
faces: torch.Tensor,
img_size: tuple = (256, 256),
aa_factor: int = 1,
vis_bs: int = 100):
assert len(mesh_vertices) == len(faces) or len(faces) == 1
if len(faces) == 1:
faces = faces.repeat(len(mesh_vertices), 1, 1)
x_max = y_max = mesh_vertices[..., 0:2].abs().max()
from pytorch3d.renderer import (
DirectionalLights,
FoVOrthographicCameras,
HardPhongShader,
Materials,
MeshRasterizer,
MeshRenderer,
RasterizationSettings,
TexturesVertex,
)
from pytorch3d.renderer.blending import BlendParams
from pytorch3d.structures import Meshes
aa_factor = int(aa_factor)
device = mesh_vertices.device
verts_batch_lst = torch.split(mesh_vertices, vis_bs)
faces_batch_lst = torch.split(faces, vis_bs)
R = torch.eye(3).to(mesh_vertices.device)
R[0, 0] = -1
R[2, 2] = -1
T = torch.zeros(3).to(mesh_vertices.device)
T[2] = 10
R, T = R[None], T[None]
W, H = img_size
cameras = FoVOrthographicCameras(
device=device,
R=R,
T=T,
znear=0.01,
zfar=3,
max_x=x_max * 1.2,
min_x=-x_max * 1.2,
max_y=y_max * 1.2,
min_y=-y_max * 1.2,
)
# cameras = FoVPerspectiveCameras(
# device=device,
# R=R,
# T=T,
# znear=0.01,
# zfar=3,
# aspect_ratio=1,
# fov=0.3,
# degrees=False
# )
raster_settings = RasterizationSettings(
image_size=(W * aa_factor, H * aa_factor),
blur_radius=0.0,
faces_per_pixel=1,
)
raster = MeshRasterizer(
cameras=cameras,
raster_settings=raster_settings,
)
blend_params = BlendParams(
sigma=0.0, gamma=0.0, background_color=(0.0, 0.0, 0.0))
lights = DirectionalLights(
device=device,
direction=((0, 0, 1), ),
ambient_color=((0.3, 0.3, 0.3), ),
diffuse_color=((0.6, 0.6, 0.6), ),
specular_color=((0.1, 0.1, 0.1), ))
materials = Materials(
ambient_color=((1, 1, 1), ),
diffuse_color=((1, 1, 1), ),
specular_color=((1, 1, 1), ),
shininess=15,
device=device)
shader = HardPhongShader(
device=device,
cameras=cameras,
lights=lights,
materials=materials,
blend_params=blend_params)
renderer = MeshRenderer(
rasterizer=raster,
shader=shader,
)
rendered_imgs = []
for verts_batch, faces_batch in tqdm(
zip(verts_batch_lst, faces_batch_lst),
total=(len(verts_batch_lst))):
textures = TexturesVertex(verts_features=torch.ones_like(verts_batch))
meshes = Meshes(
verts=verts_batch, faces=faces_batch, textures=textures)
with torch.no_grad():
imgs = renderer(meshes)
rendered_imgs.append(imgs.cpu())
rendered_imgs = torch.cat(rendered_imgs, dim=0)
if aa_factor > 1:
rendered_imgs = rendered_imgs.permute(0, 3, 1, 2) # NHWC -> NCHW
rendered_imgs = torch.nn.functional.interpolate(
rendered_imgs, scale_factor=1 / aa_factor, mode='bicubic')
rendered_imgs = rendered_imgs.permute(0, 2, 3, 1) # NCHW -> NHWC
return rendered_imgs
def read_obj(in_path):
with open(in_path, 'r') as obj_file:
# Read the lines of the OBJ file
lines = obj_file.readlines()
# Initialize empty lists for vertices and faces
verts = []
faces = []
for line in lines:
line = line.strip() # Remove leading/trailing whitespace
elements = line.split() # Split the line into elements
if len(elements) == 0:
continue # Skip empty lines
# Check the type of line (vertex or face)
if elements[0] == 'v':
# Vertex line
x, y, z = map(float,
elements[1:4]) # Extract the vertex coordinates
verts.append((x, y, z)) # Add the vertex to the list
elif elements[0] == 'f':
# Face line
face_indices = [
int(index.split('/')[0]) for index in elements[1:]
] # Extract the vertex indices
faces.append(face_indices) # Add the face to the list
return np.array(verts), np.array(faces)
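# Illustrative sketch, not part of the original file: the same 'v' / 'f' line
# parsing as read_obj, applied to an in-memory string. Unlike read_obj, this
# hypothetical helper converts the 1-based OBJ face indices to 0-based
# immediately (the original defers that step to get_obj_faces).

```python
def parse_obj_text(text):
    """Return (vertices, zero_based_faces) parsed from OBJ-formatted text."""
    verts, faces = [], []
    for line in text.splitlines():
        parts = line.split()
        if not parts:
            continue  # skip blank lines
        if parts[0] == 'v':
            verts.append(tuple(float(c) for c in parts[1:4]))
        elif parts[0] == 'f':
            # A face token looks like "v", "v/vt" or "v/vt/vn"; keep only v.
            faces.append([int(tok.split('/')[0]) - 1 for tok in parts[1:]])
    return verts, faces


obj_text = """
v 0 0 0
v 1 0 0
v 0 1 0
f 1/1/1 2/2/2 3/3/3
"""
verts, faces = parse_obj_text(obj_text)
```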
def pad_for_libx264(image_array):
"""Pad zeros if width or height of image_array is not divisible by 2.
Otherwise libx264 fails with an error such as
\"[libx264 @ 0x1b1d560] width not divisible by 2\".
Args:
image_array (np.ndarray):
Image or images loaded by cv2.imread().
Possible shapes:
1. [height, width]
2. [height, width, channels]
3. [images, height, width]
4. [images, height, width, channels]
Returns:
np.ndarray:
An image array whose height and width are both divisible by 2.
"""
if image_array.ndim == 2 or \
(image_array.ndim == 3 and image_array.shape[2] == 3):
hei_index = 0
wid_index = 1
elif image_array.ndim == 4 or \
(image_array.ndim == 3 and image_array.shape[2] != 3):
hei_index = 1
wid_index = 2
else:
return image_array
hei_pad = image_array.shape[hei_index] % 2
wid_pad = image_array.shape[wid_index] % 2
if hei_pad + wid_pad > 0:
pad_width = []
for dim_index in range(image_array.ndim):
if dim_index == hei_index:
pad_width.append((0, hei_pad))
elif dim_index == wid_index:
pad_width.append((0, wid_pad))
else:
pad_width.append((0, 0))
values = 0
image_array = \
np.pad(image_array,
pad_width,
mode='constant', constant_values=values)
return image_array
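# Illustrative sketch, not part of the original file: the padding rule of
# pad_for_libx264 for the 4-D case (images, height, width, channels). The
# function name pad_frames_to_even is hypothetical.

```python
import numpy as np


def pad_frames_to_even(frames):
    """Zero-pad the bottom/right edge so height and width become even."""
    pad_h = frames.shape[1] % 2
    pad_w = frames.shape[2] % 2
    pad_width = [(0, 0), (0, pad_h), (0, pad_w), (0, 0)]
    return np.pad(frames, pad_width, mode='constant', constant_values=0)


frames = np.ones((4, 5, 7, 3), dtype=np.uint8)  # odd height and width
padded = pad_frames_to_even(frames)
```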
def array_to_video(
image_array: np.ndarray,
output_path: str,
fps=30,
resolution=None,
disable_log: bool = False,
) -> None:
"""Convert an array to a video directly, gif not supported.
Args:
image_array (np.ndarray): array of shape (frames, height, width, 3).
output_path (str): output video file path.
fps (Union[int, float], optional): fps. Defaults to 30.
resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
optional): (height, width) of the output video.
Defaults to None.
disable_log (bool, optional): whether to skip printing the ffmpeg
command. Defaults to False.
Raises:
FileNotFoundError: check output path.
TypeError: check input array.
Returns:
None.
"""
if not isinstance(image_array, np.ndarray):
raise TypeError('Input should be np.ndarray.')
assert image_array.ndim == 4
assert image_array.shape[-1] == 3
if resolution:
height, width = resolution
width += width % 2
height += height % 2
else:
image_array = pad_for_libx264(image_array)
height, width = image_array.shape[1], image_array.shape[2]
command = [
'ffmpeg',  # resolved via PATH rather than a hard-coded /usr/bin location
'-y', # (optional) overwrite output file if it exists
'-f',
'rawvideo',
'-s',
f'{int(width)}x{int(height)}', # size of one frame
'-pix_fmt',
'bgr24',
'-r',
f'{fps}', # frames per second
'-loglevel',
'error',
'-threads',
'4',
'-i',
'-', # The input comes from a pipe
'-vcodec',
'libx264',
'-an', # Tells FFMPEG not to expect any audio
output_path,
]
if not disable_log:
print(f'Running \"{" ".join(command)}\"')
process = subprocess.Popen(
command,
stdin=subprocess.PIPE,
stderr=subprocess.PIPE,
)
if process.stdin is None or process.stderr is None:
raise BrokenPipeError('No buffer received.')
# Stream every frame to ffmpeg's stdin as raw bytes.
for frame in image_array:
process.stdin.write(frame.tobytes())
process.stdin.close()
process.stderr.close()
process.wait()
def get_obj_faces(annot_type: str):
template_obj_path_dict = {
'3DETF_blendshape_weight': 'resources/obj_template/3DETF_blendshape_weight.obj',
'FLAME_5023_vertices': 'resources/obj_template/FLAME_5023_vertices.obj',
'BIWI_23370_vertices': 'resources/obj_template/BIWI_23370_vertices.obj',
'flame_params_from_dadhead': 'resources/obj_template/flame_params_from_dadhead.obj',
'inhouse_blendshape_weight': 'resources/obj_template/inhouse_blendshape_weight.obj',
'meshtalk_6172_vertices': 'resources/obj_template/meshtalk_6172_vertices.obj'
}
faces = np.array(read_obj(template_obj_path_dict[annot_type])[1])
if faces.min() == 1:
faces = faces - 1
return faces
def vis_model_out_proxy():
import sys
npz_path = sys.argv[1]
audio_dir = sys.argv[2]
out_dir = sys.argv[3]
save_images = False
os.makedirs(out_dir, exist_ok=True)
out_dict = np.load(npz_path, allow_pickle=True)
print(list(out_dict.keys()))
for wav_f in out_dict.keys():
wav_path = os.path.join(audio_dir, wav_f)
cur_out_dict = out_dict[wav_f]
for datasetname, out in cur_out_dict.item().items():
out = torch.Tensor(out)
annot_type = dataset_config[datasetname]['annot_type']
faces = get_obj_faces(annot_type)
with torch.no_grad():
if save_images:
out_size = (1024, 1024)
else:
out_size = (512, 512)
img_array = render_meshes(out.cuda(),
torch.from_numpy(faces[None]).cuda(),
out_size)
if save_images:
channels = 4
else:
channels = 3
img_array = (img_array[..., :channels].cpu().numpy() *
255).astype(np.uint8)
out_key = f'{wav_f[:-4]}'
if save_images:
out_images_dir = os.path.join(out_dir, datasetname, out_key)
os.makedirs(out_images_dir, exist_ok=True)
for imgidx, img in enumerate(img_array):
img_path = os.path.join(out_images_dir,
f'{imgidx:04d}.png')
cv2.imwrite(img_path, img)
continue
out_video_path = os.path.join(out_dir,
f'{out_key}_{datasetname}.mp4')
out_video_dir = os.path.dirname(out_video_path)
os.makedirs(out_video_dir, exist_ok=True)
tmp_path = os.path.join(out_video_dir, 'tmp.mp4')
if os.path.exists(tmp_path):
os.remove(tmp_path)
if annot_type == 'BIWI_23370_vertices':
fps = 25
else:
fps = 30
print(img_array.shape, img_array.dtype)
array_to_video(img_array, output_path=tmp_path, fps=fps)
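# Pass the arguments as a list (no shell) so paths containing spaces work.
cmd = [
'ffmpeg', '-i', tmp_path, '-i', wav_path, '-c:v', 'copy', '-c:a', 'aac',
'-shortest', '-strict', '-2', out_video_path, '-y'
]
print(' '.join(cmd))
subprocess.run(cmd)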
os.remove(tmp_path)
if __name__ == '__main__':
vis_model_out_proxy()
================================================
FILE: main/train.py
================================================
#!/usr/bin/env python
# yapf: disable
import os
import torch
import torch.optim
import torch.utils.data
from tensorboardX import SummaryWriter
from dataset.dataset import get_dataloaders
from loss.loss import UniTalkerLoss
from models.unitalker import UniTalker
from utils.utils import (
get_average_meter_dict, get_logger, get_parser, get_audio_encoder_dim,
load_ckpt, seed_everything, filter_unitalker_state_dict
)
from .train_eval_loop import train_epoch, validate_epoch
# yapf: enable
def main():
seed_everything(42)
args = get_parser()
args.audio_encoder_feature_dim = get_audio_encoder_dim(args.audio_encoder_repo)
train_loader, val_loader, test_loader = get_dataloaders(args)
args.identity_num = train_loader.dataset.get_identity_num()
args.annot_type_list = train_loader.dataset.annot_type_list
if args.random_id_prob > 0.0:
args.identity_num = args.identity_num + 1
os.makedirs(args.save_path, exist_ok=True)
log_file = os.path.join(args.save_path, 'train.log')
logger = get_logger(log_file)
writer = SummaryWriter(args.save_path)
logger.info('=> creating model ...')
model = UniTalker(args)
if args.weight_path is not None:
logger.info(f'=> loading model {args.weight_path} ...')
load_ckpt(model, args.weight_path, args.re_init_decoder_and_head)
logger.info(args)
model = model.cuda()
model.summary(logger)
optimizer = torch.optim.Adam(
filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
loss_module = UniTalkerLoss(args).cuda()
args.train_meter, args.train_metric_name_list = \
get_average_meter_dict(args, 'train')
args.val_meter, args.val_metric_name_list = \
get_average_meter_dict(args, 'val')
args.logger = logger
args.writer = writer
for epoch in range(args.epochs):
if args.fix_encoder_first:
if epoch == 0:
model.audio_encoder._freeze_wav2vec2_parameters(True)
optimizer = torch.optim.Adam(
filter(lambda p: p.requires_grad, model.parameters()),
lr=args.lr)
elif epoch == 1:
model.audio_encoder._freeze_wav2vec2_parameters(
args.freeze_wav2vec)
model.audio_encoder.feature_extractor._freeze_parameters()
optimizer = torch.optim.Adam(
filter(lambda p: p.requires_grad, model.parameters()),
lr=args.lr)
epoch_plus = epoch + 1
train_epoch(train_loader, model, loss_module, optimizer, epoch_plus,
args)
if epoch_plus % args.eval_every == 0:
validate_epoch(val_loader, model, loss_module, epoch_plus, 'val',
args)
if epoch_plus % args.test_every == 0:
validate_epoch(test_loader, model, loss_module, epoch_plus, 'test',
args)
if epoch_plus % args.save_every == 0:
model_path = os.path.join(args.save_path, f'{epoch_plus:03d}.pt')
torch.save(filter_unitalker_state_dict(model.state_dict()), model_path)
if __name__ == '__main__':
main()
================================================
FILE: main/train_eval_loop.py
================================================
import numpy as np
import torch
from tqdm import tqdm
from utils.utils import log_datasetloss, write_to_tensorboard
def train_epoch(data_loader, model, loss_module, optimizer, epoch, args):
model.train()
train_meter = args.train_meter
metric_name_list = args.train_metric_name_list
zero_loss = torch.tensor(0.0)
mixed_dataset_name = 'mixed_dataset'
pbar = tqdm(data_loader)
pbar.set_description(f'Epoch: {epoch} train')
for batch in pbar:
data = batch['data'].cuda(non_blocking=True)
template = batch['template'].cuda(non_blocking=True)
audio = batch['audio'].cuda(non_blocking=True)
identity = batch['id'].cuda(non_blocking=True)
if args.random_id_prob > 0.0:
if np.random.random() < args.random_id_prob:
identity = identity.fill_(args.identity_num - 1)
annot_type = batch['annot_type'][0]
dataset_name = batch['dataset_name'][0]
out_motion, out_pca, gt_pca = model(
audio, template, data, style_idx=identity, annot_type=annot_type)
rec_loss = loss_module(out_motion, data, annot_type)
if out_pca is not None:
pca_loss = loss_module.pca_loss(out_pca, gt_pca)
else:
pca_loss = zero_loss
loss = rec_loss + args.pca_weight * pca_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
for k, v in zip(metric_name_list, (rec_loss, pca_loss)):
train_meter[dataset_name][k].update(v.item())
train_meter[mixed_dataset_name][k].update(v.item())
log_datasetloss(args.logger, epoch, 'train', train_meter, only_mix=True)
write_to_tensorboard(
args.writer, epoch, 'train', train_meter, only_mix=False)
for dataset_name, sub_metric_dict in train_meter.items():
for metric_name in sub_metric_dict.keys():
sub_metric_dict[metric_name].reset()
return
def validate_epoch(data_loader, model, loss_module, epoch, phase, args):
assert phase in ('val', 'test')
model.eval()
val_meter = args.val_meter
metric_name_list = args.val_metric_name_list
mixed_dataset_name = 'mixed_dataset'
with torch.no_grad():
pbar = tqdm(data_loader)
pbar.set_description(f'Epoch: {epoch} {phase:<5}')
for batch in pbar:
data = batch['data'].cuda(non_blocking=True)
template = batch['template'].cuda(non_blocking=True)
audio = batch['audio'].cuda(non_blocking=True)
identity = batch['id'].cuda(non_blocking=True)
annot_type = batch['annot_type'][0]
dataset_name = batch['dataset_name'][0]
if args.random_id_prob >= 1.0:
identity = identity.fill_(args.identity_num - 1)
out_motion, _, _ = model(
audio,
template,
data,
style_idx=identity,
annot_type=annot_type)
rec_loss = loss_module(out_motion, data, annot_type)
mouth_metric_L2, mouth_metric_L2_norm = loss_module.mouth_metric(
out_motion, data, annot_type)
for k, v in zip(metric_name_list,
(rec_loss, mouth_metric_L2, mouth_metric_L2_norm)):
val_meter[dataset_name][k].update(v.item())
val_meter[mixed_dataset_name][k].update(v.item())
if getattr(args, 'logger', None) is not None:
log_datasetloss(
args.logger, epoch, phase, val_meter, only_mix=True)
if getattr(args, 'writer', None) is not None:
write_to_tensorboard(
args.writer, epoch, phase, val_meter, only_mix=False)
for dataset_name, sub_metric_dict in val_meter.items():
for metric_name in sub_metric_dict.keys():
sub_metric_dict[metric_name].reset()
return
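# Illustrative sketch, not part of the original file: the identity replacement
# applied per batch in train_epoch. With probability random_id_prob, every
# identity in the batch is replaced by the extra "common" index
# identity_num - 1. The helper name is hypothetical, and it returns a new
# array instead of filling the tensor in place as the training loop does.

```python
import numpy as np


def maybe_use_common_id(identity, identity_num, random_id_prob, rng):
    """Replace the whole batch of identities with the common index at random."""
    if random_id_prob > 0.0 and rng.random() < random_id_prob:
        return np.full_like(identity, identity_num - 1)
    return identity


rng = np.random.default_rng(0)
ids = np.array([2, 5, 7])
always = maybe_use_common_id(ids, identity_num=10, random_id_prob=1.0, rng=rng)
never = maybe_use_common_id(ids, identity_num=10, random_id_prob=0.0, rng=rng)
```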
================================================
FILE: main/train_pca.py
================================================
#!/usr/bin/env python
import numpy as np
from models.pca import PCA
from utils.utils import get_parser, get_audio_encoder_dim
def main():
args = get_parser()
args.audio_encoder_feature_dim = get_audio_encoder_dim(
args.wav2vec2_type)
main_worker(args)
def main_worker(cfg):
pca_builder = PCA()
from dataset.dataset import get_dataset_list
train_dataset_list = get_dataset_list(cfg)
for dataset in train_dataset_list:
data = pca_builder.load_dataset(dataset)
pca_info = pca_builder.build_incremental_PCA(
data, batch_size=1095, trunc_dim=647)
np.savez('mesh_talk_pca.npz', **pca_info)
if __name__ == '__main__':
main()
================================================
FILE: models/base_model.py
================================================
import numpy as np
import torch.nn as nn
class BaseModel(nn.Module):
"""Base class for all models."""
def __init__(self):
super(BaseModel, self).__init__()
# self.logger = logging.getLogger(self.__class__.__name__)
def forward(self, *x):
"""Forward pass logic.
:return: Model output
"""
raise NotImplementedError
def freeze_model(self, do_freeze: bool = True):
for param in self.parameters():
param.requires_grad = (not do_freeze)
def summary(self, logger, writer=None):
"""Model summary."""
model_parameters = filter(lambda p: p.requires_grad, self.parameters())
params = sum([np.prod(p.size())
for p in model_parameters]) / 1e6 # Unit is Mega
logger.info('===>Trainable parameters: %.3f M' % params)
if writer is not None:
writer.add_text('Model Summary',
'Trainable parameters: %.3f M' % params)
================================================
FILE: models/hubert.py
================================================
import numpy as np
import torch
from transformers import HubertModel
from transformers.modeling_outputs import BaseModelOutput
from typing import Optional, Tuple
from .model_utils import linear_interpolation
# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model # noqa: E501
# initialize our encoder with the pre-trained wav2vec 2.0 weights.
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.Tensor] = None,
min_masks: int = 0,
) -> np.ndarray:
bsz, all_sz = shape
mask = np.full((bsz, all_sz), False)
all_num_mask = int(mask_prob * all_sz / float(mask_length) +
np.random.rand())
all_num_mask = max(min_masks, all_num_mask)
mask_idcs = []
padding_mask = attention_mask.ne(1) if attention_mask is not None else None
for i in range(bsz):
if padding_mask is not None:
sz = all_sz - padding_mask[i].long().sum().item()
num_mask = int(mask_prob * sz / float(mask_length) +
np.random.rand())
num_mask = max(min_masks, num_mask)
else:
sz = all_sz
num_mask = all_num_mask
lengths = np.full(num_mask, mask_length)
if sum(lengths) == 0:
lengths[0] = min(mask_length, sz - 1)
min_len = min(lengths)
if sz - min_len <= num_mask:
min_len = sz - num_mask - 1
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
mask_idc = np.asarray([
mask_idc[j] + offset for j in range(len(mask_idc))
for offset in range(lengths[j])
])
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
min_len = min([len(m) for m in mask_idcs])
for i, mask_idc in enumerate(mask_idcs):
if len(mask_idc) > min_len:
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
mask[i, mask_idc] = True
return mask
class HubertModel(HubertModel):
def __init__(self, config):
super().__init__(config)
def _freeze_parameters(self, do_freeze: bool = True):
for param in self.parameters():
param.requires_grad = (not do_freeze)
def forward(
self,
input_values,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
frame_num=None,
interpolate_pos: int = 0,
):
self.config.output_attentions = True
output_attentions = output_attentions if \
output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else
self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None \
else self.config.use_return_dict
conv_features = self.feature_extractor(input_values)
hidden_states = conv_features.transpose(1, 2)
if interpolate_pos == 0:
hidden_states = linear_interpolation(
hidden_states, output_len=frame_num)
if attention_mask is not None:
output_lengths = self._get_feat_extract_output_lengths(
attention_mask.sum(-1))
attention_mask = torch.zeros(
hidden_states.shape[:2],
dtype=hidden_states.dtype,
device=hidden_states.device)
attention_mask[(torch.arange(
attention_mask.shape[0],
device=hidden_states.device), output_lengths - 1)] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip(
[-1]).bool()
hidden_states = self.feature_projection(hidden_states)
if self.config.apply_spec_augment and self.training:
batch_size, sequence_length, hidden_size = hidden_states.size()
if self.config.mask_time_prob > 0:
mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length),
self.config.mask_time_prob,
self.config.mask_time_length,
attention_mask=attention_mask,
min_masks=2,
)
hidden_states[torch.from_numpy(
mask_time_indices)] = self.masked_spec_embed.to(
hidden_states.dtype)
if self.config.mask_feature_prob > 0:
mask_feature_indices = _compute_mask_indices(
(batch_size, hidden_size),
self.config.mask_feature_prob,
self.config.mask_feature_length,
)
mask_feature_indices = torch.from_numpy(
mask_feature_indices).to(hidden_states.device)
hidden_states[mask_feature_indices[:, None].expand(
-1, sequence_length, -1)] = 0
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = encoder_outputs[0]
if interpolate_pos == 1:
hidden_states = linear_interpolation(
hidden_states, output_len=frame_num)
if not return_dict:
return (hidden_states, ) + encoder_outputs[1:]
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
================================================
FILE: models/model_utils.py
================================================
# Borrowed from https://github.com/EvelynFan/FaceFormer/blob/main/main.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# linear interpolation layer
def linear_interpolation(features, output_len: int):
features = features.transpose(1, 2)
output_features = F.interpolate(
features, size=output_len, align_corners=True, mode='linear')
return output_features.transpose(1, 2)
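# Illustrative sketch, not part of the original module: a pure-Python view of
# what F.interpolate(..., mode='linear', align_corners=True) computes along the
# time axis for one channel. The helper name is hypothetical.
def _lerp_resample_sketch(values, output_len):
    n_in = len(values)
    if output_len == 1 or n_in == 1:
        return [float(values[0])] * output_len
    out = []
    for i in range(output_len):
        # align_corners=True pins the first and last output samples onto the
        # first and last input samples.
        pos = i * (n_in - 1) / (output_len - 1)
        lo = min(int(pos), n_in - 2)
        frac = pos - lo
        out.append(values[lo] * (1.0 - frac) + values[lo + 1] * frac)
    return out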
# Temporal Bias
def init_biased_mask(n_head, max_seq_len, period):
def get_slopes(n):
def get_slopes_power_of_2(n):
start = (2**(-2**-(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2**math.floor(math.log2(n))
return get_slopes_power_of_2(closest_power_of_2) + get_slopes(
2 * closest_power_of_2)[0::2][:n - closest_power_of_2]
slopes = torch.Tensor(get_slopes(n_head))
bias = torch.div(
torch.arange(start=0, end=max_seq_len,
step=period).unsqueeze(1).repeat(1, period).view(-1),
period,
rounding_mode='floor')
bias = -torch.flip(bias, dims=[0])
alibi = torch.zeros(max_seq_len, max_seq_len)
for i in range(max_seq_len):
alibi[i, :i + 1] = bias[-(i + 1):]
alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0)
mask = (torch.triu(torch.ones(max_seq_len,
max_seq_len)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
mask == 1, float(0.0))
mask = mask.unsqueeze(0) + alibi
return mask
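# Illustrative sketch, not part of the original module: the value that
# init_biased_mask appears to store at [head, i, j], written in plain Python.
# Helper names are hypothetical; slopes are shown only for a power-of-two
# number of heads.
import math


def _alibi_slopes_pow2_sketch(n_head):
    start = 2 ** (-2 ** -(math.log2(n_head) - 3))
    return [start * start ** i for i in range(n_head)]


def _periodic_causal_bias_sketch(seq_len, period, slope):
    rows = []
    for i in range(seq_len):
        row = []
        for j in range(seq_len):
            if j > i:
                row.append(float('-inf'))  # future frames are masked out
            else:
                # the penalty grows by one slope step per full period of distance
                row.append(-((i - j) // period) * slope)
        rows.append(row)
    return rows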
# Alignment Bias
def enc_dec_mask(device, T, S):
mask = torch.ones(T, S)
for i in range(T):
mask[i, i] = 0
return (mask == 1).to(device=device)
# Periodic Positional Encoding
class PeriodicPositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, period=25, max_seq_len=3000):
super(PeriodicPositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(period, d_model)
position = torch.arange(0, period, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # (1, period, d_model)
repeat_num = (max_seq_len // period) + 1
pe = pe.repeat(1, repeat_num, 1)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :x.size(1), :]
return self.dropout(x)
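# Illustrative sketch, not part of the original module: because the table built
# above is one period tiled along time, the encoding of frame t depends only on
# t % period. The helper name is hypothetical.
import math


def _periodic_pe_sketch(t, d_model, period=25):
    pos = t % period
    row = [0.0] * d_model
    for k in range(0, d_model, 2):
        angle = pos * math.exp(k * (-math.log(10000.0) / d_model))
        row[k] = math.sin(angle)
        if k + 1 < d_model:
            row[k + 1] = math.cos(angle)
    return row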
class TCN(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.first_net = SeqTranslator1D(
in_dim, out_dim, min_layers_num=3, residual=True, norm='ln')
self.dropout = nn.Dropout(0.1)
def forward(self, spectrogram, style_embed, time_steps=None):
        # spectrogram: (B, C, T) audio features; style_embed: (B, C_style)
spectrogram = self.dropout(spectrogram)
style_embed = style_embed.unsqueeze(2)
style_embed = style_embed.repeat(1, 1, spectrogram.shape[2])
spectrogram = torch.cat([spectrogram, style_embed], dim=1)
x1 = self.first_net(spectrogram) # .permute(0, 2, 1)
if time_steps is not None:
x1 = F.interpolate(
x1, size=time_steps, align_corners=False, mode='linear')
return x1
class SeqTranslator1D(nn.Module):
"""(B, C, T)->(B, C_out, T)"""
def __init__(self,
C_in,
C_out,
kernel_size=None,
stride=None,
min_layers_num=None,
residual=True,
norm='bn'):
super(SeqTranslator1D, self).__init__()
conv_layers = nn.ModuleList([])
conv_layers.append(
ConvNormRelu(
in_channels=C_in,
out_channels=C_out,
type='1d',
kernel_size=kernel_size,
stride=stride,
residual=residual,
norm=norm))
self.num_layers = 1
if min_layers_num is not None and self.num_layers < min_layers_num:
while self.num_layers < min_layers_num:
conv_layers.append(
ConvNormRelu(
in_channels=C_out,
out_channels=C_out,
type='1d',
kernel_size=kernel_size,
stride=stride,
residual=residual,
norm=norm))
self.num_layers += 1
self.conv_layers = nn.Sequential(*conv_layers)
def forward(self, x):
return self.conv_layers(x)
class ConvNormRelu(nn.Module):
"""(B,C_in,H,W) -> (B, C_out, H, W) there exist some kernel size that makes
the result is not H/s.
    TODO: there might be some problems with the residual branch.
"""
def __init__(self,
in_channels,
out_channels,
type='1d',
leaky=False,
downsample=False,
kernel_size=None,
stride=None,
padding=None,
p=0,
groups=1,
residual=False,
norm='bn'):
"""conv-bn-relu."""
super(ConvNormRelu, self).__init__()
self.residual = residual
self.norm_type = norm
# kernel_size = k
# stride = s
if kernel_size is None and stride is None:
if not downsample:
kernel_size = 3
stride = 1
else:
kernel_size = 4
stride = 2
if padding is None:
if isinstance(kernel_size, int) and isinstance(stride, tuple):
padding = tuple(int((kernel_size - st) / 2) for st in stride)
elif isinstance(kernel_size, tuple) and isinstance(stride, int):
padding = tuple(int((ks - stride) / 2) for ks in kernel_size)
elif isinstance(kernel_size, tuple) and isinstance(stride, tuple):
padding = tuple(
int((ks - st) / 2) for ks, st in zip(kernel_size, stride))
else:
padding = int((kernel_size - stride) / 2)
if self.residual:
if downsample:
if type == '1d':
self.residual_layer = nn.Sequential(
nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding))
elif type == '2d':
self.residual_layer = nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding))
else:
if in_channels == out_channels:
self.residual_layer = nn.Identity()
else:
if type == '1d':
self.residual_layer = nn.Sequential(
nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding))
elif type == '2d':
self.residual_layer = nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding))
in_channels = in_channels * groups
out_channels = out_channels * groups
if type == '1d':
self.conv = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups)
self.norm = nn.BatchNorm1d(out_channels)
self.dropout = nn.Dropout(p=p)
elif type == '2d':
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups)
self.norm = nn.BatchNorm2d(out_channels)
self.dropout = nn.Dropout2d(p=p)
if norm == 'gn':
self.norm = nn.GroupNorm(2, out_channels)
elif norm == 'ln':
self.norm = nn.LayerNorm(out_channels)
if leaky:
self.relu = nn.LeakyReLU(negative_slope=0.2)
else:
self.relu = nn.ReLU()
def forward(self, x, **kwargs):
if self.norm_type == 'ln':
out = self.dropout(self.conv(x))
out = self.norm(out.transpose(1, 2)).transpose(1, 2)
else:
out = self.norm(self.dropout(self.conv(x)))
if self.residual:
residual = self.residual_layer(x)
out += residual
return self.relu(out)
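# Illustrative sketch, not part of the original module: the default padding
# rule used by ConvNormRelu combined with the standard Conv1d output-length
# formula (dilation 1 assumed). It shows why some kernel sizes do not give
# exactly length / stride. The helper name is hypothetical.
def _conv_out_len_sketch(length, kernel_size, stride):
    padding = int((kernel_size - stride) / 2)  # same rule as ConvNormRelu
    return (length + 2 * padding - kernel_size) // stride + 1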
================================================
FILE: models/pca.py
================================================
import numpy as np
import torch
import torch.nn as nn
class PCALayer(nn.Module):
def __init__(self, pca_info_path: str):
super().__init__()
pca_info = np.load(pca_info_path)
feat_to_data = pca_info['components_to_data'].astype(np.float32)
data_to_feat = feat_to_data.T
data_bias = pca_info['original_data_mean'].astype(np.float32)
feat_mean = pca_info['data_components_mean'].astype(np.float32)
feat_std = pca_info['data_components_std'].astype(np.float32)
self.register_buffer('feat_to_data', torch.from_numpy(feat_to_data))
self.register_buffer('data_to_feat', torch.from_numpy(data_to_feat))
self.register_buffer('data_bias', torch.from_numpy(data_bias))
self.register_buffer('feat_mean', torch.from_numpy(feat_mean))
self.register_buffer('feat_std', torch.from_numpy(feat_std))
self.pca_dim = len(feat_mean)
def encode(
self,
x: torch.Tensor,
pca_dim: int = None,
):
x = x - self.data_bias
feat = x @ (self.data_to_feat[:, :pca_dim])
return feat
def decode(self, feat: torch.Tensor, pca_dim: int = None):
data = feat[..., :pca_dim] @ (self.feat_to_data[:pca_dim])
data = data + self.data_bias
return data
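# Illustrative sketch, not part of the original module: PCALayer.encode and
# PCALayer.decode written with plain lists for a single sample. `basis` plays
# the role of feat_to_data (one component per row) and `mean` the role of
# data_bias. The helper name is hypothetical.
def _pca_roundtrip_sketch(x, mean, basis, pca_dim=None):
    rows = basis[:pca_dim]
    centered = [xi - mi for xi, mi in zip(x, mean)]
    # encode: project the centered sample onto the first pca_dim components
    feat = [sum(c * b for c, b in zip(centered, row)) for row in rows]
    # decode: weighted sum of the components plus the mean
    recon = [
        mi + sum(f * row[d] for f, row in zip(feat, rows))
        for d, mi in enumerate(mean)
    ]
    return feat, recon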
class PCA:
def __init__(self, trunc_dim: int = None):
super().__init__()
def buld_pca_for_dataset(self, dataset, trunc_dim: int = 1024):
annot_array = self.load_dataset(dataset)
self.build_pca(annot_array, trunc_dim)
def load_dataset(self, dataset):
annot_list = []
indices = np.arange(len(dataset))
for i in indices:
data_dict = dataset.__getitem__(i)
annot_list.append(data_dict['data'] - data_dict['template'])
annot_array = np.concatenate(annot_list, axis=0)
return annot_array
def build_incremental_PCA(
self,
in_data: np.ndarray,
trunc_dim=None,
batch_size=1024,
):
indices = np.arange(len(in_data))
np.random.shuffle(indices)
in_data = in_data[indices]
if in_data.dtype != np.float32:
in_data = in_data.astype(np.float32)
from sklearn.decomposition import IncrementalPCA
ipca = IncrementalPCA(n_components=trunc_dim, batch_size=batch_size)
data_mean = in_data.mean(0)
in_data = in_data - data_mean
in_data_slices = np.split(
in_data, np.arange(batch_size, len(in_data), batch_size))
for batch_data in in_data_slices:
if len(batch_data) != batch_size:
continue
ipca.partial_fit(batch_data)
S = ipca.singular_values_
components_to_data = ipca.components_
explained_variance_ratio = ipca.explained_variance_ratio_
data_components = ipca.transform(in_data)
data_components_mean = data_components.mean(0)
data_components_std = data_components.std(0)
ret_dict = {
'explained_variance_ratio': explained_variance_ratio,
'components_to_data': components_to_data,
'original_data_mean': data_mean,
'S': S,
'data_components_mean': data_components_mean,
'data_components_std': data_components_std
}
return ret_dict
def build_pca(
self,
in_data: np.ndarray,
trunc_dim=None,
save_compactness=True,
check_svd=True,
):
ret_dict = {}
in_data = torch.Tensor(in_data).cuda()
count, in_dim = in_data.shape
mean_data = in_data.mean(dim=0)
in_data_center = in_data - mean_data # (count, in_dim)
in_data_center = in_data_center.cpu()
        # With K = min(count, in_dim): U (count, K), S (K,), Vt (K, in_dim)
U, S, Vt = torch.linalg.svd(in_data_center, full_matrices=False)
if check_svd:
S_mat = torch.diag(S)
rebuilt = U @ S_mat @ Vt
is_precise = torch.allclose(rebuilt, in_data_center)
print('svd rebuilt all close ?', is_precise)
var_S = S**2
if save_compactness:
compact_vector = var_S / var_S.sum()
compact_value = compact_vector[0:trunc_dim].sum()
            ret_dict['compactness'] = compact_value
pca_bases = Vt[0:trunc_dim].reshape(trunc_dim, -1)
ret_dict['pca_bases'] = pca_bases
ret_dict['pca_std'] = S[0:trunc_dim]
ret_dict['mean_data'] = mean_data
return ret_dict
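# Illustrative sketch, not part of the original module: the "compactness"
# computed in build_pca is the share of total variance (squared singular
# values) kept by the first trunc_dim components. The helper name is
# hypothetical.
def _compactness_sketch(singular_values, trunc_dim):
    variances = [s ** 2 for s in singular_values]
    return sum(variances[:trunc_dim]) / sum(variances)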
================================================
FILE: models/unitalker.py
================================================
import os.path as osp
import torch.nn as nn
from dataset.dataset_config import dataset_config
from models.wav2vec2 import Wav2Vec2Model
from .base_model import BaseModel
# from models.hubert import HubertModel
from models.wavlm import WavLMModel
class OutHead(BaseModel):
def __init__(self, args, out_dim: int):
super().__init__()
head_layer = []
in_dim = args.decoder_dimension
for i in range(args.headlayer - 1):
head_layer.append(nn.Linear(in_dim, in_dim))
head_layer.append(nn.Tanh())
head_layer.append(nn.Linear(in_dim, out_dim))
self.head_layer = nn.Sequential(*head_layer)
def forward(self, x):
return self.head_layer(x)
class UniTalker(BaseModel):
def __init__(self, args):
super().__init__()
self.args = args
if 'wav2vec2' in args.audio_encoder_repo:
self.audio_encoder = Wav2Vec2Model.from_pretrained(
args.audio_encoder_repo)
elif 'wavlm' in args.audio_encoder_repo:
self.audio_encoder = WavLMModel.from_pretrained(
args.audio_encoder_repo)
else:
raise ValueError("wrong audio_encoder_repo")
self.audio_encoder.feature_extractor._freeze_parameters()
if args.freeze_wav2vec:
self.audio_encoder._freeze_wav2vec2_parameters()
if args.decoder_type == 'conv':
from .unitalker_decoder import UniTalkerDecoderTCN as Decoder
elif args.decoder_type == 'transformer':
from .unitalker_decoder import \
UniTalkerDecoderTransformer as Decoder
else:
            raise ValueError('unknown decoder type')
self.decoder = Decoder(args)
if args.use_pca:
pca_layer_dict = {}
from models.pca import PCALayer
for dataset_name in args.dataset:
if dataset_config[dataset_name]['pca']:
annot_type = dataset_config[dataset_name]['annot_type']
dirname = dataset_config[dataset_name]['dirname']
pca_path = osp.join(args.data_root, dirname, 'pca.npz')
pca_layer_dict[annot_type] = PCALayer(pca_path)
self.pca_layer_dict = nn.ModuleDict(pca_layer_dict)
else:
self.pca_layer_dict = {}
self.pca_dim = args.pca_dim
self.use_pca = args.use_pca
self.interpolate_pos = args.interpolate_pos
out_head_dict = {}
for dataset_name in args.dataset:
annot_type = dataset_config[dataset_name]['annot_type']
out_dim = dataset_config[dataset_name]['annot_dim']
if (args.use_pca is True) and (annot_type in self.pca_layer_dict):
pca_dim = self.pca_layer_dict[annot_type].pca_dim
out_dim = min(args.pca_dim, out_dim, pca_dim)
if args.headlayer == 1:
out_projection = nn.Linear(args.decoder_dimension, out_dim)
nn.init.constant_(out_projection.weight, 0)
nn.init.constant_(out_projection.bias, 0)
else:
out_projection = OutHead(args, out_dim)
out_head_dict[annot_type] = out_projection
self.out_head_dict = nn.ModuleDict(out_head_dict)
self.identity_num = args.identity_num
return
def forward(self,
audio,
template,
face_motion,
style_idx,
annot_type: str,
fps: float = None):
if face_motion is not None:
frame_num = face_motion.shape[1]
else:
frame_num = round(audio.shape[-1] / 16000 * fps)
hidden_states = self.audio_encoder(
audio, frame_num=frame_num, interpolate_pos=self.interpolate_pos)
hidden_states = hidden_states.last_hidden_state
decoder_out = self.decoder(hidden_states, style_idx, frame_num)
if (not self.use_pca) or (annot_type not in self.pca_layer_dict):
out_motion = self.out_head_dict[annot_type](decoder_out)
out_motion = out_motion + template
out_pca, gt_pca = None, None
else:
out_pca = self.out_head_dict[annot_type](decoder_out)
out_motion = self.pca_layer_dict[annot_type].decode(
out_pca, self.pca_dim)
out_motion = out_motion + template
if face_motion is not None:
gt_pca = self.pca_layer_dict[annot_type].encode(
face_motion - template, self.pca_dim)
else:
gt_pca = None
return out_motion, out_pca, gt_pca
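# Illustrative sketch, not part of the original module: how forward() derives
# the number of output frames when no ground-truth motion is given. The 16 kHz
# rate mirrors the constant hard-coded above; the helper name is hypothetical.
def _frame_num_sketch(num_samples, fps, sample_rate=16000):
    return round(num_samples / sample_rate * fps)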
================================================
FILE: models/unitalker_decoder.py
================================================
# yapf: disable
import torch
import torch.nn as nn
from .base_model import BaseModel
from .model_utils import (
TCN, PeriodicPositionalEncoding, enc_dec_mask, init_biased_mask,
linear_interpolation,
)
# yapf: enable
class UniTalkerDecoderTCN(BaseModel):
def __init__(self, args) -> None:
super().__init__()
self.learnable_style_emb = nn.Embedding(args.identity_num, 64)
in_dim = args.decoder_dimension
out_dim = args.decoder_dimension
self.audio_feature_map = nn.Linear(args.audio_encoder_feature_dim, in_dim)
self.tcn = TCN(in_dim + 64, out_dim)
self.interpolate_pos = args.interpolate_pos
def forward(self,
hidden_states: torch.Tensor,
style_idx: torch.Tensor,
frame_num: int = None):
batch_size = len(hidden_states)
if style_idx is None:
style_embedding = self.learnable_style_emb.weight[-1].unsqueeze(
0).repeat(batch_size, 1)
else:
style_embedding = self.learnable_style_emb(style_idx)
feature = self.audio_feature_map(hidden_states).transpose(1, 2)
feature = self.tcn(feature, style_embedding)
feature = feature.transpose(1, 2)
if self.interpolate_pos == 2:
feature = linear_interpolation(feature, output_len=frame_num)
return feature
class UniTalkerDecoderTransformer(BaseModel):
def __init__(self, args) -> None:
super().__init__()
out_dim = args.decoder_dimension
self.learnable_style_emb = nn.Embedding(args.identity_num, out_dim)
self.audio_feature_map = nn.Linear(args.audio_encoder_feature_dim, out_dim)
self.PPE = PeriodicPositionalEncoding(
out_dim, period=args.period, max_seq_len=3000)
self.biased_mask = init_biased_mask(
n_head=4, max_seq_len=3000, period=args.period)
decoder_layer = nn.TransformerDecoderLayer(
d_model=out_dim,
nhead=4,
dim_feedforward=2 * out_dim,
batch_first=True)
self.transformer_decoder = nn.TransformerDecoder(
decoder_layer, num_layers=1)
self.interpolate_pos = args.interpolate_pos
def forward(self, hidden_states: torch.Tensor, style_idx: torch.Tensor,
frame_num: int):
obj_embedding = self.learnable_style_emb(style_idx)
obj_embedding = obj_embedding.unsqueeze(1).repeat(1, frame_num, 1)
hidden_states = self.audio_feature_map(hidden_states)
style_input = self.PPE(obj_embedding)
tgt_mask = self.biased_mask[:, :style_input.shape[1], :style_input.
shape[1]].clone().detach().to(
device=style_input.device)
memory_mask = enc_dec_mask(hidden_states.device, style_input.shape[1],
frame_num)
feat_out = self.transformer_decoder(
style_input,
hidden_states,
tgt_mask=tgt_mask,
memory_mask=memory_mask)
if self.interpolate_pos == 2:
feat_out = linear_interpolation(feat_out, output_len=frame_num)
return feat_out
================================================
FILE: models/wav2vec2.py
================================================
import numpy as np
import torch
from transformers import Wav2Vec2Model
from transformers.modeling_outputs import BaseModelOutput
from typing import Optional, Tuple
from .model_utils import linear_interpolation
# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model # noqa: E501
# initialize our encoder with the pre-trained wav2vec 2.0 weights.
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.Tensor] = None,
min_masks: int = 0,
) -> np.ndarray:
bsz, all_sz = shape
mask = np.full((bsz, all_sz), False)
all_num_mask = int(mask_prob * all_sz / float(mask_length) +
np.random.rand())
all_num_mask = max(min_masks, all_num_mask)
mask_idcs = []
padding_mask = attention_mask.ne(1) if attention_mask is not None else None
for i in range(bsz):
if padding_mask is not None:
sz = all_sz - padding_mask[i].long().sum().item()
num_mask = int(mask_prob * sz / float(mask_length) +
np.random.rand())
num_mask = max(min_masks, num_mask)
else:
sz = all_sz
num_mask = all_num_mask
lengths = np.full(num_mask, mask_length)
if sum(lengths) == 0:
lengths[0] = min(mask_length, sz - 1)
min_len = min(lengths)
if sz - min_len <= num_mask:
min_len = sz - num_mask - 1
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
mask_idc = np.asarray([
mask_idc[j] + offset for j in range(len(mask_idc))
for offset in range(lengths[j])
])
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
min_len = min([len(m) for m in mask_idcs])
for i, mask_idc in enumerate(mask_idcs):
if len(mask_idc) > min_len:
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
mask[i, mask_idc] = True
return mask
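# Illustrative sketch, not part of the original module: how many mask spans
# _compute_mask_indices requests per sequence. `u` stands in for the
# np.random.rand() draw in [0, 1); the helper name is hypothetical.
def _num_mask_spans_sketch(mask_prob, seq_len, mask_length, min_masks=0, u=0.0):
    return max(min_masks, int(mask_prob * seq_len / float(mask_length) + u))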
class Wav2Vec2Model(Wav2Vec2Model):
def __init__(self, config):
super().__init__(config)
def _freeze_wav2vec2_parameters(self, do_freeze: bool = True):
for param in self.parameters():
param.requires_grad = (not do_freeze)
def forward(
self,
input_values,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
frame_num=None,
interpolate_pos: int = 0,
):
self.config.output_attentions = True
output_attentions = output_attentions if \
output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else
self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None \
else self.config.use_return_dict
conv_features = self.feature_extractor(input_values)
hidden_states = conv_features.transpose(1, 2)
if interpolate_pos == 0:
hidden_states = linear_interpolation(
hidden_states, output_len=frame_num)
if attention_mask is not None:
output_lengths = self._get_feat_extract_output_lengths(
attention_mask.sum(-1))
attention_mask = torch.zeros(
hidden_states.shape[:2],
dtype=hidden_states.dtype,
device=hidden_states.device)
attention_mask[(torch.arange(
attention_mask.shape[0],
device=hidden_states.device), output_lengths - 1)] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip(
[-1]).bool()
hidden_states, _ = self.feature_projection(hidden_states)
if self.config.apply_spec_augment and self.training:
batch_size, sequence_length, hidden_size = hidden_states.size()
if self.config.mask_time_prob > 0:
mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length),
self.config.mask_time_prob,
self.config.mask_time_length,
attention_mask=attention_mask,
min_masks=2,
)
hidden_states[torch.from_numpy(
mask_time_indices)] = self.masked_spec_embed.to(
hidden_states.dtype)
if self.config.mask_feature_prob > 0:
mask_feature_indices = _compute_mask_indices(
(batch_size, hidden_size),
self.config.mask_feature_prob,
self.config.mask_feature_length,
)
mask_feature_indices = torch.from_numpy(
mask_feature_indices).to(hidden_states.device)
hidden_states[mask_feature_indices[:, None].expand(
-1, sequence_length, -1)] = 0
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = encoder_outputs[0]
if interpolate_pos == 1:
hidden_states = linear_interpolation(
hidden_states, output_len=frame_num)
if not return_dict:
return (hidden_states, ) + encoder_outputs[1:]
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
================================================
FILE: models/wavlm.py
================================================
import numpy as np
import torch
from transformers import WavLMModel
from transformers.modeling_outputs import Wav2Vec2BaseModelOutput
from typing import Optional, Tuple, Union
from .model_utils import linear_interpolation
# The forward pass below is adapted from the WavLMModel implementation in
# Hugging Face transformers; the encoder is initialized from pre-trained WavLM weights.
class WavLMModel(WavLMModel):
def __init__(self, config):
super().__init__(config)
def _freeze_wav2vec2_parameters(self, do_freeze: bool = True):
for param in self.parameters():
param.requires_grad = (not do_freeze)
def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
mask_time_indices: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
frame_num=None,
interpolate_pos: int = 0,
) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
extract_features = self.feature_extractor(input_values)
extract_features = extract_features.transpose(1, 2)
if interpolate_pos == 0:
extract_features = linear_interpolation(
extract_features, output_len=frame_num)
if attention_mask is not None:
# compute reduced attention_mask corresponding to feature vectors
attention_mask = self._get_feature_vector_attention_mask(
extract_features.shape[1], attention_mask, add_adapter=False
)
hidden_states, extract_features = self.feature_projection(extract_features)
hidden_states = self._mask_hidden_states(
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
)
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = encoder_outputs[0]
if interpolate_pos == 1:
hidden_states = linear_interpolation(
hidden_states, output_len=frame_num)
if self.adapter is not None:
hidden_states = self.adapter(hidden_states)
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
return Wav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
================================================
FILE: utils/config.py
================================================
# -----------------------------------------------------------------------------
# Functions for parsing args
# -----------------------------------------------------------------------------
import copy
import os
import yaml
from ast import literal_eval
class CfgNode(dict):
"""CfgNode represents an internal node in the configuration tree.
It's a simple dict-like container that allows for attribute-based access to
keys.
"""
def __init__(self, init_dict=None, key_list=None, new_allowed=False):
# Recursively convert nested dictionaries in init_dict into CfgNodes
init_dict = {} if init_dict is None else init_dict
key_list = [] if key_list is None else key_list
for k, v in init_dict.items():
if type(v) is dict:
# Convert dict to CfgNode
init_dict[k] = CfgNode(v, key_list=key_list + [k])
super(CfgNode, self).__init__(init_dict)
def __getattr__(self, name):
if name in self:
return self[name]
else:
raise AttributeError(name)
def __setattr__(self, name, value):
self[name] = value
def __str__(self):
def _indent(s_, num_spaces):
s = s_.split('\n')
if len(s) == 1:
return s_
first = s.pop(0)
s = [(num_spaces * ' ') + line for line in s]
s = '\n'.join(s)
s = first + '\n' + s
return s
r = ''
s = []
for k, v in sorted(self.items()):
separator = '\n' if isinstance(v, CfgNode) else ' '
attr_str = '{}:{}{}'.format(str(k), separator, str(v))
attr_str = _indent(attr_str, 2)
s.append(attr_str)
r += '\n'.join(s)
return r
def __repr__(self):
return '{}({})'.format(self.__class__.__name__,
super(CfgNode, self).__repr__())
def load_cfg_from_cfg_file(file):
cfg = {}
assert os.path.isfile(file) and file.endswith('.yaml'), \
'{} is not a yaml file'.format(file)
with open(file, 'r') as f:
cfg_from_file = yaml.safe_load(f)
for key in cfg_from_file:
for k, v in cfg_from_file[key].items():
cfg[k] = v
cfg = CfgNode(cfg)
return cfg
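# Illustrative sketch, not part of the original module: load_cfg_from_cfg_file
# drops the top-level section names and merges every section into one flat
# mapping, so a key repeated in a later section overrides the earlier one.
# The helper name and the section names in the test are hypothetical.
def _flatten_sections_sketch(cfg_from_file):
    flat = {}
    for section in cfg_from_file.values():
        flat.update(section)
    return flat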
def merge_cfg_from_list(cfg, cfg_list):
new_cfg = copy.deepcopy(cfg)
assert len(cfg_list) % 2 == 0
for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
subkey = full_key.split('.')[-1]
assert subkey in cfg, 'Non-existent key: {}'.format(full_key)
value = _decode_cfg_value(v)
value = _check_and_coerce_cfg_value_type(value, cfg[subkey], subkey,
full_key)
setattr(new_cfg, subkey, value)
return new_cfg
def _decode_cfg_value(v):
"""Decodes a raw config value (e.g., from a yaml config files or command
line argument) into a Python object."""
# All remaining processing is only applied to strings
if not isinstance(v, str):
return v
# Try to interpret `v` as a:
# string, number, tuple, list, dict, boolean, or None
try:
v = literal_eval(v)
# The following two excepts allow v to pass through when it represents a
# string.
#
# Longer explanation:
# The type of v is always a string (before calling literal_eval), but
# sometimes it *represents* a string and other times a data structure, like
# a list. In the case that v represents a string, what we got back from the
# yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
# ok with '"foo"', but will raise a ValueError if given 'foo'. In other
# cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
# will raise a SyntaxError.
except ValueError:
pass
except SyntaxError:
pass
return v
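# Illustrative sketch, not part of the original module: the same decode rule in
# compact form. Strings that parse as Python literals become objects; anything
# else (bare words, paths) is passed through unchanged. The helper name is
# hypothetical.
from ast import literal_eval


def _decode_value_sketch(raw):
    if not isinstance(raw, str):
        return raw
    try:
        return literal_eval(raw)
    except (ValueError, SyntaxError):
        return raw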
def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
"""Checks that `replacement`, which is intended to replace `original` is of
the right type.
The type is correct if it matches exactly or is one of a few cases in which
the type can be easily coerced.
"""
original_type = type(original)
replacement_type = type(replacement)
# The types must match (with some exceptions)
if replacement_type == original_type or original is None:
return replacement
# types match from_type and to_type
def conditional_cast(from_type, to_type):
if replacement_type == from_type and original_type == to_type:
return True, to_type(replacement)
else:
return False, None
# Conditionally casts
# list <-> tuple
casts = [(tuple, list), (list, tuple)]
# For py2: allow converting from str (bytes) to a unicode string
try:
casts.append((str, unicode)) # noqa: F821
except Exception:
pass
for (from_type, to_type) in casts:
converted, converted_value = conditional_cast(from_type, to_type)
if converted:
return converted_value
raise ValueError(
'Type mismatch ({} vs. {}) with values ({} vs. {}) for config '
'key: {}'.format(original_type, replacement_type, original,
replacement, full_key))
================================================
FILE: utils/utils.py
================================================
import argparse
import logging
import numpy as np
import os
import random
import torch
from copy import deepcopy
from . import config
def get_audio_encoder_dim(audio_encoder: str):
if audio_encoder == 'microsoft/wavlm-base-plus':
return 768
elif audio_encoder == 'facebook/wav2vec2-large-xlsr-53':
return 1024
else:
raise ValueError("wrong audio_encoder")
def filer_list(in_list: list, key: str):
ret_list = []
for line in in_list:
if line.startswith(key):
ret_list.append(line)
return ret_list
def count_checkpoint_params(ckpt: dict):
params = 0
pca_params = 0
for k, v in ckpt.items():
if 'pca' in k:
print(k)
pca_params += np.prod(v.size())
params += np.prod(v.size())
return params, pca_params
def read_obj(in_path):
with open(in_path, 'r') as obj_file:
# Read the lines of the OBJ file
lines = obj_file.readlines()
# Initialize empty lists for vertices and faces
verts = []
faces = []
for line in lines:
line = line.strip() # Remove leading/trailing whitespace
elements = line.split() # Split the line into elements
if len(elements) == 0:
continue # Skip empty lines
# Check the type of line (vertex or face)
if elements[0] == 'v':
# Vertex line
x, y, z = map(float,
elements[1:4]) # Extract the vertex coordinates
verts.append((x, y, z)) # Add the vertex to the list
elif elements[0] == 'f':
# Face line
face_indices = [
int(index.split('/')[0]) for index in elements[1:]
] # Extract the vertex indices
faces.append(face_indices) # Add the face to the list
verts = np.array(verts)
faces = np.array(faces)
if faces.min() == 1:
faces = faces - 1
return verts, faces
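# Illustrative sketch, not part of the original module: parsing one OBJ face
# line. Each token may be "v", "v/vt", "v//vn" or "v/vt/vn"; only the vertex
# index is kept. Unlike read_obj, this sketch always converts to zero-based
# indices. The helper name is hypothetical.
def _parse_obj_face_sketch(line):
    tokens = line.split()
    return [int(tok.split('/')[0]) - 1 for tok in tokens[1:]]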
def get_logger(log_file: str):
logger_name = 'main-logger'
logger = logging.getLogger(logger_name)
logger.setLevel(logging.INFO)
handler = logging.FileHandler(log_file, mode='w')
fmt = '[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d]=>%(message)s' # noqa: E501
handler.setFormatter(logging.Formatter(fmt))
    for hdl in list(logger.handlers):  # iterate over a copy; removing while iterating skips handlers
logger.removeHandler(hdl)
logger.addHandler(handler)
return logger
def get_template_dict(data_root, annot_type: str):
import pickle
if annot_type == 'BIWI_23370_vertices':
template_pkl_path = os.path.join(data_root, 'BIWI', 'templates.pkl')
    elif annot_type == 'FLAME_5023_vertices':
        template_pkl_path = os.path.join(data_root, 'vocaset', 'templates.pkl')
    else:
        raise ValueError(f'unsupported annot_type: {annot_type}')
with open(template_pkl_path, 'rb') as f:
template_dict = pickle.load(f, encoding='latin')
return template_dict
def filter_unitalker_state_dict(in_dict):
out_dict = {}
for k, v in in_dict.items():
        if 'pca_layer_dict' not in k:
out_dict[k] = v
return out_dict
def get_template_verts(data_root, dataset_name: str, subject: int):
from dataset.dataset_config import dataset_config
template_path = os.path.join(data_root, dataset_config[dataset_name]['dirname'], 'id_template.npy')
return np.load(template_path)[subject]
def load_ckpt(model, ckpt_path, re_init_decoder_and_head: bool = False):
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    model_state = model.state_dict()
    id_embedding_key = 'decoder.learnable_style_emb.weight'
    if len(checkpoint[id_embedding_key]) != len(
            model_state[id_embedding_key]):
        # Identity counts differ: reuse the checkpoint's last style
        # embedding for every identity of the current model.
        target_id_num = len(model_state[id_embedding_key])
        checkpoint[id_embedding_key] = checkpoint[id_embedding_key][-1][
            None].repeat(target_id_num, 1)
    if re_init_decoder_and_head is True:
        # Keep checkpoint weights only for the keys matched by
        # 'audio_encoder'; every other key is overwritten with the model's
        # current value.
        ckpt_keys = list(checkpoint.keys())
        encoder_keys = filer_list(ckpt_keys, 'audio_encoder')
        reinit_keys = list(set(ckpt_keys) - set(encoder_keys))
        for k in reinit_keys:
            if k in model_state:
                checkpoint[k] = model_state[k]
    model.load_state_dict(checkpoint, strict=False)
class AverageMeter(object):
    """Computes and stores the average and current value."""
    def __init__(self):
        self.reset()
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
    def get_avg(self):
        return self.avg
def get_average_meter_dict(args, phase):
    if phase == 'train':
        metric_name_list = ['rec_L2_loss', 'pca_loss']
    elif phase in ('val', 'test'):
        metric_name_list = ['rec_L2_loss', 'LVE', 'LVD']
    else:
        raise ValueError(f'wrong phase for average meter: {phase}')
    average_meter_dict_template = {}
    for k in metric_name_list:
        average_meter_dict_template[k] = AverageMeter()
    mixed_dataset_name = 'mixed_dataset'
    dataset_name_list = args.dataset + [mixed_dataset_name]
    # Each dataset (and the mix) gets its own independent set of meters.
    average_meter_dict = {}
    for dataset_name in dataset_name_list:
        average_meter_dict[dataset_name] = deepcopy(
            average_meter_dict_template)
    return average_meter_dict, metric_name_list
def log_datasetloss(
    logger,
    epoch: int,
    phase: str,
    dataset_loss_dict: dict,
    only_mix: bool = True,
):
    base_str = f'{phase.upper():<5} Epoch: {epoch:03d} '
    for dataset_name, loss_info in dataset_loss_dict.items():
        if only_mix and dataset_name != 'mixed_dataset':
            continue
        for loss_type, avg_meter in loss_info.items():
            base_str += f'{loss_type}_{dataset_name}:{avg_meter.avg: .8f}, '
    logger.info(base_str)
    return base_str
def write_to_tensorboard(writer,
                         epoch: int,
                         phase: str,
                         dataset_loss_dict: dict,
                         only_mix: bool = False):
    for dataset_name, loss_info in dataset_loss_dict.items():
        if only_mix and dataset_name != 'mixed_dataset':
            continue
        for loss_type, avg_meter in loss_info.items():
            key = f'{phase}/{loss_type}_{dataset_name}'
            writer.add_scalar(key, avg_meter.avg, epoch)
    return 0
def get_parser():
    parser = argparse.ArgumentParser(description=' ')
    parser.add_argument(
        '--config',
        type=str,
        default='config/unitalker.yaml',
        help='config file')
    parser.add_argument(
        '--condition_id', type=str, default='common', help='condition id')
    parser.add_argument(
        'opts', help=' ', default=None, nargs=argparse.REMAINDER)
    args = parser.parse_args()
    assert args.config is not None
    cfg = config.load_cfg_from_cfg_file(args.config)
    if args.opts is not None:
        cfg = config.merge_cfg_from_list(cfg, args.opts)
    cfg.condition_id = args.condition_id
    cfg.dataset = cfg.dataset.split(',')
    cfg.demo_dataset = cfg.demo_dataset.split(',')
    if isinstance(cfg.duplicate_list, str):
        cfg.duplicate_list = cfg.duplicate_list.split(',')
        cfg.duplicate_list = [int(i) for i in cfg.duplicate_list]
    return cfg
def seed_everything(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
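Usage sketch for the meter helpers in utils/utils.py. It restates AverageMeter so the snippet runs on its own, then mirrors how get_average_meter_dict builds one meter per metric for each dataset plus a 'mixed_dataset' entry. The names D0 and D1 come from config/unitalker.yaml. The helper build_meters is a hypothetical stand-in for get_average_meter_dict without the args object. The loss values and batch sizes are made up, and updating both the per-dataset and the mixed meter is an assumption about main/train_eval_loop.py.

```python
from copy import deepcopy


class AverageMeter:
    """Running average weighted by sample count (restated from utils/utils.py)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def build_meters(dataset_names, metric_names):
    """Hypothetical helper: one meter per metric for each dataset and the mix."""
    template = {name: AverageMeter() for name in metric_names}
    # deepcopy gives every dataset its own independent meter objects.
    return {
        dataset: deepcopy(template)
        for dataset in list(dataset_names) + ['mixed_dataset']
    }


meters = build_meters(['D0', 'D1'], ['rec_L2_loss', 'pca_loss'])

# Illustrative values only: (dataset, batch loss, batch size).
for dataset, loss, batch_size in [('D0', 0.5, 4), ('D1', 1.0, 1)]:
    # Assumed pattern: record each batch under its dataset and under the mix.
    for key in (dataset, 'mixed_dataset'):
        meters[key]['rec_L2_loss'].update(loss, n=batch_size)

mixed_avg = meters['mixed_dataset']['rec_L2_loss'].avg
# Same formatting that log_datasetloss uses for each entry.
print(f'rec_L2_loss_mixed_dataset:{mixed_avg: .8f}')
```

Because update weights each value by n, the mixed average is weighted by sample count rather than being a plain mean of the batch losses.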
SYMBOL INDEX (150 symbols across 18 files)
FILE: dataset/data_item.py
class DataItem (line 8) | class DataItem:
method __init__ (line 10) | def __init__(self,
method load_audio (line 49) | def load_audio(self, audio_path: str):
method truncate (line 55) | def truncate(self, audio_array: np.ndarray, sr: int,
method offset_id (line 65) | def offset_id(self, offset: int):
method scale_and_offset (line 70) | def scale_and_offset(self,
method to_dict (line 76) | def to_dict(self, ) -> Any:
method get_dict (line 91) | def get_dict(self, ):
FILE: dataset/dataset.py
class AudioFaceDataset (line 13) | class AudioFaceDataset(data.Dataset):
method __init__ (line 15) | def __init__(self, data_label_path: str, args: dict = None) -> None:
method __len__ (line 44) | def __len__(self, ):
method __getitem__ (line 47) | def __getitem__(self, index: Any) -> Any:
method get_identity_num (line 51) | def get_identity_num(self, ):
class MixAudioFaceDataset (line 55) | class MixAudioFaceDataset(AudioFaceDataset):
method __init__ (line 57) | def __init__(self,
function get_dataset_list (line 80) | def get_dataset_list(args):
function get_single_dataset (line 104) | def get_single_dataset(args, json_path: str = ''):
function get_dataloaders (line 123) | def get_dataloaders(args):
FILE: loss/loss.py
class FlameParams (line 11) | class FlameParams:
class FlameLayer (line 18) | class FlameLayer(nn.Module):
method __init__ (line 20) | def __init__(self, model_path: str) -> None:
method flame_params_from_3dmm (line 76) | def flame_params_from_3dmm(self,
method forward (line 97) | def forward(self, tensor_3dmm: torch.Tensor):
class FlameParamLossForDADHead (line 126) | class FlameParamLossForDADHead(nn.Module):
method __init__ (line 128) | def __init__(self, mouth_indices_path: str, loss_mask_path: str,
method forward (line 138) | def forward(self, x: torch.Tensor, target: torch.Tensor):
method get_vertices (line 147) | def get_vertices(
method mouth_metric (line 154) | def mouth_metric(self, x: torch.Tensor, target: torch.Tensor):
class VerticesLoss (line 165) | class VerticesLoss(nn.Module):
method __init__ (line 167) | def __init__(self, mouth_indices_path: str) -> None:
method forward (line 173) | def forward(self, x: torch.Tensor, target: torch.Tensor):
method get_vertices (line 176) | def get_vertices(
method mouth_metric (line 182) | def mouth_metric(self, x: torch.Tensor, target: torch.Tensor):
class BlendShapeLoss (line 193) | class BlendShapeLoss(nn.Module):
method __init__ (line 195) | def __init__(self,
method forward (line 215) | def forward(self, x: torch.Tensor, target: torch.Tensor):
method bs2vertices (line 219) | def bs2vertices(self, x: torch.Tensor):
method get_vertices (line 224) | def get_vertices(
method mouth_metric (line 231) | def mouth_metric(self, x: torch.Tensor, target: torch.Tensor):
class ParamLoss (line 244) | class ParamLoss(nn.Module):
method __init__ (line 246) | def __init__(self, weight: float):
method forward (line 252) | def forward(self, x: torch.Tensor, target: torch.Tensor):
method mouth_metric (line 255) | def mouth_metric(self, x: torch.Tensor, target: torch.Tensor):
method get_vertices (line 258) | def get_vertices(
class UniTalkerLoss (line 265) | class UniTalkerLoss(nn.Module):
method __init__ (line 267) | def __init__(self, args) -> None:
method forward (line 344) | def forward(
method pca_loss (line 352) | def pca_loss(self, x: torch.Tensor, target: torch.Tensor):
method mouth_metric (line 355) | def mouth_metric(self, x: torch.Tensor, target: torch.Tensor,
method get_vertices (line 359) | def get_vertices(self, x: torch.Tensor, annot_type: str):
FILE: main/demo.py
function get_all_audios (line 14) | def get_all_audios(audio_root:str):
function split_long_audio (line 26) | def split_long_audio(
function merge_out_list (line 47) | def merge_out_list(out_list: list, fps:int):
function main (line 69) | def main():
FILE: main/render.py
function render_meshes (line 23) | def render_meshes(mesh_vertices: torch.Tensor,
function read_obj (line 131) | def read_obj(in_path):
function pad_for_libx264 (line 161) | def pad_for_libx264(image_array):
function array_to_video (line 209) | def array_to_video(
function get_obj_faces (line 287) | def get_obj_faces(annot_type: str):
function vis_model_out_proxy (line 303) | def vis_model_out_proxy():
FILE: main/train.py
function main (line 21) | def main():
FILE: main/train_eval_loop.py
function train_epoch (line 8) | def train_epoch(data_loader, model, loss_module, optimizer, epoch, args):
function validate_epoch (line 53) | def validate_epoch(data_loader, model, loss_module, epoch, phase, args):
FILE: main/train_pca.py
function main (line 8) | def main():
function main_worker (line 15) | def main_worker(cfg):
FILE: models/base_model.py
class BaseModel (line 5) | class BaseModel(nn.Module):
method __init__ (line 8) | def __init__(self):
method forward (line 12) | def forward(self, *x):
method freeze_model (line 19) | def freeze_model(self, do_freeze: bool = True):
method summary (line 23) | def summary(self, logger, writer=None):
FILE: models/hubert.py
function _compute_mask_indices (line 12) | def _compute_mask_indices(
class HubertModel (line 61) | class HubertModel(HubertModel):
method __init__ (line 63) | def __init__(self, config):
method _freeze_parameters (line 66) | def _freeze_parameters(self, do_freeze: bool = True):
method forward (line 70) | def forward(
FILE: models/model_utils.py
function linear_interpolation (line 9) | def linear_interpolation(features, output_len: int):
function init_biased_mask (line 17) | def init_biased_mask(n_head, max_seq_len, period):
function enc_dec_mask (line 53) | def enc_dec_mask(device, T, S):
class PeriodicPositionalEncoding (line 61) | class PeriodicPositionalEncoding(nn.Module):
method __init__ (line 63) | def __init__(self, d_model, dropout=0.1, period=25, max_seq_len=3000):
method forward (line 78) | def forward(self, x):
class TCN (line 83) | class TCN(nn.Module):
method __init__ (line 85) | def __init__(self, in_dim, out_dim):
method forward (line 91) | def forward(self, spectrogram, style_embed, time_steps=None):
class SeqTranslator1D (line 106) | class SeqTranslator1D(nn.Module):
method __init__ (line 109) | def __init__(self,
method forward (line 144) | def forward(self, x):
class ConvNormRelu (line 148) | class ConvNormRelu(nn.Module):
method __init__ (line 155) | def __init__(self,
method forward (line 264) | def forward(self, x, **kwargs):
FILE: models/pca.py
class PCALayer (line 6) | class PCALayer(nn.Module):
method __init__ (line 8) | def __init__(self, pca_info_path: str):
method encode (line 23) | def encode(
method decode (line 32) | def decode(self, feat: torch.Tensor, pca_dim: int = None):
class PCA (line 38) | class PCA:
method __init__ (line 40) | def __init__(self, trunc_dim: int = None):
method buld_pca_for_dataset (line 43) | def buld_pca_for_dataset(self, dataset, trunc_dim: int = 1024):
method load_dataset (line 47) | def load_dataset(self, dataset):
method build_incremental_PCA (line 56) | def build_incremental_PCA(
method build_pca (line 94) | def build_pca(
FILE: models/unitalker.py
class OutHead (line 9) | class OutHead(BaseModel):
method __init__ (line 11) | def __init__(self, args, out_dim: int):
method forward (line 22) | def forward(self, x):
class UniTalker (line 26) | class UniTalker(BaseModel):
method __init__ (line 28) | def __init__(self, args):
method forward (line 87) | def forward(self,
FILE: models/unitalker_decoder.py
class UniTalkerDecoderTCN (line 14) | class UniTalkerDecoderTCN(BaseModel):
method __init__ (line 16) | def __init__(self, args) -> None:
method forward (line 25) | def forward(self,
class UniTalkerDecoderTransformer (line 43) | class UniTalkerDecoderTransformer(BaseModel):
method __init__ (line 45) | def __init__(self, args) -> None:
method forward (line 63) | def forward(self, hidden_states: torch.Tensor, style_idx: torch.Tensor,
FILE: models/wav2vec2.py
function _compute_mask_indices (line 12) | def _compute_mask_indices(
class Wav2Vec2Model (line 61) | class Wav2Vec2Model(Wav2Vec2Model):
method __init__ (line 63) | def __init__(self, config):
method _freeze_wav2vec2_parameters (line 66) | def _freeze_wav2vec2_parameters(self, do_freeze: bool = True):
method forward (line 70) | def forward(
FILE: models/wavlm.py
class WavLMModel (line 16) | class WavLMModel(WavLMModel):
method __init__ (line 17) | def __init__(self, config):
method _freeze_wav2vec2_parameters (line 20) | def _freeze_wav2vec2_parameters(self, do_freeze: bool = True):
method forward (line 24) | def forward(
FILE: utils/config.py
class CfgNode (line 10) | class CfgNode(dict):
method __init__ (line 17) | def __init__(self, init_dict=None, key_list=None, new_allowed=False):
method __getattr__ (line 27) | def __getattr__(self, name):
method __setattr__ (line 33) | def __setattr__(self, name, value):
method __str__ (line 36) | def __str__(self):
method __repr__ (line 58) | def __repr__(self):
function load_cfg_from_cfg_file (line 63) | def load_cfg_from_cfg_file(file):
function merge_cfg_from_list (line 79) | def merge_cfg_from_list(cfg, cfg_list):
function _decode_cfg_value (line 93) | def _decode_cfg_value(v):
function _check_and_coerce_cfg_value_type (line 121) | def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
FILE: utils/utils.py
function get_audio_encoder_dim (line 12) | def get_audio_encoder_dim(audio_encoder: str):
function filer_list (line 20) | def filer_list(in_list: list, key: str):
function count_checkpoint_params (line 28) | def count_checkpoint_params(ckpt: dict):
function read_obj (line 39) | def read_obj(in_path):
function get_logger (line 73) | def get_logger(log_file: str):
function get_template_dict (line 86) | def get_template_dict(data_root, annot_type: str):
function filter_unitalker_state_dict (line 97) | def filter_unitalker_state_dict(in_dict):
function get_template_verts (line 104) | def get_template_verts(data_root, dataset_name: str, subject: int):
function load_ckpt (line 110) | def load_ckpt(model, ckpt_path, re_init_decoder_and_head: bool = False):
class AverageMeter (line 129) | class AverageMeter(object):
method __init__ (line 132) | def __init__(self):
method reset (line 135) | def reset(self):
method update (line 141) | def update(self, val, n=1):
method get_avg (line 147) | def get_avg(self, ):
function get_average_meter_dict (line 151) | def get_average_meter_dict(args, phase):
function log_datasetloss (line 170) | def log_datasetloss(
function write_to_tensorboard (line 187) | def write_to_tensorboard(writer,
function get_parser (line 201) | def get_parser():
function seed_everything (line 226) | def seed_everything(seed: int = 42):