Repository: X-niper/UniTalker Branch: main Commit: 21481130f126 Files: 23 Total size: 117.9 KB Directory structure: gitextract_hq8r3so5/ ├── .gitignore ├── LICENSE ├── README.md ├── config/ │ └── unitalker.yaml ├── dataset/ │ ├── data_item.py │ ├── dataset.py │ └── dataset_config.py ├── loss/ │ └── loss.py ├── main/ │ ├── demo.py │ ├── render.py │ ├── train.py │ ├── train_eval_loop.py │ └── train_pca.py ├── models/ │ ├── base_model.py │ ├── hubert.py │ ├── model_utils.py │ ├── pca.py │ ├── unitalker.py │ ├── unitalker_decoder.py │ ├── wav2vec2.py │ └── wavlm.py └── utils/ ├── config.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized files *__pycache__/ *.py[cod] # Distribution / packaging dist/ build/ *.egg-info/ *.egg # Virtual environments venv/ env/ .venv/ .pip/ pipenv/ # IDE specific files .idea/ .vscode/ # Logs and temporary files *.log *.swp *.swo *.tmp # Miscellaneous .DS_Store records.txt binary_resources do_experiment.py small_tools scancel_job.py results template_resources ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # UniTalker: Scaling up Audio-Driven 3D Facial Animation through A Unified Model [ECCV2024] ## Useful Links
[Homepage]      [arXiv]      [Video]     
UniTalker Architecture
> UniTalker generates realistic facial motion from different audio domains, including clean and noisy voices in various languages, text-to-speech-generated audio, and even noisy songs accompanied by background music. > > UniTalker can output multiple annotations. > > For datasets with new annotations, one can simply plug new heads into UniTalker and train it with existing datasets or solely with new ones, avoiding retopology. ## Installation ### Environment - Linux - Python 3.10 - Pytorch 2.2.0 - CUDA 12.1 - transformers 4.39.3 - Pytorch3d 0.7.7 (Optional: just for rendering the results) ```bash conda create -n unitalker python==3.10 conda activate unitalker conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=12.1 -c pytorch -c nvidia pip install transformers librosa tensorboardX smplx chumpy numpy==1.23.5 opencv-python ``` ## Inference ### Download checkpoints, PCA models and template resources [UniTalker-B-[D0-D7]](https://drive.google.com/file/d/1PmF8I6lyo0_64-NgeN5qIQAX6Bg0yw44/view?usp=sharing): The base model in the paper. Download it and place it in `./pretrained_models`. [UniTalker-L-[D0-D7]](https://drive.google.com/file/d/1sH2T7KLFNjUnTM-V1eRMM1Tytxd2sYAp/view?usp=sharing): The default model in the paper. Please try the base model first to make sure the pipeline runs end to end. [Unitalker-data-release-V1](https://drive.google.com/file/d/1Un7TB0Z5A1CG6bgeqKlhnSOECFN-C6KK/view?usp=sharing): The released datasets, PCA models, data-split json files and id-template numpy arrays. Download it and unzip it into this repo. [FLAME2020](https://flame.is.tue.mpg.de/download.php): Please download FLAME 2020 and copy generic_model.pkl to resources/binary_resources/flame.pkl. Use `git lfs pull` to get `./resources.zip` and `./test_audios.zip` and unzip them into this repo.
Finally, these files should be organized as follows: ```text ├── pretrained_models │ ├── UniTalker-B-D0-D7.pt │ ├── UniTalker-L-D0-D7.pt ├── resources │ ├── binary_resources │ │ ├── 02_flame_mouth_idx.npy │ │ ├── ... │ │ └── vocaset_FDD_wo_eyes.npy │ └── obj_template │ ├── 3DETF_blendshape_weight.obj │ ├── ... │ └── meshtalk_6172_vertices.obj └── unitalker_data_release_V1 │ ├── D0_BIWI │ │ ├── id_template.npy │ │ └── pca.npz │ ├── D1_vocaset │ │ ├── id_template.npy │ │ └── pca.npz │ ├── D2_meshtalk │ │ ├── id_template.npy │ │ └── pca.npz │ ├── D3D4_3DETF │ │ ├── D3_HDTF │ │ └── D4_RAVDESS │ ├── D5_unitalker_faceforensics++ │ │ ├── id_template.npy │ │ ├── test │ │ ├── test.json │ │ ├── train │ │ ├── train.json │ │ ├── val │ │ └── val.json │ ├── D6_unitalker_Chinese_speech │ │ ├── id_template.npy │ │ ├── test │ │ ├── test.json │ │ ├── train │ │ ├── train.json │ │ ├── val │ │ └── val.json │ └── D7_unitalker_song │ ├── id_template.npy │ ├── test │ ├── test.json │ ├── train │ ├── train.json │ ├── val │ └── val.json ``` ### Demo ```bash python -m main.demo --config config/unitalker.yaml test_out_path ./test_results/demo.npz python -m main.render ./test_results/demo.npz ./test_audios ./test_results/ ``` ## Train ### Download Data [Unitalker-data-release-V1](https://drive.google.com/file/d/1qRBPsTdOWp72ty04oD1Q_ivtwMjrACLH/view?usp=sharing) contains D5, D6 and D7. The datasets have been processed and split into train, validation and test sets. Please use these three datasets to try the training step. If you want to train the model on all of D0-D7, you need to download the remaining datasets from the following links: [D0: BIWI](https://github.com/Doubiiu/CodeTalker/blob/main/BIWI/README.md). [D1: VOCASET](https://voca.is.tue.mpg.de/). [D2: meshtalk](https://github.com/facebookresearch/meshtalk?tab=readme-ov-file). [D3,D4: 3DETF](https://github.com/psyai-net/EmoTalk_release).
### Modify Config and Train Please modify `dataset` and `duplicate_list` in `config/unitalker.yaml` according to the datasets you have prepared, ensuring that both lists maintain the same length. ```bash python -m main.train --config config/unitalker.yaml ``` ================================================ FILE: config/unitalker.yaml ================================================ DATA: dataset: "D0,D1,D2,D3,D4,D5,D6,D7" demo_dataset: "D0,D1,D2,D3,D4,D6,D7" # To visualize D5, please download FLAME2020 model first and place it in "resources/binary_resources/flame.pkl" data_root: "./unitalker_data_release_V1/" duplicate_list: "10,5,4,1,1,1,1,1" train_json_key: train.json LOSS: rec_weight: 1.0 pca_weight: 0.01 blendshape_weight: 0.0001 NETWORK: use_pca: True pca_dim: 512 MODEL: audio_encoder_repo: 'microsoft/wavlm-base-plus' # choose from 'microsoft/wavlm-base-plus' and facebook/wav2vec2-large-xlsr-53' freeze_wav2vec: False interpolate_pos: 1 decoder_dimension: 256 decoder_type: conv # choose from conv and transformer period: 30 headlayer: 1 TRAIN: fix_encoder_first: True random_id_prob: 0.1 re_init_decoder_and_head: False only_head: False workers: 4 # data loader workers batch_size: 1 # batch size for training batch_size_val: 1 # batch size for validation during training, memory and speed tradeoff lr: 0.0001 epochs: 100 save_every: 1 save_path: "./results" weight_path: "./pretrained_models/UniTalker-B-D0-D7.pt" eval_every: 1 test_every: 1 log_every: 1 do_test: True TEST: test_wav_dir: "./test_audios" test_out_path: "./test_results/demo.npz" ================================================ FILE: dataset/data_item.py ================================================ import librosa import numpy as np from typing import Any from .dataset_config import dataset_config class DataItem: def __init__(self, annot_path: str, audio_path: str, identity_idx: int, annot_type: str = '', dataset_name: str = '', id_template: np.ndarray = None, fps: float = 30, processor=None) -> None: 
self.annot_path = annot_path self.audio_path = audio_path self.identity_idx = identity_idx self.fps = fps self.annot_data = np.load(annot_path) if self.fps == 60: self.annot_data = self.annot_data[::2] self.fps = 30 elif self.fps == 100: self.annot_data = self.annot_data[::4] self.fps = 25 elif self.fps != 30 and self.fps != 25: raise ValueError('wrong fps') scale = dataset_config[dataset_name]['scale'] self.annot_data = self.scale_and_offset(self.annot_data, scale, 0.0) self.audio_data, sr = self.load_audio(audio_path) self.original_audio_data, self.annot_data = self.truncate( self.audio_data, sr, self.annot_data, self.fps) self.audio_data = np.squeeze( processor(self.original_audio_data, sampling_rate=sr).input_values) self.audio_data = self.audio_data.astype(np.float32) self.annot_type = annot_type self.id_template = self.scale_and_offset(id_template, scale, 0.0) self.annot_data = self.annot_data.reshape(len(self.annot_data), -1).astype(np.float32) self.id_template = self.id_template.reshape(1, -1).astype(np.float32) self.dataset_name = dataset_name self.data_dict = self.to_dict() return def load_audio(self, audio_path: str): if audio_path.endswith('.npy'): return np.load(audio_path), 16000 else: return librosa.load(audio_path, sr=16000) def truncate(self, audio_array: np.ndarray, sr: int, annot_data: np.ndarray, fps: int): audio_duration = len(audio_array) / sr annot_duration = len(annot_data) / fps duration = min(audio_duration, annot_duration, 20) audio_length = round(duration * sr) annot_length = round(duration * fps) self.duration = duration return audio_array[:audio_length], annot_data[:annot_length] def offset_id(self, offset: int): self.identity_idx = self.identity_idx + offset self.data_dict['id'] = self.identity_idx return def scale_and_offset(self, data: np.ndarray, scale: float = 1.0, offset: np.ndarray = 0.0): return data * scale + offset def to_dict(self, ) -> Any: item = { 'data': self.annot_data, 'audio': self.audio_data, 'original_audio': 
self.original_audio_data, 'fps': self.fps, 'id': self.identity_idx, 'annot_path': self.annot_path, 'audio_path': self.audio_path, 'annot_type': self.annot_type, 'template': self.id_template, 'dataset_name': self.dataset_name } return item def get_dict(self, ): return self.data_dict ================================================ FILE: dataset/dataset.py ================================================ import json import numpy as np import os.path as osp from torch.utils import data from tqdm import tqdm from transformers import Wav2Vec2FeatureExtractor from typing import Any, List from .data_item import DataItem from .dataset_config import dataset_config class AudioFaceDataset(data.Dataset): def __init__(self, data_label_path: str, args: dict = None) -> None: super().__init__() data_root = osp.dirname(data_label_path, ) id_template_path = osp.join(data_root, 'id_template.npy') id_template_list = np.load(id_template_path) with open(data_label_path) as f: labels = json.load(f) info = labels['info'] self.id_list = info['id_list'] self.sr = 16000 data_info_list = labels['data'] data_list = [] for data_info in tqdm(data_info_list): data = DataItem( annot_path=osp.join(data_root, data_info['annot_path']), audio_path=osp.join(data_root, data_info['audio_path']), identity_idx=data_info['id'], annot_type=data_info['annot_type'], dataset_name=data_info['dataset'], id_template=id_template_list[data_info['id']], fps=data_info['fps'], processor=args.processor, ) if data.duration < 0.5: continue data_list.append(data) self.data_list = data_list return def __len__(self, ): return len(self.data_list) def __getitem__(self, index: Any) -> Any: data = self.data_list[index] return data.get_dict() def get_identity_num(self, ): return len(self.id_list) class MixAudioFaceDataset(AudioFaceDataset): def __init__(self, dataset_list: List[AudioFaceDataset], duplicate_list: list = None) -> None: super(AudioFaceDataset).__init__() self.id_list = [] self.sr = 16000 self.data_list = [] 
self.annot_type_list = [] if duplicate_list is None: duplicate_list = [1] * len(dataset_list) else: assert len(duplicate_list) == len(dataset_list) for dup, dataset in zip(duplicate_list, dataset_list): id_index_offset = len(self.id_list) for d in dataset.data_list: d.offset_id(id_index_offset) if d.annot_type not in self.annot_type_list: self.annot_type_list.append(d.annot_type) self.id_list = self.id_list + dataset.id_list self.data_list = self.data_list + dataset.data_list * dup return def get_dataset_list(args): dataset_name_list = args.dataset dataset_dir_list = [ dataset_config[name]['dirname'] for name in dataset_name_list ] train_json_list = [ osp.join(args.data_root, dataset_name, args.train_json_key) for dataset_name in dataset_dir_list ] args.read_audio = False args.processor = None train_dataset_list = [] processor = Wav2Vec2FeatureExtractor.from_pretrained( args.audio_encoder_repo) args.processor = processor for train_json in train_json_list: dataset = AudioFaceDataset( train_json, args, ) train_dataset_list.append(dataset) return train_dataset_list def get_single_dataset(args, json_path: str = ''): args.read_audio = True processor = Wav2Vec2FeatureExtractor.from_pretrained( args.audio_encoder_repo) args.processor = processor dataset = AudioFaceDataset( json_path, args, ) test_loader = data.DataLoader( dataset=dataset, batch_size=1, shuffle=False, pin_memory=True, num_workers=args.workers) return test_loader def get_dataloaders(args): dataset_name_list = args.dataset print(dataset_name_list) dataset_dir_list = [ dataset_config[name]['dirname'] for name in dataset_name_list ] train_json_list = [ osp.join(args.data_root, dataset_dir, args.train_json_key) for dataset_dir in dataset_dir_list ] val_json_list = [ osp.join(args.data_root, dataset_dir, 'val.json') for dataset_dir in dataset_dir_list ] test_json_list = [ osp.join(args.data_root, dataset_dir, 'test.json') for dataset_dir in dataset_dir_list ] processor = Wav2Vec2FeatureExtractor.from_pretrained( 
args.audio_encoder_repo) args.processor = processor if getattr(args, 'do_train', True) is True: train_dataset_list = [] for train_json in train_json_list: dataset = AudioFaceDataset(train_json, args) train_dataset_list.append(dataset) mix_train_dataset = MixAudioFaceDataset(train_dataset_list, args.duplicate_list) train_loader = data.DataLoader( dataset=mix_train_dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=args.workers) else: train_loader = None if getattr(args, 'do_validate', True) is True: val_dataset_list = [] for val_json in val_json_list: dataset = AudioFaceDataset(val_json, args) val_dataset_list.append(dataset) mix_val_dataset = MixAudioFaceDataset(val_dataset_list) val_loader = data.DataLoader( dataset=mix_val_dataset, batch_size=1, shuffle=False, pin_memory=True, num_workers=args.workers) else: val_loader = None if getattr(args, 'do_test', True) is True: test_dataset_list = [] for test_json in test_json_list: dataset = AudioFaceDataset(test_json, args) test_dataset_list.append(dataset) mix_test_dataset = MixAudioFaceDataset(test_dataset_list) test_loader = data.DataLoader( dataset=mix_test_dataset, batch_size=1, shuffle=False, pin_memory=True, num_workers=args.workers) else: test_loader = None return train_loader, val_loader, test_loader ================================================ FILE: dataset/dataset_config.py ================================================ dataset_config = { 'D0': { 'dirname': 'D0_BIWI', 'annot_type': 'BIWI_23370_vertices', 'scale': 0.2, 'annot_dim': 23370 * 3, 'subjects': 6, 'pca': True, }, 'D1': { 'dirname': 'D1_vocaset', 'annot_type': 'FLAME_5023_vertices', 'scale': 1.0, 'annot_dim': 5023 * 3, 'subjects': 12, 'pca': True, }, 'D2': { 'dirname': 'D2_meshtalk', 'annot_type': 'meshtalk_6172_vertices', 'scale': 0.001, 'annot_dim': 6172 * 3, 'subjects': 13, 'pca': True, }, 'D3': { 'dirname': 'D3D4_3DETF/D3_HDTF', 'annot_type': '3DETF_blendshape_weight', 'scale': 1.0, 'annot_dim': 52, 'subjects': 
141, 'pca': False, }, 'D4': { 'dirname': 'D3D4_3DETF/D4_RAVDESS', 'annot_type': '3DETF_blendshape_weight', 'scale': 1.0, 'annot_dim': 52, 'subjects': 24, 'pca': False, }, 'D5': { 'dirname': 'D5_unitalker_faceforensics++', 'annot_type': 'flame_params_from_dadhead', 'scale': 1.0, 'annot_dim': 413, 'subjects': 719, 'pca': False, }, 'D6': { 'dirname': 'D6_unitalker_Chinese_speech', 'annot_type': 'inhouse_blendshape_weight', 'scale': 1.0, 'annot_dim': 51, 'subjects': 8, 'pca': False, }, 'D7': { 'dirname': 'D7_unitalker_song', 'annot_type': 'inhouse_blendshape_weight', 'scale': 1.0, 'annot_dim': 51, 'subjects': 11, 'pca': False, }, } ================================================ FILE: loss/loss.py ================================================ import numpy as np import pickle import torch from dataclasses import dataclass from smplx.lbs import lbs from smplx.utils import Struct, to_np, to_tensor from torch import nn @dataclass class FlameParams: betas: torch.Tensor jaw: torch.Tensor eyeballs: torch.Tensor neck: torch.Tensor class FlameLayer(nn.Module): def __init__(self, model_path: str) -> None: super().__init__() self.dtype = torch.float32 with open(model_path, 'rb') as f: self.flame_model = Struct(**pickle.load(f, encoding='latin1')) shapedirs = self.flame_model.shapedirs # The shape components self.register_buffer('shapedirs', to_tensor(to_np(shapedirs), dtype=self.dtype)) j_regressor = to_tensor( to_np(self.flame_model.J_regressor), dtype=self.dtype) self.register_buffer('J_regressor', j_regressor) self.register_buffer( 'v_template', to_tensor(to_np(self.flame_model.v_template), dtype=self.dtype), ) num_pose_basis = self.flame_model.posedirs.shape[-1] posedirs = np.reshape(self.flame_model.posedirs, [-1, num_pose_basis]).T self.register_buffer('posedirs', to_tensor(to_np(posedirs), dtype=self.dtype)) parents = to_tensor(to_np(self.flame_model.kintree_table[0])).long() parents[0] = -1 self.register_buffer('parents', parents) self.register_buffer( 'lbs_weights', 
to_tensor(to_np(self.flame_model.weights), dtype=self.dtype)) default_neck_pose = torch.zeros([1, 3], dtype=self.dtype, requires_grad=False) self.register_parameter( 'neck_pose', nn.Parameter(default_neck_pose, requires_grad=False)) default_eyeball_pose = torch.zeros([1, 6], dtype=self.dtype, requires_grad=False) self.register_parameter( 'eyeballs', nn.Parameter(default_eyeball_pose, requires_grad=False)) default_jaw = torch.zeros([1, 3], dtype=self.dtype, requires_grad=False) self.register_parameter('jaw', nn.Parameter(default_jaw, requires_grad=False)) self.FLAME_CONSTS = { 'betas': 400, 'rotation': 6, 'jaw': 3, 'eyeballs': 0, 'neck': 0, 'translation': 3, 'scale': 1, } def flame_params_from_3dmm(self, tensor_3dmm: torch.Tensor, zero_expr: bool = False): constants = self.FLAME_CONSTS assert tensor_3dmm.ndim == 2 cur_index: int = 0 betas = tensor_3dmm[:, :constants['betas']] cur_index += constants['betas'] jaw = tensor_3dmm[:, cur_index:cur_index + constants['jaw']] cur_index += constants['jaw'] cur_index += constants['rotation'] eyeballs = tensor_3dmm[:, cur_index:cur_index + constants['eyeballs']] cur_index += constants['eyeballs'] neck = tensor_3dmm[:, cur_index:cur_index + constants['neck']] cur_index += constants['neck'] cur_index += constants['translation'] cur_index += constants['scale'] return FlameParams(betas=betas, jaw=jaw, eyeballs=eyeballs, neck=neck) def forward(self, tensor_3dmm: torch.Tensor): bs = tensor_3dmm.shape[0] flame_params = self.flame_params_from_3dmm(tensor_3dmm) betas = flame_params.betas rotation = torch.zeros([bs, 3], device=tensor_3dmm.device) neck_pose = flame_params.neck if not ( 0 in flame_params.neck.shape) else self.neck_pose[[0]].expand( bs, -1) eyeballs = flame_params.eyeballs if not ( 0 in flame_params.eyeballs.shape) else self.eyeballs[[0]].expand( bs, -1) jaw = flame_params.jaw if not ( 0 in flame_params.jaw.shape) else self.jaw[[0]].expand(bs, -1) full_pose = torch.cat([rotation, neck_pose, jaw, eyeballs], dim=1) 
template_vertices = self.v_template.unsqueeze(0).repeat(bs, 1, 1) vertices, _ = lbs( betas, full_pose, template_vertices, self.shapedirs, self.posedirs, self.J_regressor, self.parents, self.lbs_weights, ) return vertices class FlameParamLossForDADHead(nn.Module): def __init__(self, mouth_indices_path: str, loss_mask_path: str, flame_model_path: str) -> None: super().__init__() mouth_indices = np.load(mouth_indices_path) loss_indices = np.load(loss_mask_path) self.register_buffer('mouth_idx', torch.from_numpy(mouth_indices)) self.register_buffer('loss_indices', torch.from_numpy(loss_indices)) self.mse = nn.MSELoss() self.flame_layer = FlameLayer(flame_model_path) def forward(self, x: torch.Tensor, target: torch.Tensor): ori_shape = x.shape x = self.flame_layer(x.reshape(-1, ori_shape[-1])) target = self.flame_layer(target.reshape(-1, ori_shape[-1])) x = x.reshape((ori_shape[0], ori_shape[1], -1, 3)) target = target.reshape((ori_shape[0], ori_shape[1], -1, 3)) return self.mse(x[:, :, self.loss_indices], target[:, :, self.loss_indices]) def get_vertices( self, x: torch.Tensor, ): x = self.flame_layer(x.reshape(-1, x.shape[-1])) return x[:, self.loss_indices] def mouth_metric(self, x: torch.Tensor, target: torch.Tensor): ori_shape = x.shape x = self.flame_layer(x.reshape(-1, ori_shape[-1])) target = self.flame_layer(target.reshape(-1, ori_shape[-1])) metric_L2 = (target[:, self.mouth_idx] - x[:, self.mouth_idx])**2 metric_L2 = metric_L2.sum(-1) metric_L2 = metric_L2.max(-1)[0] metric_L2norm = metric_L2**0.5 return metric_L2.mean(), metric_L2norm.mean() class VerticesLoss(nn.Module): def __init__(self, mouth_indices_path: str) -> None: super().__init__() mouth_indices = np.load(mouth_indices_path) self.register_buffer('mouth_idx', torch.from_numpy(mouth_indices)) self.mse = nn.MSELoss() def forward(self, x: torch.Tensor, target: torch.Tensor): return self.mse(x, target) def get_vertices( self, x: torch.Tensor, ): return x.reshape(-1, x.shape[-1] // 3, 3) def 
mouth_metric(self, x: torch.Tensor, target: torch.Tensor): ori_shape = x.shape target = target.reshape(*ori_shape[:-1], -1, 3) x = x.reshape(*ori_shape[:-1], -1, 3) metric_L2 = (target[:, :, self.mouth_idx] - x[:, :, self.mouth_idx])**2 metric_L2 = metric_L2.sum(-1) metric_L2 = metric_L2.max(-1)[0] metric_L2norm = metric_L2**0.5 return metric_L2.mean(), metric_L2norm.mean() class BlendShapeLoss(nn.Module): def __init__(self, mouth_indices_path: str, blendshape_path: str, bs_beta: float = 0.0, components_num: int = None) -> None: super().__init__() mouth_indices = np.load(mouth_indices_path) blendshape = np.load(blendshape_path) self.register_buffer('mouth_idx', torch.from_numpy(mouth_indices)) mean_shape = blendshape['meanshape'] mean_shape = mean_shape.reshape(-1).astype(np.float32) blend_shape = blendshape['blendshape'] blend_shape = blend_shape.reshape(len(blend_shape), -1).astype(np.float32) self.register_buffer('mean_shape', torch.from_numpy(mean_shape)) self.register_buffer('blend_shape', torch.from_numpy(blend_shape)) self.bs_beta = bs_beta self.mse = nn.MSELoss() self.components_num = components_num def forward(self, x: torch.Tensor, target: torch.Tensor): return self.bs_beta * self.mse(x, target) + self.mse( self.bs2vertices(x), self.bs2vertices(target)) def bs2vertices(self, x: torch.Tensor): return x[:, :self.components_num] @ \ self.blend_shape[:self.components_num] + \ self.mean_shape def get_vertices( self, x: torch.Tensor, ): x = self.bs2vertices(x) return x.reshape(-1, x.shape[-1] // 3, 3) def mouth_metric(self, x: torch.Tensor, target: torch.Tensor): target = self.bs2vertices(target) x = self.bs2vertices(x) ori_shape = x.shape target = target.reshape(*ori_shape[:-1], -1, 3) x = x.reshape(*ori_shape[:-1], -1, 3) metric_L2 = (target[:, :, self.mouth_idx] - x[:, :, self.mouth_idx])**2 metric_L2 = metric_L2.sum(-1) metric_L2 = metric_L2.max(-1)[0] metric_L2norm = metric_L2**0.5 return metric_L2.mean(), metric_L2norm.mean() class ParamLoss(nn.Module): 
def __init__(self, weight: float): super().__init__() self.weight = weight self.mse = nn.MSELoss() self.L1 = nn.L1Loss() def forward(self, x: torch.Tensor, target: torch.Tensor): return self.weight * self.mse(x, target) def mouth_metric(self, x: torch.Tensor, target: torch.Tensor): return self.mse(x, target), self.L1(x, target) def get_vertices( self, x: torch.Tensor, ): return x.squeeze(0) class UniTalkerLoss(nn.Module): def __init__(self, args) -> None: super().__init__() loss_config = { 'flame_params_from_dadhead': { 'class': FlameParamLossForDADHead, 'args': { 'mouth_indices_path': 'resources/binary_resources/02_flame_mouth_idx.npy', 'loss_mask_path': 'resources/binary_resources/flame_humanface_index.npy', 'flame_model_path': 'resources/binary_resources/flame.pkl', } }, '3DETF_blendshape_weight': { 'class': BlendShapeLoss, 'args': { 'mouth_indices_path': 'resources/binary_resources/03_emo_talk_mouth_idx.npy', 'blendshape_path': 'resources/binary_resources/EmoTalk.npz', 'bs_beta': args.blendshape_weight, }, }, 'inhouse_blendshape_weight': { 'class': BlendShapeLoss, 'args': { 'mouth_indices_path': 'resources/binary_resources/05_inhouse_arkit_mouth_idx.npy', 'blendshape_path': 'resources/binary_resources/inhouse_arkit.npz', 'bs_beta': args.blendshape_weight, }, }, 'FLAME_5023_vertices': { 'class': VerticesLoss, 'args': { 'mouth_indices_path': 'resources/binary_resources/02_flame_mouth_idx.npy', }, }, 'BIWI_23370_vertices': { 'class': VerticesLoss, 'args': { 'mouth_indices_path': 'resources/binary_resources/04_BIWI_mouth_idx.npy', }, }, 'FLAME_SUB_2055_vertices': { 'class': VerticesLoss, 'args': { 'mouth_indices_path': 'resources/binary_resources/mouth_idx_FLAME_5023_vertices_wo_eyes_wo_head.npy', }, }, 'meshtalk_6172_vertices': { 'class': VerticesLoss, 'args': { 'mouth_indices_path': 'resources/binary_resources/06_mesh_talk_mouth_idx.npy', }, }, '24_viseme': { 'class': ParamLoss, 'args': { 'weight': args.blendshape_weight }, } } loss_module_dict = {} for k, v in 
loss_config.items(): loss_module = v['class'](**v['args']) loss_module_dict[k] = loss_module self.loss_module_dict = nn.ModuleDict(loss_module_dict) self.mse = nn.MSELoss() return def forward( self, x: torch.Tensor, target: torch.Tensor, annot_type: str, ): return self.loss_module_dict[annot_type](x, target) def pca_loss(self, x: torch.Tensor, target: torch.Tensor): return self.mse(x, target) def mouth_metric(self, x: torch.Tensor, target: torch.Tensor, annot_type: str): return self.loss_module_dict[annot_type].mouth_metric(x, target) def get_vertices(self, x: torch.Tensor, annot_type: str): return self.loss_module_dict[annot_type].get_vertices(x) ================================================ FILE: main/demo.py ================================================ import librosa import numpy as np import os import torch import torch.optim import torch.utils.data from transformers import Wav2Vec2FeatureExtractor from dataset.dataset_config import dataset_config from loss.loss import UniTalkerLoss from models.unitalker import UniTalker from utils.utils import get_parser, get_template_verts, get_audio_encoder_dim def get_all_audios(audio_root:str): wav_f_names = [] for r, _, f in os.walk(audio_root): for wav in f: if wav.endswith('.wav') or wav.endswith('.mp3'): relative_path = os.path.join(r, wav) relative_path = os.path.relpath(relative_path, audio_root) wav_f_names.append(relative_path) wav_f_names = sorted(wav_f_names) return wav_f_names def split_long_audio( audio: np.ndarray, processor:Wav2Vec2FeatureExtractor ): # audio = audio.squeeze(0) a, b = 25, 5 sr = 16000 total_length = len(audio) /sr reps = max(0, int(np.ceil((total_length - a) / (a - b)))) + 1 in_audio_split_list = [] start, end = 0, int(a * sr) step = int((a - b) * sr) for i in range(reps): audio_split = audio[start:end] audio_split = np.squeeze( processor(audio_split, sampling_rate=sr).input_values) in_audio_split_list.append(audio_split) start += step end += step return in_audio_split_list def 
merge_out_list(out_list: list, fps:int): if len(out_list) == 1: return out_list[0] a, b = 25, 5 left_weight = np.linspace(1, 0, b * fps)[:, np.newaxis] right_weight = 1 - left_weight a = a * fps b = b * fps offset = a - b out_length = len(out_list[-1]) + offset * (len(out_list) - 1) merged_out = np.empty((out_length, out_list[-1].shape[-1]), dtype=out_list[-1].dtype) merged_out[:a] = out_list[0] for out_piece in out_list[1:]: merged_out[a - b:a] = left_weight * merged_out[ a - b:a] + right_weight * out_piece[:b] merged_out[a:a + offset] = out_piece[b:] a += offset return merged_out def main(): args = get_parser() training_dataset_name_list = args.dataset demo_dataset_name_list = args.demo_dataset condition_id_config = { 'D0': 3, 'D1': 3, 'D2': 0, 'D3': 0, 'D4': 0, 'D5': 0, 'D6': 4, 'D7': 0, } template_id_config = { 'D0': 3, 'D1': 3, 'D2': 0, 'D3': 0, 'D4': 0, 'D5': 0, 'D6': 4, 'D7': 0, } checkpoint = torch.load(args.weight_path, map_location='cpu') args.identity_num = len(checkpoint['decoder.learnable_style_emb.weight']) start_idx = 0 for dataset_name in training_dataset_name_list: annot_type = dataset_config[dataset_name]['annot_type'] id_num = dataset_config[dataset_name]['subjects'] end_idx = start_idx + id_num local_condition_idx = condition_id_config[dataset_name] template_idx = template_id_config[dataset_name] template = get_template_verts(args.data_root, dataset_name, template_idx) template_id_config[dataset_name] = torch.Tensor(template.reshape(1, -1)) if args.condition_id == 'each': condition_id_config[dataset_name] = torch.tensor( start_idx + local_condition_idx).reshape(1) elif args.condition_id == 'common': condition_id_config[dataset_name] = torch.tensor(args.identity_num - 1).reshape(1) else: try: condition_id = int(args.condition_id) condition_id_config[dataset_name] = torch.tensor( condition_id).reshape(1) except ValueError: assert args.condition_id in dataset_config.keys() condition_id_config[dataset_name] = torch.tensor( start_idx + 
local_condition_idx).reshape(1) start_idx = end_idx if args.random_id_prob > 0: assert end_idx == (args.identity_num - 1) else: assert end_idx == args.identity_num args.audio_encoder_feature_dim = get_audio_encoder_dim( args.audio_encoder_repo) model = UniTalker(args) model.load_state_dict(checkpoint, strict=False) model.eval() model.cuda() wav_f_names = get_all_audios(args.test_wav_dir) out_npz_path = args.test_out_path dirname = os.path.dirname(out_npz_path) os.makedirs(dirname, exist_ok=True) out_dict = {} loss_module = UniTalkerLoss(args).cuda() processor = Wav2Vec2FeatureExtractor.from_pretrained( args.audio_encoder_repo,) with torch.no_grad(): for wav_f in wav_f_names: out_dict[wav_f] = {} wav_path = os.path.join(args.test_wav_dir, wav_f) audio_data, sr = librosa.load(wav_path, sr=16000) audio_data = np.squeeze( processor(audio_data, sampling_rate=sr).input_values) audio_data_splits = split_long_audio(audio_data, processor) for dataset_name in demo_dataset_name_list: template = template_id_config[dataset_name].cuda() scale = dataset_config[dataset_name]['scale'] template = scale * template condition_id = condition_id_config[dataset_name].cuda() annot_type = dataset_config[dataset_name]['annot_type'] if annot_type == 'BIWI_23370_vertices': fps = 25 else: fps = 30 out_list = [] for audio_data in audio_data_splits: audio_data = torch.Tensor(audio_data[None]).cuda() out, _, _ = model( audio_data, template, face_motion=None, style_idx=condition_id, annot_type=annot_type, fps=fps, ) out_list.append(out.cpu().numpy().squeeze(0)) out = merge_out_list(out_list, fps) out = loss_module.get_vertices(torch.from_numpy(out).cuda(), annot_type) out_dict[wav_f][dataset_name] = out.cpu() print(f"save results to {out_npz_path}") np.savez(out_npz_path, **out_dict) if __name__ == '__main__': main() ================================================ FILE: main/render.py ================================================ """Max-Planck-Gesellschaft zur Foerderung der Wissenschaften e.V. 
(MPG) is holder of all proprietary rights on this computer program. You can only use this computer program if you have closed a license agreement with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. Any use of the computer program without a valid license is prohibited and liable to prosecution. Copyright 2019 Max-Planck- Gesellschaft zur Foerderung der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute for Intelligent Systems and the Max Planck Institute for Biological Cybernetics. All rights reserved. More information about VOCA is available at http://voca.is.tue.mpg.de. For comments or questions, please email us at voca@tue.mpg.de """ import cv2 import numpy as np import os import subprocess import torch from tqdm import tqdm from dataset.dataset_config import dataset_config def render_meshes(mesh_vertices: torch.Tensor, faces: torch.Tensor, img_size: tuple = (256, 256), aa_factor: int = 1, vis_bs: int = 100): assert len(mesh_vertices) == len(faces) or len(faces) == 1 if len(faces) == 1: faces = faces.repeat(len(mesh_vertices), 1, 1) x_max = y_max = mesh_vertices[..., 0:2].abs().max() from pytorch3d.renderer import ( DirectionalLights, FoVOrthographicCameras, HardPhongShader, Materials, MeshRasterizer, MeshRenderer, RasterizationSettings, TexturesVertex, ) from pytorch3d.renderer.blending import BlendParams from pytorch3d.renderer.materials import Materials from pytorch3d.structures import Meshes aa_factor = int(aa_factor) device = mesh_vertices.device verts_batch_lst = torch.split(mesh_vertices, vis_bs) faces_batch_lst = torch.split(faces, vis_bs) R = torch.eye(3).to(mesh_vertices.device) R[0, 0] = -1 R[2, 2] = -1 T = torch.zeros(3).to(mesh_vertices.device) T[2] = 10 R, T = R[None], T[None] W, H = img_size cameras = FoVOrthographicCameras( device=device, R=R, T=T, znear=0.01, zfar=3, max_x=x_max * 1.2, min_x=-x_max * 1.2, max_y=y_max * 1.2, min_y=-y_max * 1.2, ) # cameras = 
FoVPerspectiveCameras( # device=device, # R=R, # T=T, # znear=0.01, # zfar=3, # aspect_ratio=1, # fov=0.3, # degrees=False # ) raster_settings = RasterizationSettings( image_size=(W * aa_factor, H * aa_factor), blur_radius=0.0, faces_per_pixel=1, ) raster = MeshRasterizer( cameras=cameras, raster_settings=raster_settings, ) blend_params = BlendParams( sigma=0.0, gamma=0.0, background_color=(0.0, 0.0, 0.0)) lights = DirectionalLights( device=device, direction=((0, 0, 1), ), ambient_color=((0.3, 0.3, 0.3), ), diffuse_color=((0.6, 0.6, 0.6), ), specular_color=((0.1, 0.1, 0.1), )) materias = Materials( ambient_color=((1, 1, 1), ), diffuse_color=((1, 1, 1), ), specular_color=((1, 1, 1), ), shininess=15, device=device) shader = HardPhongShader( device=device, cameras=cameras, lights=lights, materials=materias, blend_params=blend_params) renderer = MeshRenderer( rasterizer=raster, shader=shader, ) rendered_imgs = [] for verts_batch, faces_batch in tqdm( zip(verts_batch_lst, faces_batch_lst), total=(len(verts_batch_lst))): textures = TexturesVertex(verts_features=torch.ones_like(verts_batch)) meshes = Meshes( verts=verts_batch, faces=faces_batch, textures=textures) with torch.no_grad(): imgs = renderer(meshes) rendered_imgs.append(imgs.cpu()) rendered_imgs = torch.cat(rendered_imgs, dim=0) if aa_factor > 1: rendered_imgs = rendered_imgs.permute(0, 3, 1, 2) # NHWC -> NCHW rendered_imgs = torch.nn.functional.interpolate( rendered_imgs, scale_factor=1 / aa_factor, mode='bicubic') rendered_imgs = rendered_imgs.permute(0, 2, 3, 1) # NCHW -> NHWC return rendered_imgs def read_obj(in_path): with open(in_path, 'r') as obj_file: # Read the lines of the OBJ file lines = obj_file.readlines() # Initialize empty lists for vertices and faces verts = [] faces = [] for line in lines: line = line.strip() # Remove leading/trailing whitespace elements = line.split() # Split the line into elements if len(elements) == 0: continue # Skip empty lines # Check the type of line (vertex or face) if 
elements[0] == 'v': # Vertex line x, y, z = map(float, elements[1:4]) # Extract the vertex coordinates verts.append((x, y, z)) # Add the vertex to the list elif elements[0] == 'f': # Face line face_indices = [ int(index.split('/')[0]) for index in elements[1:] ] # Extract the vertex indices faces.append(face_indices) # Add the face to the list return np.array(verts), np.array(faces) def pad_for_libx264(image_array): """Pad zeros if width or height of image_array is not divisible by 2. Otherwise you will get. \"[libx264 @ 0x1b1d560] width not divisible by 2 \" Args: image_array (np.ndarray): Image or images load by cv2.imread(). Possible shapes: 1. [height, width] 2. [height, width, channels] 3. [images, height, width] 4. [images, height, width, channels] Returns: np.ndarray: A image with both edges divisible by 2. """ if image_array.ndim == 2 or \ (image_array.ndim == 3 and image_array.shape[2] == 3): hei_index = 0 wid_index = 1 elif image_array.ndim == 4 or \ (image_array.ndim == 3 and image_array.shape[2] != 3): hei_index = 1 wid_index = 2 else: return image_array hei_pad = image_array.shape[hei_index] % 2 wid_pad = image_array.shape[wid_index] % 2 if hei_pad + wid_pad > 0: pad_width = [] for dim_index in range(image_array.ndim): if dim_index == hei_index: pad_width.append((0, hei_pad)) elif dim_index == wid_index: pad_width.append((0, wid_pad)) else: pad_width.append((0, 0)) values = 0 image_array = \ np.pad(image_array, pad_width, mode='constant', constant_values=values) return image_array def array_to_video( image_array: np.ndarray, output_path: str, fps=30, resolution=None, disable_log: bool = False, ) -> None: """Convert an array to a video directly, gif not supported. Args: image_array (np.ndarray): shape should be (f * h * w * 3). output_path (str): output video file path. fps (Union[int, float, optional): fps. Defaults to 30. resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]], optional): (height, width) of the output video. 
Defaults to None. disable_log (bool, optional): whether close the ffmepg command info. Defaults to False. Raises: FileNotFoundError: check output path. TypeError: check input array. Returns: None. """ if not isinstance(image_array, np.ndarray): raise TypeError('Input should be np.ndarray.') assert image_array.ndim == 4 assert image_array.shape[-1] == 3 if resolution: height, width = resolution width += width % 2 height += height % 2 else: image_array = pad_for_libx264(image_array) height, width = image_array.shape[1], image_array.shape[2] command = [ '/usr/bin/ffmpeg', '-y', # (optional) overwrite output file if it exists '-f', 'rawvideo', '-s', f'{int(width)}x{int(height)}', # size of one frame '-pix_fmt', 'bgr24', '-r', f'{fps}', # frames per second '-loglevel', 'error', '-threads', '4', '-i', '-', # The input comes from a pipe '-vcodec', 'libx264', '-an', # Tells FFMPEG not to expect any audio output_path, ] if not disable_log: print(f'Running \"{" ".join(command)}\"') process = subprocess.Popen( command, stdin=subprocess.PIPE, stderr=subprocess.PIPE, ) if process.stdin is None or process.stderr is None: raise BrokenPipeError('No buffer received.') index = 0 while True: if index >= image_array.shape[0]: break process.stdin.write(image_array[index].tobytes()) index += 1 process.stdin.close() process.stderr.close() process.wait() def get_obj_faces(annot_type: str): template_obj_path_dict = { '3DETF_blendshape_weight': 'resources/obj_template/3DETF_blendshape_weight.obj', 'FLAME_5023_vertices': 'resources/obj_template/FLAME_5023_vertices.obj', 'BIWI_23370_vertices': 'resources/obj_template/BIWI_23370_vertices.obj', 'flame_params_from_dadhead': 'resources/obj_template/flame_params_from_dadhead.obj', 'inhouse_blendshape_weight': 'resources/obj_template/inhouse_blendshape_weight.obj', 'meshtalk_6172_vertices': 'resources/obj_template/meshtalk_6172_vertices.obj' } faces = np.array(read_obj(template_obj_path_dict[annot_type])[1]) if faces.min() == 1: faces = faces - 1 
return faces def vis_model_out_proxy(): import sys npz_path = sys.argv[1] audio_dir = sys.argv[2] out_dir = sys.argv[3] save_images = False os.makedirs(out_dir, exist_ok=True) out_dict = np.load(npz_path, allow_pickle=True) print(list(out_dict.keys())) for wav_f in out_dict.keys(): wav_path = os.path.join(audio_dir, wav_f) cur_out_dict = out_dict[wav_f] for datasetname, out in cur_out_dict.item().items(): out = torch.Tensor(out) annot_type = dataset_config[datasetname]['annot_type'] faces = get_obj_faces(annot_type) with torch.no_grad(): if save_images: out_size = (1024, 1024) else: out_size = (512, 512) img_array = render_meshes(out.cuda(), torch.from_numpy(faces[None]).cuda(), out_size) if save_images: channels = 4 else: channels = 3 img_array = (img_array[..., :channels].cpu().numpy() * 255).astype(np.uint8) out_key = f'{wav_f[:-4]}' if save_images: out_images_dir = os.path.join(out_dir, datasetname, out_key) os.makedirs(out_images_dir, exist_ok=True) for imgidx, img in enumerate(img_array): img_path = os.path.join(out_images_dir, f'{imgidx:04d}.png') cv2.imwrite(img_path, img) continue out_video_path = os.path.join(out_dir, f'{out_key}_{datasetname}.mp4') out_video_dir = os.path.dirname(out_video_path) os.makedirs(out_video_dir, exist_ok=True) tmp_path = os.path.join(out_video_dir, 'tmp.mp4') if os.path.exists(tmp_path): os.remove(tmp_path) if annot_type == 'BIWI_23370_vertices': fps = 25 else: fps = 30 print(img_array.shape, img_array.dtype) array_to_video(img_array, output_path=tmp_path, fps=fps) cmd = f'ffmpeg -i {tmp_path} -i {wav_path} -c:v copy -c:a aac -shortest -strict -2 {out_video_path} -y ' print(cmd) subprocess.run(cmd, shell=True) os.remove(tmp_path) if __name__ == '__main__': vis_model_out_proxy() ================================================ FILE: main/train.py ================================================ #!/usr/bin/env python # yapf: disable import os import torch import torch.optim import torch.utils.data from tensorboardX import 
SummaryWriter from dataset.dataset import get_dataloaders from loss.loss import UniTalkerLoss from models.unitalker import UniTalker from utils.utils import ( get_average_meter_dict, get_logger, get_parser, get_audio_encoder_dim, load_ckpt, seed_everything, filter_unitalker_state_dict ) from .train_eval_loop import train_epoch, validate_epoch # yapf: enable def main(): seed_everything(42) args = get_parser() args.audio_encoder_feature_dim = get_audio_encoder_dim(args.audio_encoder_repo) train_loader, val_loader, test_loader = get_dataloaders(args) args.identity_num = train_loader.dataset.get_identity_num() args.annot_type_list = train_loader.dataset.annot_type_list if args.random_id_prob > 0.0: args.identity_num = args.identity_num + 1 os.makedirs(args.save_path, exist_ok=True) log_file = os.path.join(args.save_path, 'train.log') logger = get_logger(log_file) writer = SummaryWriter(args.save_path) logger.info('=> creating model ...') model = UniTalker(args) if args.weight_path is not None: logger.info(f'=> loading model {args.weight_path} ...') load_ckpt(model, args.weight_path, args.re_init_decoder_and_head) logger.info(args) model = model.cuda() model.summary(logger) optimizer = torch.optim.Adam( filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr) loss_module = UniTalkerLoss(args).cuda() args.train_meter, args.train_metric_name_list = \ get_average_meter_dict(args, 'train') args.val_meter, args.val_metric_name_list = \ get_average_meter_dict(args, 'val') args.logger = logger args.writer = writer for epoch in range(args.epochs): if args.fix_encoder_first: if epoch == 0: model.audio_encoder._freeze_wav2vec2_parameters(True) optimizer = torch.optim.Adam( filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr) elif epoch == 1: model.audio_encoder._freeze_wav2vec2_parameters( args.freeze_wav2vec) model.audio_encoder.feature_extractor._freeze_parameters() optimizer = torch.optim.Adam( filter(lambda p: p.requires_grad, model.parameters()), 
lr=args.lr) epoch_plus = epoch + 1 train_epoch(train_loader, model, loss_module, optimizer, epoch_plus, args) if epoch_plus % args.eval_every == 0: validate_epoch(val_loader, model, loss_module, epoch_plus, 'val', args) if epoch_plus % args.test_every == 0: validate_epoch(test_loader, model, loss_module, epoch_plus, 'test', args) if epoch_plus % args.save_every == 0: model_path = os.path.join(args.save_path, f'{epoch_plus:03d}.pt') torch.save(filter_unitalker_state_dict(model.state_dict()), model_path) if __name__ == '__main__': main() ================================================ FILE: main/train_eval_loop.py ================================================ import numpy as np import torch from tqdm import tqdm from utils.utils import log_datasetloss, write_to_tensorboard def train_epoch(data_loader, model, loss_module, optimizer, epoch, args): model.train() train_meter = args.train_meter metric_name_list = args.train_metric_name_list zero_loss = torch.tensor(0.0) mixed_dataset_name = 'mixed_dataset' pbar = tqdm(data_loader) pbar.set_description(f'Epoch: {epoch} train') for batch in pbar: data = batch['data'].cuda(non_blocking=True) template = batch['template'].cuda(non_blocking=True) audio = batch['audio'].cuda(non_blocking=True) identity = batch['id'].cuda(non_blocking=True) if args.random_id_prob > 0.0: if np.random.random() < args.random_id_prob: identity = identity.fill_(args.identity_num - 1) annot_type = batch['annot_type'][0] dataset_name = batch['dataset_name'][0] out_motion, out_pca, gt_pca = model( audio, template, data, style_idx=identity, annot_type=annot_type) rec_loss = loss_module(out_motion, data, annot_type) if out_pca is not None: pca_loss = loss_module.pca_loss(out_pca, gt_pca) else: pca_loss = zero_loss loss = rec_loss + args.pca_weight * pca_loss optimizer.zero_grad() loss.backward() optimizer.step() for k, v in zip(metric_name_list, (rec_loss, pca_loss)): train_meter[dataset_name][k].update(v.item()) 
train_meter[mixed_dataset_name][k].update(v.item()) log_datasetloss(args.logger, epoch, 'train', train_meter, only_mix=True) write_to_tensorboard( args.writer, epoch, 'train', train_meter, only_mix=False) for dataset_name, sub_metric_dict in train_meter.items(): for metric_name in sub_metric_dict.keys(): sub_metric_dict[metric_name].reset() return def validate_epoch(data_loader, model, loss_module, epoch, phase, args): assert phase in ('val', 'test') model.eval() val_meter = args.val_meter metric_name_list = args.val_metric_name_list mixed_dataset_name = 'mixed_dataset' with torch.no_grad(): pbar = tqdm(data_loader) pbar.set_description(f'Epoch: {epoch} {phase:<5}') for batch in pbar: data = batch['data'].cuda(non_blocking=True) template = batch['template'].cuda(non_blocking=True) audio = batch['audio'].cuda(non_blocking=True) identity = batch['id'].cuda(non_blocking=True) annot_type = batch['annot_type'][0] dataset_name = batch['dataset_name'][0] if args.random_id_prob >= 1.0: identity = identity.fill_(args.identity_num - 1) out_motion, _, _ = model( audio, template, data, style_idx=identity, annot_type=annot_type) rec_loss = loss_module(out_motion, data, annot_type) mouth_metric_L2, mouth_metric_L2_norm = loss_module.mouth_metric( out_motion, data, annot_type) for k, v in zip(metric_name_list, (rec_loss, mouth_metric_L2, mouth_metric_L2_norm)): val_meter[dataset_name][k].update(v.item()) val_meter[mixed_dataset_name][k].update(v.item()) if getattr(args, 'logger', None) is not None: log_datasetloss( args.logger, epoch, phase, val_meter, only_mix=True) if getattr(args, 'writer', None) is not None: write_to_tensorboard( args.writer, epoch, phase, val_meter, only_mix=False) for dataset_name, sub_metric_dict in val_meter.items(): for metric_name in sub_metric_dict.keys(): sub_metric_dict[metric_name].reset() return ================================================ FILE: main/train_pca.py ================================================ #!/usr/bin/env python import 
numpy as np from models.pca import PCA from utils.utils import get_parser, get_audio_encoder_dim def main(): args = get_parser() args.audio_encoder_feature_dim = get_audio_encoder_dim( args.wav2vec2_type) main_worker(args) def main_worker(cfg): pca_builder = PCA() from dataset.dataset import get_dataset_list train_dataset_list = get_dataset_list(cfg) for dataset in train_dataset_list: data = pca_builder.load_dataset(dataset) pca_info = pca_builder.build_incremental_PCA( data, batch_size=1095, trunc_dim=647) np.savez('mesh_talk_pca.npz', **pca_info) if __name__ == '__main__': main() ================================================ FILE: models/base_model.py ================================================ import numpy as np import torch.nn as nn class BaseModel(nn.Module): """Base class for all models.""" def __init__(self): super(BaseModel, self).__init__() # self.logger = logging.getLogger(self.__class__.__name__) def forward(self, *x): """Forward pass logic. :return: Model output """ raise NotImplementedError def freeze_model(self, do_freeze: bool = True): for param in self.parameters(): param.requires_grad = (not do_freeze) def summary(self, logger, writer=None): """Model summary.""" model_parameters = filter(lambda p: p.requires_grad, self.parameters()) params = sum([np.prod(p.size()) for p in model_parameters]) / 1e6 # Unit is Mega logger.info('===>Trainable parameters: %.3f M' % params) if writer is not None: writer.add_text('Model Summary', 'Trainable parameters: %.3f M' % params) ================================================ FILE: models/hubert.py ================================================ import numpy as np import torch from transformers import HubertModel from transformers.modeling_outputs import BaseModelOutput from typing import Optional, Tuple from .model_utils import linear_interpolation # the implementation of Wav2Vec2Model is borrowed from 
https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model # noqa: E501 # initialize our encoder with the pre-trained wav2vec 2.0 weights. def _compute_mask_indices( shape: Tuple[int, int], mask_prob: float, mask_length: int, attention_mask: Optional[torch.Tensor] = None, min_masks: int = 0, ) -> np.ndarray: bsz, all_sz = shape mask = np.full((bsz, all_sz), False) all_num_mask = int(mask_prob * all_sz / float(mask_length) + np.random.rand()) all_num_mask = max(min_masks, all_num_mask) mask_idcs = [] padding_mask = attention_mask.ne(1) if attention_mask is not None else None for i in range(bsz): if padding_mask is not None: sz = all_sz - padding_mask[i].long().sum().item() num_mask = int(mask_prob * sz / float(mask_length) + np.random.rand()) num_mask = max(min_masks, num_mask) else: sz = all_sz num_mask = all_num_mask lengths = np.full(num_mask, mask_length) if sum(lengths) == 0: lengths[0] = min(mask_length, sz - 1) min_len = min(lengths) if sz - min_len <= num_mask: min_len = sz - num_mask - 1 mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) mask_idc = np.asarray([ mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j]) ]) mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) min_len = min([len(m) for m in mask_idcs]) for i, mask_idc in enumerate(mask_idcs): if len(mask_idc) > min_len: mask_idc = np.random.choice(mask_idc, min_len, replace=False) mask[i, mask_idc] = True return mask class HubertModel(HubertModel): def __init__(self, config): super().__init__(config) def _freeze_parameters(self, do_freeze: bool = True): for param in self.parameters(): param.requires_grad = (not do_freeze) def forward( self, input_values, attention_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None, frame_num=None, interpolate_pos: int = 0, ): self.config.output_attentions = True output_attentions = output_attentions if \ output_attentions is not None else 
self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) return_dict = return_dict if return_dict is not None \ else self.config.use_return_dict conv_features = self.feature_extractor(input_values) hidden_states = conv_features.transpose(1, 2) if interpolate_pos == 0: hidden_states = linear_interpolation( hidden_states, output_len=frame_num) if attention_mask is not None: output_lengths = self._get_feat_extract_output_lengths( attention_mask.sum(-1)) attention_mask = torch.zeros( hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device) attention_mask[(torch.arange( attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)] = 1 attention_mask = attention_mask.flip([-1]).cumsum(-1).flip( [-1]).bool() hidden_states = self.feature_projection(hidden_states) if self.config.apply_spec_augment and self.training: batch_size, sequence_length, hidden_size = hidden_states.size() if self.config.mask_time_prob > 0: mask_time_indices = _compute_mask_indices( (batch_size, sequence_length), self.config.mask_time_prob, self.config.mask_time_length, attention_mask=attention_mask, min_masks=2, ) hidden_states[torch.from_numpy( mask_time_indices)] = self.masked_spec_embed.to( hidden_states.dtype) if self.config.mask_feature_prob > 0: mask_feature_indices = _compute_mask_indices( (batch_size, hidden_size), self.config.mask_feature_prob, self.config.mask_feature_length, ) mask_feature_indices = torch.from_numpy( mask_feature_indices).to(hidden_states.device) hidden_states[mask_feature_indices[:, None].expand( -1, sequence_length, -1)] = 0 encoder_outputs = self.encoder( hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = encoder_outputs[0] if interpolate_pos == 1: hidden_states = linear_interpolation( hidden_states, output_len=frame_num) if 
not return_dict: return (hidden_states, ) + encoder_outputs[1:] return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) ================================================ FILE: models/model_utils.py ================================================ # Borrowed from https://github.com/EvelynFan/FaceFormer/blob/main/main.py import math import torch import torch.nn as nn import torch.nn.functional as F # linear interpolation layer def linear_interpolation(features, output_len: int): features = features.transpose(1, 2) output_features = F.interpolate( features, size=output_len, align_corners=True, mode='linear') return output_features.transpose(1, 2) # Temporal Bias def init_biased_mask(n_head, max_seq_len, period): def get_slopes(n): def get_slopes_power_of_2(n): start = (2**(-2**-(math.log2(n) - 3))) ratio = start return [start * ratio**i for i in range(n)] if math.log2(n).is_integer(): return get_slopes_power_of_2(n) else: closest_power_of_2 = 2**math.floor(math.log2(n)) return get_slopes_power_of_2(closest_power_of_2) + get_slopes( 2 * closest_power_of_2)[0::2][:n - closest_power_of_2] slopes = torch.Tensor(get_slopes(n_head)) bias = torch.div( torch.arange(start=0, end=max_seq_len, step=period).unsqueeze(1).repeat(1, period).view(-1), period, rounding_mode='floor') bias = -torch.flip(bias, dims=[0]) alibi = torch.zeros(max_seq_len, max_seq_len) for i in range(max_seq_len): alibi[i, :i + 1] = bias[-(i + 1):] alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0) mask = (torch.triu(torch.ones(max_seq_len, max_seq_len)) == 1).transpose(0, 1) mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill( mask == 1, float(0.0)) mask = mask.unsqueeze(0) + alibi return mask # Alignment Bias def enc_dec_mask(device, T, S): mask = torch.ones(T, S) for i in range(T): mask[i, i] = 0 return (mask == 1).to(device=device) # Periodic Positional Encoding class 
PeriodicPositionalEncoding(nn.Module): def __init__(self, d_model, dropout=0.1, period=25, max_seq_len=3000): super(PeriodicPositionalEncoding, self).__init__() self.dropout = nn.Dropout(p=dropout) pe = torch.zeros(period, d_model) position = torch.arange(0, period, dtype=torch.float).unsqueeze(1) div_term = torch.exp( torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) # (1, period, d_model) repeat_num = (max_seq_len // period) + 1 pe = pe.repeat(1, repeat_num, 1) self.register_buffer('pe', pe) def forward(self, x): x = x + self.pe[:, :x.size(1), :] return self.dropout(x) class TCN(nn.Module): def __init__(self, in_dim, out_dim): super().__init__() self.first_net = SeqTranslator1D( in_dim, out_dim, min_layers_num=3, residual=True, norm='ln') self.dropout = nn.Dropout(0.1) def forward(self, spectrogram, style_embed, time_steps=None): spectrogram = spectrogram spectrogram = self.dropout(spectrogram) style_embed = style_embed.unsqueeze(2) style_embed = style_embed.repeat(1, 1, spectrogram.shape[2]) spectrogram = torch.cat([spectrogram, style_embed], dim=1) x1 = self.first_net(spectrogram) # .permute(0, 2, 1) if time_steps is not None: x1 = F.interpolate( x1, size=time_steps, align_corners=False, mode='linear') return x1 class SeqTranslator1D(nn.Module): """(B, C, T)->(B, C_out, T)""" def __init__(self, C_in, C_out, kernel_size=None, stride=None, min_layers_num=None, residual=True, norm='bn'): super(SeqTranslator1D, self).__init__() conv_layers = nn.ModuleList([]) conv_layers.append( ConvNormRelu( in_channels=C_in, out_channels=C_out, type='1d', kernel_size=kernel_size, stride=stride, residual=residual, norm=norm)) self.num_layers = 1 if min_layers_num is not None and self.num_layers < min_layers_num: while self.num_layers < min_layers_num: conv_layers.append( ConvNormRelu( in_channels=C_out, out_channels=C_out, type='1d', 
kernel_size=kernel_size, stride=stride, residual=residual, norm=norm)) self.num_layers += 1 self.conv_layers = nn.Sequential(*conv_layers) def forward(self, x): return self.conv_layers(x) class ConvNormRelu(nn.Module): """(B,C_in,H,W) -> (B, C_out, H, W) there exist some kernel size that makes the result is not H/s. #TODO: there might some problems with residual """ def __init__(self, in_channels, out_channels, type='1d', leaky=False, downsample=False, kernel_size=None, stride=None, padding=None, p=0, groups=1, residual=False, norm='bn'): """conv-bn-relu.""" super(ConvNormRelu, self).__init__() self.residual = residual self.norm_type = norm # kernel_size = k # stride = s if kernel_size is None and stride is None: if not downsample: kernel_size = 3 stride = 1 else: kernel_size = 4 stride = 2 if padding is None: if isinstance(kernel_size, int) and isinstance(stride, tuple): padding = tuple(int((kernel_size - st) / 2) for st in stride) elif isinstance(kernel_size, tuple) and isinstance(stride, int): padding = tuple(int((ks - stride) / 2) for ks in kernel_size) elif isinstance(kernel_size, tuple) and isinstance(stride, tuple): padding = tuple( int((ks - st) / 2) for ks, st in zip(kernel_size, stride)) else: padding = int((kernel_size - stride) / 2) if self.residual: if downsample: if type == '1d': self.residual_layer = nn.Sequential( nn.Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)) elif type == '2d': self.residual_layer = nn.Sequential( nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)) else: if in_channels == out_channels: self.residual_layer = nn.Identity() else: if type == '1d': self.residual_layer = nn.Sequential( nn.Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)) elif type == '2d': self.residual_layer = nn.Sequential( nn.Conv2d( in_channels=in_channels, 
out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)) in_channels = in_channels * groups out_channels = out_channels * groups if type == '1d': self.conv = nn.Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups) self.norm = nn.BatchNorm1d(out_channels) self.dropout = nn.Dropout(p=p) elif type == '2d': self.conv = nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups) self.norm = nn.BatchNorm2d(out_channels) self.dropout = nn.Dropout2d(p=p) if norm == 'gn': self.norm = nn.GroupNorm(2, out_channels) elif norm == 'ln': self.norm = nn.LayerNorm(out_channels) if leaky: self.relu = nn.LeakyReLU(negative_slope=0.2) else: self.relu = nn.ReLU() def forward(self, x, **kwargs): if self.norm_type == 'ln': out = self.dropout(self.conv(x)) out = self.norm(out.transpose(1, 2)).transpose(1, 2) else: out = self.norm(self.dropout(self.conv(x))) if self.residual: residual = self.residual_layer(x) out += residual return self.relu(out) ================================================ FILE: models/pca.py ================================================ import numpy as np import torch import torch.nn as nn class PCALayer(nn.Module): def __init__(self, pca_info_path: str): super().__init__() pca_info = np.load(pca_info_path) feat_to_data = pca_info['components_to_data'].astype(np.float32) data_to_feat = feat_to_data.T data_bias = pca_info['original_data_mean'].astype(np.float32) feat_mean = pca_info['data_components_mean'].astype(np.float32) feat_std = pca_info['data_components_std'].astype(np.float32) self.register_buffer('feat_to_data', torch.from_numpy(feat_to_data)) self.register_buffer('data_to_feat', torch.from_numpy(data_to_feat)) self.register_buffer('data_bias', torch.from_numpy(data_bias)) self.register_buffer('feat_mean', torch.from_numpy(feat_mean)) self.register_buffer('feat_std', 
torch.from_numpy(feat_std)) self.pca_dim = len(feat_mean) def encode( self, x: torch.Tensor, pca_dim: int = None, ): x = x - self.data_bias feat = x @ (self.data_to_feat[:, :pca_dim]) return feat def decode(self, feat: torch.Tensor, pca_dim: int = None): data = feat[..., :pca_dim] @ (self.feat_to_data[:pca_dim]) data = data + self.data_bias return data class PCA: def __init__(self, trunc_dim: int = None): super().__init__() def buld_pca_for_dataset(self, dataset, trunc_dim: int = 1024): annot_array = self.load_dataset(dataset) self.build_pca(annot_array, trunc_dim) def load_dataset(self, dataset): annot_list = [] indices = np.arange(len(dataset)) for i in indices: data_dict = dataset.__getitem__(i) annot_list.append(data_dict['data'] - data_dict['template']) annot_array = np.concatenate(annot_list, axis=0) return annot_array def build_incremental_PCA( self, in_data: np.ndarray, trunc_dim=None, batch_size=1024, ): indices = np.arange(len(in_data)) np.random.shuffle(indices) in_data = in_data[indices] if in_data.dtype != np.float32: in_data = in_data.astype(np.float32) from sklearn.decomposition import IncrementalPCA ipca = IncrementalPCA(n_components=trunc_dim, batch_size=batch_size) data_mean = in_data.mean(0) in_data = in_data - data_mean in_data_slices = np.split( in_data, np.arange(batch_size, len(in_data), batch_size)) for batch_data in in_data_slices: if len(batch_data) != batch_size: continue ipca.partial_fit(batch_data) S = ipca.singular_values_ components_to_data = ipca.components_ explained_variance_ratio = ipca.explained_variance_ratio_ data_components = ipca.transform(in_data) data_components_mean = data_components.mean(0) data_components_std = data_components.std(0) ret_dict = { 'explained_variance_ratio': explained_variance_ratio, 'components_to_data': components_to_data, 'original_data_mean': data_mean, 'S': S, 'data_components_mean': data_components_mean, 'data_components_std': data_components_std } return ret_dict def build_pca( self, in_data: 
np.ndarray, trunc_dim=None, save_compactness=True, check_svd=True, ): ret_dict = {} in_data = torch.Tensor(in_data).cuda() count, in_dim = in_data.shape mean_data = in_data.mean(dim=0) in_data_center = in_data - mean_data # (count, in_dim) in_data_center = in_data_center.cpu() # U (count, count), S(count, ), Vt (count, in_dim) U, S, Vt = torch.linalg.svd(in_data_center, full_matrices=False) if check_svd: S_mat = torch.diag(S) rebuilt = U @ S_mat @ Vt is_precise = torch.allclose(rebuilt, in_data_center) print('svd rebuilt all close ?', is_precise) var_S = S**2 if save_compactness: compact_vector = var_S / var_S.sum() compact_value = compact_vector[0:trunc_dim].sum() ret_dict['compactness, ', compact_value] pca_bases = Vt[0:trunc_dim].reshape(trunc_dim, -1) ret_dict['pca_bases'] = pca_bases ret_dict['pca_std'] = S[0:trunc_dim] ret_dict['mean_data'] = mean_data return ret_dict ================================================ FILE: models/unitalker.py ================================================ import os.path as osp import torch.nn as nn from dataset.dataset_config import dataset_config from models.wav2vec2 import Wav2Vec2Model from .base_model import BaseModel # from models.hubert import HubertModel from models.wavlm import WavLMModel class OutHead(BaseModel): def __init__(self, args, out_dim: int): super().__init__() head_layer = [] in_dim = args.decoder_dimension for i in range(args.headlayer - 1): head_layer.append(nn.Linear(in_dim, in_dim)) head_layer.append(nn.Tanh()) head_layer.append(nn.Linear(in_dim, out_dim)) self.head_layer = nn.Sequential(*head_layer) def forward(self, x): return self.head_layer(x) class UniTalker(BaseModel): def __init__(self, args): super().__init__() self.args = args if 'wav2vec2' in args.audio_encoder_repo: self.audio_encoder = Wav2Vec2Model.from_pretrained( args.audio_encoder_repo) elif 'wavlm' in args.audio_encoder_repo: self.audio_encoder = WavLMModel.from_pretrained( args.audio_encoder_repo) else: raise ValueError("wrong 
audio_encoder_repo") self.audio_encoder.feature_extractor._freeze_parameters() if args.freeze_wav2vec: self.audio_encoder._freeze_wav2vec2_parameters() if args.decoder_type == 'conv': from .unitalker_decoder import UniTalkerDecoderTCN as Decoder elif args.decoder_type == 'transformer': from .unitalker_decoder import \ UniTalkerDecoderTransformer as Decoder else: ValueError('unknown decoder type ') self.decoder = Decoder(args) if args.use_pca: pca_layer_dict = {} from models.pca import PCALayer for dataset_name in args.dataset: if dataset_config[dataset_name]['pca']: annot_type = dataset_config[dataset_name]['annot_type'] dirname = dataset_config[dataset_name]['dirname'] pca_path = osp.join(args.data_root, dirname, 'pca.npz') pca_layer_dict[annot_type] = PCALayer(pca_path) self.pca_layer_dict = nn.ModuleDict(pca_layer_dict) else: self.pca_layer_dict = {} self.pca_dim = args.pca_dim self.use_pca = args.use_pca self.interpolate_pos = args.interpolate_pos out_head_dict = {} for dataset_name in args.dataset: annot_type = dataset_config[dataset_name]['annot_type'] out_dim = dataset_config[dataset_name]['annot_dim'] if (args.use_pca is True) and (annot_type in self.pca_layer_dict): pca_dim = self.pca_layer_dict[annot_type].pca_dim out_dim = min(args.pca_dim, out_dim, pca_dim) if args.headlayer == 1: out_projection = nn.Linear(args.decoder_dimension, out_dim) nn.init.constant_(out_projection.weight, 0) nn.init.constant_(out_projection.bias, 0) else: out_projection = OutHead(args, out_dim) out_head_dict[annot_type] = out_projection self.out_head_dict = nn.ModuleDict(out_head_dict) self.identity_num = args.identity_num return def forward(self, audio, template, face_motion, style_idx, annot_type: str, fps: float = None): if face_motion is not None: frame_num = face_motion.shape[1] else: frame_num = round(audio.shape[-1] / 16000 * fps) hidden_states = self.audio_encoder( audio, frame_num=frame_num, interpolate_pos=self.interpolate_pos) hidden_states = 
hidden_states.last_hidden_state decoder_out = self.decoder(hidden_states, style_idx, frame_num) if (not self.use_pca) or (annot_type not in self.pca_layer_dict): out_motion = self.out_head_dict[annot_type](decoder_out) out_motion = out_motion + template out_pca, gt_pca = None, None else: out_pca = self.out_head_dict[annot_type](decoder_out) out_motion = self.pca_layer_dict[annot_type].decode( out_pca, self.pca_dim) out_motion = out_motion + template if face_motion is not None: gt_pca = self.pca_layer_dict[annot_type].encode( face_motion - template, self.pca_dim) else: gt_pca = None return out_motion, out_pca, gt_pca ================================================ FILE: models/unitalker_decoder.py ================================================ # yapf: disable import torch import torch.nn as nn from .base_model import BaseModel from .model_utils import ( TCN, PeriodicPositionalEncoding, enc_dec_mask, init_biased_mask, linear_interpolation, ) # yapf: enable class UniTalkerDecoderTCN(BaseModel): def __init__(self, args) -> None: super().__init__() self.learnable_style_emb = nn.Embedding(args.identity_num, 64) in_dim = args.decoder_dimension out_dim = args.decoder_dimension self.audio_feature_map = nn.Linear(args.audio_encoder_feature_dim, in_dim) self.tcn = TCN(in_dim + 64, out_dim) self.interpolate_pos = args.interpolate_pos def forward(self, hidden_states: torch.Tensor, style_idx: torch.Tensor, frame_num: int = None): batch_size = len(hidden_states) if style_idx is None: style_embedding = self.learnable_style_emb.weight[-1].unsqueeze( 0).repeat(batch_size, 1) else: style_embedding = self.learnable_style_emb(style_idx) feature = self.audio_feature_map(hidden_states).transpose(1, 2) feature = self.tcn(feature, style_embedding) feature = feature.transpose(1, 2) if self.interpolate_pos == 2: feature = linear_interpolation(feature, output_len=frame_num) return feature class UniTalkerDecoderTransformer(BaseModel): def __init__(self, args) -> None: super().__init__() 
out_dim = args.decoder_dimension self.learnable_style_emb = nn.Embedding(args.identity_num, out_dim) self.audio_feature_map = nn.Linear(args.audio_encoder_feature_dim, out_dim) self.PPE = PeriodicPositionalEncoding( out_dim, period=args.period, max_seq_len=3000) self.biased_mask = init_biased_mask( n_head=4, max_seq_len=3000, period=args.period) decoder_layer = nn.TransformerDecoderLayer( d_model=out_dim, nhead=4, dim_feedforward=2 * out_dim, batch_first=True) self.transformer_decoder = nn.TransformerDecoder( decoder_layer, num_layers=1) self.interpolate_pos = args.interpolate_pos def forward(self, hidden_states: torch.Tensor, style_idx: torch.Tensor, frame_num: int): obj_embedding = self.learnable_style_emb(style_idx) obj_embedding = obj_embedding.unsqueeze(1).repeat(1, frame_num, 1) hidden_states = self.audio_feature_map(hidden_states) style_input = self.PPE(obj_embedding) tgt_mask = self.biased_mask[:, :style_input.shape[1], :style_input. shape[1]].clone().detach().to( device=style_input.device) memory_mask = enc_dec_mask(hidden_states.device, style_input.shape[1], frame_num) feat_out = self.transformer_decoder( style_input, hidden_states, tgt_mask=tgt_mask, memory_mask=memory_mask) if self.interpolate_pos == 2: feat_out = linear_interpolation(feat_out, output_len=frame_num) return feat_out ================================================ FILE: models/wav2vec2.py ================================================ import numpy as np import torch from transformers import Wav2Vec2Model from transformers.modeling_outputs import BaseModelOutput from typing import Optional, Tuple from .model_utils import linear_interpolation # the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model # noqa: E501 # initialize our encoder with the pre-trained wav2vec 2.0 weights. 
def _compute_mask_indices( shape: Tuple[int, int], mask_prob: float, mask_length: int, attention_mask: Optional[torch.Tensor] = None, min_masks: int = 0, ) -> np.ndarray: bsz, all_sz = shape mask = np.full((bsz, all_sz), False) all_num_mask = int(mask_prob * all_sz / float(mask_length) + np.random.rand()) all_num_mask = max(min_masks, all_num_mask) mask_idcs = [] padding_mask = attention_mask.ne(1) if attention_mask is not None else None for i in range(bsz): if padding_mask is not None: sz = all_sz - padding_mask[i].long().sum().item() num_mask = int(mask_prob * sz / float(mask_length) + np.random.rand()) num_mask = max(min_masks, num_mask) else: sz = all_sz num_mask = all_num_mask lengths = np.full(num_mask, mask_length) if sum(lengths) == 0: lengths[0] = min(mask_length, sz - 1) min_len = min(lengths) if sz - min_len <= num_mask: min_len = sz - num_mask - 1 mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) mask_idc = np.asarray([ mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j]) ]) mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) min_len = min([len(m) for m in mask_idcs]) for i, mask_idc in enumerate(mask_idcs): if len(mask_idc) > min_len: mask_idc = np.random.choice(mask_idc, min_len, replace=False) mask[i, mask_idc] = True return mask class Wav2Vec2Model(Wav2Vec2Model): def __init__(self, config): super().__init__(config) def _freeze_wav2vec2_parameters(self, do_freeze: bool = True): for param in self.parameters(): param.requires_grad = (not do_freeze) def forward( self, input_values, attention_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None, frame_num=None, interpolate_pos: int = 0, ): self.config.output_attentions = True output_attentions = output_attentions if \ output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) return_dict = return_dict if 
return_dict is not None \ else self.config.use_return_dict conv_features = self.feature_extractor(input_values) hidden_states = conv_features.transpose(1, 2) if interpolate_pos == 0: hidden_states = linear_interpolation( hidden_states, output_len=frame_num) if attention_mask is not None: output_lengths = self._get_feat_extract_output_lengths( attention_mask.sum(-1)) attention_mask = torch.zeros( hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device) attention_mask[(torch.arange( attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)] = 1 attention_mask = attention_mask.flip([-1]).cumsum(-1).flip( [-1]).bool() hidden_states, _ = self.feature_projection(hidden_states) if self.config.apply_spec_augment and self.training: batch_size, sequence_length, hidden_size = hidden_states.size() if self.config.mask_time_prob > 0: mask_time_indices = _compute_mask_indices( (batch_size, sequence_length), self.config.mask_time_prob, self.config.mask_time_length, attention_mask=attention_mask, min_masks=2, ) hidden_states[torch.from_numpy( mask_time_indices)] = self.masked_spec_embed.to( hidden_states.dtype) if self.config.mask_feature_prob > 0: mask_feature_indices = _compute_mask_indices( (batch_size, hidden_size), self.config.mask_feature_prob, self.config.mask_feature_length, ) mask_feature_indices = torch.from_numpy( mask_feature_indices).to(hidden_states.device) hidden_states[mask_feature_indices[:, None].expand( -1, sequence_length, -1)] = 0 encoder_outputs = self.encoder( hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = encoder_outputs[0] if interpolate_pos == 1: hidden_states = linear_interpolation( hidden_states, output_len=frame_num) if not return_dict: return (hidden_states, ) + encoder_outputs[1:] return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_outputs.hidden_states, 
attentions=encoder_outputs.attentions, ) ================================================ FILE: models/wavlm.py ================================================ import numpy as np import torch from transformers import WavLMModel from transformers.modeling_outputs import Wav2Vec2BaseModelOutput from typing import Optional, Tuple, Union from .model_utils import linear_interpolation # the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model # noqa: E501 # initialize our encoder with the pre-trained wav2vec 2.0 weights. class WavLMModel(WavLMModel): def __init__(self, config): super().__init__(config) def _freeze_wav2vec2_parameters(self, do_freeze: bool = True): for param in self.parameters(): param.requires_grad = (not do_freeze) def forward( self, input_values: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor] = None, mask_time_indices: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, frame_num=None, interpolate_pos: int = 0, ) -> Union[Tuple, Wav2Vec2BaseModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict extract_features = self.feature_extractor(input_values) extract_features = extract_features.transpose(1, 2) if interpolate_pos == 0: extract_features = linear_interpolation( extract_features, output_len=frame_num) if attention_mask is not None: # compute reduced attention_mask corresponding to feature vectors attention_mask = self._get_feature_vector_attention_mask( extract_features.shape[1], attention_mask, add_adapter=False ) hidden_states, extract_features = 
self.feature_projection(extract_features) hidden_states = self._mask_hidden_states( hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask ) encoder_outputs = self.encoder( hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = encoder_outputs[0] if interpolate_pos == 1: hidden_states = linear_interpolation( hidden_states, output_len=frame_num) if self.adapter is not None: hidden_states = self.adapter(hidden_states) if not return_dict: return (hidden_states, extract_features) + encoder_outputs[1:] return Wav2Vec2BaseModelOutput( last_hidden_state=hidden_states, extract_features=extract_features, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) ================================================ FILE: utils/config.py ================================================ # ----------------------------------------------------------------------------- # Functions for parsing args # ----------------------------------------------------------------------------- import copy import os import yaml from ast import literal_eval class CfgNode(dict): """CfgNode represents an internal node in the configuration tree. It's a simple dict-like container that allows for attribute-based access to keys. 
""" def __init__(self, init_dict=None, key_list=None, new_allowed=False): # Recursively convert nested dictionaries in init_dict into CfgNodes init_dict = {} if init_dict is None else init_dict key_list = [] if key_list is None else key_list for k, v in init_dict.items(): if type(v) is dict: # Convert dict to CfgNode init_dict[k] = CfgNode(v, key_list=key_list + [k]) super(CfgNode, self).__init__(init_dict) def __getattr__(self, name): if name in self: return self[name] else: raise AttributeError(name) def __setattr__(self, name, value): self[name] = value def __str__(self): def _indent(s_, num_spaces): s = s_.split('\n') if len(s) == 1: return s_ first = s.pop(0) s = [(num_spaces * ' ') + line for line in s] s = '\n'.join(s) s = first + '\n' + s return s r = '' s = [] for k, v in sorted(self.items()): separator = '\n' if isinstance(v, CfgNode) else ' ' attr_str = '{}:{}{}'.format(str(k), separator, str(v)) attr_str = _indent(attr_str, 2) s.append(attr_str) r += '\n'.join(s) return r def __repr__(self): return '{}({})'.format(self.__class__.__name__, super(CfgNode, self).__repr__()) def load_cfg_from_cfg_file(file): cfg = {} assert os.path.isfile(file) and file.endswith('.yaml'), \ '{} is not a yaml file'.format(file) with open(file, 'r') as f: cfg_from_file = yaml.safe_load(f) for key in cfg_from_file: for k, v in cfg_from_file[key].items(): cfg[k] = v cfg = CfgNode(cfg) return cfg def merge_cfg_from_list(cfg, cfg_list): new_cfg = copy.deepcopy(cfg) assert len(cfg_list) % 2 == 0 for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]): subkey = full_key.split('.')[-1] assert subkey in cfg, 'Non-existent key: {}'.format(full_key) value = _decode_cfg_value(v) value = _check_and_coerce_cfg_value_type(value, cfg[subkey], subkey, full_key) setattr(new_cfg, subkey, value) return new_cfg def _decode_cfg_value(v): """Decodes a raw config value (e.g., from a yaml config files or command line argument) into a Python object.""" # All remaining processing is only applied to 
strings if not isinstance(v, str): return v # Try to interpret `v` as a: # string, number, tuple, list, dict, boolean, or None try: v = literal_eval(v) # The following two excepts allow v to pass through when it represents a # string. # # Longer explanation: # The type of v is always a string (before calling literal_eval), but # sometimes it *represents* a string and other times a data structure, like # a list. In the case that v represents a string, what we got back from the # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is # ok with '"foo"', but will raise a ValueError if given 'foo'. In other # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval # will raise a SyntaxError. except ValueError: pass except SyntaxError: pass return v def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): """Checks that `replacement`, which is intended to replace `original` is of the right type. The type is correct if it matches exactly or is one of a few cases in which the type can be easily coerced. """ original_type = type(original) replacement_type = type(replacement) # The types must match (with some exceptions) if replacement_type == original_type or original is None: return replacement # types match from_type and to_type def conditional_cast(from_type, to_type): if replacement_type == from_type and original_type == to_type: return True, to_type(replacement) else: return False, None # Conditionally casts # list <-> tuple casts = [(tuple, list), (list, tuple)] # For py2: allow converting from str (bytes) to a unicode string try: casts.append((str, unicode)) # noqa: F821 except Exception: pass for (from_type, to_type) in casts: converted, converted_value = conditional_cast(from_type, to_type) if converted: return converted_value raise ValueError( 'Type mismatch ({} vs. {}) with values ({} vs. 
{}) for config ' 'key: {}'.format(original_type, replacement_type, original, replacement, full_key)) ================================================ FILE: utils/utils.py ================================================ import argparse import logging import numpy as np import os import random import torch from copy import deepcopy from . import config def get_audio_encoder_dim(audio_encoder: str): if audio_encoder == 'microsoft/wavlm-base-plus': return 768 elif audio_encoder == 'facebook/wav2vec2-large-xlsr-53': return 1024 else: raise ValueError("wrong audio_encoder") def filer_list(in_list: list, key: str): ret_list = [] for line in in_list: if line.startswith(key): ret_list.append(line) return ret_list def count_checkpoint_params(ckpt: dict): params = 0 pca_params = 0 for k, v in ckpt.items(): if 'pca' in k: print(k) pca_params += np.prod(v.size()) params += np.prod(v.size()) return params, pca_params def read_obj(in_path): with open(in_path, 'r') as obj_file: # Read the lines of the OBJ file lines = obj_file.readlines() # Initialize empty lists for vertices and faces verts = [] faces = [] for line in lines: line = line.strip() # Remove leading/trailing whitespace elements = line.split() # Split the line into elements if len(elements) == 0: continue # Skip empty lines # Check the type of line (vertex or face) if elements[0] == 'v': # Vertex line x, y, z = map(float, elements[1:4]) # Extract the vertex coordinates verts.append((x, y, z)) # Add the vertex to the list elif elements[0] == 'f': # Face line face_indices = [ int(index.split('/')[0]) for index in elements[1:] ] # Extract the vertex indices faces.append(face_indices) # Add the face to the list verts = np.array(verts) faces = np.array(faces) if faces.min() == 1: faces = faces - 1 return verts, faces def get_logger(log_file: str): logger_name = 'main-logger' logger = logging.getLogger(logger_name) logger.setLevel(logging.INFO) handler = logging.FileHandler(log_file, mode='w') fmt = '[%(asctime)s 
%(levelname)s %(filename)s line %(lineno)d %(process)d]=>%(message)s' # noqa: E501 handler.setFormatter(logging.Formatter(fmt)) for hdl in logger.handlers: logger.removeHandler(hdl) logger.addHandler(handler) return logger def get_template_dict(data_root, annot_type: str): import pickle if annot_type == 'BIWI_23370_vertices': template_pkl_path = os.path.join(data_root, 'BIWI', 'templates.pkl') elif annot_type == 'FLAME_5023_vertices': template_pkl_path = os.path.join(data_root, 'vocaset', 'templates.pkl') with open(template_pkl_path, 'rb') as f: template_dict = pickle.load(f, encoding='latin') return template_dict def filter_unitalker_state_dict(in_dict): out_dict = {} for k, v in in_dict.items(): if not 'pca_layer_dict' in k: out_dict[k] = v return out_dict def get_template_verts(data_root, dataset_name: str, subject: int): from dataset.dataset_config import dataset_config template_path = os.path.join(data_root, dataset_config[dataset_name]['dirname'], 'id_template.npy') return np.load(template_path)[subject] def load_ckpt(model, ckpt_path, re_init_decoder_and_head: bool = False): checkpoint = torch.load(ckpt_path, map_location='cpu') id_embedding_key = 'decoder.learnable_style_emb.weight' if len(checkpoint[id_embedding_key]) != len( model.state_dict()[id_embedding_key]): target_id_num = len(model.state_dict()[id_embedding_key]) checkpoint[id_embedding_key] = checkpoint[id_embedding_key][-1][ None].repeat(target_id_num, 1) if re_init_decoder_and_head is True: ckpt_keys = list(checkpoint.keys()) encoder_keys = filer_list(ckpt_keys, 'audio_encoder') reinit_keys = list(set(ckpt_keys) - set(encoder_keys)) for k in reinit_keys: if k in model.state_dict(): checkpoint[k] = model.state_dict()[k] model.load_state_dict(checkpoint, strict=False) return class AverageMeter(object): """Computes and stores the average and current value.""" def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val 
= val self.sum += val * n self.count += n self.avg = self.sum / self.count def get_avg(self, ): return self.avg def get_average_meter_dict(args, phase): if phase == 'train': metric_name_list = ['rec_L2_loss', 'pca_loss'] elif phase == 'val' or phase == 'test': metric_name_list = ['rec_L2_loss', 'LVE', 'LVD'] else: raise ValueError('wrong phase for average meter') average_meter_dict_template = {} for k in metric_name_list: average_meter_dict_template[k] = AverageMeter() mixed_dataset_name = 'mixed_dataset' dataset_name_list = args.dataset + [mixed_dataset_name] average_meter_dict = {} for dataset_name in dataset_name_list: average_meter_dict[dataset_name] = deepcopy( average_meter_dict_template) return average_meter_dict, metric_name_list def log_datasetloss( logger, epoch: int, phase: str, dataset_loss_dict: dict, only_mix: bool = True, ): base_str = f'{phase.upper():<5} Epoch: {epoch:03d} ' for dataset_name, loss_info in dataset_loss_dict.items(): if only_mix and dataset_name != 'mixed_dataset': continue for loss_type, avg_meter in loss_info.items(): base_str += f'{loss_type}_{dataset_name}:{avg_meter.avg: .8f}, ' logger.info(base_str) return base_str def write_to_tensorboard(writer, epoch: int, phase: str, dataset_loss_dict: dict, only_mix: bool = False): for dataset_name, loss_info in dataset_loss_dict.items(): if only_mix and dataset_name != 'mixed_dataset': continue for loss_type, avg_meter in loss_info.items(): key = f'{phase}/{loss_type}_{dataset_name}' writer.add_scalar(key, avg_meter.avg, epoch) return 0 def get_parser(): parser = argparse.ArgumentParser(description=' ') parser.add_argument( '--config', type=str, default='config/unitalker.yaml', help='config file') parser.add_argument( '--condition_id', type=str, default='common', help='condition id') parser.add_argument( 'opts', help=' ', default=None, nargs=argparse.REMAINDER) args = parser.parse_args() assert args.config is not None cfg = config.load_cfg_from_cfg_file(args.config) if args.opts is not 
None: cfg = config.merge_cfg_from_list(cfg, args.opts) cfg.condition_id = args.condition_id cfg.dataset = cfg.dataset.split(',') cfg.demo_dataset = cfg.demo_dataset.split(',') if isinstance(cfg.duplicate_list, str): cfg.duplicate_list = cfg.duplicate_list.split(',') cfg.duplicate_list = [int(i) for i in cfg.duplicate_list] return cfg def seed_everything(seed: int = 42): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False return