Repository: X-niper/UniTalker
Branch: main
Commit: 21481130f126
Files: 23
Total size: 117.9 KB
Directory structure:
gitextract_hq8r3so5/
├── .gitignore
├── LICENSE
├── README.md
├── config/
│ └── unitalker.yaml
├── dataset/
│ ├── data_item.py
│ ├── dataset.py
│ └── dataset_config.py
├── loss/
│ └── loss.py
├── main/
│ ├── demo.py
│ ├── render.py
│ ├── train.py
│ ├── train_eval_loop.py
│ └── train_pca.py
├── models/
│ ├── base_model.py
│ ├── hubert.py
│ ├── model_utils.py
│ ├── pca.py
│ ├── unitalker.py
│ ├── unitalker_decoder.py
│ ├── wav2vec2.py
│ └── wavlm.py
└── utils/
├── config.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized files
*__pycache__/
*.py[cod]
# Distribution / packaging
dist/
build/
*.egg-info/
*.egg
# Virtual environments
venv/
env/
.venv/
.pip/
pipenv/
# IDE specific files
.idea/
.vscode/
# Logs and temporary files
*.log
*.swp
*.swo
*.tmp
# Miscellaneous
.DS_Store
records.txt
binary_resources
do_experiment.py
small_tools
scancel_job.py
results
template_resources
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# UniTalker: Scaling up Audio-Driven 3D Facial Animation through A Unified Model [ECCV2024]
## Useful Links
> UniTalker generates realistic facial motion from different audio domains, including clean and noisy voices in various languages, text-to-speech-generated audios, and even noisy songs accompanied by back-ground music.
>
> UniTalker can output multiple annotations.
>
> For datasets with new annotations, one can simply plug new heads into UniTalker and train it with existing datasets or solely with new ones, avoiding retopology.
## Installation
### Environment
- Linux
- Python 3.10
- Pytorch 2.2.0
- CUDA 12.1
- transformers 4.39.3
- Pytorch3d 0.7.7 (Optional: just for rendering the results)
```bash
conda create -n unitalker python==3.10
conda activate unitalker
conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=12.1 -c pytorch -c nvidia
pip install transformers librosa tensorboardX smplx chumpy numpy==1.23.5 opencv-python
```
## Inference
### Download checkpoints, PCA models and template resources
[UniTalker-B-[D0-D7]](https://drive.google.com/file/d/1PmF8I6lyo0_64-NgeN5qIQAX6Bg0yw44/view?usp=sharing): The base model in paper. Download it and place it in `./pretrained_models` .
[UniTalker-L-[D0-D7]](https://drive.google.com/file/d/1sH2T7KLFNjUnTM-V1eRMM1Tytxd2sYAp/view?usp=sharing): The default model in paper. Please first try the base model to run the pipeline through.
[Unitalker-data-release-V1](https://drive.google.com/file/d/1Un7TB0Z5A1CG6bgeqKlhnSOECFN-C6KK/view?usp=sharing): The released datasets, PCA models, data-split json files and id-template numpy array. Download and unzip it in this repo.
[FLAME2020](https://flame.is.tue.mpg.de/download.php): Please download FLAME 2020 and move generic_model.pkl into resources/binary_resources/flame.pkl.
Use `git lfs pull` to get `./resources.zip` and `./test_audios.zip` and unzip it in this repo.
Finally, these files should be organized as follows:
```text
├── pretrained_models
│ ├── UniTalker-B-D0-D7.pt
│ ├── UniTalker-L-D0-D7.pt
├── resources
│ ├── binary_resources
│ │ ├── 02_flame_mouth_idx.npy
│ │ ├── ...
│ │ └── vocaset_FDD_wo_eyes.npy
│ └── obj_template
│ ├── 3DETF_blendshape_weight.obj
│ ├── ...
│ └── meshtalk_6172_vertices.obj
└── unitalker_data_release_V1
│ ├── D0_BIWI
│ │ ├── id_template.npy
│ │ └── pca.npz
│ ├── D1_vocaset
│ │ ├── id_template.npy
│ │ └── pca.npz
│ ├── D2_meshtalk
│ │ ├── id_template.npy
│ │ └── pca.npz
│ ├── D3D4_3DETF
│ │ ├── D3_HDTF
│ │ └── D4_RAVDESS
│ ├── D5_unitalker_faceforensics++
│ │ ├── id_template.npy
│ │ ├── test
│ │ ├── test.json
│ │ ├── train
│ │ ├── train.json
│ │ ├── val
│ │ └── val.json
│ ├── D6_unitalker_Chinese_speech
│ │ ├── id_template.npy
│ │ ├── test
│ │ ├── test.json
│ │ ├── train
│ │ ├── train.json
│ │ ├── val
│ │ └── val.json
│ └── D7_unitalker_song
│ ├── id_template.npy
│ ├── test
│ ├── test.json
│ ├── train
│ ├── train.json
│ ├── val
│ └── val.json
```
### Demo
```bash
python -m main.demo --config config/unitalker.yaml test_out_path ./test_results/demo.npz
python -m main.render ./test_results/demo.npz ./test_audios ./test_results/
```
## Train
### Download Data
[Unitalker-data-release-V1](https://drive.google.com/file/d/1qRBPsTdOWp72ty04oD1Q_ivtwMjrACLH/view?usp=sharing) contains D5, D6 and D7. The datasets have been processed and grouped into train, validation and test. Please use these three datasets to try the training step.
If you want to train the model on the D0-D7, you need to download the datasets following these links:
[D0: BIWI](https://github.com/Doubiiu/CodeTalker/blob/main/BIWI/README.md).
[D1: VOCASET](https://voca.is.tue.mpg.de/).
[D2: meshtalk](https://github.com/facebookresearch/meshtalk?tab=readme-ov-file).
[D4,D5: 3DETF](https://github.com/psyai-net/EmoTalk_release).
### Modify Config and Train
Please modify `dataset` and `duplicate_list` in `config/unitalker.yaml` according to the datasets you have prepared, ensuring that both lists maintain the same length.
```bash
python -m main.train --config config/unitalker.yaml
```
================================================
FILE: config/unitalker.yaml
================================================
DATA:
dataset: "D0,D1,D2,D3,D4,D5,D6,D7"
demo_dataset: "D0,D1,D2,D3,D4,D6,D7" # To visualize D5, please download FLAME2020 model first and place it in "resources/binary_resources/flame.pkl"
data_root: "./unitalker_data_release_V1/"
duplicate_list: "10,5,4,1,1,1,1,1"
train_json_key: train.json
LOSS:
rec_weight: 1.0
pca_weight: 0.01
blendshape_weight: 0.0001
NETWORK:
use_pca: True
pca_dim: 512
MODEL:
audio_encoder_repo: 'microsoft/wavlm-base-plus' # choose from 'microsoft/wavlm-base-plus' and facebook/wav2vec2-large-xlsr-53'
freeze_wav2vec: False
interpolate_pos: 1
decoder_dimension: 256
decoder_type: conv # choose from conv and transformer
period: 30
headlayer: 1
TRAIN:
fix_encoder_first: True
random_id_prob: 0.1
re_init_decoder_and_head: False
only_head: False
workers: 4 # data loader workers
batch_size: 1 # batch size for training
batch_size_val: 1 # batch size for validation during training, memory and speed tradeoff
lr: 0.0001
epochs: 100
save_every: 1
save_path: "./results"
weight_path: "./pretrained_models/UniTalker-B-D0-D7.pt"
eval_every: 1
test_every: 1
log_every: 1
do_test: True
TEST:
test_wav_dir: "./test_audios"
test_out_path: "./test_results/demo.npz"
================================================
FILE: dataset/data_item.py
================================================
import librosa
import numpy as np
from typing import Any
from .dataset_config import dataset_config
class DataItem:
def __init__(self,
annot_path: str,
audio_path: str,
identity_idx: int,
annot_type: str = '',
dataset_name: str = '',
id_template: np.ndarray = None,
fps: float = 30,
processor=None) -> None:
self.annot_path = annot_path
self.audio_path = audio_path
self.identity_idx = identity_idx
self.fps = fps
self.annot_data = np.load(annot_path)
if self.fps == 60:
self.annot_data = self.annot_data[::2]
self.fps = 30
elif self.fps == 100:
self.annot_data = self.annot_data[::4]
self.fps = 25
elif self.fps != 30 and self.fps != 25:
raise ValueError('wrong fps')
scale = dataset_config[dataset_name]['scale']
self.annot_data = self.scale_and_offset(self.annot_data, scale, 0.0)
self.audio_data, sr = self.load_audio(audio_path)
self.original_audio_data, self.annot_data = self.truncate(
self.audio_data, sr, self.annot_data, self.fps)
self.audio_data = np.squeeze(
processor(self.original_audio_data, sampling_rate=sr).input_values)
self.audio_data = self.audio_data.astype(np.float32)
self.annot_type = annot_type
self.id_template = self.scale_and_offset(id_template, scale, 0.0)
self.annot_data = self.annot_data.reshape(len(self.annot_data),
-1).astype(np.float32)
self.id_template = self.id_template.reshape(1, -1).astype(np.float32)
self.dataset_name = dataset_name
self.data_dict = self.to_dict()
return
def load_audio(self, audio_path: str):
if audio_path.endswith('.npy'):
return np.load(audio_path), 16000
else:
return librosa.load(audio_path, sr=16000)
def truncate(self, audio_array: np.ndarray, sr: int,
annot_data: np.ndarray, fps: int):
audio_duration = len(audio_array) / sr
annot_duration = len(annot_data) / fps
duration = min(audio_duration, annot_duration, 20)
audio_length = round(duration * sr)
annot_length = round(duration * fps)
self.duration = duration
return audio_array[:audio_length], annot_data[:annot_length]
def offset_id(self, offset: int):
self.identity_idx = self.identity_idx + offset
self.data_dict['id'] = self.identity_idx
return
def scale_and_offset(self,
data: np.ndarray,
scale: float = 1.0,
offset: np.ndarray = 0.0):
return data * scale + offset
def to_dict(self, ) -> Any:
item = {
'data': self.annot_data,
'audio': self.audio_data,
'original_audio': self.original_audio_data,
'fps': self.fps,
'id': self.identity_idx,
'annot_path': self.annot_path,
'audio_path': self.audio_path,
'annot_type': self.annot_type,
'template': self.id_template,
'dataset_name': self.dataset_name
}
return item
def get_dict(self, ):
return self.data_dict
================================================
FILE: dataset/dataset.py
================================================
import json
import numpy as np
import os.path as osp
from torch.utils import data
from tqdm import tqdm
from transformers import Wav2Vec2FeatureExtractor
from typing import Any, List
from .data_item import DataItem
from .dataset_config import dataset_config
class AudioFaceDataset(data.Dataset):
def __init__(self, data_label_path: str, args: dict = None) -> None:
super().__init__()
data_root = osp.dirname(data_label_path, )
id_template_path = osp.join(data_root, 'id_template.npy')
id_template_list = np.load(id_template_path)
with open(data_label_path) as f:
labels = json.load(f)
info = labels['info']
self.id_list = info['id_list']
self.sr = 16000
data_info_list = labels['data']
data_list = []
for data_info in tqdm(data_info_list):
data = DataItem(
annot_path=osp.join(data_root, data_info['annot_path']),
audio_path=osp.join(data_root, data_info['audio_path']),
identity_idx=data_info['id'],
annot_type=data_info['annot_type'],
dataset_name=data_info['dataset'],
id_template=id_template_list[data_info['id']],
fps=data_info['fps'],
processor=args.processor,
)
if data.duration < 0.5:
continue
data_list.append(data)
self.data_list = data_list
return
def __len__(self, ):
return len(self.data_list)
def __getitem__(self, index: Any) -> Any:
data = self.data_list[index]
return data.get_dict()
def get_identity_num(self, ):
return len(self.id_list)
class MixAudioFaceDataset(AudioFaceDataset):
def __init__(self,
dataset_list: List[AudioFaceDataset],
duplicate_list: list = None) -> None:
super(AudioFaceDataset).__init__()
self.id_list = []
self.sr = 16000
self.data_list = []
self.annot_type_list = []
if duplicate_list is None:
duplicate_list = [1] * len(dataset_list)
else:
assert len(duplicate_list) == len(dataset_list)
for dup, dataset in zip(duplicate_list, dataset_list):
id_index_offset = len(self.id_list)
for d in dataset.data_list:
d.offset_id(id_index_offset)
if d.annot_type not in self.annot_type_list:
self.annot_type_list.append(d.annot_type)
self.id_list = self.id_list + dataset.id_list
self.data_list = self.data_list + dataset.data_list * dup
return
def get_dataset_list(args):
dataset_name_list = args.dataset
dataset_dir_list = [
dataset_config[name]['dirname'] for name in dataset_name_list
]
train_json_list = [
osp.join(args.data_root, dataset_name, args.train_json_key)
for dataset_name in dataset_dir_list
]
args.read_audio = False
args.processor = None
train_dataset_list = []
processor = Wav2Vec2FeatureExtractor.from_pretrained(
args.audio_encoder_repo)
args.processor = processor
for train_json in train_json_list:
dataset = AudioFaceDataset(
train_json,
args,
)
train_dataset_list.append(dataset)
return train_dataset_list
def get_single_dataset(args, json_path: str = ''):
args.read_audio = True
processor = Wav2Vec2FeatureExtractor.from_pretrained(
args.audio_encoder_repo)
args.processor = processor
dataset = AudioFaceDataset(
json_path,
args,
)
test_loader = data.DataLoader(
dataset=dataset,
batch_size=1,
shuffle=False,
pin_memory=True,
num_workers=args.workers)
return test_loader
def get_dataloaders(args):
dataset_name_list = args.dataset
print(dataset_name_list)
dataset_dir_list = [
dataset_config[name]['dirname'] for name in dataset_name_list
]
train_json_list = [
osp.join(args.data_root, dataset_dir, args.train_json_key)
for dataset_dir in dataset_dir_list
]
val_json_list = [
osp.join(args.data_root, dataset_dir, 'val.json')
for dataset_dir in dataset_dir_list
]
test_json_list = [
osp.join(args.data_root, dataset_dir, 'test.json')
for dataset_dir in dataset_dir_list
]
processor = Wav2Vec2FeatureExtractor.from_pretrained(
args.audio_encoder_repo)
args.processor = processor
if getattr(args, 'do_train', True) is True:
train_dataset_list = []
for train_json in train_json_list:
dataset = AudioFaceDataset(train_json, args)
train_dataset_list.append(dataset)
mix_train_dataset = MixAudioFaceDataset(train_dataset_list,
args.duplicate_list)
train_loader = data.DataLoader(
dataset=mix_train_dataset,
batch_size=args.batch_size,
shuffle=True,
pin_memory=True,
num_workers=args.workers)
else:
train_loader = None
if getattr(args, 'do_validate', True) is True:
val_dataset_list = []
for val_json in val_json_list:
dataset = AudioFaceDataset(val_json, args)
val_dataset_list.append(dataset)
mix_val_dataset = MixAudioFaceDataset(val_dataset_list)
val_loader = data.DataLoader(
dataset=mix_val_dataset,
batch_size=1,
shuffle=False,
pin_memory=True,
num_workers=args.workers)
else:
val_loader = None
if getattr(args, 'do_test', True) is True:
test_dataset_list = []
for test_json in test_json_list:
dataset = AudioFaceDataset(test_json, args)
test_dataset_list.append(dataset)
mix_test_dataset = MixAudioFaceDataset(test_dataset_list)
test_loader = data.DataLoader(
dataset=mix_test_dataset,
batch_size=1,
shuffle=False,
pin_memory=True,
num_workers=args.workers)
else:
test_loader = None
return train_loader, val_loader, test_loader
================================================
FILE: dataset/dataset_config.py
================================================
dataset_config = {
'D0': {
'dirname': 'D0_BIWI',
'annot_type': 'BIWI_23370_vertices',
'scale': 0.2,
'annot_dim': 23370 * 3,
'subjects': 6,
'pca': True,
},
'D1': {
'dirname': 'D1_vocaset',
'annot_type': 'FLAME_5023_vertices',
'scale': 1.0,
'annot_dim': 5023 * 3,
'subjects': 12,
'pca': True,
},
'D2': {
'dirname': 'D2_meshtalk',
'annot_type': 'meshtalk_6172_vertices',
'scale': 0.001,
'annot_dim': 6172 * 3,
'subjects': 13,
'pca': True,
},
'D3': {
'dirname': 'D3D4_3DETF/D3_HDTF',
'annot_type': '3DETF_blendshape_weight',
'scale': 1.0,
'annot_dim': 52,
'subjects': 141,
'pca': False,
},
'D4': {
'dirname': 'D3D4_3DETF/D4_RAVDESS',
'annot_type': '3DETF_blendshape_weight',
'scale': 1.0,
'annot_dim': 52,
'subjects': 24,
'pca': False,
},
'D5': {
'dirname': 'D5_unitalker_faceforensics++',
'annot_type': 'flame_params_from_dadhead',
'scale': 1.0,
'annot_dim': 413,
'subjects': 719,
'pca': False,
},
'D6': {
'dirname': 'D6_unitalker_Chinese_speech',
'annot_type': 'inhouse_blendshape_weight',
'scale': 1.0,
'annot_dim': 51,
'subjects': 8,
'pca': False,
},
'D7': {
'dirname': 'D7_unitalker_song',
'annot_type': 'inhouse_blendshape_weight',
'scale': 1.0,
'annot_dim': 51,
'subjects': 11,
'pca': False,
},
}
================================================
FILE: loss/loss.py
================================================
import numpy as np
import pickle
import torch
from dataclasses import dataclass
from smplx.lbs import lbs
from smplx.utils import Struct, to_np, to_tensor
from torch import nn
@dataclass
class FlameParams:
betas: torch.Tensor
jaw: torch.Tensor
eyeballs: torch.Tensor
neck: torch.Tensor
class FlameLayer(nn.Module):
def __init__(self, model_path: str) -> None:
super().__init__()
self.dtype = torch.float32
with open(model_path, 'rb') as f:
self.flame_model = Struct(**pickle.load(f, encoding='latin1'))
shapedirs = self.flame_model.shapedirs
# The shape components
self.register_buffer('shapedirs',
to_tensor(to_np(shapedirs), dtype=self.dtype))
j_regressor = to_tensor(
to_np(self.flame_model.J_regressor), dtype=self.dtype)
self.register_buffer('J_regressor', j_regressor)
self.register_buffer(
'v_template',
to_tensor(to_np(self.flame_model.v_template), dtype=self.dtype),
)
num_pose_basis = self.flame_model.posedirs.shape[-1]
posedirs = np.reshape(self.flame_model.posedirs,
[-1, num_pose_basis]).T
self.register_buffer('posedirs',
to_tensor(to_np(posedirs), dtype=self.dtype))
parents = to_tensor(to_np(self.flame_model.kintree_table[0])).long()
parents[0] = -1
self.register_buffer('parents', parents)
self.register_buffer(
'lbs_weights',
to_tensor(to_np(self.flame_model.weights), dtype=self.dtype))
default_neck_pose = torch.zeros([1, 3],
dtype=self.dtype,
requires_grad=False)
self.register_parameter(
'neck_pose', nn.Parameter(default_neck_pose, requires_grad=False))
default_eyeball_pose = torch.zeros([1, 6],
dtype=self.dtype,
requires_grad=False)
self.register_parameter(
'eyeballs',
nn.Parameter(default_eyeball_pose, requires_grad=False))
default_jaw = torch.zeros([1, 3],
dtype=self.dtype,
requires_grad=False)
self.register_parameter('jaw',
nn.Parameter(default_jaw, requires_grad=False))
self.FLAME_CONSTS = {
'betas': 400,
'rotation': 6,
'jaw': 3,
'eyeballs': 0,
'neck': 0,
'translation': 3,
'scale': 1,
}
def flame_params_from_3dmm(self,
tensor_3dmm: torch.Tensor,
zero_expr: bool = False):
constants = self.FLAME_CONSTS
assert tensor_3dmm.ndim == 2
cur_index: int = 0
betas = tensor_3dmm[:, :constants['betas']]
cur_index += constants['betas']
jaw = tensor_3dmm[:, cur_index:cur_index + constants['jaw']]
cur_index += constants['jaw']
cur_index += constants['rotation']
eyeballs = tensor_3dmm[:, cur_index:cur_index + constants['eyeballs']]
cur_index += constants['eyeballs']
neck = tensor_3dmm[:, cur_index:cur_index + constants['neck']]
cur_index += constants['neck']
cur_index += constants['translation']
cur_index += constants['scale']
return FlameParams(betas=betas, jaw=jaw, eyeballs=eyeballs, neck=neck)
def forward(self, tensor_3dmm: torch.Tensor):
bs = tensor_3dmm.shape[0]
flame_params = self.flame_params_from_3dmm(tensor_3dmm)
betas = flame_params.betas
rotation = torch.zeros([bs, 3], device=tensor_3dmm.device)
neck_pose = flame_params.neck if not (
0 in flame_params.neck.shape) else self.neck_pose[[0]].expand(
bs, -1)
eyeballs = flame_params.eyeballs if not (
0 in flame_params.eyeballs.shape) else self.eyeballs[[0]].expand(
bs, -1)
jaw = flame_params.jaw if not (
0 in flame_params.jaw.shape) else self.jaw[[0]].expand(bs, -1)
full_pose = torch.cat([rotation, neck_pose, jaw, eyeballs], dim=1)
template_vertices = self.v_template.unsqueeze(0).repeat(bs, 1, 1)
vertices, _ = lbs(
betas,
full_pose,
template_vertices,
self.shapedirs,
self.posedirs,
self.J_regressor,
self.parents,
self.lbs_weights,
)
return vertices
class FlameParamLossForDADHead(nn.Module):
def __init__(self, mouth_indices_path: str, loss_mask_path: str,
flame_model_path: str) -> None:
super().__init__()
mouth_indices = np.load(mouth_indices_path)
loss_indices = np.load(loss_mask_path)
self.register_buffer('mouth_idx', torch.from_numpy(mouth_indices))
self.register_buffer('loss_indices', torch.from_numpy(loss_indices))
self.mse = nn.MSELoss()
self.flame_layer = FlameLayer(flame_model_path)
def forward(self, x: torch.Tensor, target: torch.Tensor):
ori_shape = x.shape
x = self.flame_layer(x.reshape(-1, ori_shape[-1]))
target = self.flame_layer(target.reshape(-1, ori_shape[-1]))
x = x.reshape((ori_shape[0], ori_shape[1], -1, 3))
target = target.reshape((ori_shape[0], ori_shape[1], -1, 3))
return self.mse(x[:, :, self.loss_indices], target[:, :,
self.loss_indices])
def get_vertices(
self,
x: torch.Tensor,
):
x = self.flame_layer(x.reshape(-1, x.shape[-1]))
return x[:, self.loss_indices]
def mouth_metric(self, x: torch.Tensor, target: torch.Tensor):
ori_shape = x.shape
x = self.flame_layer(x.reshape(-1, ori_shape[-1]))
target = self.flame_layer(target.reshape(-1, ori_shape[-1]))
metric_L2 = (target[:, self.mouth_idx] - x[:, self.mouth_idx])**2
metric_L2 = metric_L2.sum(-1)
metric_L2 = metric_L2.max(-1)[0]
metric_L2norm = metric_L2**0.5
return metric_L2.mean(), metric_L2norm.mean()
class VerticesLoss(nn.Module):
def __init__(self, mouth_indices_path: str) -> None:
super().__init__()
mouth_indices = np.load(mouth_indices_path)
self.register_buffer('mouth_idx', torch.from_numpy(mouth_indices))
self.mse = nn.MSELoss()
def forward(self, x: torch.Tensor, target: torch.Tensor):
return self.mse(x, target)
def get_vertices(
self,
x: torch.Tensor,
):
return x.reshape(-1, x.shape[-1] // 3, 3)
def mouth_metric(self, x: torch.Tensor, target: torch.Tensor):
ori_shape = x.shape
target = target.reshape(*ori_shape[:-1], -1, 3)
x = x.reshape(*ori_shape[:-1], -1, 3)
metric_L2 = (target[:, :, self.mouth_idx] - x[:, :, self.mouth_idx])**2
metric_L2 = metric_L2.sum(-1)
metric_L2 = metric_L2.max(-1)[0]
metric_L2norm = metric_L2**0.5
return metric_L2.mean(), metric_L2norm.mean()
class BlendShapeLoss(nn.Module):
def __init__(self,
mouth_indices_path: str,
blendshape_path: str,
bs_beta: float = 0.0,
components_num: int = None) -> None:
super().__init__()
mouth_indices = np.load(mouth_indices_path)
blendshape = np.load(blendshape_path)
self.register_buffer('mouth_idx', torch.from_numpy(mouth_indices))
mean_shape = blendshape['meanshape']
mean_shape = mean_shape.reshape(-1).astype(np.float32)
blend_shape = blendshape['blendshape']
blend_shape = blend_shape.reshape(len(blend_shape),
-1).astype(np.float32)
self.register_buffer('mean_shape', torch.from_numpy(mean_shape))
self.register_buffer('blend_shape', torch.from_numpy(blend_shape))
self.bs_beta = bs_beta
self.mse = nn.MSELoss()
self.components_num = components_num
def forward(self, x: torch.Tensor, target: torch.Tensor):
return self.bs_beta * self.mse(x, target) + self.mse(
self.bs2vertices(x), self.bs2vertices(target))
def bs2vertices(self, x: torch.Tensor):
return x[:, :self.components_num] @ \
self.blend_shape[:self.components_num] + \
self.mean_shape
def get_vertices(
self,
x: torch.Tensor,
):
x = self.bs2vertices(x)
return x.reshape(-1, x.shape[-1] // 3, 3)
def mouth_metric(self, x: torch.Tensor, target: torch.Tensor):
target = self.bs2vertices(target)
x = self.bs2vertices(x)
ori_shape = x.shape
target = target.reshape(*ori_shape[:-1], -1, 3)
x = x.reshape(*ori_shape[:-1], -1, 3)
metric_L2 = (target[:, :, self.mouth_idx] - x[:, :, self.mouth_idx])**2
metric_L2 = metric_L2.sum(-1)
metric_L2 = metric_L2.max(-1)[0]
metric_L2norm = metric_L2**0.5
return metric_L2.mean(), metric_L2norm.mean()
class ParamLoss(nn.Module):
def __init__(self, weight: float):
super().__init__()
self.weight = weight
self.mse = nn.MSELoss()
self.L1 = nn.L1Loss()
def forward(self, x: torch.Tensor, target: torch.Tensor):
return self.weight * self.mse(x, target)
def mouth_metric(self, x: torch.Tensor, target: torch.Tensor):
return self.mse(x, target), self.L1(x, target)
def get_vertices(
self,
x: torch.Tensor,
):
return x.squeeze(0)
class UniTalkerLoss(nn.Module):
def __init__(self, args) -> None:
super().__init__()
loss_config = {
'flame_params_from_dadhead': {
'class': FlameParamLossForDADHead,
'args': {
'mouth_indices_path':
'resources/binary_resources/02_flame_mouth_idx.npy',
'loss_mask_path':
'resources/binary_resources/flame_humanface_index.npy',
'flame_model_path': 'resources/binary_resources/flame.pkl',
}
},
'3DETF_blendshape_weight': {
'class': BlendShapeLoss,
'args': {
'mouth_indices_path':
'resources/binary_resources/03_emo_talk_mouth_idx.npy',
'blendshape_path':
'resources/binary_resources/EmoTalk.npz',
'bs_beta': args.blendshape_weight,
},
},
'inhouse_blendshape_weight': {
'class': BlendShapeLoss,
'args': {
'mouth_indices_path':
'resources/binary_resources/05_inhouse_arkit_mouth_idx.npy',
'blendshape_path':
'resources/binary_resources/inhouse_arkit.npz',
'bs_beta': args.blendshape_weight,
},
},
'FLAME_5023_vertices': {
'class': VerticesLoss,
'args': {
'mouth_indices_path':
'resources/binary_resources/02_flame_mouth_idx.npy',
},
},
'BIWI_23370_vertices': {
'class': VerticesLoss,
'args': {
'mouth_indices_path':
'resources/binary_resources/04_BIWI_mouth_idx.npy',
},
},
'FLAME_SUB_2055_vertices': {
'class': VerticesLoss,
'args': {
'mouth_indices_path':
'resources/binary_resources/mouth_idx_FLAME_5023_vertices_wo_eyes_wo_head.npy',
},
},
'meshtalk_6172_vertices': {
'class': VerticesLoss,
'args': {
'mouth_indices_path':
'resources/binary_resources/06_mesh_talk_mouth_idx.npy',
},
},
'24_viseme': {
'class': ParamLoss,
'args': {
'weight': args.blendshape_weight
},
}
}
loss_module_dict = {}
for k, v in loss_config.items():
loss_module = v['class'](**v['args'])
loss_module_dict[k] = loss_module
self.loss_module_dict = nn.ModuleDict(loss_module_dict)
self.mse = nn.MSELoss()
return
def forward(
self,
x: torch.Tensor,
target: torch.Tensor,
annot_type: str,
):
return self.loss_module_dict[annot_type](x, target)
def pca_loss(self, x: torch.Tensor, target: torch.Tensor):
return self.mse(x, target)
def mouth_metric(self, x: torch.Tensor, target: torch.Tensor,
annot_type: str):
return self.loss_module_dict[annot_type].mouth_metric(x, target)
def get_vertices(self, x: torch.Tensor, annot_type: str):
return self.loss_module_dict[annot_type].get_vertices(x)
================================================
FILE: main/demo.py
================================================
import librosa
import numpy as np
import os
import torch
import torch.optim
import torch.utils.data
from transformers import Wav2Vec2FeatureExtractor
from dataset.dataset_config import dataset_config
from loss.loss import UniTalkerLoss
from models.unitalker import UniTalker
from utils.utils import get_parser, get_template_verts, get_audio_encoder_dim
def get_all_audios(audio_root:str):
wav_f_names = []
for r, _, f in os.walk(audio_root):
for wav in f:
if wav.endswith('.wav') or wav.endswith('.mp3'):
relative_path = os.path.join(r, wav)
relative_path = os.path.relpath(relative_path,
audio_root)
wav_f_names.append(relative_path)
wav_f_names = sorted(wav_f_names)
return wav_f_names
def split_long_audio(
audio: np.ndarray,
processor:Wav2Vec2FeatureExtractor
):
# audio = audio.squeeze(0)
a, b = 25, 5
sr = 16000
total_length = len(audio) /sr
reps = max(0, int(np.ceil((total_length - a) / (a - b)))) + 1
in_audio_split_list = []
start, end = 0, int(a * sr)
step = int((a - b) * sr)
for i in range(reps):
audio_split = audio[start:end]
audio_split = np.squeeze(
processor(audio_split, sampling_rate=sr).input_values)
in_audio_split_list.append(audio_split)
start += step
end += step
return in_audio_split_list
def merge_out_list(out_list: list, fps:int):
if len(out_list) == 1:
return out_list[0]
a, b = 25, 5
left_weight = np.linspace(1, 0, b * fps)[:, np.newaxis]
right_weight = 1 - left_weight
a = a * fps
b = b * fps
offset = a - b
out_length = len(out_list[-1]) + offset * (len(out_list) - 1)
merged_out = np.empty((out_length, out_list[-1].shape[-1]),
dtype=out_list[-1].dtype)
merged_out[:a] = out_list[0]
for out_piece in out_list[1:]:
merged_out[a - b:a] = left_weight * merged_out[
a - b:a] + right_weight * out_piece[:b]
merged_out[a:a + offset] = out_piece[b:]
a += offset
return merged_out
def main():
args = get_parser()
training_dataset_name_list = args.dataset
demo_dataset_name_list = args.demo_dataset
condition_id_config = {
'D0': 3,
'D1': 3,
'D2': 0,
'D3': 0,
'D4': 0,
'D5': 0,
'D6': 4,
'D7': 0,
}
template_id_config = {
'D0': 3,
'D1': 3,
'D2': 0,
'D3': 0,
'D4': 0,
'D5': 0,
'D6': 4,
'D7': 0,
}
checkpoint = torch.load(args.weight_path, map_location='cpu')
args.identity_num = len(checkpoint['decoder.learnable_style_emb.weight'])
start_idx = 0
for dataset_name in training_dataset_name_list:
annot_type = dataset_config[dataset_name]['annot_type']
id_num = dataset_config[dataset_name]['subjects']
end_idx = start_idx + id_num
local_condition_idx = condition_id_config[dataset_name]
template_idx = template_id_config[dataset_name]
template = get_template_verts(args.data_root, dataset_name, template_idx)
template_id_config[dataset_name] = torch.Tensor(template.reshape(1, -1))
if args.condition_id == 'each':
condition_id_config[dataset_name] = torch.tensor(
start_idx + local_condition_idx).reshape(1)
elif args.condition_id == 'common':
condition_id_config[dataset_name] = torch.tensor(args.identity_num -
1).reshape(1)
else:
try:
condition_id = int(args.condition_id)
condition_id_config[dataset_name] = torch.tensor(
condition_id).reshape(1)
except ValueError:
assert args.condition_id in dataset_config.keys()
condition_id_config[dataset_name] = torch.tensor(
start_idx + local_condition_idx).reshape(1)
start_idx = end_idx
if args.random_id_prob > 0:
assert end_idx == (args.identity_num - 1)
else:
assert end_idx == args.identity_num
args.audio_encoder_feature_dim = get_audio_encoder_dim(
args.audio_encoder_repo)
model = UniTalker(args)
model.load_state_dict(checkpoint, strict=False)
model.eval()
model.cuda()
wav_f_names = get_all_audios(args.test_wav_dir)
out_npz_path = args.test_out_path
dirname = os.path.dirname(out_npz_path)
os.makedirs(dirname, exist_ok=True)
out_dict = {}
loss_module = UniTalkerLoss(args).cuda()
processor = Wav2Vec2FeatureExtractor.from_pretrained(
args.audio_encoder_repo,)
with torch.no_grad():
for wav_f in wav_f_names:
out_dict[wav_f] = {}
wav_path = os.path.join(args.test_wav_dir, wav_f)
audio_data, sr = librosa.load(wav_path, sr=16000)
audio_data = np.squeeze(
processor(audio_data, sampling_rate=sr).input_values)
audio_data_splits = split_long_audio(audio_data, processor)
for dataset_name in demo_dataset_name_list:
template = template_id_config[dataset_name].cuda()
scale = dataset_config[dataset_name]['scale']
template = scale * template
condition_id = condition_id_config[dataset_name].cuda()
annot_type = dataset_config[dataset_name]['annot_type']
if annot_type == 'BIWI_23370_vertices':
fps = 25
else:
fps = 30
out_list = []
for audio_data in audio_data_splits:
audio_data = torch.Tensor(audio_data[None]).cuda()
out, _, _ = model(
audio_data,
template,
face_motion=None,
style_idx=condition_id,
annot_type=annot_type,
fps=fps,
)
out_list.append(out.cpu().numpy().squeeze(0))
out = merge_out_list(out_list, fps)
out = loss_module.get_vertices(torch.from_numpy(out).cuda(), annot_type)
out_dict[wav_f][dataset_name] = out.cpu()
print(f"save results to {out_npz_path}")
np.savez(out_npz_path, **out_dict)
if __name__ == '__main__':
main()
================================================
FILE: main/render.py
================================================
"""Max-Planck-Gesellschaft zur Foerderung der Wissenschaften e.V.
(MPG) is holder of all proprietary rights on this computer program. You can
only use this computer program if you have closed a license agreement with MPG
or you get the right to use the computer program from someone who is authorized
to grant you that right. Any use of the computer program without a valid
license is prohibited and liable to prosecution. Copyright 2019 Max-Planck-
Gesellschaft zur Foerderung der Wissenschaften e.V. (MPG). acting on behalf of
its Max Planck Institute for Intelligent Systems and the Max Planck Institute
for Biological Cybernetics. All rights reserved. More information about VOCA is
available at http://voca.is.tue.mpg.de. For comments or questions, please email
us at voca@tue.mpg.de
"""
import cv2
import numpy as np
import os
import subprocess
import torch
from tqdm import tqdm
from dataset.dataset_config import dataset_config
def render_meshes(mesh_vertices: torch.Tensor,
faces: torch.Tensor,
img_size: tuple = (256, 256),
aa_factor: int = 1,
vis_bs: int = 100):
assert len(mesh_vertices) == len(faces) or len(faces) == 1
if len(faces) == 1:
faces = faces.repeat(len(mesh_vertices), 1, 1)
x_max = y_max = mesh_vertices[..., 0:2].abs().max()
from pytorch3d.renderer import (
DirectionalLights,
FoVOrthographicCameras,
HardPhongShader,
Materials,
MeshRasterizer,
MeshRenderer,
RasterizationSettings,
TexturesVertex,
)
from pytorch3d.renderer.blending import BlendParams
from pytorch3d.renderer.materials import Materials
from pytorch3d.structures import Meshes
aa_factor = int(aa_factor)
device = mesh_vertices.device
verts_batch_lst = torch.split(mesh_vertices, vis_bs)
faces_batch_lst = torch.split(faces, vis_bs)
R = torch.eye(3).to(mesh_vertices.device)
R[0, 0] = -1
R[2, 2] = -1
T = torch.zeros(3).to(mesh_vertices.device)
T[2] = 10
R, T = R[None], T[None]
W, H = img_size
cameras = FoVOrthographicCameras(
device=device,
R=R,
T=T,
znear=0.01,
zfar=3,
max_x=x_max * 1.2,
min_x=-x_max * 1.2,
max_y=y_max * 1.2,
min_y=-y_max * 1.2,
)
# cameras = FoVPerspectiveCameras(
# device=device,
# R=R,
# T=T,
# znear=0.01,
# zfar=3,
# aspect_ratio=1,
# fov=0.3,
# degrees=False
# )
raster_settings = RasterizationSettings(
image_size=(W * aa_factor, H * aa_factor),
blur_radius=0.0,
faces_per_pixel=1,
)
raster = MeshRasterizer(
cameras=cameras,
raster_settings=raster_settings,
)
blend_params = BlendParams(
sigma=0.0, gamma=0.0, background_color=(0.0, 0.0, 0.0))
lights = DirectionalLights(
device=device,
direction=((0, 0, 1), ),
ambient_color=((0.3, 0.3, 0.3), ),
diffuse_color=((0.6, 0.6, 0.6), ),
specular_color=((0.1, 0.1, 0.1), ))
materias = Materials(
ambient_color=((1, 1, 1), ),
diffuse_color=((1, 1, 1), ),
specular_color=((1, 1, 1), ),
shininess=15,
device=device)
shader = HardPhongShader(
device=device,
cameras=cameras,
lights=lights,
materials=materias,
blend_params=blend_params)
renderer = MeshRenderer(
rasterizer=raster,
shader=shader,
)
rendered_imgs = []
for verts_batch, faces_batch in tqdm(
zip(verts_batch_lst, faces_batch_lst),
total=(len(verts_batch_lst))):
textures = TexturesVertex(verts_features=torch.ones_like(verts_batch))
meshes = Meshes(
verts=verts_batch, faces=faces_batch, textures=textures)
with torch.no_grad():
imgs = renderer(meshes)
rendered_imgs.append(imgs.cpu())
rendered_imgs = torch.cat(rendered_imgs, dim=0)
if aa_factor > 1:
rendered_imgs = rendered_imgs.permute(0, 3, 1, 2) # NHWC -> NCHW
rendered_imgs = torch.nn.functional.interpolate(
rendered_imgs, scale_factor=1 / aa_factor, mode='bicubic')
rendered_imgs = rendered_imgs.permute(0, 2, 3, 1) # NCHW -> NHWC
return rendered_imgs
def read_obj(in_path):
with open(in_path, 'r') as obj_file:
# Read the lines of the OBJ file
lines = obj_file.readlines()
# Initialize empty lists for vertices and faces
verts = []
faces = []
for line in lines:
line = line.strip() # Remove leading/trailing whitespace
elements = line.split() # Split the line into elements
if len(elements) == 0:
continue # Skip empty lines
# Check the type of line (vertex or face)
if elements[0] == 'v':
# Vertex line
x, y, z = map(float,
elements[1:4]) # Extract the vertex coordinates
verts.append((x, y, z)) # Add the vertex to the list
elif elements[0] == 'f':
# Face line
face_indices = [
int(index.split('/')[0]) for index in elements[1:]
] # Extract the vertex indices
faces.append(face_indices) # Add the face to the list
return np.array(verts), np.array(faces)
def pad_for_libx264(image_array):
"""Pad zeros if width or height of image_array is not divisible by 2.
Otherwise you will get.
\"[libx264 @ 0x1b1d560] width not divisible by 2 \"
Args:
image_array (np.ndarray):
Image or images load by cv2.imread().
Possible shapes:
1. [height, width]
2. [height, width, channels]
3. [images, height, width]
4. [images, height, width, channels]
Returns:
np.ndarray:
A image with both edges divisible by 2.
"""
if image_array.ndim == 2 or \
(image_array.ndim == 3 and image_array.shape[2] == 3):
hei_index = 0
wid_index = 1
elif image_array.ndim == 4 or \
(image_array.ndim == 3 and image_array.shape[2] != 3):
hei_index = 1
wid_index = 2
else:
return image_array
hei_pad = image_array.shape[hei_index] % 2
wid_pad = image_array.shape[wid_index] % 2
if hei_pad + wid_pad > 0:
pad_width = []
for dim_index in range(image_array.ndim):
if dim_index == hei_index:
pad_width.append((0, hei_pad))
elif dim_index == wid_index:
pad_width.append((0, wid_pad))
else:
pad_width.append((0, 0))
values = 0
image_array = \
np.pad(image_array,
pad_width,
mode='constant', constant_values=values)
return image_array
def array_to_video(
image_array: np.ndarray,
output_path: str,
fps=30,
resolution=None,
disable_log: bool = False,
) -> None:
"""Convert an array to a video directly, gif not supported.
Args:
image_array (np.ndarray): shape should be (f * h * w * 3).
output_path (str): output video file path.
fps (Union[int, float, optional): fps. Defaults to 30.
resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
optional): (height, width) of the output video.
Defaults to None.
disable_log (bool, optional): whether close the ffmepg command info.
Defaults to False.
Raises:
FileNotFoundError: check output path.
TypeError: check input array.
Returns:
None.
"""
if not isinstance(image_array, np.ndarray):
raise TypeError('Input should be np.ndarray.')
assert image_array.ndim == 4
assert image_array.shape[-1] == 3
if resolution:
height, width = resolution
width += width % 2
height += height % 2
else:
image_array = pad_for_libx264(image_array)
height, width = image_array.shape[1], image_array.shape[2]
command = [
'/usr/bin/ffmpeg',
'-y', # (optional) overwrite output file if it exists
'-f',
'rawvideo',
'-s',
f'{int(width)}x{int(height)}', # size of one frame
'-pix_fmt',
'bgr24',
'-r',
f'{fps}', # frames per second
'-loglevel',
'error',
'-threads',
'4',
'-i',
'-', # The input comes from a pipe
'-vcodec',
'libx264',
'-an', # Tells FFMPEG not to expect any audio
output_path,
]
if not disable_log:
print(f'Running \"{" ".join(command)}\"')
process = subprocess.Popen(
command,
stdin=subprocess.PIPE,
stderr=subprocess.PIPE,
)
if process.stdin is None or process.stderr is None:
raise BrokenPipeError('No buffer received.')
index = 0
while True:
if index >= image_array.shape[0]:
break
process.stdin.write(image_array[index].tobytes())
index += 1
process.stdin.close()
process.stderr.close()
process.wait()
def get_obj_faces(annot_type: str):
template_obj_path_dict = {
'3DETF_blendshape_weight': 'resources/obj_template/3DETF_blendshape_weight.obj',
'FLAME_5023_vertices': 'resources/obj_template/FLAME_5023_vertices.obj',
'BIWI_23370_vertices': 'resources/obj_template/BIWI_23370_vertices.obj',
'flame_params_from_dadhead': 'resources/obj_template/flame_params_from_dadhead.obj',
'inhouse_blendshape_weight': 'resources/obj_template/inhouse_blendshape_weight.obj',
'meshtalk_6172_vertices': 'resources/obj_template/meshtalk_6172_vertices.obj'
}
faces = np.array(read_obj(template_obj_path_dict[annot_type])[1])
if faces.min() == 1:
faces = faces - 1
return faces
def vis_model_out_proxy():
import sys
npz_path = sys.argv[1]
audio_dir = sys.argv[2]
out_dir = sys.argv[3]
save_images = False
os.makedirs(out_dir, exist_ok=True)
out_dict = np.load(npz_path, allow_pickle=True)
print(list(out_dict.keys()))
for wav_f in out_dict.keys():
wav_path = os.path.join(audio_dir, wav_f)
cur_out_dict = out_dict[wav_f]
for datasetname, out in cur_out_dict.item().items():
out = torch.Tensor(out)
annot_type = dataset_config[datasetname]['annot_type']
faces = get_obj_faces(annot_type)
with torch.no_grad():
if save_images:
out_size = (1024, 1024)
else:
out_size = (512, 512)
img_array = render_meshes(out.cuda(),
torch.from_numpy(faces[None]).cuda(),
out_size)
if save_images:
channels = 4
else:
channels = 3
img_array = (img_array[..., :channels].cpu().numpy() *
255).astype(np.uint8)
out_key = f'{wav_f[:-4]}'
if save_images:
out_images_dir = os.path.join(out_dir, datasetname, out_key)
os.makedirs(out_images_dir, exist_ok=True)
for imgidx, img in enumerate(img_array):
img_path = os.path.join(out_images_dir,
f'{imgidx:04d}.png')
cv2.imwrite(img_path, img)
continue
out_video_path = os.path.join(out_dir,
f'{out_key}_{datasetname}.mp4')
out_video_dir = os.path.dirname(out_video_path)
os.makedirs(out_video_dir, exist_ok=True)
tmp_path = os.path.join(out_video_dir, 'tmp.mp4')
if os.path.exists(tmp_path):
os.remove(tmp_path)
if annot_type == 'BIWI_23370_vertices':
fps = 25
else:
fps = 30
print(img_array.shape, img_array.dtype)
array_to_video(img_array, output_path=tmp_path, fps=fps)
cmd = f'ffmpeg -i {tmp_path} -i {wav_path} -c:v copy -c:a aac -shortest -strict -2 {out_video_path} -y '
print(cmd)
subprocess.run(cmd, shell=True)
os.remove(tmp_path)
if __name__ == '__main__':
vis_model_out_proxy()
================================================
FILE: main/train.py
================================================
#!/usr/bin/env python
# yapf: disable
import os
import torch
import torch.optim
import torch.utils.data
from tensorboardX import SummaryWriter
from dataset.dataset import get_dataloaders
from loss.loss import UniTalkerLoss
from models.unitalker import UniTalker
from utils.utils import (
get_average_meter_dict, get_logger, get_parser, get_audio_encoder_dim,
load_ckpt, seed_everything, filter_unitalker_state_dict
)
from .train_eval_loop import train_epoch, validate_epoch
# yapf: enable
def main():
seed_everything(42)
args = get_parser()
args.audio_encoder_feature_dim = get_audio_encoder_dim(args.audio_encoder_repo)
train_loader, val_loader, test_loader = get_dataloaders(args)
args.identity_num = train_loader.dataset.get_identity_num()
args.annot_type_list = train_loader.dataset.annot_type_list
if args.random_id_prob > 0.0:
args.identity_num = args.identity_num + 1
os.makedirs(args.save_path, exist_ok=True)
log_file = os.path.join(args.save_path, 'train.log')
logger = get_logger(log_file)
writer = SummaryWriter(args.save_path)
logger.info('=> creating model ...')
model = UniTalker(args)
if args.weight_path is not None:
logger.info(f'=> loading model {args.weight_path} ...')
load_ckpt(model, args.weight_path, args.re_init_decoder_and_head)
logger.info(args)
model = model.cuda()
model.summary(logger)
optimizer = torch.optim.Adam(
filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
loss_module = UniTalkerLoss(args).cuda()
args.train_meter, args.train_metric_name_list = \
get_average_meter_dict(args, 'train')
args.val_meter, args.val_metric_name_list = \
get_average_meter_dict(args, 'val')
args.logger = logger
args.writer = writer
for epoch in range(args.epochs):
if args.fix_encoder_first:
if epoch == 0:
model.audio_encoder._freeze_wav2vec2_parameters(True)
optimizer = torch.optim.Adam(
filter(lambda p: p.requires_grad, model.parameters()),
lr=args.lr)
elif epoch == 1:
model.audio_encoder._freeze_wav2vec2_parameters(
args.freeze_wav2vec)
model.audio_encoder.feature_extractor._freeze_parameters()
optimizer = torch.optim.Adam(
filter(lambda p: p.requires_grad, model.parameters()),
lr=args.lr)
epoch_plus = epoch + 1
train_epoch(train_loader, model, loss_module, optimizer, epoch_plus,
args)
if epoch_plus % args.eval_every == 0:
validate_epoch(val_loader, model, loss_module, epoch_plus, 'val',
args)
if epoch_plus % args.test_every == 0:
validate_epoch(test_loader, model, loss_module, epoch_plus, 'test',
args)
if epoch_plus % args.save_every == 0:
model_path = os.path.join(args.save_path, f'{epoch_plus:03d}.pt')
torch.save(filter_unitalker_state_dict(model.state_dict()), model_path)
if __name__ == '__main__':
main()
================================================
FILE: main/train_eval_loop.py
================================================
import numpy as np
import torch
from tqdm import tqdm
from utils.utils import log_datasetloss, write_to_tensorboard
def train_epoch(data_loader, model, loss_module, optimizer, epoch, args):
model.train()
train_meter = args.train_meter
metric_name_list = args.train_metric_name_list
zero_loss = torch.tensor(0.0)
mixed_dataset_name = 'mixed_dataset'
pbar = tqdm(data_loader)
pbar.set_description(f'Epoch: {epoch} train')
for batch in pbar:
data = batch['data'].cuda(non_blocking=True)
template = batch['template'].cuda(non_blocking=True)
audio = batch['audio'].cuda(non_blocking=True)
identity = batch['id'].cuda(non_blocking=True)
if args.random_id_prob > 0.0:
if np.random.random() < args.random_id_prob:
identity = identity.fill_(args.identity_num - 1)
annot_type = batch['annot_type'][0]
dataset_name = batch['dataset_name'][0]
out_motion, out_pca, gt_pca = model(
audio, template, data, style_idx=identity, annot_type=annot_type)
rec_loss = loss_module(out_motion, data, annot_type)
if out_pca is not None:
pca_loss = loss_module.pca_loss(out_pca, gt_pca)
else:
pca_loss = zero_loss
loss = rec_loss + args.pca_weight * pca_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
for k, v in zip(metric_name_list, (rec_loss, pca_loss)):
train_meter[dataset_name][k].update(v.item())
train_meter[mixed_dataset_name][k].update(v.item())
log_datasetloss(args.logger, epoch, 'train', train_meter, only_mix=True)
write_to_tensorboard(
args.writer, epoch, 'train', train_meter, only_mix=False)
for dataset_name, sub_metric_dict in train_meter.items():
for metric_name in sub_metric_dict.keys():
sub_metric_dict[metric_name].reset()
return
def validate_epoch(data_loader, model, loss_module, epoch, phase, args):
assert phase in ('val', 'test')
model.eval()
val_meter = args.val_meter
metric_name_list = args.val_metric_name_list
mixed_dataset_name = 'mixed_dataset'
with torch.no_grad():
pbar = tqdm(data_loader)
pbar.set_description(f'Epoch: {epoch} {phase:<5}')
for batch in pbar:
data = batch['data'].cuda(non_blocking=True)
template = batch['template'].cuda(non_blocking=True)
audio = batch['audio'].cuda(non_blocking=True)
identity = batch['id'].cuda(non_blocking=True)
annot_type = batch['annot_type'][0]
dataset_name = batch['dataset_name'][0]
if args.random_id_prob >= 1.0:
identity = identity.fill_(args.identity_num - 1)
out_motion, _, _ = model(
audio,
template,
data,
style_idx=identity,
annot_type=annot_type)
rec_loss = loss_module(out_motion, data, annot_type)
mouth_metric_L2, mouth_metric_L2_norm = loss_module.mouth_metric(
out_motion, data, annot_type)
for k, v in zip(metric_name_list,
(rec_loss, mouth_metric_L2, mouth_metric_L2_norm)):
val_meter[dataset_name][k].update(v.item())
val_meter[mixed_dataset_name][k].update(v.item())
if getattr(args, 'logger', None) is not None:
log_datasetloss(
args.logger, epoch, phase, val_meter, only_mix=True)
if getattr(args, 'writer', None) is not None:
write_to_tensorboard(
args.writer, epoch, phase, val_meter, only_mix=False)
for dataset_name, sub_metric_dict in val_meter.items():
for metric_name in sub_metric_dict.keys():
sub_metric_dict[metric_name].reset()
return
================================================
FILE: main/train_pca.py
================================================
#!/usr/bin/env python
import numpy as np
from models.pca import PCA
from utils.utils import get_parser, get_audio_encoder_dim
def main():
args = get_parser()
args.audio_encoder_feature_dim = get_audio_encoder_dim(
args.wav2vec2_type)
main_worker(args)
def main_worker(cfg):
pca_builder = PCA()
from dataset.dataset import get_dataset_list
train_dataset_list = get_dataset_list(cfg)
for dataset in train_dataset_list:
data = pca_builder.load_dataset(dataset)
pca_info = pca_builder.build_incremental_PCA(
data, batch_size=1095, trunc_dim=647)
np.savez('mesh_talk_pca.npz', **pca_info)
if __name__ == '__main__':
main()
================================================
FILE: models/base_model.py
================================================
import numpy as np
import torch.nn as nn
class BaseModel(nn.Module):
"""Base class for all models."""
def __init__(self):
super(BaseModel, self).__init__()
# self.logger = logging.getLogger(self.__class__.__name__)
def forward(self, *x):
"""Forward pass logic.
:return: Model output
"""
raise NotImplementedError
def freeze_model(self, do_freeze: bool = True):
for param in self.parameters():
param.requires_grad = (not do_freeze)
def summary(self, logger, writer=None):
"""Model summary."""
model_parameters = filter(lambda p: p.requires_grad, self.parameters())
params = sum([np.prod(p.size())
for p in model_parameters]) / 1e6 # Unit is Mega
logger.info('===>Trainable parameters: %.3f M' % params)
if writer is not None:
writer.add_text('Model Summary',
'Trainable parameters: %.3f M' % params)
================================================
FILE: models/hubert.py
================================================
import numpy as np
import torch
from transformers import HubertModel
from transformers.modeling_outputs import BaseModelOutput
from typing import Optional, Tuple
from .model_utils import linear_interpolation
# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model # noqa: E501
# initialize our encoder with the pre-trained wav2vec 2.0 weights.
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.Tensor] = None,
min_masks: int = 0,
) -> np.ndarray:
bsz, all_sz = shape
mask = np.full((bsz, all_sz), False)
all_num_mask = int(mask_prob * all_sz / float(mask_length) +
np.random.rand())
all_num_mask = max(min_masks, all_num_mask)
mask_idcs = []
padding_mask = attention_mask.ne(1) if attention_mask is not None else None
for i in range(bsz):
if padding_mask is not None:
sz = all_sz - padding_mask[i].long().sum().item()
num_mask = int(mask_prob * sz / float(mask_length) +
np.random.rand())
num_mask = max(min_masks, num_mask)
else:
sz = all_sz
num_mask = all_num_mask
lengths = np.full(num_mask, mask_length)
if sum(lengths) == 0:
lengths[0] = min(mask_length, sz - 1)
min_len = min(lengths)
if sz - min_len <= num_mask:
min_len = sz - num_mask - 1
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
mask_idc = np.asarray([
mask_idc[j] + offset for j in range(len(mask_idc))
for offset in range(lengths[j])
])
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
min_len = min([len(m) for m in mask_idcs])
for i, mask_idc in enumerate(mask_idcs):
if len(mask_idc) > min_len:
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
mask[i, mask_idc] = True
return mask
class HubertModel(HubertModel):
def __init__(self, config):
super().__init__(config)
def _freeze_parameters(self, do_freeze: bool = True):
for param in self.parameters():
param.requires_grad = (not do_freeze)
def forward(
self,
input_values,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
frame_num=None,
interpolate_pos: int = 0,
):
self.config.output_attentions = True
output_attentions = output_attentions if \
output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else
self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None \
else self.config.use_return_dict
conv_features = self.feature_extractor(input_values)
hidden_states = conv_features.transpose(1, 2)
if interpolate_pos == 0:
hidden_states = linear_interpolation(
hidden_states, output_len=frame_num)
if attention_mask is not None:
output_lengths = self._get_feat_extract_output_lengths(
attention_mask.sum(-1))
attention_mask = torch.zeros(
hidden_states.shape[:2],
dtype=hidden_states.dtype,
device=hidden_states.device)
attention_mask[(torch.arange(
attention_mask.shape[0],
device=hidden_states.device), output_lengths - 1)] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip(
[-1]).bool()
hidden_states = self.feature_projection(hidden_states)
if self.config.apply_spec_augment and self.training:
batch_size, sequence_length, hidden_size = hidden_states.size()
if self.config.mask_time_prob > 0:
mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length),
self.config.mask_time_prob,
self.config.mask_time_length,
attention_mask=attention_mask,
min_masks=2,
)
hidden_states[torch.from_numpy(
mask_time_indices)] = self.masked_spec_embed.to(
hidden_states.dtype)
if self.config.mask_feature_prob > 0:
mask_feature_indices = _compute_mask_indices(
(batch_size, hidden_size),
self.config.mask_feature_prob,
self.config.mask_feature_length,
)
mask_feature_indices = torch.from_numpy(
mask_feature_indices).to(hidden_states.device)
hidden_states[mask_feature_indices[:, None].expand(
-1, sequence_length, -1)] = 0
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = encoder_outputs[0]
if interpolate_pos == 1:
hidden_states = linear_interpolation(
hidden_states, output_len=frame_num)
if not return_dict:
return (hidden_states, ) + encoder_outputs[1:]
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
================================================
FILE: models/model_utils.py
================================================
# Borrowed from https://github.com/EvelynFan/FaceFormer/blob/main/main.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# linear interpolation layer
def linear_interpolation(features, output_len: int):
features = features.transpose(1, 2)
output_features = F.interpolate(
features, size=output_len, align_corners=True, mode='linear')
return output_features.transpose(1, 2)
# Temporal Bias
def init_biased_mask(n_head, max_seq_len, period):
def get_slopes(n):
def get_slopes_power_of_2(n):
start = (2**(-2**-(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2**math.floor(math.log2(n))
return get_slopes_power_of_2(closest_power_of_2) + get_slopes(
2 * closest_power_of_2)[0::2][:n - closest_power_of_2]
slopes = torch.Tensor(get_slopes(n_head))
bias = torch.div(
torch.arange(start=0, end=max_seq_len,
step=period).unsqueeze(1).repeat(1, period).view(-1),
period,
rounding_mode='floor')
bias = -torch.flip(bias, dims=[0])
alibi = torch.zeros(max_seq_len, max_seq_len)
for i in range(max_seq_len):
alibi[i, :i + 1] = bias[-(i + 1):]
alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0)
mask = (torch.triu(torch.ones(max_seq_len,
max_seq_len)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
mask == 1, float(0.0))
mask = mask.unsqueeze(0) + alibi
return mask
# Alignment Bias
def enc_dec_mask(device, T, S):
mask = torch.ones(T, S)
for i in range(T):
mask[i, i] = 0
return (mask == 1).to(device=device)
# Periodic Positional Encoding
class PeriodicPositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, period=25, max_seq_len=3000):
super(PeriodicPositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(period, d_model)
position = torch.arange(0, period, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # (1, period, d_model)
repeat_num = (max_seq_len // period) + 1
pe = pe.repeat(1, repeat_num, 1)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :x.size(1), :]
return self.dropout(x)
class TCN(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.first_net = SeqTranslator1D(
in_dim, out_dim, min_layers_num=3, residual=True, norm='ln')
self.dropout = nn.Dropout(0.1)
def forward(self, spectrogram, style_embed, time_steps=None):
spectrogram = spectrogram
spectrogram = self.dropout(spectrogram)
style_embed = style_embed.unsqueeze(2)
style_embed = style_embed.repeat(1, 1, spectrogram.shape[2])
spectrogram = torch.cat([spectrogram, style_embed], dim=1)
x1 = self.first_net(spectrogram) # .permute(0, 2, 1)
if time_steps is not None:
x1 = F.interpolate(
x1, size=time_steps, align_corners=False, mode='linear')
return x1
class SeqTranslator1D(nn.Module):
"""(B, C, T)->(B, C_out, T)"""
def __init__(self,
C_in,
C_out,
kernel_size=None,
stride=None,
min_layers_num=None,
residual=True,
norm='bn'):
super(SeqTranslator1D, self).__init__()
conv_layers = nn.ModuleList([])
conv_layers.append(
ConvNormRelu(
in_channels=C_in,
out_channels=C_out,
type='1d',
kernel_size=kernel_size,
stride=stride,
residual=residual,
norm=norm))
self.num_layers = 1
if min_layers_num is not None and self.num_layers < min_layers_num:
while self.num_layers < min_layers_num:
conv_layers.append(
ConvNormRelu(
in_channels=C_out,
out_channels=C_out,
type='1d',
kernel_size=kernel_size,
stride=stride,
residual=residual,
norm=norm))
self.num_layers += 1
self.conv_layers = nn.Sequential(*conv_layers)
def forward(self, x):
return self.conv_layers(x)
class ConvNormRelu(nn.Module):
"""(B,C_in,H,W) -> (B, C_out, H, W) there exist some kernel size that makes
the result is not H/s.
#TODO: there might some problems with residual
"""
def __init__(self,
in_channels,
out_channels,
type='1d',
leaky=False,
downsample=False,
kernel_size=None,
stride=None,
padding=None,
p=0,
groups=1,
residual=False,
norm='bn'):
"""conv-bn-relu."""
super(ConvNormRelu, self).__init__()
self.residual = residual
self.norm_type = norm
# kernel_size = k
# stride = s
if kernel_size is None and stride is None:
if not downsample:
kernel_size = 3
stride = 1
else:
kernel_size = 4
stride = 2
if padding is None:
if isinstance(kernel_size, int) and isinstance(stride, tuple):
padding = tuple(int((kernel_size - st) / 2) for st in stride)
elif isinstance(kernel_size, tuple) and isinstance(stride, int):
padding = tuple(int((ks - stride) / 2) for ks in kernel_size)
elif isinstance(kernel_size, tuple) and isinstance(stride, tuple):
padding = tuple(
int((ks - st) / 2) for ks, st in zip(kernel_size, stride))
else:
padding = int((kernel_size - stride) / 2)
if self.residual:
if downsample:
if type == '1d':
self.residual_layer = nn.Sequential(
nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding))
elif type == '2d':
self.residual_layer = nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding))
else:
if in_channels == out_channels:
self.residual_layer = nn.Identity()
else:
if type == '1d':
self.residual_layer = nn.Sequential(
nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding))
elif type == '2d':
self.residual_layer = nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding))
in_channels = in_channels * groups
out_channels = out_channels * groups
if type == '1d':
self.conv = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups)
self.norm = nn.BatchNorm1d(out_channels)
self.dropout = nn.Dropout(p=p)
elif type == '2d':
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups)
self.norm = nn.BatchNorm2d(out_channels)
self.dropout = nn.Dropout2d(p=p)
if norm == 'gn':
self.norm = nn.GroupNorm(2, out_channels)
elif norm == 'ln':
self.norm = nn.LayerNorm(out_channels)
if leaky:
self.relu = nn.LeakyReLU(negative_slope=0.2)
else:
self.relu = nn.ReLU()
def forward(self, x, **kwargs):
if self.norm_type == 'ln':
out = self.dropout(self.conv(x))
out = self.norm(out.transpose(1, 2)).transpose(1, 2)
else:
out = self.norm(self.dropout(self.conv(x)))
if self.residual:
residual = self.residual_layer(x)
out += residual
return self.relu(out)
================================================
FILE: models/pca.py
================================================
import numpy as np
import torch
import torch.nn as nn
class PCALayer(nn.Module):
def __init__(self, pca_info_path: str):
super().__init__()
pca_info = np.load(pca_info_path)
feat_to_data = pca_info['components_to_data'].astype(np.float32)
data_to_feat = feat_to_data.T
data_bias = pca_info['original_data_mean'].astype(np.float32)
feat_mean = pca_info['data_components_mean'].astype(np.float32)
feat_std = pca_info['data_components_std'].astype(np.float32)
self.register_buffer('feat_to_data', torch.from_numpy(feat_to_data))
self.register_buffer('data_to_feat', torch.from_numpy(data_to_feat))
self.register_buffer('data_bias', torch.from_numpy(data_bias))
self.register_buffer('feat_mean', torch.from_numpy(feat_mean))
self.register_buffer('feat_std', torch.from_numpy(feat_std))
self.pca_dim = len(feat_mean)
def encode(
self,
x: torch.Tensor,
pca_dim: int = None,
):
x = x - self.data_bias
feat = x @ (self.data_to_feat[:, :pca_dim])
return feat
def decode(self, feat: torch.Tensor, pca_dim: int = None):
data = feat[..., :pca_dim] @ (self.feat_to_data[:pca_dim])
data = data + self.data_bias
return data
class PCA:
def __init__(self, trunc_dim: int = None):
super().__init__()
def buld_pca_for_dataset(self, dataset, trunc_dim: int = 1024):
annot_array = self.load_dataset(dataset)
self.build_pca(annot_array, trunc_dim)
def load_dataset(self, dataset):
annot_list = []
indices = np.arange(len(dataset))
for i in indices:
data_dict = dataset.__getitem__(i)
annot_list.append(data_dict['data'] - data_dict['template'])
annot_array = np.concatenate(annot_list, axis=0)
return annot_array
def build_incremental_PCA(
self,
in_data: np.ndarray,
trunc_dim=None,
batch_size=1024,
):
indices = np.arange(len(in_data))
np.random.shuffle(indices)
in_data = in_data[indices]
if in_data.dtype != np.float32:
in_data = in_data.astype(np.float32)
from sklearn.decomposition import IncrementalPCA
ipca = IncrementalPCA(n_components=trunc_dim, batch_size=batch_size)
data_mean = in_data.mean(0)
in_data = in_data - data_mean
in_data_slices = np.split(
in_data, np.arange(batch_size, len(in_data), batch_size))
for batch_data in in_data_slices:
if len(batch_data) != batch_size:
continue
ipca.partial_fit(batch_data)
S = ipca.singular_values_
components_to_data = ipca.components_
explained_variance_ratio = ipca.explained_variance_ratio_
data_components = ipca.transform(in_data)
data_components_mean = data_components.mean(0)
data_components_std = data_components.std(0)
ret_dict = {
'explained_variance_ratio': explained_variance_ratio,
'components_to_data': components_to_data,
'original_data_mean': data_mean,
'S': S,
'data_components_mean': data_components_mean,
'data_components_std': data_components_std
}
return ret_dict
def build_pca(
self,
in_data: np.ndarray,
trunc_dim=None,
save_compactness=True,
check_svd=True,
):
ret_dict = {}
in_data = torch.Tensor(in_data).cuda()
count, in_dim = in_data.shape
mean_data = in_data.mean(dim=0)
in_data_center = in_data - mean_data # (count, in_dim)
in_data_center = in_data_center.cpu()
# U (count, count), S(count, ), Vt (count, in_dim)
U, S, Vt = torch.linalg.svd(in_data_center, full_matrices=False)
if check_svd:
S_mat = torch.diag(S)
rebuilt = U @ S_mat @ Vt
is_precise = torch.allclose(rebuilt, in_data_center)
print('svd rebuilt all close ?', is_precise)
var_S = S**2
if save_compactness:
compact_vector = var_S / var_S.sum()
compact_value = compact_vector[0:trunc_dim].sum()
ret_dict['compactness, ', compact_value]
pca_bases = Vt[0:trunc_dim].reshape(trunc_dim, -1)
ret_dict['pca_bases'] = pca_bases
ret_dict['pca_std'] = S[0:trunc_dim]
ret_dict['mean_data'] = mean_data
return ret_dict
================================================
FILE: models/unitalker.py
================================================
import os.path as osp
import torch.nn as nn
from dataset.dataset_config import dataset_config
from models.wav2vec2 import Wav2Vec2Model
from .base_model import BaseModel
# from models.hubert import HubertModel
from models.wavlm import WavLMModel
class OutHead(BaseModel):
def __init__(self, args, out_dim: int):
super().__init__()
head_layer = []
in_dim = args.decoder_dimension
for i in range(args.headlayer - 1):
head_layer.append(nn.Linear(in_dim, in_dim))
head_layer.append(nn.Tanh())
head_layer.append(nn.Linear(in_dim, out_dim))
self.head_layer = nn.Sequential(*head_layer)
def forward(self, x):
return self.head_layer(x)
class UniTalker(BaseModel):
def __init__(self, args):
super().__init__()
self.args = args
if 'wav2vec2' in args.audio_encoder_repo:
self.audio_encoder = Wav2Vec2Model.from_pretrained(
args.audio_encoder_repo)
elif 'wavlm' in args.audio_encoder_repo:
self.audio_encoder = WavLMModel.from_pretrained(
args.audio_encoder_repo)
else:
raise ValueError("wrong audio_encoder_repo")
self.audio_encoder.feature_extractor._freeze_parameters()
if args.freeze_wav2vec:
self.audio_encoder._freeze_wav2vec2_parameters()
if args.decoder_type == 'conv':
from .unitalker_decoder import UniTalkerDecoderTCN as Decoder
elif args.decoder_type == 'transformer':
from .unitalker_decoder import \
UniTalkerDecoderTransformer as Decoder
else:
ValueError('unknown decoder type ')
self.decoder = Decoder(args)
if args.use_pca:
pca_layer_dict = {}
from models.pca import PCALayer
for dataset_name in args.dataset:
if dataset_config[dataset_name]['pca']:
annot_type = dataset_config[dataset_name]['annot_type']
dirname = dataset_config[dataset_name]['dirname']
pca_path = osp.join(args.data_root, dirname, 'pca.npz')
pca_layer_dict[annot_type] = PCALayer(pca_path)
self.pca_layer_dict = nn.ModuleDict(pca_layer_dict)
else:
self.pca_layer_dict = {}
self.pca_dim = args.pca_dim
self.use_pca = args.use_pca
self.interpolate_pos = args.interpolate_pos
out_head_dict = {}
for dataset_name in args.dataset:
annot_type = dataset_config[dataset_name]['annot_type']
out_dim = dataset_config[dataset_name]['annot_dim']
if (args.use_pca is True) and (annot_type in self.pca_layer_dict):
pca_dim = self.pca_layer_dict[annot_type].pca_dim
out_dim = min(args.pca_dim, out_dim, pca_dim)
if args.headlayer == 1:
out_projection = nn.Linear(args.decoder_dimension, out_dim)
nn.init.constant_(out_projection.weight, 0)
nn.init.constant_(out_projection.bias, 0)
else:
out_projection = OutHead(args, out_dim)
out_head_dict[annot_type] = out_projection
self.out_head_dict = nn.ModuleDict(out_head_dict)
self.identity_num = args.identity_num
return
def forward(self,
audio,
template,
face_motion,
style_idx,
annot_type: str,
fps: float = None):
if face_motion is not None:
frame_num = face_motion.shape[1]
else:
frame_num = round(audio.shape[-1] / 16000 * fps)
hidden_states = self.audio_encoder(
audio, frame_num=frame_num, interpolate_pos=self.interpolate_pos)
hidden_states = hidden_states.last_hidden_state
decoder_out = self.decoder(hidden_states, style_idx, frame_num)
if (not self.use_pca) or (annot_type not in self.pca_layer_dict):
out_motion = self.out_head_dict[annot_type](decoder_out)
out_motion = out_motion + template
out_pca, gt_pca = None, None
else:
out_pca = self.out_head_dict[annot_type](decoder_out)
out_motion = self.pca_layer_dict[annot_type].decode(
out_pca, self.pca_dim)
out_motion = out_motion + template
if face_motion is not None:
gt_pca = self.pca_layer_dict[annot_type].encode(
face_motion - template, self.pca_dim)
else:
gt_pca = None
return out_motion, out_pca, gt_pca
================================================
FILE: models/unitalker_decoder.py
================================================
# yapf: disable
import torch
import torch.nn as nn
from .base_model import BaseModel
from .model_utils import (
TCN, PeriodicPositionalEncoding, enc_dec_mask, init_biased_mask,
linear_interpolation,
)
# yapf: enable
class UniTalkerDecoderTCN(BaseModel):
def __init__(self, args) -> None:
super().__init__()
self.learnable_style_emb = nn.Embedding(args.identity_num, 64)
in_dim = args.decoder_dimension
out_dim = args.decoder_dimension
self.audio_feature_map = nn.Linear(args.audio_encoder_feature_dim, in_dim)
self.tcn = TCN(in_dim + 64, out_dim)
self.interpolate_pos = args.interpolate_pos
def forward(self,
hidden_states: torch.Tensor,
style_idx: torch.Tensor,
frame_num: int = None):
batch_size = len(hidden_states)
if style_idx is None:
style_embedding = self.learnable_style_emb.weight[-1].unsqueeze(
0).repeat(batch_size, 1)
else:
style_embedding = self.learnable_style_emb(style_idx)
feature = self.audio_feature_map(hidden_states).transpose(1, 2)
feature = self.tcn(feature, style_embedding)
feature = feature.transpose(1, 2)
if self.interpolate_pos == 2:
feature = linear_interpolation(feature, output_len=frame_num)
return feature
class UniTalkerDecoderTransformer(BaseModel):
def __init__(self, args) -> None:
super().__init__()
out_dim = args.decoder_dimension
self.learnable_style_emb = nn.Embedding(args.identity_num, out_dim)
self.audio_feature_map = nn.Linear(args.audio_encoder_feature_dim, out_dim)
self.PPE = PeriodicPositionalEncoding(
out_dim, period=args.period, max_seq_len=3000)
self.biased_mask = init_biased_mask(
n_head=4, max_seq_len=3000, period=args.period)
decoder_layer = nn.TransformerDecoderLayer(
d_model=out_dim,
nhead=4,
dim_feedforward=2 * out_dim,
batch_first=True)
self.transformer_decoder = nn.TransformerDecoder(
decoder_layer, num_layers=1)
self.interpolate_pos = args.interpolate_pos
def forward(self, hidden_states: torch.Tensor, style_idx: torch.Tensor,
frame_num: int):
obj_embedding = self.learnable_style_emb(style_idx)
obj_embedding = obj_embedding.unsqueeze(1).repeat(1, frame_num, 1)
hidden_states = self.audio_feature_map(hidden_states)
style_input = self.PPE(obj_embedding)
tgt_mask = self.biased_mask[:, :style_input.shape[1], :style_input.
shape[1]].clone().detach().to(
device=style_input.device)
memory_mask = enc_dec_mask(hidden_states.device, style_input.shape[1],
frame_num)
feat_out = self.transformer_decoder(
style_input,
hidden_states,
tgt_mask=tgt_mask,
memory_mask=memory_mask)
if self.interpolate_pos == 2:
feat_out = linear_interpolation(feat_out, output_len=frame_num)
return feat_out
================================================
FILE: models/wav2vec2.py
================================================
import numpy as np
import torch
from transformers import Wav2Vec2Model
from transformers.modeling_outputs import BaseModelOutput
from typing import Optional, Tuple
from .model_utils import linear_interpolation
# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model # noqa: E501
# initialize our encoder with the pre-trained wav2vec 2.0 weights.
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.Tensor] = None,
min_masks: int = 0,
) -> np.ndarray:
bsz, all_sz = shape
mask = np.full((bsz, all_sz), False)
all_num_mask = int(mask_prob * all_sz / float(mask_length) +
np.random.rand())
all_num_mask = max(min_masks, all_num_mask)
mask_idcs = []
padding_mask = attention_mask.ne(1) if attention_mask is not None else None
for i in range(bsz):
if padding_mask is not None:
sz = all_sz - padding_mask[i].long().sum().item()
num_mask = int(mask_prob * sz / float(mask_length) +
np.random.rand())
num_mask = max(min_masks, num_mask)
else:
sz = all_sz
num_mask = all_num_mask
lengths = np.full(num_mask, mask_length)
if sum(lengths) == 0:
lengths[0] = min(mask_length, sz - 1)
min_len = min(lengths)
if sz - min_len <= num_mask:
min_len = sz - num_mask - 1
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
mask_idc = np.asarray([
mask_idc[j] + offset for j in range(len(mask_idc))
for offset in range(lengths[j])
])
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
min_len = min([len(m) for m in mask_idcs])
for i, mask_idc in enumerate(mask_idcs):
if len(mask_idc) > min_len:
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
mask[i, mask_idc] = True
return mask
class Wav2Vec2Model(Wav2Vec2Model):
def __init__(self, config):
super().__init__(config)
def _freeze_wav2vec2_parameters(self, do_freeze: bool = True):
for param in self.parameters():
param.requires_grad = (not do_freeze)
def forward(
self,
input_values,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
frame_num=None,
interpolate_pos: int = 0,
):
self.config.output_attentions = True
output_attentions = output_attentions if \
output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else
self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None \
else self.config.use_return_dict
conv_features = self.feature_extractor(input_values)
hidden_states = conv_features.transpose(1, 2)
if interpolate_pos == 0:
hidden_states = linear_interpolation(
hidden_states, output_len=frame_num)
if attention_mask is not None:
output_lengths = self._get_feat_extract_output_lengths(
attention_mask.sum(-1))
attention_mask = torch.zeros(
hidden_states.shape[:2],
dtype=hidden_states.dtype,
device=hidden_states.device)
attention_mask[(torch.arange(
attention_mask.shape[0],
device=hidden_states.device), output_lengths - 1)] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip(
[-1]).bool()
hidden_states, _ = self.feature_projection(hidden_states)
if self.config.apply_spec_augment and self.training:
batch_size, sequence_length, hidden_size = hidden_states.size()
if self.config.mask_time_prob > 0:
mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length),
self.config.mask_time_prob,
self.config.mask_time_length,
attention_mask=attention_mask,
min_masks=2,
)
hidden_states[torch.from_numpy(
mask_time_indices)] = self.masked_spec_embed.to(
hidden_states.dtype)
if self.config.mask_feature_prob > 0:
mask_feature_indices = _compute_mask_indices(
(batch_size, hidden_size),
self.config.mask_feature_prob,
self.config.mask_feature_length,
)
mask_feature_indices = torch.from_numpy(
mask_feature_indices).to(hidden_states.device)
hidden_states[mask_feature_indices[:, None].expand(
-1, sequence_length, -1)] = 0
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = encoder_outputs[0]
if interpolate_pos == 1:
hidden_states = linear_interpolation(
hidden_states, output_len=frame_num)
if not return_dict:
return (hidden_states, ) + encoder_outputs[1:]
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
================================================
FILE: models/wavlm.py
================================================
import numpy as np
import torch
from transformers import WavLMModel
from transformers.modeling_outputs import Wav2Vec2BaseModelOutput
from typing import Optional, Tuple, Union
from .model_utils import linear_interpolation
# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model # noqa: E501
# initialize our encoder with the pre-trained wav2vec 2.0 weights.
class WavLMModel(WavLMModel):
def __init__(self, config):
super().__init__(config)
def _freeze_wav2vec2_parameters(self, do_freeze: bool = True):
for param in self.parameters():
param.requires_grad = (not do_freeze)
def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
mask_time_indices: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
frame_num=None,
interpolate_pos: int = 0,
) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
extract_features = self.feature_extractor(input_values)
extract_features = extract_features.transpose(1, 2)
if interpolate_pos == 0:
extract_features = linear_interpolation(
extract_features, output_len=frame_num)
if attention_mask is not None:
# compute reduced attention_mask corresponding to feature vectors
attention_mask = self._get_feature_vector_attention_mask(
extract_features.shape[1], attention_mask, add_adapter=False
)
hidden_states, extract_features = self.feature_projection(extract_features)
hidden_states = self._mask_hidden_states(
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
)
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = encoder_outputs[0]
if interpolate_pos == 1:
hidden_states = linear_interpolation(
hidden_states, output_len=frame_num)
if self.adapter is not None:
hidden_states = self.adapter(hidden_states)
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
return Wav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
================================================
FILE: utils/config.py
================================================
# -----------------------------------------------------------------------------
# Functions for parsing args
# -----------------------------------------------------------------------------
import copy
import os
import yaml
from ast import literal_eval
class CfgNode(dict):
"""CfgNode represents an internal node in the configuration tree.
It's a simple dict-like container that allows for attribute-based access to
keys.
"""
def __init__(self, init_dict=None, key_list=None, new_allowed=False):
# Recursively convert nested dictionaries in init_dict into CfgNodes
init_dict = {} if init_dict is None else init_dict
key_list = [] if key_list is None else key_list
for k, v in init_dict.items():
if type(v) is dict:
# Convert dict to CfgNode
init_dict[k] = CfgNode(v, key_list=key_list + [k])
super(CfgNode, self).__init__(init_dict)
def __getattr__(self, name):
if name in self:
return self[name]
else:
raise AttributeError(name)
def __setattr__(self, name, value):
self[name] = value
def __str__(self):
def _indent(s_, num_spaces):
s = s_.split('\n')
if len(s) == 1:
return s_
first = s.pop(0)
s = [(num_spaces * ' ') + line for line in s]
s = '\n'.join(s)
s = first + '\n' + s
return s
r = ''
s = []
for k, v in sorted(self.items()):
separator = '\n' if isinstance(v, CfgNode) else ' '
attr_str = '{}:{}{}'.format(str(k), separator, str(v))
attr_str = _indent(attr_str, 2)
s.append(attr_str)
r += '\n'.join(s)
return r
def __repr__(self):
return '{}({})'.format(self.__class__.__name__,
super(CfgNode, self).__repr__())
def load_cfg_from_cfg_file(file):
cfg = {}
assert os.path.isfile(file) and file.endswith('.yaml'), \
'{} is not a yaml file'.format(file)
with open(file, 'r') as f:
cfg_from_file = yaml.safe_load(f)
for key in cfg_from_file:
for k, v in cfg_from_file[key].items():
cfg[k] = v
cfg = CfgNode(cfg)
return cfg
def merge_cfg_from_list(cfg, cfg_list):
new_cfg = copy.deepcopy(cfg)
assert len(cfg_list) % 2 == 0
for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
subkey = full_key.split('.')[-1]
assert subkey in cfg, 'Non-existent key: {}'.format(full_key)
value = _decode_cfg_value(v)
value = _check_and_coerce_cfg_value_type(value, cfg[subkey], subkey,
full_key)
setattr(new_cfg, subkey, value)
return new_cfg
def _decode_cfg_value(v):
"""Decodes a raw config value (e.g., from a yaml config files or command
line argument) into a Python object."""
# All remaining processing is only applied to strings
if not isinstance(v, str):
return v
# Try to interpret `v` as a:
# string, number, tuple, list, dict, boolean, or None
try:
v = literal_eval(v)
# The following two excepts allow v to pass through when it represents a
# string.
#
# Longer explanation:
# The type of v is always a string (before calling literal_eval), but
# sometimes it *represents* a string and other times a data structure, like
# a list. In the case that v represents a string, what we got back from the
# yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
# ok with '"foo"', but will raise a ValueError if given 'foo'. In other
# cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
# will raise a SyntaxError.
except ValueError:
pass
except SyntaxError:
pass
return v
def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
"""Checks that `replacement`, which is intended to replace `original` is of
the right type.
The type is correct if it matches exactly or is one of a few cases in which
the type can be easily coerced.
"""
original_type = type(original)
replacement_type = type(replacement)
# The types must match (with some exceptions)
if replacement_type == original_type or original is None:
return replacement
# types match from_type and to_type
def conditional_cast(from_type, to_type):
if replacement_type == from_type and original_type == to_type:
return True, to_type(replacement)
else:
return False, None
# Conditionally casts
# list <-> tuple
casts = [(tuple, list), (list, tuple)]
# For py2: allow converting from str (bytes) to a unicode string
try:
casts.append((str, unicode)) # noqa: F821
except Exception:
pass
for (from_type, to_type) in casts:
converted, converted_value = conditional_cast(from_type, to_type)
if converted:
return converted_value
raise ValueError(
'Type mismatch ({} vs. {}) with values ({} vs. {}) for config '
'key: {}'.format(original_type, replacement_type, original,
replacement, full_key))
================================================
FILE: utils/utils.py
================================================
import argparse
import logging
import numpy as np
import os
import random
import torch
from copy import deepcopy
from . import config
def get_audio_encoder_dim(audio_encoder: str):
if audio_encoder == 'microsoft/wavlm-base-plus':
return 768
elif audio_encoder == 'facebook/wav2vec2-large-xlsr-53':
return 1024
else:
raise ValueError("wrong audio_encoder")
def filer_list(in_list: list, key: str):
ret_list = []
for line in in_list:
if line.startswith(key):
ret_list.append(line)
return ret_list
def count_checkpoint_params(ckpt: dict):
params = 0
pca_params = 0
for k, v in ckpt.items():
if 'pca' in k:
print(k)
pca_params += np.prod(v.size())
params += np.prod(v.size())
return params, pca_params
def read_obj(in_path):
with open(in_path, 'r') as obj_file:
# Read the lines of the OBJ file
lines = obj_file.readlines()
# Initialize empty lists for vertices and faces
verts = []
faces = []
for line in lines:
line = line.strip() # Remove leading/trailing whitespace
elements = line.split() # Split the line into elements
if len(elements) == 0:
continue # Skip empty lines
# Check the type of line (vertex or face)
if elements[0] == 'v':
# Vertex line
x, y, z = map(float,
elements[1:4]) # Extract the vertex coordinates
verts.append((x, y, z)) # Add the vertex to the list
elif elements[0] == 'f':
# Face line
face_indices = [
int(index.split('/')[0]) for index in elements[1:]
] # Extract the vertex indices
faces.append(face_indices) # Add the face to the list
verts = np.array(verts)
faces = np.array(faces)
if faces.min() == 1:
faces = faces - 1
return verts, faces
def get_logger(log_file: str):
logger_name = 'main-logger'
logger = logging.getLogger(logger_name)
logger.setLevel(logging.INFO)
handler = logging.FileHandler(log_file, mode='w')
fmt = '[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d]=>%(message)s' # noqa: E501
handler.setFormatter(logging.Formatter(fmt))
for hdl in logger.handlers:
logger.removeHandler(hdl)
logger.addHandler(handler)
return logger
def get_template_dict(data_root, annot_type: str):
import pickle
if annot_type == 'BIWI_23370_vertices':
template_pkl_path = os.path.join(data_root, 'BIWI', 'templates.pkl')
elif annot_type == 'FLAME_5023_vertices':
template_pkl_path = os.path.join(data_root, 'vocaset', 'templates.pkl')
with open(template_pkl_path, 'rb') as f:
template_dict = pickle.load(f, encoding='latin')
return template_dict
def filter_unitalker_state_dict(in_dict):
out_dict = {}
for k, v in in_dict.items():
if not 'pca_layer_dict' in k:
out_dict[k] = v
return out_dict
def get_template_verts(data_root, dataset_name: str, subject: int):
from dataset.dataset_config import dataset_config
template_path = os.path.join(data_root, dataset_config[dataset_name]['dirname'], 'id_template.npy')
return np.load(template_path)[subject]
def load_ckpt(model, ckpt_path, re_init_decoder_and_head: bool = False):
checkpoint = torch.load(ckpt_path, map_location='cpu')
id_embedding_key = 'decoder.learnable_style_emb.weight'
if len(checkpoint[id_embedding_key]) != len(
model.state_dict()[id_embedding_key]):
target_id_num = len(model.state_dict()[id_embedding_key])
checkpoint[id_embedding_key] = checkpoint[id_embedding_key][-1][
None].repeat(target_id_num, 1)
if re_init_decoder_and_head is True:
ckpt_keys = list(checkpoint.keys())
encoder_keys = filer_list(ckpt_keys, 'audio_encoder')
reinit_keys = list(set(ckpt_keys) - set(encoder_keys))
for k in reinit_keys:
if k in model.state_dict():
checkpoint[k] = model.state_dict()[k]
model.load_state_dict(checkpoint, strict=False)
return
class AverageMeter(object):
"""Computes and stores the average and current value."""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def get_avg(self, ):
return self.avg
def get_average_meter_dict(args, phase):
if phase == 'train':
metric_name_list = ['rec_L2_loss', 'pca_loss']
elif phase == 'val' or phase == 'test':
metric_name_list = ['rec_L2_loss', 'LVE', 'LVD']
else:
raise ValueError('wrong phase for average meter')
average_meter_dict_template = {}
for k in metric_name_list:
average_meter_dict_template[k] = AverageMeter()
mixed_dataset_name = 'mixed_dataset'
dataset_name_list = args.dataset + [mixed_dataset_name]
average_meter_dict = {}
for dataset_name in dataset_name_list:
average_meter_dict[dataset_name] = deepcopy(
average_meter_dict_template)
return average_meter_dict, metric_name_list
def log_datasetloss(
logger,
epoch: int,
phase: str,
dataset_loss_dict: dict,
only_mix: bool = True,
):
base_str = f'{phase.upper():<5} Epoch: {epoch:03d} '
for dataset_name, loss_info in dataset_loss_dict.items():
if only_mix and dataset_name != 'mixed_dataset':
continue
for loss_type, avg_meter in loss_info.items():
base_str += f'{loss_type}_{dataset_name}:{avg_meter.avg: .8f}, '
logger.info(base_str)
return base_str
def write_to_tensorboard(writer,
epoch: int,
phase: str,
dataset_loss_dict: dict,
only_mix: bool = False):
for dataset_name, loss_info in dataset_loss_dict.items():
if only_mix and dataset_name != 'mixed_dataset':
continue
for loss_type, avg_meter in loss_info.items():
key = f'{phase}/{loss_type}_{dataset_name}'
writer.add_scalar(key, avg_meter.avg, epoch)
return 0
def get_parser():
parser = argparse.ArgumentParser(description=' ')
parser.add_argument(
'--config',
type=str,
default='config/unitalker.yaml',
help='config file')
parser.add_argument(
'--condition_id', type=str, default='common', help='condition id')
parser.add_argument(
'opts', help=' ', default=None, nargs=argparse.REMAINDER)
args = parser.parse_args()
assert args.config is not None
cfg = config.load_cfg_from_cfg_file(args.config)
if args.opts is not None:
cfg = config.merge_cfg_from_list(cfg, args.opts)
cfg.condition_id = args.condition_id
cfg.dataset = cfg.dataset.split(',')
cfg.demo_dataset = cfg.demo_dataset.split(',')
if isinstance(cfg.duplicate_list, str):
cfg.duplicate_list = cfg.duplicate_list.split(',')
cfg.duplicate_list = [int(i) for i in cfg.duplicate_list]
return cfg
def seed_everything(seed: int = 42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
return