Repository: theEricMa/DiffSpeaker
Branch: main
Commit: 0f2ae2d4e88f
Files: 66
Total size: 5.1 MB
Directory structure:
gitextract_v4nqs5sn/
├── README.md
├── alm/
│ ├── callback/
│ │ ├── __init__.py
│ │ └── progress.py
│ ├── config.py
│ ├── data/
│ │ ├── BIWI/
│ │ │ ├── __init__.py
│ │ │ └── dataset.py
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── biwi.py
│ │ ├── get_data.py
│ │ ├── voca/
│ │ │ ├── __init__.py
│ │ │ └── dataset.py
│ │ └── vocaset.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── architectures/
│ │ │ ├── __init__.py
│ │ │ ├── adpt_bias_denoiser.py
│ │ │ └── tools/
│ │ │ ├── embeddings.py
│ │ │ ├── transformer_adpt.py
│ │ │ └── utils.py
│ │ ├── get_model.py
│ │ ├── losses/
│ │ │ ├── __init__.py
│ │ │ ├── utils.py
│ │ │ └── voca.py
│ │ └── modeltype/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── diffusion_bias.py
│ └── utils/
│ ├── __init__.py
│ ├── demo_utils.py
│ ├── logger.py
│ └── temos_utils.py
├── configs/
│ ├── assets/
│ │ ├── biwi.yaml
│ │ └── vocaset.yaml
│ ├── base.yaml
│ └── diffusion/
│ ├── biwi/
│ │ ├── diffspeaker_hubert_biwi.yaml
│ │ └── diffspeaker_wav2vec2_biwi.yaml
│ ├── diffusion_bias_modules/
│ │ ├── denoiser.yaml
│ │ └── scheduler.yaml
│ └── vocaset/
│ ├── diffspeaker_hubert_vocaset.yaml
│ └── diffspeaker_wav2vec2_vocaset.yaml
├── datasets/
│ ├── biwi/
│ │ ├── README.md
│ │ ├── regions/
│ │ │ ├── fdd.txt
│ │ │ └── lve.txt
│ │ ├── templates/
│ │ │ └── BIWI.ply
│ │ └── templates.pkl
│ └── vocaset/
│ ├── FLAME_masks.pkl
│ ├── README.md
│ ├── templates/
│ │ ├── FLAME_sample.ply
│ │ └── README.md
│ └── templates.pkl
├── demo_biwi.py
├── demo_vocaset.py
├── demo_vocaset_text.py
├── eval_biwi.py
├── eval_vocaset.py
├── requirements.txt
├── scripts/
│ ├── demo/
│ │ ├── demo_biwi.sh
│ │ └── demo_vocaset.sh
│ └── diffusion/
│ ├── biwi_evaluation/
│ │ ├── diffspeaker_hubert_biwi.sh
│ │ └── diffspeaker_wav2vec2_biwi.sh
│ ├── biwi_training/
│ │ ├── diffspeaker_hubert_biwi.sh
│ │ └── diffspeaker_wav2vec2_biwi.sh
│ ├── vocaset_evaluation/
│ │ ├── diffspeaker_hubert_vocaset.sh
│ │ └── diffspeaker_wav2vec2_vocaset.sh
│ └── vocaset_training/
│ ├── diffspeaker_hubert_vocaset.sh
│ └── diffspeaker_wav2vec2_vocaset.sh
└── train.py
================================================
FILE CONTENTS
================================================
================================================
FILE: README.md
================================================
# DiffSpeaker: Speech-Driven 3D Facial Animation with Diffusion Transformer
## [Paper](https://arxiv.org/pdf/2402.05712.pdf) | [Demo](https://www.youtube.com/watch?v=4-NBygHePk0)
## Update
- [30/03/2024]: The evaluation code is updated.
- [07/02/2024]: The inference script is released.
- [06/02/2024]: The model weights are released.
## Get started
### Environment Setup
```
conda create --name diffspeaker python=3.9
conda activate diffspeaker
```
Install the MPI-IS `mesh` package by following the instructions in [MPI-IS](https://github.com/MPI-IS/mesh). Depending on whether the `/usr/include/boost/` directory exists on your system, the commands are likely to be:
```
git clone https://github.com/MPI-IS/mesh.git
cd mesh
sudo apt-get install libboost-dev
python -m pip install pip==20.2.4
BOOST_INCLUDE_DIRS=/usr/include/boost/ make all
python -m pip install --upgrade pip
```
Then install the rest of the dependencies.
```
cd ..
git clone https://github.com/theEricMa/DiffSpeaker.git
cd DiffSpeaker
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
pip install imageio-ffmpeg
pip install -r requirements.txt
```
### Model Weights
The model weights are available [here](https://drive.google.com/drive/folders/1PezaNpQHIjyE8UE5YW0jpDPV8jtepxSL?usp=sharing). Place the downloaded `checkpoints` folder in the root directory of the project. It contains models trained on the `BIWI` and `vocaset` datasets with `wav2vec2` and `hubert` backbones.
### Prediction
For the BIWI model, use the script below to perform inference on your chosen audio files. Specify the audio file using the `--example` argument.
```
sh scripts/demo/demo_biwi.sh
```
For the vocaset model, run the following script.
```
sh scripts/demo/demo_vocaset.sh
```
### Evaluation
To reproduce the metrics reported in the paper, use the scripts in `scripts/diffusion/biwi_evaluation` and `scripts/diffusion/vocaset_evaluation`. For example, to evaluate DiffSpeaker on the BIWI dataset with the hubert backbone, run:
```
sh scripts/diffusion/biwi_evaluation/diffspeaker_hubert_biwi.sh
```
## Training
### Data Preparation
### Model Training
```
mkdir experiments
```
================================================
FILE: alm/callback/__init__.py
================================================
from .progress import ProgressLogger
================================================
FILE: alm/callback/progress.py
================================================
import logging
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
import psutil
logger = logging.getLogger()
class ProgressLogger(Callback):
    def __init__(self, metric_monitor: dict, precision: int = 3):
        # Metric to monitor
        self.metric_monitor = metric_monitor
        self.precision = precision
    def on_train_start(self, trainer: Trainer, pl_module: LightningModule,
                       **kwargs) -> None:
        logger.info("Training started")
    def on_train_end(self, trainer: Trainer, pl_module: LightningModule,
                     **kwargs) -> None:
        logger.info("Training done")
    def on_validation_epoch_end(self, trainer: Trainer,
                                pl_module: LightningModule, **kwargs) -> None:
        if trainer.sanity_checking:
            logger.info("Sanity checking ok.")
    def on_train_epoch_end(self,
                           trainer: Trainer,
                           pl_module: LightningModule,
                           padding=False,
                           **kwargs) -> None:
        metric_format = f"{{:.{self.precision}e}}"
        line = f"Epoch {trainer.current_epoch}"
        if padding:
            line = f"{line:>{len('Epoch xxxx')}}"  # Right-align (pad on the left)
        metrics_str = []
        losses_dict = trainer.callback_metrics
        for metric_name, dico_name in self.metric_monitor.items():
            if dico_name in losses_dict:
                metric = losses_dict[dico_name].item()
                metric = metric_format.format(metric)
                metric = f"{metric_name} {metric}"
                metrics_str.append(metric)
        if len(metrics_str) == 0:
            return
        memory = f"Memory {psutil.virtual_memory().percent}%"
        line = line + ": " + " ".join(metrics_str) + " " + memory
        logger.info(line)
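`on_train_epoch_end` builds one log line per epoch: each monitored metric is printed in scientific notation with `precision` digits, and nothing is logged when none of the monitored keys are present. The sketch below reproduces that formatting without PyTorch Lightning, assuming a plain dict of floats stands in for `trainer.callback_metrics` and a fixed number stands in for `psutil.virtual_memory().percent`; `format_epoch_line` and the metric keys are hypothetical, not part of the repository.

```python
from typing import Dict, Optional


def format_epoch_line(epoch: int,
                      metric_monitor: Dict[str, str],
                      losses: Dict[str, float],
                      memory_percent: float,
                      precision: int = 3,
                      padding: bool = False) -> Optional[str]:
    """Mirror ProgressLogger.on_train_epoch_end with plain floats (hypothetical helper)."""
    metric_format = f"{{:.{precision}e}}"  # e.g. "{:.3e}" for precision=3
    line = f"Epoch {epoch}"
    if padding:
        # right-align the prefix to the width of "Epoch xxxx"
        line = f"{line:>{len('Epoch xxxx')}}"
    metrics_str = []
    for metric_name, dico_name in metric_monitor.items():
        if dico_name in losses:
            metrics_str.append(f"{metric_name} {metric_format.format(losses[dico_name])}")
    if not metrics_str:
        return None  # none of the monitored metrics were logged this epoch
    return line + ": " + " ".join(metrics_str) + f" Memory {memory_percent}%"


if __name__ == "__main__":
    monitor = {"Train_loss": "total/train", "Val_loss": "total/val"}  # hypothetical keys
    print(format_epoch_line(3, monitor, {"total/train": 0.0123}, 41.5))
    # Epoch 3: Train_loss 1.230e-02 Memory 41.5%
```

Metrics missing from the dict (here `Val_loss`) are silently skipped, matching the `if dico_name in losses_dict` guard in the callback.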
================================================
FILE: alm/config.py
================================================
import importlib
from argparse import ArgumentParser
from omegaconf import OmegaConf
import os
def get_module_config(cfg_model, path="modules"):
    files = os.listdir(f'./configs/{path}/')
    for file in files:
        if file.endswith('.yaml'):
            with open(f'./configs/{path}/' + file, 'r') as f:
                cfg_model.merge_with(OmegaConf.load(f))
    return cfg_model
def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)
def instantiate_from_config(config):
    if "target" not in config:
        if config == '__is_first_stage__':
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
def parse_args(phase="train"):
    parser = ArgumentParser()
    group = parser.add_argument_group("Training options")
    if phase in ["train", "test", "demo"]:
        group.add_argument(
            "--cfg",
            type=str,
            required=False,
            default="./configs/config.yaml",
            help="config file",
        )
        group.add_argument(
            "--cfg_assets",
            type=str,
            required=False,
            default="./configs/assets.yaml",
            help="config file for asset paths",
        )
        group.add_argument(
            "--frame_rate",
            type=float,
            default=30,
            help="the frame rate for the input/output motion",
        )
        group.add_argument(
            "--resume",
            type=str,
            required=False,
            help="resume from a checkpoint",
        )
        group.add_argument("--batch_size",
                           type=int,
                           required=False,
                           help="training batch size")
        group.add_argument("--device",
                           type=int,
                           nargs="+",
                           required=False,
                           help="training device")
        group.add_argument("--nodebug",
                           action="store_true",
                           required=False,
                           help="debug or not")
        group.add_argument("--dir",
                           type=str,
                           required=False,
                           help="evaluate existing npys")
    if phase == "demo":
        # group.add_argument("--motion_transfer", action='store_true', help="Motion Distribution Transfer")
        group.add_argument("--render",
                           action="store_true",
                           help="Render visualized figures")
        group.add_argument("--render_mode", type=str, help="video or sequence")
        group.add_argument(
            "--example",
            type=str,
            required=False,
            help="input audio file",
        )
        group.add_argument(
            "--out_dir",
            type=str,
            required=False,
            help="output dir",
        )
        group.add_argument(
            "--template",
            type=str,
            required=False,
            help="template path",
        )
        group.add_argument(
            "--checkpoint",
            type=str,
            required=True,
            help="model checkpoint path",
        )
        group.add_argument(
            "--id",
            type=str,
            required=True,
            help="the candidate subject identity",
        )
        group.add_argument(
            "--ply",
            type=str,
            required=True,
            help="template mesh (.ply) path",
        )
    if phase == "render":
        group.add_argument(
            "--cfg",
            type=str,
            required=False,
            default="./configs/render.yaml",
            help="config file",
        )
        group.add_argument(
            "--cfg_assets",
            type=str,
            required=False,
            default="./configs/assets.yaml",
            help="config file for asset paths",
        )
        # group.add_argument("--motion_transfer", action='store_true', help="Motion Distribution Transfer")
        group.add_argument("--dir",
                           type=str,
                           required=False,
                           default=None,
                           help="npy motion folder")
    # remove None params, and create a dictionary
    params = parser.parse_args()
    # params = {key: val for key, val in vars(opt).items() if val is not None}
    # update config from files
    cfg_base = OmegaConf.load('./configs/base.yaml')
    cfg_exp = OmegaConf.merge(cfg_base, OmegaConf.load(params.cfg))
    cfg_assets = OmegaConf.load(params.cfg_assets)
    cfg_model = get_module_config(cfg_exp.model, cfg_exp.model.target)
    cfg = OmegaConf.merge(cfg_exp, cfg_model, cfg_assets)
    if phase in ["train", "test"]:
        cfg.TRAIN.BATCH_SIZE = (params.batch_size
                                if params.batch_size else cfg.TRAIN.BATCH_SIZE)
        cfg.DEVICE = params.device if params.device else cfg.DEVICE
        cfg.DEBUG = not params.nodebug if params.nodebug is not None else cfg.DEBUG
        # resume from a checkpoint, added by Zhiyuan Ma
        cfg.TRAIN.RESUME = params.resume if params.resume else cfg.TRAIN.RESUME
    # no debug in test
    cfg.DEBUG = False if phase == "test" else cfg.DEBUG
    if phase == "test":
        cfg.DEBUG = False
        cfg.DEVICE = [0]
        print("Force no debugging and one gpu when testing")
        cfg.TEST.TEST_DIR = params.dir if params.dir else cfg.TEST.TEST_DIR
    if phase == "demo":
        # cfg.DEMO.MOTION_TRANSFER = params.motion_transfer
        cfg.DEMO.RENDER = params.render
        cfg.DEMO.FRAME_RATE = params.frame_rate
        cfg.DEMO.EXAMPLE = params.example
        cfg.DEMO.CHECKPOINTS = params.checkpoint
        cfg.DEMO.TEMPLATE = params.template
        cfg.DEMO.ID = params.id
        cfg.DEMO.PLY = params.ply
        cfg.TEST.FOLDER = params.out_dir if params.out_dir else cfg.TEST.FOLDER
    if phase == "render":
        # --npy, --joint_type and --mode are not registered for this phase, so read them defensively
        if getattr(params, "npy", None):
            cfg.RENDER.NPY = params.npy
            cfg.RENDER.INPUT_MODE = "npy"
        if params.dir:
            cfg.RENDER.DIR = params.dir
            cfg.RENDER.INPUT_MODE = "dir"
        if getattr(params, "joint_type", None):
            cfg.RENDER.JOINT_TYPE = params.joint_type
        if getattr(params, "mode", None):
            cfg.RENDER.MODE = params.mode
    # debug mode
    if cfg.DEBUG:
        cfg.NAME = "debug--" + cfg.NAME
        cfg.LOGGER.WANDB.OFFLINE = True
        cfg.LOGGER.VAL_EVERY_STEPS = 1
    return cfg
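`instantiate_from_config` turns a config node with a dotted `target` path and an optional `params` mapping into an object: `get_obj_from_str` splits off the last path component as the attribute name, imports the module, and the attribute is then called with `params` as keyword arguments. The sketch below restates the two helpers with the standard library only and uses `fractions.Fraction` as a stand-in target; in the repository the targets are the classes named in the YAML configs.

```python
import importlib
from fractions import Fraction


def get_obj_from_str(string, reload=False):
    # "package.module.Class" -> import "package.module", return its "Class" attribute
    module, cls = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module))
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    if "target" not in config:
        # sentinel strings that mean "no module here"
        if config in ("__is_first_stage__", "__is_unconditional__"):
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


if __name__ == "__main__":
    obj = instantiate_from_config({"target": "fractions.Fraction",
                                   "params": {"numerator": 1, "denominator": 3}})
    print(obj, obj == Fraction(1, 3))  # 1/3 True
```

Keeping the class path in the config rather than in code is what lets a YAML file swap, for example, one backbone or denoiser class for another without editing Python.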
================================================
FILE: alm/data/BIWI/__init__.py
================================================
from .dataset import BIWIDataset
================================================
FILE: alm/data/BIWI/dataset.py
================================================
import numpy as np
import torch
from torch.utils import data
class BIWIDataset(data.Dataset):
    def __init__(self,
                 data,
                 subjects_dict,
                 data_type="train",
                 ):
        self.data = data
        self.len = len(self.data)
        self.subjects_dict = subjects_dict
        self.data_type = data_type
        self.one_hot_labels = np.eye(len(subjects_dict["train"]))
        # each training clip is visited 20 times per epoch
        self.repeat = 20 if data_type == 'train' else 1
    def __len__(self):
        return self.len * self.repeat
    def __getitem__(self, index):
        # map the repeated index back onto the real clips
        index = index % self.len
        # seq_len, fea_dim
        file_name = self.data[index]["name"]
        file_path = self.data[index]["path"]
        audio = self.data[index]["audio"]
        vertice = self.data[index]["vertice"]
        template = self.data[index]["template"]
        if self.data_type == "train":
            subject = "_".join(file_name.split("_")[:-1])
            one_hot = self.one_hot_labels[self.subjects_dict["train"].index(subject)]
        elif self.data_type == "val":
            one_hot = self.one_hot_labels
        elif self.data_type == "test":
            subject = "_".join(file_name.split("_")[:-1])
            if subject in self.subjects_dict["train"]:
                one_hot = self.one_hot_labels[self.subjects_dict["train"].index(subject)]
            else:
                # unseen subject: keep every training one-hot code
                one_hot = self.one_hot_labels
        return {
            'audio': torch.FloatTensor(audio),
            'audio_attention': torch.ones_like(torch.Tensor(audio)).long(),
            'vertice': torch.FloatTensor(vertice),
            'vertice_attention': torch.ones_like(torch.Tensor(vertice)[..., 0]).long(),
            'template': torch.FloatTensor(template),
            'id': torch.FloatTensor(one_hot),
            'file_name': file_name,
            'file_path': file_path
        }
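`BIWIDataset` does two pieces of bookkeeping. During training it reports `repeat` (20) times its real length and folds the index back with `index % self.len`, so one epoch visits every clip 20 times. It also recovers the subject from the file name by dropping the last `_`-separated token, then uses that subject's row of an identity matrix as the one-hot style code. The sketch below restates both with the standard library only; `RepeatedClips`, `subject_of`, `one_hot` and the file names are hypothetical, and lists of 0/1 stand in for `np.eye` rows.

```python
from typing import List


def subject_of(file_name: str) -> str:
    # "F2_e01" -> "F2": everything before the last "_" is the subject id
    return "_".join(file_name.split("_")[:-1])


def one_hot(subject: str, train_subjects: List[str]) -> List[float]:
    # row of the identity matrix at the subject's training index
    row = [0.0] * len(train_subjects)
    row[train_subjects.index(subject)] = 1.0
    return row


class RepeatedClips:
    """Hypothetical stand-in for BIWIDataset's length/index handling."""

    def __init__(self, names: List[str], repeat: int = 20):
        self.names = names
        self.len = len(names)
        self.repeat = repeat

    def __len__(self) -> int:
        return self.len * self.repeat

    def __getitem__(self, index: int) -> str:
        return self.names[index % self.len]  # fold back into the real range


if __name__ == "__main__":
    clips = RepeatedClips(["F2_e01", "M3_e02"], repeat=20)
    print(len(clips), clips[41])                                 # 40 M3_e02
    print(one_hot(subject_of(clips[41]), ["F2", "F3", "M3"]))    # [0.0, 0.0, 1.0]
```

Repeating the index this way avoids copying data while making each epoch long enough to amortize dataloader start-up.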
================================================
FILE: alm/data/__init__.py
================================================
================================================
FILE: alm/data/base.py
================================================
from os.path import join as pjoin
import numpy as np
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
class BASEDataModule(pl.LightningDataModule):
    def __init__(self, collate_fn, batch_size: int, num_workers: int):
        super().__init__()
        # self.dataloader_options = {
        #     "batch_size": batch_size, "num_workers": num_workers,"collate_fn": collate_datastruct_and_text}
        self.dataloader_options = {
            "batch_size": batch_size,
            "num_workers": num_workers,
            "collate_fn": collate_fn,
            "prefetch_factor": 2,
            "pin_memory": True,
        }
        # self.collate_fn = collate_fn
        self.persistent_workers = True
        self.is_mm = False
        # need to be overloaded:
        # - self.Dataset
        # - self._sample_set => load only a small subset
        #   There is a helper below (get_sample_set)
        # - self.nfeats
        # - self.transforms
    def __getattr__(self, item):
        # train_dataset/val_dataset etc cached like properties
        # question
        if item.endswith("_dataset") and not item.startswith("_"):
            subset = item[:-len("_dataset")]
            item_c = "_" + item
            if item_c not in self.__dict__:
                # todo: config name not consistent
                subset = subset.upper() if subset != "val" else "EVAL"
                split = eval(f"self.cfg.{subset}.SPLIT")
                split_file = pjoin(
                    eval(f"self.cfg.DATASET.{self.name.upper()}.SPLIT_ROOT"),
                    eval(f"self.cfg.{subset}.SPLIT") + ".txt",
                )
                self.__dict__[item_c] = self.Dataset(split_file=split_file,
                                                     split=split,
                                                     **self.hparams)
            return getattr(self, item_c)
        classname = self.__class__.__name__
        raise AttributeError(f"'{classname}' object has no attribute '{item}'")
    def setup(self, stage=None):
        self.stage = stage
        # Use the getter the first time to load the data
        if stage in (None, "fit"):
            _ = self.train_dataset
            _ = self.val_dataset
        if stage in (None, "test"):
            _ = self.test_dataset
    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            shuffle=True,
            persistent_workers=True,
            **self.dataloader_options,
        )
    def predict_dataloader(self):
        dataloader_options = self.dataloader_options.copy()
        dataloader_options[
            "batch_size"] = 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE
        dataloader_options["num_workers"] = self.cfg.TEST.NUM_WORKERS
        dataloader_options["shuffle"] = False
        return DataLoader(
            self.test_dataset,
            persistent_workers=True,
            **dataloader_options,
        )
    def val_dataloader(self):
        # overrides batch_size and num_workers
        dataloader_options = self.dataloader_options.copy()
        dataloader_options["batch_size"] = self.cfg.EVAL.BATCH_SIZE
        dataloader_options["num_workers"] = self.cfg.EVAL.NUM_WORKERS
        dataloader_options["shuffle"] = False
        return DataLoader(
            self.val_dataset,
            persistent_workers=True,
            **dataloader_options,
        )
    def test_dataloader(self):
        # overrides batch_size and num_workers
        dataloader_options = self.dataloader_options.copy()
        dataloader_options[
            "batch_size"] = 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE
        dataloader_options["num_workers"] = self.cfg.TEST.NUM_WORKERS
        # dataloader_options["drop_last"] = True
        dataloader_options["shuffle"] = False
        return DataLoader(
            self.test_dataset,
            persistent_workers=True,
            **dataloader_options,
        )
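`BASEDataModule.__getattr__` turns `train_dataset`, `val_dataset` and `test_dataset` into lazily built, cached attributes. Python calls `__getattr__` only when normal attribute lookup fails, and because the dataset is stored in `self.__dict__` under an underscored name (`_train_dataset`), the hook still runs on every access to `train_dataset`; the `item_c not in self.__dict__` check is what ensures each split is constructed only once. The sketch below isolates that pattern; `LazySplits`, its `build` callback and the `calls` counter are hypothetical, the counter existing only to show that construction happens once per split.

```python
from typing import Callable, Dict


class LazySplits:
    """Hypothetical minimal version of the *_dataset caching in BASEDataModule."""

    def __init__(self, build: Callable[[str], object]):
        self.build = build
        self.calls: Dict[str, int] = {}

    def __getattr__(self, item):
        # reached on every "<split>_dataset" access, since the cache key is "_<split>_dataset"
        if item.endswith("_dataset") and not item.startswith("_"):
            subset = item[:-len("_dataset")]
            item_c = "_" + item
            if item_c not in self.__dict__:
                # first access: build and cache
                self.calls[subset] = self.calls.get(subset, 0) + 1
                self.__dict__[item_c] = self.build(subset)
            return self.__dict__[item_c]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")


if __name__ == "__main__":
    splits = LazySplits(lambda name: f"<{name} split>")
    print(splits.train_dataset, splits.train_dataset)  # <train split> <train split>
    print(splits.calls)                                 # {'train': 1}
```

Raising `AttributeError` for every other name keeps `hasattr` and `getattr(..., default)` working as usual on the data module.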
================================================
FILE: alm/data/biwi.py
================================================
from .base import BASEDataModule
from transformers import Wav2Vec2Processor
from collections import defaultdict
from .BIWI import BIWIDataset
import os
from os.path import join as pjoin
import pickle
from tqdm import tqdm
import librosa
import numpy as np
from multiprocessing import Pool
def load_data(args):
    file, root_dir, processor, templates, audio_dir, vertice_dir = args
    if file.endswith('wav'):
        wav_path = os.path.join(root_dir, audio_dir, file)
        speech_array, sampling_rate = librosa.load(wav_path, sr=16000)
        input_values = np.squeeze(processor(speech_array, sampling_rate=16000).input_values)
        key = file.replace("wav", "npy")
        result = {}
        result["audio"] = input_values
        subject_id = "_".join(key.split("_")[:-1])
        temp = templates[subject_id]
        result["name"] = file.replace(".wav", "")
        result["path"] = os.path.abspath(wav_path)
        result["template"] = temp.reshape((-1))
        vertice_path = os.path.join(root_dir, vertice_dir, file.replace("wav", "npy"))
        if not os.path.exists(vertice_path):
            return None
        else:
            # unlike VOCASET, no [::2, :] frame subsampling is applied here
            result["vertice"] = np.load(vertice_path, allow_pickle=True)
            return (key, result)
class BIWIDataModule(BASEDataModule):
    def __init__(self,
                 cfg,
                 batch_size,
                 num_workers,
                 collate_fn=None,
                 phase="train",
                 **kwargs):
        super().__init__(batch_size=batch_size,
                         num_workers=num_workers,
                         collate_fn=collate_fn)
        self.save_hyperparameters(logger=False)
        self.name = 'VOCASET'
        self.Dataset = BIWIDataset  # same interface as VOCASETDataset
        self.cfg = cfg
        # customized to BIWI
        self.subjects = {
            'train': [
                "F2",
                "F3",
                "F4",
                "M3",
                "M4",
                "M5",
            ],
            'val': [
                "F2",
                "F3",
                "F4",
                "M3",
                "M4",
                "M5",
            ],
            # 'test': [  # for BIWI test B
            #     "F1",
            #     "F5",
            #     "F6",
            #     "F7",
            #     "F8",
            #     "M1",
            #     "M2",
            #     "M6"
            # ]
            'test': [  # for BIWI test A
                "F2",
                "F3",
                "F4",
                "M3",
                "M4",
                "M5",
            ]
        }
        self.root_dir = kwargs.get('data_root', 'datasets/BIWI')
        self.audio_dir = 'wav'
        self.vertice_dir = 'vertices_npy'
        processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
        self.template_file = 'templates.pkl'
        self.nfeats = 70110
        # load
        data = defaultdict(dict)
        with open(os.path.join(self.root_dir, self.template_file), 'rb') as fin:
            templates = pickle.load(fin, encoding='latin1')
        count = 0
        args_list = []
        for r, ds, fs in os.walk(os.path.join(self.root_dir, self.audio_dir)):
            for f in fs:
                args_list.append((f, self.root_dir, processor, templates, self.audio_dir, self.vertice_dir, ))
                # # comment off for full dataset
                # count += 1
                # if count > 50:
                #     break
        # split dataset: sentences 1-32 train, 33-36 val, 37-40 test
        self.data_splits = {
            'train': [],
            'val': [],
            'test': [],
        }
        splits = {
            'train': range(1, 33),
            'val': range(33, 37),
            'test': range(37, 41)
        }
        # alternative split, kept for reference:
        # splits = {
        #     'train':range(1,41),
        #     'val':range(21,41),
        #     'test':range(21,41)
        # }
        motion_list = []
        with Pool(processes=10) as pool:
            results = pool.map(load_data, args_list)
            for result in results:
                if result is not None:
                    key, value = result
                    data[key] = value
                    # motion = value["vertice"] - value["template"]
                    # motion_list.append(motion)
        # # calculate mean and std
        # import pdb; pdb.set_trace()
        # motion_list = np.concatenate(motion_list, axis=0)
        # self.mean = np.mean(motion_list, axis=0)
        # self.std = np.std(motion_list, axis=0)
        for k, v in data.items():
            subject_id = "_".join(k.split("_")[:-1])
            sentence_id = int(k.split(".")[0][-2:])
            for sub in ['train', 'val', 'test']:
                if subject_id in self.subjects[sub] and sentence_id in splits[sub]:
                    self.data_splits[sub].append(v)
        # self._sample_set = self.__getattr__("test_dataset")
    def __getattr__(self, item):
        # train_dataset/val_dataset etc cached like properties
        if item.endswith("_dataset") and not item.startswith("_"):
            subset = item[:-len("_dataset")]
            item_c = "_" + item
            if item_c not in self.__dict__:
                # todo: config name not consistent
                self.__dict__[item_c] = self.Dataset(
                    data=self.data_splits[subset],
                    subjects_dict=self.subjects,
                    data_type=subset
                )
            return getattr(self, item_c)
        classname = self.__class__.__name__
        raise AttributeError(f"'{classname}' object has no attribute '{item}'")
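`BIWIDataModule` assigns each loaded clip to a split from two parts of its key: the subject (everything before the last `_`) must be listed in `self.subjects[split]`, and the sentence number (the last two characters of the file stem) must fall in `splits[split]`, i.e. 1 to 32 for train, 33 to 36 for val and 37 to 40 for test. With the "test A" subject list, the test split therefore contains unseen sentences from the training speakers. The sketch below restates that rule as a pure function; `assign_splits` and the example keys are hypothetical, and it assumes keys of the form `<subject>_<...NN>.npy` as produced by `load_data`.

```python
from typing import Dict, Iterable, List

# split boundaries copied from BIWIDataModule (range end is exclusive)
SPLITS = {"train": range(1, 33), "val": range(33, 37), "test": range(37, 41)}


def assign_splits(keys: Iterable[str],
                  subjects: Dict[str, List[str]]) -> Dict[str, List[str]]:
    out = {name: [] for name in SPLITS}
    for k in keys:
        subject_id = "_".join(k.split("_")[:-1])
        sentence_id = int(k.split(".")[0][-2:])  # last two digits of the stem
        for sub in ("train", "val", "test"):
            if subject_id in subjects[sub] and sentence_id in SPLITS[sub]:
                out[sub].append(k)
    return out


if __name__ == "__main__":
    same = ["F2", "M3"]  # test A: the same speakers in every split
    subjects = {"train": same, "val": same, "test": same}
    keys = ["F2_e01.npy", "F2_e33.npy", "M3_e40.npy", "F1_e05.npy"]
    print(assign_splits(keys, subjects))
    # {'train': ['F2_e01.npy'], 'val': ['F2_e33.npy'], 'test': ['M3_e40.npy']}
```

Clips from subjects outside every list (here the hypothetical `F1`) are dropped, which is what switching to the commented-out "test B" list changes.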
================================================
FILE: alm/data/get_data.py
================================================
import numpy as np
import torch
def collate_tensors(batch):
    # zero-pad every tensor to the largest size along each dim, then stack
    dims = batch[0].dim()
    max_size = [max([b.size(i) for b in batch]) for i in range(dims)]
    size = (len(batch), ) + tuple(max_size)
    canvas = batch[0].new_zeros(size=size)
    for i, b in enumerate(batch):
        sub_tensor = canvas[i]
        for d in range(dims):
            sub_tensor = sub_tensor.narrow(d, 0, b.size(d))
        sub_tensor.add_(b)
    return canvas
def vocaset_collate_fn(batch):
    notnone_batches = [b for b in batch if b is not None]
    # notnone_batches.sort(key=lambda x: x['vertice_length'], reverse=True)
    adapted_batch = {
        'audio': collate_tensors([b['audio'].float() for b in notnone_batches]),
        'audio_attention': collate_tensors([b['audio_attention'] for b in notnone_batches]),
        'vertice': collate_tensors([b['vertice'].float() for b in notnone_batches]),
        'vertice_attention': collate_tensors([b['vertice_attention'] for b in notnone_batches]),
        'template': collate_tensors([b['template'].float() for b in notnone_batches]),
        'id': collate_tensors([b['id'].float() for b in notnone_batches]),
        'file_name': [b['file_name'] for b in notnone_batches],
        'file_path': [b['file_path'] for b in notnone_batches],
    }
    return adapted_batch
def voxcelebinsta_collate_fn(batch):
    notnone_batches = [b for b in batch if b is not None]
    # notnone_batches.sort(key=lambda x: x['vertice_length'], reverse=True)
    adapted_batch = {
        'audio': collate_tensors([b['audio'].float() for b in notnone_batches]),
        'audio_attention': collate_tensors([b['audio_attention'] for b in notnone_batches]),
        'vertice': collate_tensors([b['vertice'].float() for b in notnone_batches]),
        'vertice_attention': collate_tensors([b['vertice_attention'] for b in notnone_batches]),
        'template': collate_tensors([b['template'].float() for b in notnone_batches]),
        'id': collate_tensors([b['id'].float() for b in notnone_batches]),
        'file_name': [b['file_name'] for b in notnone_batches],
    }
    if 'pose' in notnone_batches[0]:
        adapted_batch['pose'] = collate_tensors([b['pose'].float() for b in notnone_batches])
        adapted_batch['pose_attention'] = collate_tensors([b['pose_attention'] for b in notnone_batches])
    if 'exp' in notnone_batches[0]:
        adapted_batch['exp'] = collate_tensors([b['exp'].float() for b in notnone_batches])
        adapted_batch['exp_attention'] = collate_tensors([b['exp_attention'] for b in notnone_batches])
    if 'image' in notnone_batches[0]:
        adapted_batch['image'] = collate_tensors([b['image'].float() for b in notnone_batches])
        adapted_batch['image_attention'] = collate_tensors([b['image_attention'] for b in notnone_batches])
    if 'depth' in notnone_batches[0]:
        adapted_batch['depth'] = collate_tensors([b['depth'].float() for b in notnone_batches])
        adapted_batch['depth_attention'] = collate_tensors([b['depth_attention'] for b in notnone_batches])
    if 'seg' in notnone_batches[0]:
        adapted_batch['seg'] = collate_tensors([b['seg'].float() for b in notnone_batches])
        adapted_batch['seg_attention'] = collate_tensors([b['seg_attention'] for b in notnone_batches])
    return adapted_batch
def voxcelebinstacoeflmdb_collate_fn(batch):
    notnone_batches = [b for b in batch if b is not None]
    # notnone_batches.sort(key=lambda x: x['vertice_length'], reverse=True)
    adapted_batch = {
        ### audio related ##########################################
        'audio': collate_tensors([b['audio'].float() for b in notnone_batches]),
        'audio_attention': collate_tensors([b['audio_attention'] for b in notnone_batches]),
        ### non-predictive features ################################
        'vertice': collate_tensors([b['vertice'].float() for b in notnone_batches]),
        'shape': collate_tensors([b['flame_shape'].float() for b in notnone_batches]),
        'template': collate_tensors([b['template'].float() for b in notnone_batches]),
        'id': collate_tensors([b['id'].float() for b in notnone_batches]),
        'coefficient_attention': collate_tensors([b['coef_attention'] for b in notnone_batches]),
        'file_name': [b['file_name'] for b in notnone_batches],
        ### predictive features ####################################
        'exp': collate_tensors([b['flame_exp'].float() for b in notnone_batches]),
        'jaw': collate_tensors([b['flame_jaw'].float() for b in notnone_batches]),
        'eyes': collate_tensors([b['flame_eyes'].float() for b in notnone_batches]),
        'eyelids': collate_tensors([b['flame_eyelids'].float() for b in notnone_batches]),
    }
    # pose-related features ####################################
    ### non-predictive features ################################
    if 'flame_fl' in notnone_batches[0]:
        adapted_batch['fl'] = collate_tensors([b['flame_fl'].float() for b in notnone_batches])
    if 'flame_pp' in notnone_batches[0]:
        adapted_batch['pp'] = collate_tensors([b['flame_pp'].float() for b in notnone_batches])
    ### predictive features ####################################
    if 'flame_R' in notnone_batches[0]:
        adapted_batch['R'] = collate_tensors([b['flame_R'].float() for b in notnone_batches])
    if 'flame_t' in notnone_batches[0]:
        adapted_batch['T'] = collate_tensors([b['flame_t'].float() for b in notnone_batches])
    return adapted_batch
def get_datasets(cfg, logger, phase='train'):
    dataset_names = eval(f"cfg.{phase.upper()}.DATASETS")
    datasets = []
    for dataset_name in dataset_names:
        if dataset_name.lower() in ['vocaset']:
            from .vocaset import VOCASETDataModule
            data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT")
            collate_fn = vocaset_collate_fn
            dataset = VOCASETDataModule(
                cfg=cfg,
                data_root=data_root,
                batch_size=cfg.TRAIN.BATCH_SIZE,
                num_workers=cfg.TRAIN.NUM_WORKERS,
                debug=cfg.DEBUG,
                collate_fn=collate_fn,
            )
            datasets.append(dataset)
        if dataset_name.lower() in ['voxcelebinsta']:
            from .voxceleb_insta import VoxCelebInstaDataModule
            data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT")
            collate_fn = voxcelebinsta_collate_fn
            dataset = VoxCelebInstaDataModule(
                cfg=cfg,
                data_root=data_root,
                batch_size=cfg.TRAIN.BATCH_SIZE,
                num_workers=cfg.TRAIN.NUM_WORKERS,
                debug=cfg.DEBUG,
                collate_fn=collate_fn,
                predict_pose=cfg.model.predict_pose,
                predict_exp=cfg.model.predict_exp,
                use_image=cfg.model.use_image,
            )
            datasets.append(dataset)
        if dataset_name.lower() in ['voxcelebinstalmdb']:
            from .voxceleb_insta_lmdb import VoxCelebInstalmDBDataModule
            data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT")
            collate_fn = voxcelebinsta_collate_fn
            dataset = VoxCelebInstalmDBDataModule(
                cfg=cfg,
                data_root=data_root,
                batch_size=cfg.TRAIN.BATCH_SIZE,
                num_workers=cfg.TRAIN.NUM_WORKERS,
                debug=cfg.DEBUG,
                collate_fn=collate_fn,
                predict_pose=cfg.model.predict_pose,
                predict_exp=cfg.model.predict_exp,
                use_image=cfg.model.use_image,
            )
            datasets.append(dataset)
        if dataset_name.lower() in ['biwi']:
            from .biwi import BIWIDataModule
            data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT")
            collate_fn = vocaset_collate_fn  # this is the same as vocaset
            dataset = BIWIDataModule(
                cfg=cfg,
                data_root=data_root,
                batch_size=cfg.TRAIN.BATCH_SIZE,
                num_workers=cfg.TRAIN.NUM_WORKERS,
                debug=cfg.DEBUG,
                collate_fn=collate_fn,
            )
            datasets.append(dataset)
        if dataset_name.lower() in ['voxcelebinstacoeflmdb']:
            from .voxceleb_insta_coef_lmdb import VoxCelebInstaCoefLMDBDataModule
            data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT")
            collate_fn = voxcelebinstacoeflmdb_collate_fn
            dataset = VoxCelebInstaCoefLMDBDataModule(
                cfg=cfg,
                data_root=data_root,
                batch_size=cfg.TRAIN.BATCH_SIZE,
                num_workers=cfg.TRAIN.NUM_WORKERS,
                debug=cfg.DEBUG,
                collate_fn=collate_fn,
            )
            datasets.append(dataset)
    cfg.DATASET.NFEATS = datasets[0].nfeats
    return datasets
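`collate_tensors` zero-pads every tensor in a batch up to the largest size along each dimension and stacks the results. Because both dataset classes return `audio_attention` and `vertice_attention` as all-ones tensors of the unpadded length, the same zero-padding turns them into masks that are 1 on real samples or frames and 0 on padding. The sketch below shows that effect for 1-D sequences with plain lists instead of PyTorch tensors; `pad_and_mask` is a hypothetical helper, not part of the repository.

```python
from typing import List, Tuple


def pad_and_mask(seqs: List[List[float]]) -> Tuple[List[List[float]], List[List[int]]]:
    """Zero-pad 1-D sequences to a common length, as collate_tensors does per dim."""
    max_len = max(len(s) for s in seqs)
    padded = [s + [0.0] * (max_len - len(s)) for s in seqs]
    # an all-ones "attention" per sequence, padded the same way, becomes the mask
    masks = [[1] * len(s) + [0] * (max_len - len(s)) for s in seqs]
    return padded, masks


if __name__ == "__main__":
    audio = [[0.1, 0.2, 0.3], [0.5]]
    padded, masks = pad_and_mask(audio)
    print(padded)  # [[0.1, 0.2, 0.3], [0.5, 0.0, 0.0]]
    print(masks)   # [[1, 1, 1], [1, 0, 0]]
```

Building the mask inside each dataset item, rather than in the collate function, keeps `collate_tensors` generic: it only needs to know how to pad.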
================================================
FILE: alm/data/voca/__init__.py
================================================
from .dataset import VOCASETDataset
================================================
FILE: alm/data/voca/dataset.py
================================================
import numpy as np
import torch
from torch.utils import data
class VOCASETDataset(data.Dataset):
    def __init__(self,
                 data,
                 subjects_dict,
                 data_type="train",
                 ):
        self.data = data
        self.len = len(self.data)
        self.subjects_dict = subjects_dict
        self.data_type = data_type
        self.one_hot_labels = np.eye(len(subjects_dict["train"]))
    def __getitem__(self, index):
        # seq_len, fea_dim
        file_name = self.data[index]["name"]
        file_path = self.data[index]["path"]
        audio = self.data[index]["audio"]
        vertice = self.data[index]["vertice"]
        template = self.data[index]["template"]
        if self.data_type == "train":
            subject = "_".join(file_name.split("_")[:-1])
            one_hot = self.one_hot_labels[self.subjects_dict["train"].index(subject)]
        elif self.data_type == "val":
            one_hot = self.one_hot_labels
        elif self.data_type == "test":
            subject = "_".join(file_name.split("_")[:-1])
            if subject in self.subjects_dict["train"]:
                one_hot = self.one_hot_labels[self.subjects_dict["train"].index(subject)]
            else:
                # unseen subject: keep every training one-hot code
                one_hot = self.one_hot_labels
        return {
            'audio': torch.FloatTensor(audio),
            'audio_attention': torch.ones_like(torch.Tensor(audio)).long(),
            'vertice': torch.FloatTensor(vertice),
            'vertice_attention': torch.ones_like(torch.Tensor(vertice)[..., 0]).long(),
            'template': torch.FloatTensor(template),
            'id': torch.FloatTensor(one_hot),
            'file_name': file_name,
            'file_path': file_path
        }
    def __len__(self):
        return self.len
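Each item carries `vertice` as a `(frames, nfeats)` array and `template` as a flat `nfeats` vector, where `nfeats` counts three coordinates per mesh vertex: 15069 = 3 × 5023 for VOCASET and 70110 = 3 × 23370 for BIWI (the values set in the two data modules). The VOCASET loader in `alm/data/vocaset.py` also keeps only every second frame (`[::2, :]`), halving the frame rate, while the BIWI loader keeps all frames; assuming the VOCASET meshes are recorded at 60 fps, this yields the 30 fps that `--frame_rate` defaults to. The sketch below shows the reshaping, the per-coordinate displacement from the template, and the frame subsampling on a tiny hypothetical mesh, using plain lists and assuming the flat layout `[x0, y0, z0, x1, ...]` that `reshape(-1)` produces from a `(V, 3)` array.

```python
from typing import List


def to_vertices(flat: List[float]) -> List[List[float]]:
    # [x0, y0, z0, x1, y1, z1, ...] -> [[x0, y0, z0], [x1, y1, z1], ...]
    assert len(flat) % 3 == 0, "nfeats must be a multiple of 3"
    return [flat[i:i + 3] for i in range(0, len(flat), 3)]


def displacement(frames: List[List[float]], template: List[float]) -> List[List[float]]:
    # per-frame offset of every coordinate from the neutral template
    return [[f - t for f, t in zip(frame, template)] for frame in frames]


def every_other_frame(frames: List[List[float]]) -> List[List[float]]:
    # same frame selection as numpy's vertices[::2, :]
    return frames[::2]


if __name__ == "__main__":
    assert 15069 == 3 * 5023 and 70110 == 3 * 23370  # VOCASET / BIWI nfeats
    template = [0.0, 0.0, 0.0, 1.0, 0.0, 0.0]          # hypothetical 2-vertex mesh
    frames = [[0.0, 0.1, 0.0, 1.0, 0.0, 0.0],
              [0.0, 0.2, 0.0, 1.0, 0.0, 0.0],
              [0.0, 0.3, 0.0, 1.0, 0.0, 0.0]]
    print(to_vertices(template))                                     # [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
    print(len(every_other_frame(frames)))                            # 2
    print(displacement(every_other_frame(frames), template)[1][1])   # 0.3
```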
================================================
FILE: alm/data/vocaset.py
================================================
from .base import BASEDataModule
from alm.data.voca import VOCASETDataset
from transformers import Wav2Vec2Processor
from collections import defaultdict
import os
from os.path import join as pjoin
import pickle
from tqdm import tqdm
import librosa
import numpy as np
from multiprocessing import Pool
def load_data(args):
file, root_dir, processor, templates, audio_dir, vertice_dir = args
if file.endswith('wav'):
wav_path = os.path.join(root_dir, audio_dir, file)
speech_array, sampling_rate = librosa.load(wav_path, sr=16000)
input_values = np.squeeze(processor(speech_array,sampling_rate=16000).input_values)
key = file.replace("wav", "npy")
result = {}
result["audio"] = input_values
subject_id = "_".join(key.split("_")[:-1])
temp = templates[subject_id]
result["name"] = file.replace(".wav", "")
result["path"] = os.path.abspath(wav_path)
result["template"] = temp.reshape((-1))
vertice_path = os.path.join(root_dir, vertice_dir, file.replace("wav", "npy"))
if not os.path.exists(vertice_path):
return None
else:
result["vertice"] = np.load(vertice_path,allow_pickle=True)[::2,:]
return (key, result)
class VOCASETDataModule(BASEDataModule):
def __init__(self,
cfg,
batch_size,
num_workers,
collate_fn = None,
phase="train",
**kwargs):
super().__init__(batch_size=batch_size,
num_workers=num_workers,
collate_fn=collate_fn)
self.save_hyperparameters(logger=False)
self.name = 'VOCASET'
self.Dataset = VOCASETDataset
self.cfg = cfg
# customized to VOCASET
self.subjects = {
'train': [
'FaceTalk_170728_03272_TA',
'FaceTalk_170904_00128_TA',
'FaceTalk_170725_00137_TA',
'FaceTalk_170915_00223_TA',
'FaceTalk_170811_03274_TA',
'FaceTalk_170913_03279_TA',
'FaceTalk_170904_03276_TA',
'FaceTalk_170912_03278_TA'
],
'val': [
'FaceTalk_170811_03275_TA',
'FaceTalk_170908_03277_TA'
],
'test': [
'FaceTalk_170809_00138_TA',
'FaceTalk_170731_00024_TA'
]
# 'test': [
# 'FaceTalk_170728_03272_TA',
# 'FaceTalk_170904_00128_TA',
# 'FaceTalk_170725_00137_TA',
# 'FaceTalk_170915_00223_TA',
# 'FaceTalk_170811_03274_TA',
# 'FaceTalk_170913_03279_TA',
# 'FaceTalk_170904_03276_TA',
# 'FaceTalk_170912_03278_TA'
# ]
}
self.root_dir = kwargs.get('data_root', 'datasets/vocaset')
self.audio_dir = 'wav'
self.vertice_dir = 'vertices_npy'
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
self.template_file = 'templates.pkl'
self.nfeats = 15069
# load
data = defaultdict(dict)
with open(os.path.join(self.root_dir, self.template_file), 'rb') as fin:
templates = pickle.load(fin, encoding='latin1')
count = 0
args_list = []
for r, ds, fs in os.walk(os.path.join(self.root_dir, self.audio_dir)):
for f in fs:
args_list.append((f, self.root_dir, processor, templates, self.audio_dir, self.vertice_dir, ))
# # uncomment the lines below to load only a few files (quick debugging); keep them commented for the full dataset
# count += 1
# if count > 10:
# break
# split dataset
self.data_splits = {
'train':[],
'val':[],
'test':[],
}
motion_list = []
if True: # multi-process
with Pool(processes=os.cpu_count()) as pool:
results = pool.map(load_data, args_list)
for result in results:
if result is not None:
key, value = result
data[key] = value
else: # single process
for args in tqdm(args_list, desc="Loading data"):
result = load_data(args)
if result is not None:
key, value = result
data[key] = value
else:
print("Warning: data not found")
# # calculate mean and std
# motion_list = np.concatenate(motion_list, axis=0)
# self.mean = np.mean(motion_list, axis=0)
# self.std = np.std(motion_list, axis=0)
splits = {
'train':range(1,41),
'val':range(21,41),
'test':range(21,41)
}
for k, v in data.items():
subject_id = "_".join(k.split("_")[:-1])
sentence_id = int(k.split(".")[0][-2:])
for sub in ['train', 'val', 'test']:
if subject_id in self.subjects[sub] and sentence_id in splits[sub]:
self.data_splits[sub].append(v)
# self._sample_set = self.__getattr__("test_dataset")
def __getattr__(self, item):
# lazily build train_dataset / val_dataset / test_dataset on first access and cache them like properties
if item.endswith("_dataset") and not item.startswith("_"):
subset = item[:-len("_dataset")]
item_c = "_" + item
if item_c not in self.__dict__:
# todo: config name not consistent
self.__dict__[item_c] = self.Dataset(
data = self.data_splits[subset] ,
subjects_dict = self.subjects,
data_type = subset
)
return getattr(self, item_c)
classname = self.__class__.__name__
raise AttributeError(f"'{classname}' object has no attribute '{item}'")
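The split assignment in `VOCASETDataModule.__init__` depends only on how each file name is parsed. The sketch below reproduces that parsing and the sentence ranges in plain Python so the resulting splits can be checked by hand. It assumes VOCASET file names of the form `<subject>_sentenceNN.wav` (e.g. `FaceTalk_170728_03272_TA_sentence01.wav`); `parse_key` and `assign_splits` are hypothetical helpers, not part of the repository.

```python
from typing import Dict, List, Tuple

# Sentence ranges copied from VOCASETDataModule.__init__ (the end of a range is exclusive).
SPLITS = {
    "train": range(1, 41),
    "val": range(21, 41),
    "test": range(21, 41),
}


def parse_key(key: str) -> Tuple[str, int]:
    """Hypothetical helper: 'FaceTalk_170728_03272_TA_sentence01.npy' -> (subject, sentence)."""
    subject_id = "_".join(key.split("_")[:-1])  # drop the trailing 'sentenceNN.npy' part
    sentence_id = int(key.split(".")[0][-2:])   # last two characters before the extension
    return subject_id, sentence_id


def assign_splits(keys: List[str], subjects: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Hypothetical helper: a key goes to every split whose subject list and sentence range both match."""
    out: Dict[str, List[str]] = {name: [] for name in SPLITS}
    for key in keys:
        subject_id, sentence_id = parse_key(key)
        for name, sentence_range in SPLITS.items():
            if subject_id in subjects[name] and sentence_id in sentence_range:
                out[name].append(key)
    return out


if __name__ == "__main__":
    subjects = {
        "train": ["FaceTalk_170728_03272_TA"],
        "val": ["FaceTalk_170811_03275_TA"],
        "test": ["FaceTalk_170809_00138_TA"],
    }
    keys = [
        "FaceTalk_170728_03272_TA_sentence01.npy",
        "FaceTalk_170811_03275_TA_sentence05.npy",  # val subject, but sentence 5 is outside 21..40
        "FaceTalk_170811_03275_TA_sentence25.npy",
        "FaceTalk_170809_00138_TA_sentence40.npy",
    ]
    print(assign_splits(keys, subjects))
```

Because training subjects use sentences 1 to 40 while validation and test subjects only use 21 to 40, sentences 1 to 20 of the validation and test speakers end up in no split at all.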
================================================
FILE: alm/models/__init__.py
================================================
================================================
FILE: alm/models/architectures/__init__.py
================================================
================================================
FILE: alm/models/architectures/adpt_bias_denoiser.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from alm.models.architectures.tools.embeddings import (TimestepEmbedding,
Timesteps)
from .tools.utils import PeriodicPositionalEncoding, init_bi_biased_mask_faceformer, init_mem_mask_faceformer
from typing import Optional, Tuple, Union, Callable
from .tools.transformer_adpt import TransformerDecoderLayer_w_Adapter, TransformerDecoder_w_Adapter
class Adpt_Bias_Denoiser(nn.Module):
# this model is based on transformer_adpt.py, with some modifications for the diffusion denoising task
def __init__(self,
nfeats: int = 15069,
latent_dim: int = 174,
ff_size: int = 1024,
num_layers: int = 6,
num_heads: int = 4,
dropout: float = 0.1,
normalize_before: bool = False,
activation: str = "gelu",
arch: str = "trans_dec",
audio_encoded_dim: int = 768,
max_len: int = 600,
id_dim: int = 10,
return_intermediate_dec: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
mem_attn_scale: float = 1.0,
tgt_attn_scale: float = 0.1,
period: int = 30,
no_cross: bool = False,
**kwargs) -> None:
super().__init__()
self.latent_dim = latent_dim
self.arch = arch
self.audio_encoded_dim = audio_encoded_dim
# audio projector
self.audio_feature_map = nn.Linear(audio_encoded_dim, latent_dim)
# motion projector
self.vertice_map = nn.Linear(nfeats, latent_dim)
# periodic positional encoding
self.PPE = PeriodicPositionalEncoding(latent_dim, period = period, max_seq_len=5000) # max_seq_len can be increased if this reports an error for long sequences
# attention bias
assert mem_attn_scale in [-1.0, 0.0, 1.0]
self.use_mem_attn_bias = mem_attn_scale != 0.0
self.use_tgt_attn_bias = tgt_attn_scale != 0.0
self.memory_bi_bias = init_mem_mask_faceformer(max_len)
if tgt_attn_scale < 0.0: # means we only use the causal attention
self.target_bi_bias = init_bi_biased_mask_faceformer(num_heads, max_len, period)
mask = torch.triu(torch.ones(max_len, max_len), diagonal=1) == 1
self.target_bi_bias = self.target_bi_bias.masked_fill(mask, float('-inf'))
else:
self.target_bi_bias = init_bi_biased_mask_faceformer(num_heads, max_len, period)
# init decoder
decoder_layer = TransformerDecoderLayer_w_Adapter(
d_model=latent_dim,
nhead=num_heads,
dim_feedforward=ff_size,
dropout=dropout,
activation=activation,
norm_first=normalize_before,
batch_first=True
)
self.transformer_decoder = TransformerDecoder_w_Adapter(
decoder_layer=decoder_layer,
num_layers=num_layers,
)
# used for diffusion denoising
self.time_proj = Timesteps(
audio_encoded_dim,
flip_sin_to_cos=flip_sin_to_cos, # because the baseline model is trained with this
downscale_freq_shift=freq_shift, # same as above
)
self.time_embedding = TimestepEmbedding(
audio_encoded_dim,
latent_dim * num_layers
)
# motion decoder
self.motion_decoder = nn.Linear(latent_dim, nfeats)
nn.init.constant_(self.motion_decoder.weight, 0)
nn.init.constant_(self.motion_decoder.bias, 0)
# style embedding
self.obj_vector = nn.Embedding(id_dim, latent_dim * num_layers, )
# whether we do not use cross attention
self.no_cross = no_cross
def forward(self,
vertice_input: torch.Tensor,
hidden_state: torch.Tensor,
timesteps: torch.Tensor,
adapter: torch.Tensor = None, # conditions other than the time embedding
tgt_mask: torch.Tensor = None,
tgt_key_padding_mask: torch.Tensor = None,
memory_mask: torch.Tensor = None,
memory_key_padding_mask: torch.Tensor = None,
**kwargs):
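"""
Denoising forward pass: all T frames are processed in parallel.
Args:
vertice_input: [N, T, nfeats] noisy motion sequence
hidden_state: [N, S, E] audio features, already projected to latent_dim (E)
timesteps: [N] diffusion timesteps
adapter: [N, A, E * num_layers] extra condition tokens, concatenated with the time embedding
tgt_mask: [N * H, T, T]
tgt_key_padding_mask: [N, T]
memory_mask: [T, S]
memory_key_padding_mask: [N, S]
"""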
# vertice projection
vertice_input = self.vertice_map(vertice_input)
vertice_input = self.PPE(vertice_input)
# time projection
time_emb = self.time_proj(timesteps).to(vertice_input.device)
time_emb = self.time_embedding(time_emb).unsqueeze(1) # time_emb.shape = [N, 1, E]
# treat the time embedding as an adapter
if adapter is not None:
adapter = torch.concat([adapter, time_emb], dim=1)
else:
adapter = time_emb
vertice_out = vertice_input
# split the adapter along the feature dimension into num_layers chunks of size latent_dim, one per decoder layer
adapters = adapter.split(self.latent_dim, dim=-1)
# if no_cross is set, prepend the audio hidden states to the motion tokens
if self.no_cross:
hidden_len = hidden_state.shape[1]
vertice_out = torch.cat([hidden_state, vertice_out], dim=1)
hidden_state = torch.cat([hidden_state, hidden_state], dim=1)
for mod, layer_adapter in zip(self.transformer_decoder.layers, adapters):
vertice_out = mod(
tgt=vertice_out,
memory=hidden_state,
adapter=layer_adapter,
tgt_mask=tgt_mask,
tgt_key_padding_mask = tgt_key_padding_mask,
memory_mask=memory_mask,
memory_key_padding_mask=memory_key_padding_mask,
**kwargs
)
if self.no_cross: # remove the hidden state
vertice_out = vertice_out[:, hidden_len:]
if self.transformer_decoder.norm is not None:
vertice_out = self.transformer_decoder.norm(vertice_out)
vertice_out = self.motion_decoder(vertice_out)
return vertice_out
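In `Adpt_Bias_Denoiser.forward`, the time embedding (and any extra condition tokens passed as `adapter`) is `latent_dim * num_layers` wide and is cut along the feature axis, so every decoder layer receives its own `latent_dim`-wide slice of every condition token. The pure-Python sketch below mimics that bookkeeping for a single sample (batch axis dropped). `obj_vector` produces embeddings of the same width, so a style embedding is a natural candidate for `adapter`, but the caller is not shown here and that pairing is an assumption; `build_layer_adapters` is a hypothetical name.

```python
from typing import List

Vector = List[float]


def build_layer_adapters(extra_token: Vector, time_token: Vector, latent_dim: int) -> List[List[Vector]]:
    """Return, for each decoder layer, the adapter tokens that layer sees.

    Mirrors `torch.concat([adapter, time_emb], dim=1)` followed by
    `adapter.split(latent_dim, dim=-1)` for one sample with a single extra token.
    """
    if len(extra_token) != len(time_token):
        raise ValueError("condition tokens must have the same width")
    if len(time_token) % latent_dim != 0:
        # torch.split would return a shorter last chunk here; the sketch rejects it instead
        raise ValueError("token width must be a multiple of latent_dim")
    tokens = [extra_token, time_token]  # token axis: A = 2 (extra condition first, time last)
    num_layers = len(time_token) // latent_dim
    per_layer = []
    for layer in range(num_layers):
        lo, hi = layer * latent_dim, (layer + 1) * latent_dim
        # every layer gets the same tokens, but only its own slice of the feature axis
        per_layer.append([tok[lo:hi] for tok in tokens])
    return per_layer


if __name__ == "__main__":
    style = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]        # latent_dim = 2, num_layers = 3
    time = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0]
    for i, adapters in enumerate(build_layer_adapters(style, time, latent_dim=2)):
        print(f"layer {i}: {adapters}")
```

Each layer then prepends its A tokens to the keys and values inside `_concate_adapter`, so the slices act as layer-specific conditioning without changing the length of the output sequence.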
================================================
FILE: alm/models/architectures/tools/embeddings.py
================================================
# This file is taken from signjoey repository
import math
import torch
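from torch import Tensor, nn
def freeze_params(module: nn.Module) -> None:
"""Disable gradient updates for all parameters of `module` (used when freeze=True)."""
for p in module.parameters():
p.requires_grad = False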
def get_activation(activation_type):
if activation_type == "relu":
return nn.ReLU()
elif activation_type == "relu6":
return nn.ReLU6()
elif activation_type == "prelu":
return nn.PReLU()
elif activation_type == "selu":
return nn.SELU()
elif activation_type == "celu":
return nn.CELU()
elif activation_type == "gelu":
return nn.GELU()
elif activation_type == "sigmoid":
return nn.Sigmoid()
elif activation_type == "softplus":
return nn.Softplus()
elif activation_type == "softshrink":
return nn.Softshrink()
elif activation_type == "softsign":
return nn.Softsign()
elif activation_type == "tanh":
return nn.Tanh()
elif activation_type == "tanhshrink":
return nn.Tanhshrink()
else:
raise ValueError("Unknown activation type {}".format(activation_type))
class MaskedNorm(nn.Module):
"""
Original Code from:
https://discuss.pytorch.org/t/batchnorm-for-different-sized-samples-in-batch/44251/8
"""
def __init__(self, norm_type, num_groups, num_features):
super().__init__()
self.norm_type = norm_type
if self.norm_type == "batch":
self.norm = nn.BatchNorm1d(num_features=num_features)
elif self.norm_type == "group":
self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=num_features)
elif self.norm_type == "layer":
self.norm = nn.LayerNorm(normalized_shape=num_features)
else:
raise ValueError("Unsupported Normalization Layer")
self.num_features = num_features
def forward(self, x: Tensor, mask: Tensor):
if self.training:
reshaped = x.reshape([-1, self.num_features])
reshaped_mask = mask.reshape([-1, 1]) > 0
selected = torch.masked_select(reshaped, reshaped_mask).reshape(
[-1, self.num_features]
)
batch_normed = self.norm(selected)
scattered = reshaped.masked_scatter(reshaped_mask, batch_normed)
return scattered.reshape([x.shape[0], -1, self.num_features])
else:
reshaped = x.reshape([-1, self.num_features])
batched_normed = self.norm(reshaped)
return batched_normed.reshape([x.shape[0], -1, self.num_features])
# TODO (Cihan): Spatial and Word Embeddings are pretty much the same
# We might as well convert them into a single module class.
# Only difference is the lut vs linear layers.
class Embeddings(nn.Module):
"""
Simple embeddings class
"""
# pylint: disable=unused-argument
def __init__(
self,
embedding_dim: int = 64,
num_heads: int = 8,
scale: bool = False,
scale_factor: float = None,
norm_type: str = None,
activation_type: str = None,
vocab_size: int = 0,
padding_idx: int = 1,
freeze: bool = False,
**kwargs
):
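"""
Create new embeddings for the vocabulary.
Use scaling for the Transformer.
:param embedding_dim: size of each embedding vector
:param scale: if True, multiply the output by scale_factor (sqrt(embedding_dim) when scale_factor is not given)
:param vocab_size: number of entries in the lookup table
:param padding_idx: index whose embedding is not updated during training
:param freeze: freeze the embeddings during training
"""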
super().__init__()
self.embedding_dim = embedding_dim
self.vocab_size = vocab_size
self.lut = nn.Embedding(vocab_size, self.embedding_dim, padding_idx=padding_idx)
self.norm_type = norm_type
if self.norm_type:
self.norm = MaskedNorm(
norm_type=norm_type, num_groups=num_heads, num_features=embedding_dim
)
self.activation_type = activation_type
if self.activation_type:
self.activation = get_activation(activation_type)
self.scale = scale
if self.scale:
if scale_factor:
self.scale_factor = scale_factor
else:
self.scale_factor = math.sqrt(self.embedding_dim)
if freeze:
freeze_params(self)
# pylint: disable=arguments-differ
def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
"""
Perform lookup for input `x` in the embedding table.
:param mask: token masks
:param x: index in the vocabulary
:return: embedded representation for `x`
"""
x = self.lut(x)
if self.norm_type:
x = self.norm(x, mask)
if self.activation_type:
x = self.activation(x)
if self.scale:
return x * self.scale_factor
else:
return x
def __repr__(self):
return "%s(embedding_dim=%d, vocab_size=%d)" % (
self.__class__.__name__,
self.embedding_dim,
self.vocab_size,
)
class SpatialEmbeddings(nn.Module):
"""
Simple Linear Projection Layer
(For encoder outputs to predict glosses)
"""
# pylint: disable=unused-argument
def __init__(
self,
embedding_dim: int,
input_size: int,
num_heads: int,
freeze: bool = False,
norm_type: str = "batch",
activation_type: str = "softsign",
scale: bool = False,
scale_factor: float = None,
**kwargs
):
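"""
Create a linear projection from input_size to embedding_dim features,
optionally followed by masked normalization, an activation and scaling.
:param embedding_dim: size of the output features
:param input_size: size of the input frame features
:param freeze: freeze the embeddings during training
"""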
super().__init__()
self.embedding_dim = embedding_dim
self.input_size = input_size
self.ln = nn.Linear(self.input_size, self.embedding_dim)
self.norm_type = norm_type
if self.norm_type:
self.norm = MaskedNorm(
norm_type=norm_type, num_groups=num_heads, num_features=embedding_dim
)
self.activation_type = activation_type
if self.activation_type:
self.activation = get_activation(activation_type)
self.scale = scale
if self.scale:
if scale_factor:
self.scale_factor = scale_factor
else:
self.scale_factor = math.sqrt(self.embedding_dim)
if freeze:
freeze_params(self)
# pylint: disable=arguments-differ
def forward(self, x: Tensor, mask: Tensor) -> Tensor:
"""
:param mask: frame masks
:param x: input frame features
:return: embedded representation for `x`
"""
x = self.ln(x)
if self.norm_type:
x = self.norm(x, mask)
if self.activation_type:
x = self.activation(x)
if self.scale:
return x * self.scale_factor
else:
return x
def __repr__(self):
return "%s(embedding_dim=%d, input_size=%d)" % (
self.__class__.__name__,
self.embedding_dim,
self.input_size,
)
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
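"""
Create sinusoidal timestep embeddings, matching the implementation in Denoising Diffusion Probabilistic Models.
:param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
:param embedding_dim: the dimension of the output.
:param flip_sin_to_cos: if True, return [cos, sin] instead of [sin, cos].
:param downscale_freq_shift: offset subtracted from half the embedding dimension in the frequency exponent.
:param scale: factor applied to the sine/cosine arguments.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""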
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
class TimestepEmbedding(nn.Module):
def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
super().__init__()
self.linear_1 = nn.Linear(channel, time_embed_dim)
self.act = None
if act_fn == "silu":
self.act = nn.SiLU()
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
def forward(self, sample):
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
)
return t_emb
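Written out, `get_timestep_embedding` uses h = embedding_dim // 2 frequencies w_i = exp(-ln(max_period) * i / (h - downscale_freq_shift)) for i = 0, ..., h-1, returns [sin(scale*t*w), cos(scale*t*w)], swaps the halves when `flip_sin_to_cos` is set, and zero-pads odd sizes. The scalar, tensor-free sketch below follows the same steps for one timestep. Its defaults mirror how `Adpt_Bias_Denoiser` constructs `Timesteps` (flip_sin_to_cos=True, freq_shift=0), which differ from the function's own defaults; `timestep_embedding_py` is a hypothetical name.

```python
import math
from typing import List


def timestep_embedding_py(t: float, embedding_dim: int, flip_sin_to_cos: bool = True,
                          downscale_freq_shift: float = 0.0, scale: float = 1.0,
                          max_period: float = 10000.0) -> List[float]:
    """Sinusoidal embedding of a single timestep, following get_timestep_embedding."""
    half_dim = embedding_dim // 2
    # geometric ladder of frequencies from 1 down towards 1 / max_period
    freqs = [math.exp(-math.log(max_period) * i / (half_dim - downscale_freq_shift))
             for i in range(half_dim)]
    args = [scale * t * f for f in freqs]
    sin_part = [math.sin(a) for a in args]
    cos_part = [math.cos(a) for a in args]
    emb = cos_part + sin_part if flip_sin_to_cos else sin_part + cos_part
    if embedding_dim % 2 == 1:
        emb.append(0.0)  # zero pad, like F.pad(emb, (0, 1, 0, 0))
    return emb


if __name__ == "__main__":
    print(timestep_embedding_py(0, 8))   # cos half first, so t = 0 gives ones then zeros
    print(timestep_embedding_py(500, 6))
```

Because the lowest-index frequency is always 1, the first entries change quickly with t while later ones vary slowly, which lets the network distinguish both nearby and distant diffusion steps.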
================================================
FILE: alm/models/architectures/tools/transformer_adpt.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from alm.models.architectures.tools.embeddings import (TimestepEmbedding,
Timesteps)
from .utils import PeriodicPositionalEncoding, init_bi_biased_mask
from typing import Optional, Tuple, Union, Callable
import math
class Transformer_Adpt(nn.Module):
def __init__(self,
nfeats: int = 15069,
latent_dim: int = 174,
ff_size: int = 1024,
num_layers: int = 6,
num_heads: int = 4,
dropout: float = 0.1,
normalize_before: bool = False,
activation: str = "gelu",
arch: str = "trans_dec",
audio_encoded_dim: int = 768,
max_len: int = 3000,
id_dim: int = 10,
return_intermediate_dec: bool = False,
require_start_token: bool = False,
require_time_encoding: bool = True,
**kwargs) -> None:
super().__init__()
self.latent_dim = latent_dim
self.arch = arch
self.audio_encoded_dim = audio_encoded_dim
# audio projector
self.audio_feature_map = nn.Linear(audio_encoded_dim, latent_dim)
# motion projector
self.vertice_map = nn.Linear(nfeats, latent_dim)
# periodic positional encoding
self.PPE = PeriodicPositionalEncoding(latent_dim, period = max_len)
# temporal bias
self.memory_bi_bias = init_bi_biased_mask(max_len) # bias for the memory (cross-attention) only, not for the whole model
# init decoder
decoder_layer = TransformerDecoderLayer_w_Adapter(
d_model=latent_dim,
nhead=num_heads,
dim_feedforward=ff_size,
dropout=dropout,
activation=activation,
norm_first=normalize_before,
batch_first=True
)
self.transformer_decoder = TransformerDecoder_w_Adapter(
decoder_layer=decoder_layer,
num_layers=num_layers,
)
# motion decoder
self.motion_decoder = nn.Linear(latent_dim, nfeats)
# used for auto-regressive decoding
if require_start_token:
self.start_token = nn.Parameter(torch.randn(1, 1, latent_dim), requires_grad=True)
# style embedding
# self.obj_vector = nn.Linear(id_dim, latent_dim, bias=False)
id_len = 1000
id_enc = torch.zeros([id_len, latent_dim])
position = torch.arange(0, id_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, latent_dim, 2).float() * (-math.log(10000.0) / latent_dim))
id_enc[:, 0::2] = torch.sin(position * div_term)
id_enc[:, 1::2] = torch.cos(position * div_term)
# non-persistent buffer: follows .to(device) so obj_vector can index it on GPU, but stays out of the state_dict
self.register_buffer('id_enc', id_enc, persistent=False)
nn.init.constant_(self.motion_decoder.weight, 0)
nn.init.constant_(self.motion_decoder.bias, 0)
def obj_vector(self, id):
# if id is a one-hot vector
if id.dim() == 2: # [N, id_dim]
id = id.argmax(dim=1)
return self.id_enc[id] # [N, latent_dim]
def _forward(self,
vertice_input: torch.Tensor,
hidden_state: torch.Tensor,
adapter: torch.Tensor = None,
tgt_mask: torch.Tensor = None,
tgt_key_padding_mask: torch.Tensor = None,
memory_mask: torch.Tensor = None,
memory_key_padding_mask: torch.Tensor = None,
**kwargs):
"""
Auto-regressive forward pass for the decoder.
To be used during training.
Args:
vertice_input: [N, T, E]
hidden_state: [N, S, E]
adapter: [N, A, E]
tgt_mask: [N * H, T, T]
tgt_key_padding_mask: [N, T]
memory_mask: [T, S]
memory_key_padding_mask: [N, S]
"""
vertice_input = self.PPE(vertice_input)
vertice_out = self.transformer_decoder(
tgt=vertice_input,
memory=hidden_state,
adapter = adapter,
tgt_mask=tgt_mask,
tgt_key_padding_mask = tgt_key_padding_mask,
memory_mask=memory_mask,
memory_key_padding_mask = memory_key_padding_mask,
**kwargs
)
vertice_out = self.motion_decoder(vertice_out)
return vertice_out
from torch.nn.modules.transformer import _get_clones
class TransformerDecoder_w_Adapter(nn.TransformerDecoder):
"""
A transformer decoder with adapter layer.
Args:
decoder_layer: an instance of the TransformerDecoderLayer() class (required).
num_layers: the number of sub-DecoderLayer in the decoder (required).
norm: the layer normalization component (optional).
"""
__constants__ = ['norm']
def __init__(self, decoder_layer, num_layers, norm=None):
super(TransformerDecoder_w_Adapter, self).__init__(decoder_layer, num_layers, norm)
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(self,
tgt: Tensor,
memory: Tensor,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
adapter: Optional[Tensor] = None,
) -> Tensor:
r"""Pass the inputs (and mask) through the decoder layer in turn.
Args:
tgt: the sequence to the decoder (required).
memory: the sequence from the last layer of the encoder (required).
tgt_mask: the mask for the tgt sequence (optional).
memory_mask: the mask for the memory sequence (optional).
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
memory_key_padding_mask: the mask for the memory keys per batch (optional).
adapter: the adapter for the decoder layer (optional).
Shape:
see the docs in Transformer class.
"""
output = tgt
for mod in self.layers:
output = mod(output, memory, tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
adapter=adapter)
if self.norm is not None:
output = self.norm(output)
return output
class TransformerDecoderLayer_w_Adapter(nn.TransformerDecoderLayer):
"""
A single layer of the transformer decoder with adapter.
Args:
d_model: the number of expected features in the input (required).
nhead: the number of heads in the multiheadattention models (required).
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
activation: the activation function of intermediate layer, relu or gelu (default=relu).
layer_norm_eps: the eps value in layer normalization components (default=1e-5).
batch_first: if ``True``, then the input and output tensors are provided as (batch, seq, feature). Default: ``False``.
norm_first: if ``True``, layer norm is applied before the self-attention, multihead-attention and feedforward blocks; otherwise after them. Default: ``False`` (after).
device: the desired device of the decoder layer. Default: if ``None``, the current default device (usually CPU) is used.
dtype: the desired dtype of the decoder layer. Default: if ``None``, the default dtype (``torch.get_default_dtype()``, normally ``torch.float32``) is used.
"""
__constants__ = ['batch_first', 'norm_first']
def __init__(self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
layer_norm_eps: float = 1e-5,
batch_first: bool = False,
norm_first: bool = False,
device=None, dtype=None) -> None:
# follow the original transformer decoder layer
factory_kwargs = {'device': device, 'dtype': dtype}
super(TransformerDecoderLayer_w_Adapter, self).__init__(
d_model, nhead, dim_feedforward, dropout, activation, layer_norm_eps, batch_first, norm_first, **factory_kwargs)
def forward(self,
tgt: Tensor,
memory: Tensor,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
adapter: Optional[Tensor] = None,
) -> Tensor:
r"""Pass the inputs (and mask) through the decoder layer.
Args:
tgt: the sequence to the decoder layer (required).
memory: the sequence from the last layer of the encoder (required).
tgt_mask: the mask for the tgt sequence (optional).
memory_mask: the mask for the memory sequence (optional).
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
memory_key_padding_mask: the mask for the memory keys per batch (optional).
adapter: the adapter for the decoder layer (optional).
Shape:
see the docs in Transformer class.
"""
x = tgt
if self.norm_first:
x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, adapter=adapter)
x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, adapter=adapter)
x = x + self._ff_block(self.norm3(x))
else:
x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, adapter=adapter))
x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, adapter=adapter))
x = self.norm3(x + self._ff_block(x))
return x
# self-attention block with adapter
def _sa_block(self, x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
adapter: Optional[Tensor] = None,
) -> Tensor:
"""
Args:
x: [B, T, E] if batch_first else [T, B, E]
attn_mask: [T, A+T] ([T, T] when adapter is None)
key_padding_mask: [B, A+T] ([B, T] when adapter is None)
adapter: [B, A, E] if batch_first else [A, B, E]
Returns:
[B, T, E] if batch_first else [T, B, E]
"""
batch_first = self.self_attn.batch_first
# prepend the adapter tokens to the keys and values if an adapter is given
if adapter is not None:
x_adpt = self._concate_adapter(adapter, x, batch_first=batch_first)
else:
x_adpt = x
# # original self-attention block
# tmp = self.self_attn(x, x_adpt, x_adpt, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=True, )[1]
# # visualize attention, use sns
# import matplotlib.pyplot as plt
# import seaborn as sns
# length = 100
# fig, ax = plt.subplots(figsize=(15, 10))
# sns.heatmap(tmp[0, :length, :length+2].detach().cpu().numpy())
# # save to disk
# plt.savefig('self_attention.png')
x = self.self_attn(x, x_adpt, x_adpt, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False, )[0]
return self.dropout1(x)
# cross-attention block with adapter
def _mha_block(self, x: Tensor, mem: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
adapter: Optional[Tensor] = None,
) -> Tensor:
"""
Args:
x: [B, T, E] if batch_first else [T, B, E]
mem: [B, S, E] if batch_first else [S, B, E]
attn_mask: [T, A+S] ([T, S] when adapter is None)
key_padding_mask: [B, A+S] ([B, S] when adapter is None)
adapter: [B, A, E] if batch_first else [A, B, E]
Returns:
[B, T, E] if batch_first else [T, B, E]
"""
batch_first = self.multihead_attn.batch_first
# prepend the adapter tokens to the memory keys and values if an adapter is given
if adapter is not None:
mem_adpt = self._concate_adapter(adapter, mem, batch_first=batch_first)
else:
mem_adpt = mem # without an adapter, cross-attention keys/values are the encoder memory
# tmp = self.multihead_attn(x, mem_adpt, mem_adpt, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=True, )[1]
# # visualize attention, use sns
# import matplotlib.pyplot as plt
# import seaborn as sns
# length = 100
# fig, ax = plt.subplots(figsize=(15, 10))
# sns.heatmap(tmp[0, :length, :length+2].detach().cpu().numpy())
# # save to disk
# plt.savefig('cross_attention.png')
# original cross-attention block
x = self.multihead_attn(x, mem_adpt, mem_adpt, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False, )[0]
return self.dropout2(x)
def _concate_adapter(self, adapter: Tensor, x: Tensor, batch_first: bool = True):
"""
Concatenate the adapter in front of x along the sequence dimension.
Args:
adapter: [B, A, E] if batch_first else [A, B, E]
x: [B, T, E] if batch_first else [T, B, E]
Returns:
x_adapted: [B, A+T, E] if batch_first else [A+T, B, E]
"""
if batch_first:
x_adapted = torch.concat([adapter, x], dim=1) # [B, A, E] + [B, T, E] -> [B, A+T, E]
else: # sequence-first layout
x_adapted = torch.concat([adapter, x], dim=0) # [A, B, E] + [T, B, E] -> [A+T, B, E]
return x_adapted
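`_concate_adapter` prepends the adapter tokens to the keys and values only, never to the queries, so each layer's output keeps the target length T while every position can also attend to the A adapter tokens (similar in spirit to prefix tuning). The sketch below shows this with a single head in plain Python. It deliberately leaves out the learned projections, multiple heads, masks and dropout of `nn.MultiheadAttention`, and `attend_with_prefix` is a hypothetical name.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def attend_with_prefix(queries: Matrix, memory: Matrix, adapter: Optional[Matrix] = None) -> Matrix:
    """Single-head scaled dot-product attention where the optional adapter tokens
    are placed in front of the keys/values (queries are left untouched)."""
    kv = (adapter or []) + memory  # [A + S, E]; mirrors torch.concat([adapter, x], dim=1)
    dim = len(queries[0])
    out = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(dim) for k in kv]
        top = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j] for w, v in zip(weights, kv)) / total for j in range(dim)])
    return out


if __name__ == "__main__":
    x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]      # T = 3 tokens, E = 2
    adapter = [[0.5, 0.5], [2.0, 0.0]]            # A = 2 adapter tokens
    print(len(attend_with_prefix(x, x)))            # self-attention without adapter
    print(len(attend_with_prefix(x, x, adapter)))   # output length is still T
```

With a zero query the softmax weights are uniform, so the output is simply the mean of the adapter and memory values; this makes the effect of the extra tokens easy to inspect.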
================================================
FILE: alm/models/architectures/tools/utils.py
================================================
import math
import torch
import torch.nn as nn
class PeriodicPositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, period=25, max_seq_len=600):
super(PeriodicPositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(period, d_model)
position = torch.arange(0, period, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # (1, period, d_model)
repeat_num = (max_seq_len//period) + 1
pe = pe.repeat(1, repeat_num, 1)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :x.size(1), :]
return self.dropout(x)
# def init_biased_mask(n_head, max_seq_len, period):
# # this code is from https://github.com/EvelynFan/FaceFormer/blob/dfaea81983665b22b99af336a80574208cfcc099/faceformer.py#L10
# # however, the original code is not working for the case where the batch size is not 1
# # so I modified it a little bit
# def get_slopes(n):
# def get_slopes_power_of_2(n):
# start = (2**(-2**-(math.log2(n)-3)))
# ratio = start
# return [start*ratio**i for i in range(n)]
# if math.log2(n).is_integer():
# return get_slopes_power_of_2(n)
# else:
# closest_power_of_2 = 2**math.floor(math.log2(n))
# return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]
# slopes = torch.Tensor(get_slopes(n_head))
# bias = torch.div(
# torch.arange(
# start=0,
# end=max_seq_len,
# step=period,
# dtype=torch.float
# ).unsqueeze(1).repeat(1,period).view(-1),
# period,
# rounding_mode='floor'
# )
# bias = - torch.flip(bias,dims=[0])
# alibi = torch.zeros(max_seq_len, max_seq_len)
# for i in range(max_seq_len):
# alibi[i, :i+1] = bias[-(i+1):]
# alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0)
# mask = (torch.triu(torch.ones(max_seq_len, max_seq_len)) == 1).transpose(0, 1)
# mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
# mask = mask.unsqueeze(0) + alibi
# return mask
# def init_bi_biased_mask_dev(n_head, max_seq_len, period):
# # this code is from # https://github.com/tensorflow/tensor2tensor
# # and I modified it a little bit to match the original code in https://github.com/EvelynFan/FaceFormer/blob/dfaea81983665b22b99af336a80574208cfcc099/faceformer.py#L10
# # such the code is working for the case where the batch size is not 1
# def get_slopes(n):
# def get_slopes_power_of_2(n):
# start = (2**(-2**-(math.log2(n)-3)))
# ratio = start
# return [start*ratio**i for i in range(n)]
# if math.log2(n).is_integer():
# return get_slopes_power_of_2(n)
# else:
# closest_power_of_2 = 2**math.floor(math.log2(n))
# return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]
# slopes = torch.Tensor(get_slopes(n_head))
# range_vec = torch.div(
# torch.arange(
# start=0,
# end=max_seq_len,
# step=period,
# dtype=torch.float
# ).unsqueeze(1).repeat(1,period).view(-1),
# period,
# rounding_mode='floor'
# )
# relative_matrix = range_vec[None, :] - range_vec[:, None]
# relative_matrix[torch.where(relative_matrix > 0)] *= -1
# alibi = slopes.unsqueeze(1).unsqueeze(1) * relative_matrix.unsqueeze(0)
# return alibi
def init_biased_mask(n_head, max_seq_len, period):
# this code is from https://github.com/EvelynFan/FaceFormer/blob/dfaea81983665b22b99af336a80574208cfcc099/faceformer.py#L10
# however, the original code is not working for the case where the batch size is not 1
# so I modified it a little bit
def get_slopes(n):
def get_slopes_power_of_2(n):
start = (2**(-2**-(math.log2(n)-3)))
ratio = start
return [start*ratio**i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2**math.floor(math.log2(n))
return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]
slopes = torch.Tensor(get_slopes(n_head))
bias = torch.div(
torch.arange(
start=0,
end=max_seq_len,
step=period,
dtype=torch.float
).unsqueeze(1).repeat(1,period).view(-1),
period,
rounding_mode='floor'
)
bias = - torch.flip(bias,dims=[0])
alibi = torch.zeros(max_seq_len, max_seq_len)
for i in range(max_seq_len):
alibi[i, :i+1] = bias[-(i+1):]
alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0)
mask = (torch.triu(torch.ones(max_seq_len, max_seq_len)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
mask = mask.unsqueeze(0) + alibi
return mask
def init_bi_biased_mask(max_seq_len, ):
# NOTE: attention masks with more than 3 dimensions are not supported
range_vec = torch.arange(
start=0,
end=max_seq_len,
dtype=torch.float
)
relative_matrix = range_vec[None, :] - range_vec[:, None]
relative_matrix[torch.where(relative_matrix > 0)] *= -1
return relative_matrix
def init_mem_mask_faceformer(max_seq_len):
mask = torch.ones(max_seq_len, max_seq_len)
# set the diagonal to 0
mask = mask.masked_fill(torch.eye(max_seq_len) == 1, 0)
return mask
def init_bi_biased_mask_faceformer(n_head, max_seq_len, period):
# NOTE: attention masks with more than 3 dimensions are not supported
def get_slopes(n):
def get_slopes_power_of_2(n):
start = (2**(-2**-(math.log2(n)-3)))
ratio = start
return [start*ratio**i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2**math.floor(math.log2(n))
return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]
slopes = torch.Tensor(get_slopes(n_head))
bias = torch.div(torch.arange(start=0, end=max_seq_len, step=period).unsqueeze(1).repeat(1,period).view(-1), period, rounding_mode='floor')
bias = - torch.flip(bias,dims=[0])
alibi = torch.zeros(max_seq_len, max_seq_len)
for i in range(max_seq_len):
alibi[i, :i+1] = bias[-(i+1):]
if i+1 < max_seq_len:
alibi[i, i+1:] = bias[-(max_seq_len-(i+1)):].flip(dims=[0])
alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0)
return alibi
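# Reference sketch (hypothetical helper, not used by the model): the per-head
# entry that init_bi_biased_mask_faceformer produces, as read from the loop
# above (worth re-checking against the tensor it returns). Past keys (j <= i)
# get -slope * floor((i - j) / period); future keys (j > i) get
# -slope * floor((j - i - 1) / period). The future side is therefore shifted
# by one frame: with period=1 the frame right after the query is not
# penalised, while the frame right before it is.
def _sketch_bidirectional_bias(slope, max_seq_len, period):
    rows = []
    for i in range(max_seq_len):
        row = []
        for j in range(max_seq_len):
            distance = (i - j) // period if j <= i else (j - i - 1) // period
            row.append(-slope * distance if distance else 0.0)
        rows.append(row)
    return rows


# e.g. _sketch_bidirectional_bias(1.0, 3, 1)[1] -> [-1.0, 0.0, 0.0]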
================================================
FILE: alm/models/get_model.py
================================================
import importlib
def get_model(cfg, datamodule):
modeltype = cfg.model.model_type
return get_module(cfg, datamodule)
# if modeltype in ["proto","faceformer"]:
# return get_module(cfg, datamodule)
# else:
# raise ValueError(f"Invalid model {modeltype}.")
def get_module(cfg, datamodule):
modeltype = cfg.model.model_type
model_module = importlib.import_module(
f".modeltype.{cfg.model.model_type}", package="alm.models")
Model = getattr(model_module, modeltype.upper())
return Model(cfg=cfg, datamodule=datamodule)
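# Illustrative sketch (hypothetical helper): get_module resolves
# cfg.model.model_type, e.g. "diffusion_bias", to the module
# alm.models.modeltype.diffusion_bias and the class DIFFUSION_BIAS (the
# upper-cased model type). The same relative-import-plus-getattr pattern is
# shown here on a standard-library package so it runs anywhere.
import importlib


def _sketch_load_class(package, submodule, class_name):
    # importlib.import_module resolves ".submodule" relative to `package`
    module = importlib.import_module(f".{submodule}", package=package)
    return getattr(module, class_name)


# e.g. _sketch_load_class("json", "decoder", "JSONDecoder") is json.decoder.JSONDecoder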
================================================
FILE: alm/models/losses/__init__.py
================================================
================================================
FILE: alm/models/losses/utils.py
================================================
import torch
import torch.nn as nn
class MaskedConsistency:
def __init__(self) -> None:
self.loss = nn.MSELoss(reduction="mean")
def __call__(self, pred, gt, mask):
return self.loss(mask * pred, mask * gt)
def __repr__(self):
return self.loss.__repr__()
class MaskedVelocityConsistency:
def __init__(self) -> None:
self.loss = nn.MSELoss(reduction="mean")
def __call__(self, pred, gt, mask):
term1 = self.velocity(mask * pred)
term2 = self.velocity(mask * gt)
return self.loss(term1, term2)
def velocity(self, term):
velocity = term[:, 1:] - term[:, :-1]
return velocity
def __repr__(self):
return self.loss.__repr__()
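# Reference sketch (hypothetical helpers, pure Python, one sequence of frames):
# MaskedConsistency zeroes padded frames in both prediction and ground truth
# and then applies nn.MSELoss(reduction="mean"), which divides by *all*
# elements, padded ones included. MaskedVelocityConsistency does the same on
# frame-to-frame differences v_t = x_{t+1} - x_t.
def _sketch_masked_mse(pred, gt, mask):
    # pred, gt: lists of frames (lists of floats); mask: 1 for valid, 0 for padding
    diffs = [
        (m * p - m * g) ** 2
        for fp, fg, m in zip(pred, gt, mask)
        for p, g in zip(fp, fg)
    ]
    return sum(diffs) / len(diffs)


def _sketch_velocity(frames):
    # difference between consecutive frames
    return [[b - a for a, b in zip(f0, f1)] for f0, f1 in zip(frames, frames[1:])]


def _sketch_masked_velocity_mse(pred, gt, mask):
    masked_pred = [[m * x for x in f] for f, m in zip(pred, mask)]
    masked_gt = [[m * x for x in f] for f, m in zip(gt, mask)]
    vp, vg = _sketch_velocity(masked_pred), _sketch_velocity(masked_gt)
    diffs = [(a - b) ** 2 for fa, fb in zip(vp, vg) for a, b in zip(fa, fb)]
    return sum(diffs) / len(diffs)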
# flame_lmk_faces = [[2210, 2212, 2213],
# [3060, 3059, 1962],
# [3485, 3060, 1961],
# [3382, 3384, 3381],
# [3385, 3388, 3386],
# [3387, 3389, 3390],
# [3392, 3418, 3419],
# [3415, 3395, 3393],
# [3414, 3399, 3397],
# [3634, 3595, 3598],
# [3643, 3637, 3594],
# [3588, 3587, 3583],
# [3584, 3581, 3582],
# [3742, 3580, 3577],
# [2012, 3756, 566],
# [2009, 2012, 2011],
# [ 728, 731, 730],
# [1983, 1984, 1985],
# [3157, 3708, 3158],
# [ 338, 3153, 335],
# [3154, 3712, 3705],
# [2179, 2178, 3684],
# [ 674, 3851, 673],
# [3863, 3868, 2135],
# [2134, 27, 16],
# [2138, 2139, 3865],
# [ 572, 571, 570],
# [2194, 3553, 3542],
# [3561, 739, 3518],
# [1757, 3521, 3501],
# [3526, 3564, 1819],
# [2748, 2746, 2750],
# [2792, 2794, 2795],
# [1692, 3556, 3507],
# [1678, 1677, 1675],
# [1612, 1618, 1610],
# [2440, 2428, 2437],
# [2383, 2453, 2495],
# [2493, 3689, 2494],
# [2509, 3631, 3632],
# [2293, 2298, 2299],
# [2333, 2296, 2295],
# [3833, 3832, 1358],
# [1342, 1343, 3855],
# [1344, 1218, 1034],
# [1182, 1175, 1154],
# [ 955, 883, 884],
# [ 881, 897, 896],
# [2845, 2715, 2714],
# [2849, 2813, 2850],
# [2811, 2866, 2774],
# [1657, 3543, 3546],
# [1694, 1657, 1751],
# [1734, 1735, 1696],
# [1730, 1578, 1579],
# [1774, 1795, 1796],
# [1802, 1865, 1866],
# [1850, 3506, 3503],
# [2905, 2949, 2948],
# [2899, 2898, 2881],
# [2719, 2718, 2845],
# [3533, 2786, 2785],
# [3533, 3531, 2786],
# [1669, 3533, 1668],
# [1730, 1578, 1579],
# [1826, 1848, 1849],
# [3504, 3509, 2937],
# [2938, 2937, 2928]] # extracted from the mica-tracker code
# def flame_vertice_2_lmk(vertice, lmk_faces = flame_lmk_faces):
# if len(vertice.shape) == 3: # if the vertices is resizes to (B, N, V * 3), reshape it back to (B, N, V, 3)
# resized = True
# B, N, _ = vertice.shape
# vertice = vertice.reshape(B, N, -1, 3)
# V = vertice.shape[-2]
# else:
# resized = False
# B, N, V, _ = vertice.shape
# _vertice = vertice.view(-1, V, 3) # (B * N, V, 3) -> (_B, V, 3)
# _B = _vertice.shape[0]
# lmk_faces = torch.tensor(lmk_faces,).clone() # (68, 3)
# lmk_faces = lmk_faces.unsqueeze(0).expand(_B, -1, -1).to(vertice.device) # (68, 3) -> (_B, 68, 3)
# lmk_faces += torch.arange(_B, device=vertice.device).view(-1, 1, 1) * V # (_B, 68, 3)
# lmk_vertices = _vertice.reshape(-1, 3)[lmk_faces].view(_B, -1, 3, 3) # (_B, 68, 3, 3), here we have to use .reshape(-1, 3) to make sure the index is correct, view reports error
# landmarks = lmk_vertices.mean(dim=-2) # (_B, 68, 3), every vertice in the face contributes to the landmark
# if resized:
# landmarks = landmarks.view(B, N, -1) # (B, N, 68 * 3)
# else:
# landmarks = landmarks.view(B, N, -1, 3) # (B, N, 68, 3)
# return landmarks
# def cusum(data, threshold, drift):
# """Cumulative sum algorithm (CUSUM) to detect abrupt changes in data."""
# # Initialize variables
# thres = torch.zeros_like(data)
# # Compute the cumulative sum using torch.cumsum
# for i in range(1, data.shape[1]):
# # Update the cumulative sum
# prev_thres = thres[:, i-1].unsqueeze(1)
# delta = (data[:, i, :] - data[:, i-1, :]).abs().unsqueeze(1) #torch.norm(data[:, i, :] - data[:, i-1, :], dim=1).unsqueeze(1)
# thres[:, i] = torch.max(torch.zeros_like(prev_thres), prev_thres + delta - threshold).squeeze(1)
# # Update the threshold
# threshold += drift
# return thres
# the following is required by BIWI
import os
import numpy as np
import pickle
with open(os.path.join("datasets/biwi/regions", "lve.txt")) as f:
maps = f.read().split(", ")
mouth_map = [int(i) for i in maps]
with open(os.path.join("datasets/biwi/regions", "fdd.txt")) as f:
maps = f.read().split(", ")
upper_map = [int(i) for i in maps]
# load the FLAME vertex masks (lips, forehead, eye region) used by the VOCASET metrics
with open(os.path.join("datasets/vocaset", "FLAME_masks.pkl"), "rb") as f:
masks = pickle.load(f, encoding='latin1')
vocaset_mouth_map = masks["lips"].tolist()
vocaset_upper_map = masks["forehead"].tolist() + masks["eye_region"].tolist()
vocaset_upper_map = list(set(vocaset_upper_map))
def vocaset_upper_face_variance(motion, ):
L2_dis_upper = np.array([np.square(motion[:,v, :]) for v in vocaset_upper_map])
L2_dis_upper = np.transpose(L2_dis_upper, (1,0,2))
L2_dis_upper = np.sum(L2_dis_upper,axis=2)
L2_dis_upper = np.std(L2_dis_upper, axis=0)
motion_std = np.mean(L2_dis_upper)
return torch.tensor(motion_std).float() #torch.from_numpy(motion_std).float()
def vocaset_mouth_distance(vertices_gt, vertices_pred):
L2_dis = np.array([np.square(vertices_gt[:,v, :] - vertices_pred[:,v, :]) for v in vocaset_mouth_map])
L2_dis = np.transpose(L2_dis, (1,0,2)) # (V, N, 3) -> (N, V, 3)
L2_dis = np.sum(L2_dis, axis=2) # (N, V, 3) -> (N, V)
L2_dis = np.max(L2_dis, axis=1) # (N, V) -> (N)
return torch.tensor(L2_dis).float()
def biwi_upper_face_variance(motion, ):
L2_dis_upper = np.array([np.square(motion[:,v, :]) for v in upper_map])
L2_dis_upper = np.transpose(L2_dis_upper, (1,0,2))
L2_dis_upper = np.sum(L2_dis_upper,axis=2)
L2_dis_upper = np.std(L2_dis_upper, axis=0)
motion_std = np.mean(L2_dis_upper)
return torch.tensor(motion_std).float() #torch.from_numpy(motion_std).float()
def biwi_mouth_distance(vertices_gt, vertices_pred):
L2_dis = np.array([np.square(vertices_gt[:,v, :] - vertices_pred[:,v, :]) for v in mouth_map])
L2_dis = np.transpose(L2_dis, (1,0,2)) # (V, N, 3) -> (N, V, 3)
L2_dis = np.sum(L2_dis, axis=2) # (N, V, 3) -> (N, V)
L2_dis = np.max(L2_dis, axis=1) # (N, V) -> (N)
return torch.tensor(L2_dis).float()
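# Vectorized sketch (hypothetical helpers, assuming numpy arrays shaped
# (frames, vertices, 3) and a list of vertex indices) of the same quantities
# as the *_mouth_distance and *_upper_face_variance loops above.
# Lip vertex error: per frame, the largest squared L2 error over the lip
# vertices; the test step then averages these values over frames.
# Upper-face dynamics: per vertex, the std over time of the squared
# displacement from the template, averaged over vertices; FDD is this value
# for the ground truth minus the value for the prediction.
import numpy as np


def _sketch_mouth_distance(vertices_gt, vertices_pred, region):
    sq = np.sum((vertices_gt[:, region] - vertices_pred[:, region]) ** 2, axis=2)  # (N, V)
    return np.max(sq, axis=1)  # (N,)


def _sketch_upper_face_variance(motion, region):
    sq = np.sum(motion[:, region] ** 2, axis=2)  # (N, V)
    return float(np.mean(np.std(sq, axis=0)))  # std over time, mean over vertices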
================================================
FILE: alm/models/losses/voca.py
================================================
# import numpy as np
# import torch
# import torch.nn as nn
# from torchmetrics import Metric
# class almLosses(Metric):
# """
# Audio latent motion losses
# """
# def __init__(self, cfg):
# super().__init__(dist_sync_on_step=cfg.LOSS.DIST_SYNC_ON_STEP)
# # Save parameters
# # self.vae = vae
# self.cfg = cfg
# self.stage = cfg.TRAIN.STAGE
# def updata(self, rs_set):
# return None
# def compute(self, split):
# count = getattr(self, "count")
# return {loss: getattr(self, loss) / count for loss in self.losses}
import numpy as np
import torch
import torch.nn as nn
from torchmetrics import Metric
import os
import pickle
class VOCALosses(Metric):
"""
Masked vertex and velocity reconstruction losses (structure follows the MLD loss class)
"""
def __init__(self, cfg, split):
super().__init__(dist_sync_on_step=cfg.LOSS.DIST_SYNC_ON_STEP)
self.cfg = cfg
# set up loss
losses = []
self._losses_func = {}
self._params = {}
reconstruct = MaskedConsistency()
reconstruct_v = MaskedVelocityConsistency()
self.split = split
is_train = split in ['losses_train']
if split in ['losses_train', 'losses_val']:
# vertice
name = "vertice_enc" # enc here means encoding, just for matching the name in the diffusion-denoising experiment
losses.append(name)
self._losses_func[name] = reconstruct
self._params[name] = cfg.LOSS.VERTICE_ENC if is_train else 1.0
self.add_state(name, default=torch.tensor(0.0), dist_reduce_fx="sum")
name = "vertice_encv" # encv here means encoding velocity, just for matching the name in the diffusion-denoising experiment
losses.append(name)
self._losses_func[name] = reconstruct_v
self._params[name] = cfg.LOSS.VERTICE_ENC_V if is_train else 1.0
self.add_state(name, default=torch.tensor(0.0), dist_reduce_fx="sum")
name = "lip_enc"
losses.append(name)
self._losses_func[name] = reconstruct
self._params[name] = cfg.LOSS.LIP_ENC if is_train else 1.0
self.add_state(name, default=torch.tensor(0.0), dist_reduce_fx="sum")
name = "lip_encv"
losses.append(name)
self._losses_func[name] = reconstruct_v
self._params[name] = cfg.LOSS.LIP_ENC_V if is_train else 1.0
self.add_state(name, default=torch.tensor(0.0), dist_reduce_fx="sum")
elif split in ['losses_test']:
pass # no loss for test
else:
raise ValueError(f"split {split} not supported")
name = "total"
losses.append(name)
self.add_state(name, default=torch.tensor(0.0), dist_reduce_fx="sum")
name = 'count'
self.add_state(name, default=torch.tensor(0), dist_reduce_fx="sum")
self.losses = losses
# # obtain the lmk index
# # the following is required by BIWI
# with open(os.path.join("datasets/biwi/regions", "lve.txt")) as f:
# maps = f.read().split(", ")
# self.biwi_mouth_map = [int(i) for i in maps]
# self.biwi_mouth_map = torch.tensor(self.biwi_mouth_map).long()
# # the following is required by vocaset
# with open(os.path.join("datasets/vocaset", "FLAME_masks.pkl"), "rb") as f:
# masks = pickle.load(f, encoding='latin1')
# self.vocaset_mouth_map = masks["lips"].tolist()
# self.vocaset_mouth_map = torch.tensor(self.vocaset_mouth_map).long()
# def vert2lip(self, vertice):
# num_verts = vertice.shape[-1] // 3
# if num_verts == 5023:
# mouth_map = self.vocaset_mouth_map.to(vertice.device)
# elif num_verts == 23370:
# mouth_map = self.biwi_mouth_map.to(vertice.device)
# else:
# raise ValueError(f"num_verts {num_verts} not supported")
# shape = vertice.shape
# lip_vertice = vertice.view(shape[0], shape[1], -1, 3)[:, :, mouth_map, :].view(shape[0], shape[1], -1)
# return lip_vertice
def update(self, rs_set):
# rs_set.keys() = dict_keys(['latent', 'latent_pred', 'vertice', 'vertice_recon', 'vertice_pred', 'vertice_attention'])
total: float = 0.0
# Compute the losses
# Compute instance loss
# padding mask
mask = rs_set['vertice_attention'].unsqueeze(-1)
if self.split in ['losses_train', 'losses_val']:
# vertice loss
total += self._update_loss("vertice_enc", rs_set['vertice'], rs_set['vertice_pred'], mask = mask)
total += self._update_loss("vertice_encv", rs_set['vertice'], rs_set['vertice_pred'], mask = mask)
# lip loss
# lip_vertice = self.vert2lip(rs_set['vertice'])
# lip_vertice_pred = self.vert2lip(rs_set['vertice_pred'])
# total += self._update_loss("lip_enc", lip_vertice, lip_vertice_pred, mask = mask)
# total += self._update_loss("lip_encv", lip_vertice, lip_vertice_pred, mask = mask)
self.total += total.detach()
self.count += 1
return total
if self.split in ['losses_test']:
raise ValueError(f"split {self.split} not supported")
def compute(self, split):
count = getattr(self, "count")
return {loss: getattr(self, loss) / count for loss in self.losses}
def _update_loss(self, loss: str, outputs, inputs, mask = None):
# Update the loss
if mask is not None:
val = self._losses_func[loss](outputs, inputs, mask)
else:
val = self._losses_func[loss](outputs, inputs)
getattr(self, loss).__iadd__(val.detach())
# Return a weighted sum
weighted_loss = self._params[loss] * val
return weighted_loss
def loss2logname(self, loss: str, split: str):
if loss == "total":
log_name = f"{loss}/{split}"
else:
loss_type, name = loss.split("_")
log_name = f"{loss_type}/{name}/{split}"
return log_name
class MaskedConsistency:
def __init__(self) -> None:
self.loss = nn.MSELoss(reduction="mean")
def __call__(self, pred, gt, mask):
# # masking nan
# is_nan = torch.logical_or(torch.isnan(pred), torch.isnan(gt))
# nan_mask = torch.logical_not(is_nan).long()
# torch.where(nan_mask[0, ..., 0] != mask[0].squeeze())
return self.loss(mask * pred, mask * gt)
def __repr__(self):
return self.loss.__repr__()
class MaskedVelocityConsistency:
def __init__(self) -> None:
self.loss = nn.MSELoss(reduction="mean")
def __call__(self, pred, gt, mask):
term1 = self.velocity(mask * pred)
term2 = self.velocity(mask * gt)
return self.loss(term1, term2)
def velocity(self, term):
velocity = term[:, 1:] - term[:, :-1]
return velocity
def __repr__(self):
return self.loss.__repr__()
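# Sketch of the bookkeeping VOCALosses performs (hypothetical class, pure
# Python): each update adds the *unweighted* value of every term to its
# running sum, while the returned total (also accumulated) is the weighted sum
# of lambda_k * L_k used for backpropagation; compute() divides every running
# sum by the number of update calls. The weights below are made-up examples.
class _SketchLossTracker:
    def __init__(self, weights):
        self.weights = dict(weights)  # e.g. {"vertice_enc": 1.0, "vertice_encv": 0.5}
        self.sums = {name: 0.0 for name in self.weights}
        self.sums["total"] = 0.0
        self.count = 0

    def update(self, values):
        total = 0.0
        for name, value in values.items():
            self.sums[name] += value  # logged unweighted
            total += self.weights[name] * value  # optimised weighted
        self.sums["total"] += total
        self.count += 1
        return total

    def compute(self):
        return {name: s / self.count for name, s in self.sums.items()}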
================================================
FILE: alm/models/modeltype/__init__.py
================================================
================================================
FILE: alm/models/modeltype/base.py
================================================
import os
from pathlib import Path
import numpy as np
from pytorch_lightning import LightningModule
import torch
from collections import OrderedDict
class BaseModel(LightningModule):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.times = []
def __post_init__(self):
trainable, nontrainable = 0, 0
for p in self.parameters():
if p.requires_grad:
trainable += np.prod(p.size())
else:
nontrainable += np.prod(p.size())
self.hparams.n_params_trainable = trainable
self.hparams.n_params_nontrainable = nontrainable
def training_step(self, batch, batch_idx):
return self.allsplit_step("train", batch, batch_idx)
def validation_step(self, batch, batch_idx):
return self.allsplit_step("val", batch, batch_idx)
def test_step(self, batch, batch_idx):
if len(self.times) *self.cfg.TEST.BATCH_SIZE % (100) > 0 and len(self.times) > 0:
print(f"Average time per sample ({self.cfg.TEST.BATCH_SIZE*len(self.times)}): ", np.mean(self.times)/self.cfg.TEST.BATCH_SIZE)
return self.allsplit_step("test", batch, batch_idx)
def predict_step(self, batch, batch_idx):
return self.forward(batch)
def allsplit_epoch_end(self, split: str, outputs):
dico = {}
if split in ["train", "val"]:
losses = self.losses[split]
loss_dict = losses.compute(split)
losses.reset()
dico.update({
losses.loss2logname(loss, split): value.item()
for loss, value in loss_dict.items() #if not torch.isnan(value)
})
dico.update({
"epoch": float(self.trainer.current_epoch),
"step": float(self.trainer.current_epoch),
})
if split == "test":
metrics = {key: [] for key in outputs[0].keys()}
for output in outputs: # collect the results from all batches
for key, value in output.items():
metrics[key].append(value)
lengths = torch.stack(metrics.pop("Length"))
for key, value in metrics.items():
if key == 'Lip Vertex Error':
# weight each batch by its sequence length
metrics[key] = torch.mean(torch.stack(value) * lengths) / torch.mean(lengths)
else:
metrics[key] = torch.mean(torch.stack(value))
dico.update(metrics)
if not self.trainer.sanity_checking:
self.log_dict(dico, sync_dist=True, rank_zero_only=True)
def training_epoch_end(self, outputs):
return self.allsplit_epoch_end("train", outputs)
def validation_epoch_end(self, outputs):
return self.allsplit_epoch_end("val", outputs)
def test_epoch_end(self, outputs):
return self.allsplit_epoch_end("test", outputs)
# def on_save_checkpoint(self, checkpoint):
# # don't save audio_encoder to checkpoint
# state_dict = checkpoint['state_dict']
# clip_k = []
# for k, v in state_dict.items():
# if 'audio_encoder' in k:
# clip_k.append(k)
# for k in clip_k:
# del checkpoint['state_dict'][k]
# def on_load_checkpoint(self, checkpoint):
# # restore clip state_dict to checkpoint
# clip_state_dict = self.audio_encoder.state_dict()
# new_state_dict = OrderedDict()
# for k, v in clip_state_dict.items():
# new_state_dict['audio_encoder.' + k] = v
# for k, v in checkpoint['state_dict'].items():
# if 'audio_encoder' not in k:
# new_state_dict[k] = v
# checkpoint['audio_dict'] = new_state_dict
# def load_state_dict(self, state_dict, strict=True):
# # load clip state_dict to checkpoint
# clip_state_dict = self.audio_encoder.state_dict()
# new_state_dict = OrderedDict()
# for k, v in clip_state_dict.items():
# new_state_dict['audio_encoder.' + k] = v
# for k, v in state_dict.items():
# if 'audio_encoder' not in k:
# new_state_dict[k] = v
# super().load_state_dict(new_state_dict, strict)
def configure_optimizers(self):
return {"optimizer": self.optimizer}
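# Sketch (hypothetical helper, pure Python) of the length-weighted average
# written for "Lip Vertex Error" in allsplit_epoch_end:
# mean(v_i * l_i) / mean(l_i) = sum(v_i * l_i) / sum(l_i), so longer test
# sequences count proportionally more than in a plain mean of per-batch values.
def _sketch_length_weighted_mean(values, lengths):
    if len(values) != len(lengths) or not lengths:
        raise ValueError("values and lengths must be non-empty and equally long")
    return sum(v * l for v, l in zip(values, lengths)) / sum(lengths)


# e.g. values [1.0, 3.0] with lengths [1.0, 3.0] give 2.5, versus a plain mean of 2.0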
================================================
FILE: alm/models/modeltype/diffusion_bias.py
================================================
import torch
from torch.optim import AdamW, Adam
import torch.nn.functional as F
from torchmetrics import MetricCollection
from transformers import Wav2Vec2Model
from alm.config import instantiate_from_config
from alm.models.modeltype.base import BaseModel
from alm.models.losses.voca import VOCALosses
from alm.utils.demo_utils import animate
import inspect
from typing import Optional, Tuple, Union, Callable
import os
import time
from multiprocessing import Process
from tqdm import tqdm
import numpy as np
from time import time as infer_time
import pickle
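# Pure-Python sketches (hypothetical helpers, not part of the model) of two
# steps used by DIFFUSION_BIAS below:
# 1. _audio_resize: audio-encoder features at `input_fps` are linearly
#    resampled to the vertex frame rate; with align_corners=True, output
#    frame k samples the input at position k * (n_in - 1) / (n_out - 1).
# 2. _tgt_key_padding_mask / _mem_key_padding_mask: two always-valid adapter
#    positions (identity and timestep tokens) are prepended and the result is
#    inverted, so True marks a position that attention must ignore.
def _sketch_output_len(n_in, input_fps, output_fps):
    # number of output frames covering the same duration, as in _audio_resize
    return int(n_in / input_fps * output_fps)


def _sketch_linear_resize(values, output_len):
    n_in = len(values)
    if output_len == 1 or n_in == 1:
        return [values[0]] * output_len
    out = []
    for k in range(output_len):
        pos = k * (n_in - 1) / (output_len - 1)  # align_corners=True sampling position
        lo = min(int(pos), n_in - 2)
        frac = pos - lo
        out.append(values[lo] * (1 - frac) + values[lo + 1] * frac)
    return out


def _sketch_key_padding_mask(attention, n_adapter=2):
    # attention: 1 for valid frames, 0 for padding
    return [False] * n_adapter + [not bool(a) for a in attention]


# e.g. 100 feature frames at a hypothetical 50 fps resampled to 30 fps -> 60 frames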
class DIFFUSION_BIAS(BaseModel):
def __init__(self, cfg, datamodule, **kwargs):
"""
Initialize the model
"""
# initialize BaseModel (a LightningModule)
super().__init__()
self.cfg = cfg
self.datamodule = datamodule
# set up losses
self._losses = MetricCollection({
split: VOCALosses(cfg=cfg, split=split)
for split in ["losses_train", "losses_test", "losses_val",] # "losses_train_val"
})
self.losses = {
key: self._losses["losses_" + key]
for key in ["train", "test", "val", ] # "train_val"
}
# set up model
self.audio_encoder = Wav2Vec2Model.from_pretrained(cfg.audio_encoder.model_name_or_path)
if cfg.audio_encoder.train_audio_encoder:
self.audio_encoder.feature_extractor._freeze_parameters() # we don't want to train the feature extractor
else:
for param in self.audio_encoder.parameters():
param.requires_grad = False
self.denoiser = instantiate_from_config(cfg.model.denoiser)
# set up optimizer
if cfg.TRAIN.OPTIM.TYPE.lower() == "adamw":
self.optimizer = AdamW(lr=cfg.TRAIN.OPTIM.LR,
params=filter(lambda p: p.requires_grad,self.parameters())
)
elif cfg.TRAIN.OPTIM.TYPE.lower() == "adam":
self.optimizer = Adam(lr=cfg.TRAIN.OPTIM.LR,
params=filter(lambda p: p.requires_grad,self.parameters())
)
else:
raise NotImplementedError(
"Do not support other optimizer for now.")
# set up diffusion specific initialization
if not cfg.model.predict_epsilon:
cfg.model.scheduler.params['prediction_type'] = 'sample'
cfg.model.noise_scheduler.params['prediction_type'] = 'sample'
self.scheduler = instantiate_from_config(cfg.model.scheduler)
self.noise_scheduler = instantiate_from_config(cfg.model.noise_scheduler)
# set up the hidden state resizing parameters
self.audio_fps = cfg.denoiser.params.audio_fps
self.hidden_fps = cfg.denoiser.params.hidden_fps
# set up the vertice dimension
self.nfeats = cfg.denoiser.params.nfeats
# guided diffusion
self.guidance_uncondp = cfg.model.guidance_uncondp if hasattr(cfg.model, "guidance_uncondp") else 0.0
self.guidance_scale = cfg.model.guidance_scale if hasattr(cfg.model, "guidance_scale") else 1.0
# assert self.guidance_scale >= 0.0 and self.guidance_scale <= 1.0
assert self.guidance_scale >= 0.0
self.do_classifier_free_guidence = self.guidance_scale > 0.0
self.smooth_output = False
if hasattr(cfg.model, "smooth_output") and cfg.model.smooth_output:
self.smooth_output = True
def allsplit_step(self, split: str, batch, batch_idx):
"""
Run one training, validation, or test step
Args:
split (str): train, test, val
batch (dict): batch
batch contains:
template (torch.Tensor): [batch_size, vert_dim]
vertice (torch.Tensor): [batch_size, vert_len, vert_dim ]
vertice_attention (torch.Tensor): [batch_size, vert_len]
audio (torch.Tensor): [batch_size, aud_len]
audio_attention (torch.Tensor): [batch_size, aud_len]
id (torch.Tensor): [batch_size, id_dim]
batch_idx (int): batch index
"""
# training
if split == "train":
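if self.guidance_uncondp > 0: # randomly zero the audio input of some samples so the model also learns the unconditional case (classifier-free guidance)
audio_mask = torch.rand(batch['audio'].shape[0], device=batch['audio'].device) < self.guidance_uncondp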
batch['audio'][audio_mask] = 0
rs_set = self._diffusion_forward(batch, batch_idx, phase="train")
loss = self.losses[split].update(rs_set)
return loss
if split in ["val", ]:
# validation speakers do not appear in the training set, so there is no matching identity;
# instead, run the model once with each training identity and average the losses
bs = batch["vertice"].shape[0]
id_dim = self.cfg.denoiser.params.id_dim
# collect the results for each id
loss_list = []
for idx in range(id_dim):
batch["id"] = torch.zeros(bs, id_dim).to(batch["vertice"].device)
batch["id"][:, idx] = 1
with torch.no_grad():
# same as the training, we use the autoregressive inference
rs_set = self._diffusion_forward(batch, batch_idx, phase="val")
loss = self.losses[split].update(rs_set)
if loss is None:
raise ValueError("loss is None")
loss_list.append(loss)
# visualize the result for the first id and the first batch
if batch_idx == 0 and idx == 0:
self._visualize(batch, rs_set)
loss = torch.stack(loss_list, dim=0).mean(dim=0)
return loss
if split in ["test"]:
from alm.models.losses.utils import biwi_upper_face_variance, biwi_mouth_distance
from alm.models.losses.utils import vocaset_upper_face_variance, vocaset_mouth_distance
# we also need to collect the results for each id in the test
bs = batch["vertice"].shape[0]
id_dim = self.cfg.denoiser.params.id_dim
# we don't need the vertice_attention in the test time
# because we do not have the ground truth vertice
if 'vertice_attention' in batch:
batch.pop('vertice_attention')
# collect the results for each id
ndim = batch['id'].ndim
if ndim == 2:
batch_id_list = [batch['id']] # the id is given
elif ndim == 3:
batch_id_list = []
for i in range(id_dim): # the id is not given, we use the id of each person in the training set
batch["id"] = torch.zeros(bs, id_dim).to(batch["vertice"].device)
batch["id"][:, i] = 1
batch_id_list.append(batch['id'])
else:
raise ValueError(f"the dimension of the id should be 2 or 3, but got {ndim}")
# collect the results for each id
metrics_list = {}
# for idx in tqdm(range(id_dim), desc="alternating among identities"):
for idx, batch_id in enumerate(batch_id_list):
# batch["id"] = torch.zeros(bs, id_dim).to(batch["vertice"].device)
# batch["id"][:, idx] = 1
id_idx = torch.where(batch_id == 1)[1][0].item()
batch["id"] = batch_id.to(batch["vertice"].device)
with torch.no_grad():
# start time
start_time = infer_time()
# same as the validation, we use the autoregressive inference
rs_set = self._diffusion_forward(batch, batch_idx, phase="val")
# end time
end_time = infer_time()
# calculate the metrics
if rs_set['vertice_pred'].shape[-1] == 70110: # BIWI
exp = 'BIWI'
pred = rs_set['vertice_pred'].view(-1, 23370, 3).detach().cpu().numpy()
gt = rs_set['vertice'].view(-1, 23370, 3).detach().cpu().numpy()
template = batch['template'].view(-1, 23370, 3).detach().cpu().numpy()
min_len = min(pred.shape[0], gt.shape[0])
pred = pred[:min_len, :]
gt = gt[:min_len, :]
metrics = {
"FDD": biwi_upper_face_variance( motion = (gt - template), ) \
- biwi_upper_face_variance(motion = (pred - template), ),
"Lip Vertex Error": biwi_mouth_distance(
vertices_gt=gt,
vertices_pred=pred,
).mean(),
"Length": torch.tensor(min_len).float(),
}
else: # VOCASET
exp = 'vocaset'
pred = rs_set['vertice_pred'].view(-1, 5023, 3).detach().cpu().numpy()
gt = rs_set['vertice'].view(-1, 5023, 3).detach().cpu().numpy()
template = batch['template'].view(-1, 5023, 3).detach().cpu().numpy()
min_len = min(pred.shape[0], gt.shape[0])
pred = pred[:min_len, :]
gt = gt[:min_len, :]
metrics = {
"FDD": vocaset_upper_face_variance( motion = (gt - template), ) \
- vocaset_upper_face_variance(motion = (pred - template), ),
"Lip Vertex Error": vocaset_mouth_distance(
vertices_gt=gt,
vertices_pred=pred,
).mean(),
"Length": torch.tensor(min_len).float(),
}
# save the results
# the code is a little bit messy, sorry for that
result_dir = self.cfg.FOLDER_EXP + '/results_{}'.format(exp)
os.makedirs(result_dir, exist_ok=True)
result_file = batch['file_name'][0].split('/')[-1].split('.')[0] + '_condition_' + str(id_idx) + '.pkl'
with open(os.path.join(result_dir, result_file), 'wb') as f:
report = {
'prediction': rs_set['vertice_pred'].detach().cpu().numpy(),
'ground_truth': rs_set['vertice'].detach().cpu().numpy(),
'template': batch['template'].detach().cpu().numpy(),
'audio_length': rs_set['vertice_pred'].shape[1] / self.cfg.hidden_fps,
'time': end_time - start_time,
'fdd': metrics['FDD'],
've': metrics['Lip Vertex Error'],
}
pickle.dump(report, f)
# save the metrics
if metrics is None:
raise ValueError("metrics is None")
for key in metrics: # collect the metrics for each id
if key not in metrics_list:
metrics_list[key] = []
metrics_list[key].append(metrics[key])
# visualize the result for the first id and the first batch
if batch_idx == 0 and idx == 0:
pass # TODO: visualize the result
# average the metrics for each id
for key in metrics_list:
metrics_list[key] = torch.stack(metrics_list[key], dim=0).mean(dim=0)
return metrics_list
def _memory_mask(self, hidden_attention, ):
"""
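Create memory_mask (the cross-attention mask) for the transformer decoder from self.denoiser.memory_bi_bias,
with two extra unmasked columns for the adapter tokens; returns None if use_mem_attn_bias is False
Args:
hidden_attention: [batch_size, source_len]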
"""
if self.denoiser.use_mem_attn_bias:
# since the source_len is the same as the target_len, we can use the same size to create the mask
memory_mask = self.denoiser.memory_bi_bias[:hidden_attention.shape[1], :hidden_attention.shape[1]]
# the adapter tokens (identity and timestep) occupy the first two memory positions, so they are left unmasked
adapter_mask = torch.zeros_like(memory_mask[:, :2]) # [source_len, 2], since the adapter length = id + time = 2
memory_mask = torch.cat([adapter_mask, memory_mask], dim = 1) # [source_len, latent_len + 2]
# # visualize the attention bias using sns.heatmap
# import seaborn as sns
# import matplotlib.pyplot as plt
# fig, ax = plt.subplots(figsize=(15, 10))
# length = 100
# # visualize the memory_mask
# minimum = 1
# mask = (-minimum* memory_mask[:length, :length+2].long()).detach().cpu().numpy().astype(int)
# sns.heatmap(mask, ax=ax,)
# # set the cbar to be discrete
# colorbar = ax.collections[0].colorbar
# colorbar.set_ticks([-minimum, 0])
# colorbar.set_ticklabels(['-inf','0'])
# # save the figure
# plt.savefig('memory_mask.png')
return memory_mask.bool().to(hidden_attention.device) # [source_len, latent_len + 2]
else:
return None
def _tgt_mask(self, vertice_attention, ):
"""
Create tgt_key_padding_mask for transformer decoder
Args:
vertice_attention: [batch_len, source_len]
frame_num: int
"""
if self.denoiser.use_tgt_attn_bias:
batch_size = vertice_attention.shape[0]
tgt_mask = self.denoiser.target_bi_bias[:, :vertice_attention.shape[1], :vertice_attention.shape[1]] # [num_heads, target_len, target_len]
adapter_mask = torch.zeros_like(tgt_mask[..., :2]) # [num_heads, target_len, 2], since the adapter length = id + time = 2
tgt_mask = torch.cat([adapter_mask, tgt_mask], dim = -1) # [num_heads, target_len, target_len + 2]
# # visualize the attention bias using sns.heatmap
# import seaborn as sns
# import matplotlib.pyplot as plt
# fig, ax = plt.subplots(figsize=(15, 10))
# length = 100
# # visualize the tgt_mask
# mask = (5 * tgt_mask[0, :length, :length+2]).long().detach().cpu().numpy().astype(int)
# # set the cbar to be discrete
# cbar_kws = {
# "ticks": np.arange(mask.min(), mask.max()+1),
# "boundaries": np.arange(mask.min() - 0.5, mask.max() + 1.5)
# }
# sns.heatmap(mask, ax=ax, cbar_kws=cbar_kws)
# # save the figure
# plt.savefig('tgt_mask.png')
# repeat the mask for each batch
tgt_mask = tgt_mask.repeat(batch_size, 1, 1) # [batch_size * num_heads, target_len, target_len + 2]
return tgt_mask.to(vertice_attention.device, non_blocking=True) # [batch_size * num_heads, target_len, target_len + 2]
else:
return None
def _mem_key_padding_mask(self, vertice_attention):
"""
Create mem_key_padding_mask for the transformer decoder, which masks the padded memory positions
Args:
vertice_attention: [batch_size, source_len], 1 for valid frames and 0 for padding
"""
# the two adapter tokens (identity and timestep) are always valid, so prepend two unmasked positions
adapter_mask = torch.ones_like(vertice_attention[:, :2]) # [batch_size, 2], since the adapter length = id + time = 2
vertice_attention = torch.cat([adapter_mask, vertice_attention], dim = 1) # [batch_size, source_len + 2]
# invert: True means the position is masked (ignored by attention)
return ~vertice_attention.bool()
def _tgt_key_padding_mask(self, vertice_attention):
"""
Create tgt_key_padding_mask for the transformer decoder, which masks the padded target positions
Args:
vertice_attention: [batch_size, target_len], 1 for valid frames and 0 for padding
"""
# the two adapter tokens (identity and timestep) are always valid, so prepend two unmasked positions
adapter_mask = torch.ones_like(vertice_attention[:, :2]) # [batch_size, 2]
vertice_attention = torch.cat([adapter_mask, vertice_attention], dim = 1) # [batch_size, target_len + 2]
# invert: True means the position is masked (ignored by attention)
return ~vertice_attention.bool()
def _audio_resize(self, hidden_state: torch.Tensor, input_fps: Optional[float] = None , output_fps: Optional[float] = None, output_len = None):
"""
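Resize the audio feature to the same length as the vertices by linear interpolation along time
Args:
hidden_state (torch.Tensor): [batch_size, seq_len, hidden_size]
input_fps (float): input fps; defaults to cfg.denoiser.params.audio_fps
output_fps (float): output fps; defaults to cfg.denoiser.params.hidden_fps
output_len (int): output length; if None, it is computed from the audio duration and output_fps
Returns:
torch.Tensor: [batch_size, output_len, hidden_size]
"""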
# if the input_fps and output_fps is not given, we use the default value
input_fps = input_fps if input_fps is not None else self.cfg.denoiser.params.audio_fps
output_fps = output_fps if output_fps is not None else self.cfg.denoiser.params.hidden_fps
hidden_state = hidden_state.transpose(1,2)
if output_len is None:
seq_len = hidden_state.shape[2] / input_fps # audio duration in seconds
output_len = int(seq_len * output_fps)
output_features = F.interpolate(hidden_state, size = output_len, align_corners=True, mode="linear")
return output_features.transpose(2,1)
def _audio_2_hidden(self, audio, audio_attention, length = None):
"""
This function takes in an audio tensor and its corresponding attention mask,
and returns a hidden state tensor that represents the audio feature map.
The function first passes the audio tensor through an audio encoder to obtain the last hidden state.
It then resizes the hidden state to match the length of the input sequence, using the _audio_resize function.
Finally, the function passes the resized hidden state through the audio_feature_map layer of the denoiser to obtain the final hidden state tensor.
The output tensor has shape [batch_size, seq_len, latent_dim], where seq_len is the length of the input audio sequence and latent_dim is the dimensionality of the latent space.
"""
hidden_state = self.audio_encoder(audio, attention_mask = audio_attention).last_hidden_state
hidden_state = self._audio_resize(
hidden_state,
output_len = length # if vertice is not given, we use the full length of the audio
)
hidden_state = self.denoiser.audio_feature_map(hidden_state) # hidden_state.shape = [batch_size, seq_len, latent_dim]
return hidden_state
def _diffusion_forward(self, batch, batch_idx, phase):
"""
Forward pass for training
Args:
batch (dict): batch
batch contains:
template (torch.Tensor): [batch_size, vert_dim]
vertice (torch.Tensor): [batch_size, vert_len, vert_dim ]
vertice_attention (torch.Tensor): [batch_size, vert_len]
audio (torch.Tensor): [batch_size, aud_len]
audio_attention (torch.Tensor): [batch_size, aud_len]
id (torch.Tensor): [batch_size, id_dim], one-hot id of the subject
phase (str): either 'train' or 'val'
batch_idx (int): batch index
"""
# process audio condition
hidden_state = self._audio_2_hidden(batch['audio'], batch['audio_attention'], length = batch['vertice'].shape[1] if 'vertice' in batch else None) # hidden_state.shape = [batch_size, seq_len, latent_dim]
if 'vertice_attention' not in batch:
# if the vertice_attention is not given, we assume that all the vertices are valid, so we set the attention to be all ones
batch['vertice_attention'] = torch.ones(
hidden_state.shape[0],
hidden_state.shape[1], # in our setting, the length of the vertice_attention should be the same as the length of the hidden_state
).long().to(hidden_state.device) # this attention should be long type
# template is subtracted from the vertice_input to make the template as the origin of the vertice
template = batch['template'].unsqueeze(1) # template.shape = [batch_size, 1, vert_dim]
if phase == 'train':
vertice_input = batch['vertice'] - template # vertice_input.shape = [batch_size, vert_len, vert_dim]
# perform the diffusion forward process
vertice_output = self._diffusion_process(
vertice_input,
hidden_state,
batch['id'],
vertice_attention = batch['vertice_attention'],
) + template # vertice_output.shape = [batch_size, vert_len, vert_dim]
elif phase == 'val':
if self.do_classifier_free_guidence:
silent_hidden_state = self._audio_2_hidden(
torch.zeros_like(batch['audio']), # we use the silent audio as the input
batch['audio_attention'],
length=hidden_state.shape[1], # just use the length of the hidden_state, in case their length is different
)
else:
silent_hidden_state = None
# perform the diffusion reverse process
vertice_output = self._diffusion_reverse(
hidden_state,
batch['id'],
vertice_attention = batch['vertice_attention'],
silent_hidden_state = silent_hidden_state,
) + template # vertice_output.shape = [batch_size, vert_len, vert_dim]
if self.smooth_output:
# smoothing the prediction does not significantly affect the metric but makes the animation smoother
vertice_output = self.smooth(vertice_output)
else:
raise ValueError(f"phase should be either 'train' or 'val', but got {phase}")
rs_set = {
"vertice_pred": vertice_output,
"vertice": batch['vertice'] if 'vertice' in batch else None,
"vertice_attention": batch['vertice_attention'],
}
return rs_set
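def smooth(self, vertices):
"""
Smooth the prediction over time with a 3-frame moving average.
Only the interior frames are overwritten: avg_pool1d zero-pads the boundaries (count_include_pad defaults to True), which would bias the first and last frames.
Note that `vertices` is modified in place.
Args:
vertices (torch.Tensor): [batch_size, vert_len, vert_dim]
"""
vertices_smooth = F.avg_pool1d(
vertices.permute(0, 2, 1), # [batch_size, vert_dim, vert_len], pool over time
kernel_size=3,
stride=1,
padding=1
).permute(0, 2, 1) # smooth the prediction with a moving average filter
vertices[:, 1:-1] = vertices_smooth[:, 1:-1]
return vertices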
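An equivalent NumPy sketch of the same filter, for illustration only. Unlike `smooth`, the hypothetical `smooth_interior` below returns a copy instead of writing into its input; like `smooth`, it leaves the first and last frames untouched:

```python
import numpy as np


def smooth_interior(vertices: np.ndarray) -> np.ndarray:
    """3-frame moving average over time; first and last frames are left unchanged.

    vertices: [batch_size, num_frames, vert_dim]. Returns a new array (the input is not modified).
    """
    out = vertices.copy()
    out[:, 1:-1] = (vertices[:, :-2] + vertices[:, 1:-1] + vertices[:, 2:]) / 3.0
    return out


v = np.array([[[0.0], [3.0], [0.0], [3.0], [0.0]]])  # one sequence, 5 frames, 1 coordinate
print(smooth_interior(v)[0, :, 0])  # [0. 1. 2. 1. 0.]
```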
def predict(self, batch, **kwargs):
"""
Predict the result in the test time
Here the length of the vertice_attention is decided by the length of the audio
"""
if 'audio_attention' not in batch:
# if the audio_attention is not given, we assume that all the audio is valid, so we set the attention to be all ones
batch['audio_attention'] = torch.ones(
batch['audio'].shape[0],
batch['audio'].shape[1],
).long().to(batch['audio'].device) # this attention should be long type
if 'id' not in batch:
# if the id is not given, use kwargs['id'] if provided, otherwise the one-hot id of the first subject in the training set
default_id = torch.zeros(
1, #batch['vertice'].shape[0],
self.cfg.denoiser.params.id_dim
).to(batch['audio'].device)
default_id[:, 0] = 1
batch['id'] = kwargs.get('id', default_id)
else:
assert batch['id'].shape[1] == self.cfg.denoiser.params.id_dim, \
f"the id dimension should be {self.cfg.denoiser.params.id_dim}, but got {batch['id'].shape[1]}"
if 'vertice' in batch:
# if the vertice is given, all of its frames are treated as valid
batch['vertice_attention'] = torch.ones(
batch['vertice'].shape[0],
batch['vertice'].shape[1]
).long().to(batch['vertice'].device) # this attention should be long type
# add the batch dimension to the template
batch['template'] = batch['template'][None, ...]
# run the 'val' branch of _diffusion_forward, which samples with the reverse diffusion process
vertice_output = self._diffusion_forward(batch, 0, 'val')['vertice_pred'] # vertice_output.shape = [batch_size, vert_len, vert_dim]
rs_set = {
"vertice_pred": vertice_output,
"vertice": batch['vertice'] if 'vertice' in batch else None,
"vertice_attention": batch['vertice_attention'],
}
return rs_set
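`predict`, `_diffusion_process` and `_diffusion_reverse` all reduce the id to an integer with `torch.argmax(id, dim = 1)` before the `obj_vector` lookup. A NumPy sketch of building a compatible one-hot id for a chosen training subject; `one_hot_id` is a hypothetical helper and subject index 3 is arbitrary:

```python
import numpy as np


def one_hot_id(subject_index: int, id_dim: int, batch_size: int = 1) -> np.ndarray:
    """Build a [batch_size, id_dim] one-hot subject id."""
    ids = np.zeros((batch_size, id_dim), dtype=np.float32)
    ids[:, subject_index] = 1.0
    return ids


ids = one_hot_id(3, id_dim=6)   # id_dim: 6 as in the BIWI configs
print(ids.argmax(axis=1))       # [3], the index used for the style-embedding lookup
```

`argmax` returns the first maximal index, so a vector with more than one hot entry silently resolves to the lowest index; an id passed in should therefore be strictly one-hot.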
def _diffusion_process(
self,
vertice_input: torch.Tensor,
hidden_state: torch.Tensor,
id: torch.Tensor,
vertice_attention: Optional[torch.Tensor] = None,
):
"""
Perform the diffusion forward process during training
Args:
vertice_input (torch.Tensor): [batch_size, vert_len, vert_dim], the ground-truth vertex offsets from the template, may include padding
hidden_state (torch.Tensor): [batch_size, seq_len, latent_dim], the audio feature, may include padding
id (torch.Tensor): [batch_size, id_dim], one-hot id of the subject
vertice_attention (torch.Tensor): [batch_size, vert_len], marks which frames are valid; since the audio feature is resampled to the vertex length, it also applies to hidden_state
Returns:
torch.Tensor: [batch_size, vert_len, vert_dim], the denoiser prediction (the clean motion when predict_epsilon is False)
"""
# extract the id style
object_emb = self.denoiser.obj_vector(torch.argmax(id, dim = 1)).unsqueeze(1) # object_emb.shape = [batch_size, 1, latent_dim]
# sample noise
noise = torch.randn_like(vertice_input) # noise.shape = [batch_size, vert_len, vert_dim]
# sample a random timestep for the minibatch
bsz = vertice_input.shape[0]
timesteps = torch.randint(
0,
self.noise_scheduler.config.num_train_timesteps,
(bsz,),
device = vertice_input.device
) # timesteps.shape = [batch_size]
# add noise to the latents
noise_input = self.noise_scheduler.add_noise(
vertice_input,
noise,
timesteps,
) # noise_input.shape = [batch_size, vert_len, vert_dim]
# predict the noise or the input
vertice_pred = self.denoiser(
vertice_input = noise_input, # noise_input.shape = [batch_size, vert_len, vert_dim]
hidden_state = hidden_state, # hidden_state.shape = [batch_size, seq_len, latent_dim]
timesteps = timesteps, # timesteps.shape = [batch_size]
adapter = object_emb, # object_emb.shape = [batch_size, 1, latent_dim]
tgt_mask = self._tgt_mask(vertice_attention), # tgt_mask.shape = [vert_len, vert_len]
memory_mask = self._memory_mask(vertice_attention), # memory_mask.shape = [vert_len, seq_len]
tgt_key_padding_mask = self._tgt_key_padding_mask(vertice_attention), # tgt_key_padding_mask.shape = [batch_size, vert_len]
memory_key_padding_mask = self._mem_key_padding_mask(vertice_attention), # memory_key_padding_mask.shape = [batch_size, seq_len]
)
return vertice_pred
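The `add_noise` call above implements the closed-form forward diffusion x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * noise. The scheduler settings live in `scheduler.yaml`, which is not shown here, so this NumPy sketch assumes a linear beta schedule with 1000 steps from 1e-4 to 0.02 (the diffusers `DDPMScheduler` defaults); the function names are hypothetical. With `predict_epsilon: False`, as in the configs, the denoiser is then trained to recover x0 from x_t rather than the noise:

```python
import numpy as np


def linear_alphas_cumprod(num_train_timesteps: int = 1000,
                          beta_start: float = 1e-4, beta_end: float = 0.02) -> np.ndarray:
    """Cumulative product of (1 - beta_t) for a linear beta schedule (assumed values)."""
    betas = np.linspace(beta_start, beta_end, num_train_timesteps)
    return np.cumprod(1.0 - betas)


def add_noise(x0: np.ndarray, noise: np.ndarray, timesteps: np.ndarray,
              alphas_cumprod: np.ndarray) -> np.ndarray:
    """x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * noise, one timestep per sample."""
    a_bar = alphas_cumprod[timesteps][:, None, None]  # broadcast over frames and features
    return np.sqrt(a_bar) * x0 + np.sqrt(1.0 - a_bar) * noise


rng = np.random.default_rng(0)
x0 = rng.standard_normal((2, 10, 4))       # [batch_size, vert_len, vert_dim]
noise = rng.standard_normal(x0.shape)
alphas_cumprod = linear_alphas_cumprod()
t = rng.integers(0, len(alphas_cumprod), size=x0.shape[0])  # one random timestep per sample
x_t = add_noise(x0, noise, t, alphas_cumprod)
```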
def _diffusion_reverse(
self,
hidden_state: torch.Tensor,
id: torch.Tensor,
vertice_attention: torch.Tensor,
silent_hidden_state: Optional[torch.Tensor] = None,
):
"""
Perform the diffusion reverse process (sampling) during inference
Args:
hidden_state (torch.Tensor): [batch_size, seq_len, latent_dim], the audio feature, may include padding
id (torch.Tensor): [batch_size, id_dim], one-hot id of the subject
vertice_attention (torch.Tensor): [batch_size, vert_len], marks which frames are valid; it also applies to hidden_state since both have the same length
silent_hidden_state (torch.Tensor, optional): [batch_size, seq_len, latent_dim], the feature of silent audio, used as the unconditional input for classifier-free guidance
Returns:
torch.Tensor: [batch_size, vert_len, vert_dim], the sampled vertex offsets from the template
"""
# extract the id style
object_emb = self.denoiser.obj_vector(torch.argmax(id, dim = 1)).unsqueeze(1) # object_emb.shape = [batch_size, 1, latent_dim]
# sample noise
vertices = torch.randn(
(
hidden_state.shape[0], # batch_size
hidden_state.shape[1], # vert_len
self.nfeats, # vert_dim
),
device = hidden_state.device,
dtype = torch.float,
)
# scale the initial noise by the standard deviation required by the scheduler
vertices = vertices * self.scheduler.init_noise_sigma
# set timesteps
self.scheduler.set_timesteps(self.cfg.model.scheduler.num_inference_timesteps)
timesteps = self.scheduler.timesteps.to(hidden_state.device, non_blocking=True)
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, and between [0, 1]
extra_step_kwargs = {}
if "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()):
extra_step_kwargs["eta"] = self.cfg.model.scheduler.eta
if silent_hidden_state is not None: # self.do_classifier_free_guidence is True
hidden_state = torch.cat([hidden_state, silent_hidden_state], dim = 0) # hidden_state.shape = [batch_size * 2, seq_len, latent_dim]
vertice_attention = torch.cat([vertice_attention, ] * 2, dim = 0) # vertice_attention.shape = [batch_size * 2, vert_len]
object_emb = torch.cat([object_emb, ] * 2, dim = 0) # object_emb.shape = [batch_size * 2, 1, latent_dim]
# perform denoising
for i, t in enumerate(timesteps):
if silent_hidden_state is not None: # self.do_classifier_free_guidence is True
vertices = torch.cat(
[vertices] * 2,
dim = 0,
) # vertices.shape = [batch_size * 2, vert_len, vert_dim]
# perform denoising step
vertices_pred = self.denoiser(
vertice_input = vertices, # vertices.shape = [batch_size, vert_len, vert_dim], doubled along the batch when guidance is used
hidden_state = hidden_state, # hidden_state.shape = [batch_size, seq_len, latent_dim]
timesteps = t.expand(hidden_state.shape[0]), # timesteps.shape = [batch_size]
adapter = object_emb, # object_emb.shape = [batch_size, 1, latent_dim]
tgt_mask = self._tgt_mask(vertice_attention), # tgt_mask.shape = [vert_len, vert_len]
memory_mask = self._memory_mask(vertice_attention), # memory_mask.shape = [vert_len, seq_len]
tgt_key_padding_mask = self._tgt_key_padding_mask(vertice_attention), # tgt_key_padding_mask.shape = [batch_size, vert_len]
memory_key_padding_mask = self._mem_key_padding_mask(vertice_attention), # memory_key_padding_mask.shape = [batch_size, seq_len]
)
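# combine the two predictions (classifier-free guidance); this equals pred_uncond + (guidance_scale + 1) * (pred_audio - pred_uncond)
if silent_hidden_state is not None: # self.do_classifier_free_guidence is True
vertices_pred_audio, vertices_pred_uncond = vertices_pred.chunk(2, dim = 0)
vertices_pred = vertices_pred_audio + (vertices_pred_audio - vertices_pred_uncond)* self.guidance_scale
# the two halves of the duplicated samples are identical, keep the first one
vertices, _ = vertices.chunk(2, dim = 0)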
vertices = self.scheduler.step(vertices_pred, t, vertices, **extra_step_kwargs).prev_sample
return vertices
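With classifier-free guidance, the batch is duplicated so that a single denoiser call yields both the audio-conditioned and the silent-audio prediction, which are then split and combined. A NumPy sketch of that bookkeeping; the elementwise scaling stands in for the denoiser purely for illustration, and `guided_prediction` is a hypothetical name. The form `audio + (audio - silent) * s` equals the common `uncond + w * (cond - uncond)` with `w = s + 1`, so `guidance_scale: 0` reduces to the audio-conditioned prediction:

```python
import numpy as np


def guided_prediction(pred_audio: np.ndarray, pred_silent: np.ndarray,
                      guidance_scale: float) -> np.ndarray:
    """Combination used in _diffusion_reverse: audio + (audio - silent) * scale."""
    return pred_audio + (pred_audio - pred_silent) * guidance_scale


rng = np.random.default_rng(0)
batch = rng.standard_normal((2, 5, 3))               # [batch_size, vert_len, vert_dim]
doubled = np.concatenate([batch, batch], axis=0)    # audio-conditioned half first, silent half second
pred = doubled * np.array([1.0, 1.0, 0.5, 0.5])[:, None, None]  # stand-in for the denoiser output
pred_audio, pred_silent = np.split(pred, 2, axis=0)
guided = guided_prediction(pred_audio, pred_silent, guidance_scale=2.0)
```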
def _visualize(self, batch, rs_set, parrallel = True):
"""
Visualize the result
Args:
batch (dict): batch
batch contains:
file_path (list): audio file path
vertice (torch.Tensor): [batch_size, vert_len, vert_dim ]
rs_set (dict): result set
rs_set contains:
vertice_pred (torch.Tensor): [batch_size, vert_len, vert_dim ]
parrallel (bool): if True, the video is rendered in a separate process,
otherwise it is rendered in the current process (blocking)
"""
# visualize the result only for the first data in the batch
data_idx = 0
audio_path = batch["file_path"][data_idx]
vis_path = os.path.join(
self.cfg.FOLDER_EXP,
"visualization",
"{}_{}.{}".format(
time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()), # current time; the epoch number would be a more informative tag
audio_path.split("/")[-1].split(".")[0], # audio name
'mp4' # video format
)
)
# visualize the result
if not parrallel:
# render in the current process (blocking)
animate(
vertices = rs_set["vertice_pred"][data_idx, ...].squeeze().cpu().numpy(),
wav_path = audio_path,
file_name = vis_path,
ply = self.cfg.DEMO.PLY,
fps = self.cfg.DEMO.FPS,
)
else:
# render in a separate process so that the caller is not blocked
p = Process(
target=animate,
args=(
rs_set["vertice_pred"][data_idx, ...].squeeze().cpu().numpy(),
audio_path,
vis_path,
self.cfg.DEMO.PLY,
self.cfg.DEMO.FPS,
),
)
p.start()
================================================
FILE: alm/utils/__init__.py
================================================
================================================
FILE: alm/utils/demo_utils.py
================================================
from transformers import Wav2Vec2Processor
import numpy as np
import librosa
import os
import torch
import cv2
import pyrender
import trimesh
import tempfile
import imageio
from tqdm import tqdm
try:
from psbody.mesh import Mesh
except ImportError:
# psbody-mesh is optional; the rendering functions below require it
Mesh = None
import platform
if platform.system() == "Linux":
# headless rendering on Linux; switch to 'osmesa' if EGL is not available
os.environ['PYOPENGL_PLATFORM'] = 'egl'
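def load_example_input(audio_path, processor = None):
"""
Load a wav file and turn it into a [1, num_samples] FloatTensor for the audio encoder.
"""
if processor is None:
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
speech_array, sampling_rate = librosa.load(
audio_path,
sr=16000 # resample to 16 kHz, the sampling rate of the pretrained audio encoders
)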
audio_feature = np.squeeze(
processor(
speech_array,
sampling_rate = sampling_rate
).input_values
)
audio_feature = np.reshape(
audio_feature,
(-1,audio_feature.shape[0])
)
return torch.FloatTensor(audio_feature)
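`load_example_input` relies on `Wav2Vec2Processor` to normalize the waveform. Assuming the checkpoint's feature-extractor config enables `do_normalize` (an assumption here; check the model's preprocessor config), the processor applies a per-utterance zero-mean, unit-variance normalization. A NumPy sketch of that step and of the final [1, num_samples] reshape; the function names and the `eps` value of 1e-7 are assumptions:

```python
import numpy as np


def normalize_waveform(speech: np.ndarray, eps: float = 1e-7) -> np.ndarray:
    """Zero-mean, unit-variance normalization of a 1-D waveform."""
    speech = speech.astype(np.float32)
    return (speech - speech.mean()) / np.sqrt(speech.var() + eps)


def to_model_input(speech: np.ndarray) -> np.ndarray:
    """Return a [1, num_samples] array, the same layout load_example_input produces."""
    return normalize_waveform(speech).reshape(1, -1)


sr = 16000
t = np.arange(sr) / sr                                  # 1 s at 16 kHz
wave = 0.5 * np.sin(2 * np.pi * 440.0 * t) + 0.1        # synthetic tone with a DC offset
x = to_model_input(wave)
print(x.shape)  # (1, 16000)
```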
# The implementation of rendering is borrowed from VOCA: https://github.com/TimoBolkart/voca/blob/master/utils/rendering.py
def render_mesh_helper(mesh, t_center, rot=np.zeros(3), tex_img=None, z_offset=0, template_type: str = "flame", rgb_per_v = None):
assert template_type in ["flame", "biwi"], "template_type should be one of ['flame', 'biwi'], but got {}".format(template_type)
if template_type == "flame":
camera_params = {'c': np.array([400, 400]),
'k': np.array([-0.19816071, 0.92822711, 0, 0, 0]),
'f': np.array([4754.97941935 / 2, 4754.97941935 / 2])}
elif template_type == "biwi":
camera_params = {'c': np.array([400, 400]),
'k': np.array([-0.19816071, 0.92822711, 0, 0, 0]),
'f': np.array([4754.97941935 / 8, 4754.97941935 / 8])}
frustum = {'near': 0.01, 'far': 3.0, 'height': 800, 'width': 800}
mesh_copy = Mesh(mesh.v, mesh.f)
mesh_copy.v[:] = cv2.Rodrigues(rot)[0].dot((mesh_copy.v-t_center).T).T+t_center
if rgb_per_v is None:
intensity = 2.0
primitive_material = pyrender.material.MetallicRoughnessMaterial(
alphaMode='BLEND',
baseColorFactor=[0.3, 0.3, 0.3, 1.0],
metallicFactor=0.8,
roughnessFactor=0.8
)
tri_mesh = trimesh.Trimesh(vertices=mesh_copy.v, faces=mesh_copy.f, vertex_colors=rgb_per_v)
render_mesh = pyrender.Mesh.from_trimesh(tri_mesh, material=primitive_material,smooth=True)
else:
intensity = 0.5
tri_mesh = trimesh.Trimesh(vertices=mesh_copy.v, faces=mesh_copy.f, vertex_colors=rgb_per_v)
render_mesh = pyrender.Mesh.from_trimesh(tri_mesh, smooth=True)
# scene = pyrender.Scene(ambient_light=[.2, .2, .2], bg_color=[0, 0, 0])
scene = pyrender.Scene(ambient_light=[.2, .2, .2], bg_color=[255, 255, 255])
camera = pyrender.IntrinsicsCamera(fx=camera_params['f'][0],
fy=camera_params['f'][1],
cx=camera_params['c'][0],
cy=camera_params['c'][1],
znear=frustum['near'],
zfar=frustum['far'])
scene.add(render_mesh, pose=np.eye(4))
camera_pose = np.eye(4)
camera_pose[:3,3] = np.array([0, 0, 1.0-z_offset])
scene.add(camera, pose=[[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 1],
[0, 0, 0, 1]])
angle = np.pi / 6.0
pos = camera_pose[:3,3]
light_color = np.array([1., 1., 1.])
light = pyrender.DirectionalLight(color=light_color, intensity=intensity)
light_pose = np.eye(4)
light_pose[:3,3] = pos
scene.add(light, pose=light_pose.copy())
light_pose[:3,3] = cv2.Rodrigues(np.array([angle, 0, 0]))[0].dot(pos)
scene.add(light, pose=light_pose.copy())
light_pose[:3,3] = cv2.Rodrigues(np.array([-angle, 0, 0]))[0].dot(pos)
scene.add(light, pose=light_pose.copy())
light_pose[:3,3] = cv2.Rodrigues(np.array([0, -angle, 0]))[0].dot(pos)
scene.add(light, pose=light_pose.copy())
light_pose[:3,3] = cv2.Rodrigues(np.array([0, angle, 0]))[0].dot(pos)
scene.add(light, pose=light_pose.copy())
flags = pyrender.RenderFlags.SKIP_CULL_FACES
r = pyrender.OffscreenRenderer(viewport_width=frustum['width'], viewport_height=frustum['height'])
color, _ = r.render(scene, flags=flags)
r.delete() # release the OpenGL context created for this frame
return color[..., ::-1] # reverse the channel order (RGB -> BGR)
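`render_mesh_helper` rotates the mesh about `t_center` with `cv2.Rodrigues(rot)[0].dot((v - t_center).T).T + t_center`. A NumPy sketch of that step, with Rodrigues' formula written out instead of calling OpenCV; `rodrigues` and `rotate_about` are hypothetical helpers. With the default `rot=np.zeros(3)` the rotation is the identity, so the mesh is rendered unrotated:

```python
import numpy as np


def rodrigues(rvec: np.ndarray) -> np.ndarray:
    """Axis-angle vector (axis * angle in radians) -> 3x3 rotation matrix."""
    theta = np.linalg.norm(rvec)
    if theta < 1e-12:
        return np.eye(3)
    kx, ky, kz = rvec / theta
    K = np.array([[0.0, -kz, ky],
                  [kz, 0.0, -kx],
                  [-ky, kx, 0.0]])  # cross-product matrix of the unit axis
    return np.eye(3) + np.sin(theta) * K + (1.0 - np.cos(theta)) * (K @ K)


def rotate_about(vertices: np.ndarray, center: np.ndarray, rvec) -> np.ndarray:
    """Rotate [num_vertices, 3] points about `center`: R (v - c) + c."""
    R = rodrigues(np.asarray(rvec, dtype=float))
    return (R @ (vertices - center).T).T + center


pts = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
rotated = rotate_about(pts, np.zeros(3), [0.0, 0.0, np.pi / 2])  # 90 degrees about z
print(np.round(rotated, 6))
```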
def render_frame(args):
predicted_vertice, f, center, template_type = args
render_mesh = Mesh(predicted_vertice, f)
pred_img = render_mesh_helper(render_mesh, center, template_type=template_type)
pred_img = pred_img.astype(np.uint8)
return pred_img
def animate(vertices: np.ndarray, wav_path: str, file_name: str, ply: str, fps: int = 25, vertice_gt: np.ndarray = None, use_tqdm: bool = False, multi_process = False):
"""
Animate the predicted vertices with the synchronized audio and save the video to file_name.
Args:
vertices: (num_frames, num_vertices*3) or (num_frames, num_vertices, 3)
wav_path: path to the wav file
file_name: path of the output video
ply: path to the template ply file; the template type ("flame" or "biwi") is inferred from its name
fps: frames per second
vertice_gt: optional ground truth with the same layout as vertices, rendered side by side with the prediction
use_tqdm: whether to show a progress bar (sequential rendering only)
multi_process: whether to render the frames in parallel with a multiprocessing Pool
"""
# make output dir
output_dir = os.path.dirname(file_name)
os.makedirs(output_dir, exist_ok=True)
template = Mesh(filename=ply)
# determine biwi or flame
if "FLAME" in ply:
template_type = "flame"
elif "BIWI" in ply:
template_type = "biwi"
else:
raise ValueError("Template type not recognized, please use either BIWI or FLAME")
# reshape vertices
predicted_vertices = vertices.reshape(-1, vertices.shape[1]//3, 3) if vertices.ndim < 3 else vertices
num_frames = predicted_vertices.shape[0]
if vertice_gt is not None:
vertice_gt = vertice_gt.reshape(-1, vertice_gt.shape[1]//3, 3) if vertice_gt.ndim < 3 else vertice_gt
num_frames = np.where(np.sum(vertice_gt, axis=(1, 2)) != 0)[0][-1] + 1 # find the number of frames where the vertices are not all zeros
tmp_video_file = tempfile.NamedTemporaryFile('w', suffix='.mp4', dir=output_dir)
center = np.mean(predicted_vertices[0], axis=0)
# make animation
if multi_process:
from multiprocessing import Pool, cpu_count
# use as many worker processes as there are CPU cores
frames = []
max_processes = cpu_count()
with Pool(processes=max_processes) as pool:
args = [(
predicted_vertice,
template.f,
center,
template_type
) for predicted_vertice in predicted_vertices]
for pred_img in pool.imap(render_frame, tqdm(args)):
frames.append(pred_img)
if vertice_gt is not None:
frames_gt = []
with Pool(processes=max_processes) as pool:
args = [(
gt_vertice,
template.f,
center,
template_type
) for gt_vertice in vertice_gt]
for gt_img in pool.imap(render_frame, tqdm(args)):
frames_gt.append(gt_img)
# concat two videos
frames_final = []
for i in range(num_frames):
frames_final.append(np.concatenate([frames_gt[i], frames[i]], axis=1))
frames = frames_final
else:
frames = []
for i_frame in tqdm(range(num_frames)) if use_tqdm else range(num_frames):
render_mesh = Mesh(predicted_vertices[i_frame], template.f)
pred_img = render_mesh_helper(render_mesh, center, template_type=template_type)
pred_img = pred_img.astype(np.uint8)
frames.append(pred_img)
if vertice_gt is not None:
frames_gt = []
for i_frame in tqdm(range(num_frames)) if use_tqdm else range(num_frames):
render_mesh = Mesh(vertice_gt[i_frame], template.f)
gt_img = render_mesh_helper(render_mesh, center, template_type=template_type)
gt_img = gt_img.astype(np.uint8)
frames_gt.append(gt_img)
# concat two videos
frames_final = []
for i in range(num_frames):
frames_final.append(np.concatenate([frames_gt[i], frames[i]], axis=1))
frames = frames_final
imageio.mimsave(tmp_video_file.name, frames, fps = fps)
cmd = " ".join(['ffmpeg', '-i', tmp_video_file.name, '-i', wav_path, '-c:v copy -c:a aac', '-pix_fmt yuv420p -qscale 0',file_name, ])
os.system(cmd)
tmp_dir = tempfile.gettempdir() # if the wav file is a temporary file, remove it after muxing
if os.path.exists(wav_path) and tmp_dir in wav_path:
os.remove(wav_path)
print(f"Video saved to {file_name}")
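`animate` reshapes flat vertex arrays to (num_frames, num_vertices, 3) and, when ground truth is given, trims the video to the last frame whose vertices are not all zero. A NumPy sketch of both steps; it uses `np.any(frames != 0)` rather than the `np.sum(...) != 0` test in `animate`, a variant that also keeps a frame whose coordinates happen to sum to zero. The helper names are hypothetical:

```python
import numpy as np


def to_frames(vertices: np.ndarray) -> np.ndarray:
    """(num_frames, num_vertices*3) -> (num_frames, num_vertices, 3); 3-D input is returned as-is."""
    if vertices.ndim < 3:
        return vertices.reshape(-1, vertices.shape[1] // 3, 3)
    return vertices


def count_valid_frames(vertice_gt: np.ndarray) -> int:
    """Frames up to and including the last one that is not all zeros (trailing padding dropped)."""
    frames = to_frames(vertice_gt)
    nonzero = np.flatnonzero(np.any(frames != 0, axis=(1, 2)))
    return int(nonzero[-1]) + 1 if nonzero.size else 0


gt = np.zeros((5, 6))   # 5 frames, 2 vertices
gt[:3] = 1.0            # last two frames are zero padding
print(to_frames(gt).shape, count_valid_frames(gt))  # (5, 2, 3) 3
```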
================================================
FILE: alm/utils/logger.py
================================================
from pathlib import Path
import os
import time
import logging
from omegaconf import OmegaConf
from pytorch_lightning.utilities.rank_zero import rank_zero_only
def create_logger(cfg, phase='train'):
# root dir set by cfg
root_output_dir = Path(cfg.FOLDER)
# set up logger
if not root_output_dir.exists():
print('=> creating {}'.format(root_output_dir))
root_output_dir.mkdir(parents=True, exist_ok=True)
cfg_name = cfg.NAME
model = cfg.model.model_type
cfg_name = os.path.basename(cfg_name).split('.')[0]
final_output_dir = root_output_dir / model / cfg_name
cfg.FOLDER_EXP = str(final_output_dir)
time_str = time.strftime('%Y-%m-%d-%H-%M-%S')
new_dir(cfg, phase, time_str, final_output_dir)
head = '%(asctime)-15s %(message)s'
logger = config_logger(final_output_dir, time_str, phase, head)
if logger is None:
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)
logging.basicConfig(format=head)
return logger
@rank_zero_only
def config_logger(final_output_dir, time_str, phase, head):
log_file = '{}_{}_{}.log'.format('log', time_str, phase)
final_log_file = final_output_dir / log_file
logging.basicConfig(filename=str(final_log_file))
logger = logging.getLogger()
logger.setLevel(logging.INFO)
console = logging.StreamHandler()
formatter = logging.Formatter(head)
console.setFormatter(formatter)
logging.getLogger('').addHandler(console)
file_handler = logging.FileHandler(final_log_file, 'w')
file_handler.setFormatter(logging.Formatter(head))
file_handler.setLevel(logging.INFO)
logging.getLogger('').addHandler(file_handler)
return logger
@rank_zero_only
def new_dir(cfg, phase, time_str, final_output_dir):
# new experiment folder
cfg.TIME = str(time_str)
if os.path.exists(
final_output_dir) and cfg.TRAIN.RESUME is None and not cfg.DEBUG:
file_list = sorted(os.listdir(final_output_dir), reverse=True)
for item in file_list:
if item.endswith('.log'):
os.rename(str(final_output_dir),
str(final_output_dir) + '_' + cfg.TIME)
break
final_output_dir.mkdir(parents=True, exist_ok=True)
# write config yaml
config_file = '{}_{}_{}.yaml'.format('config', time_str, phase)
final_config_file = final_output_dir / config_file
OmegaConf.save(config=cfg, f=final_config_file)
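`create_logger` places each run under `FOLDER / model.model_type / NAME` and names the log file `log_<time>_<phase>.log`. A small pathlib sketch of that layout, using `FOLDER` from `configs/assets/biwi.yaml` and `model_type`/`NAME` from the HuBERT BIWI config; the time string is arbitrary and the helper names are hypothetical:

```python
import os
from pathlib import Path


def experiment_dir(folder: str, model_type: str, cfg_name: str) -> Path:
    """FOLDER / model.model_type / NAME (without extension), as in create_logger."""
    name = os.path.basename(cfg_name).split('.')[0]
    return Path(folder) / model_type / name


def log_file_name(time_str: str, phase: str) -> str:
    return '{}_{}_{}.log'.format('log', time_str, phase)


exp = experiment_dir('./experiments/biwi', 'diffusion_bias', 'diffspeaker_hubert_biwi')
print(exp.as_posix())                                  # experiments/biwi/diffusion_bias/diffspeaker_hubert_biwi
print(log_file_name('2024-01-01-00-00-00', 'train'))   # log_2024-01-01-00-00-00_train.log
```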
================================================
FILE: alm/utils/temos_utils.py
================================================
from typing import Dict, List, Optional
import numpy as np
import torch
from torch import Tensor
def lengths_to_mask(lengths: Tensor, # [batch_size]
device: torch.device,
max_len: Optional[int] = None) -> Tensor:
"""Return a [batch_size, max_len] bool mask that is True for the first lengths[i] positions of row i."""
max_len = max_len if max_len else max(lengths.cpu().tolist())
mask = torch.arange(max_len, device=device).expand(
len(lengths), max_len) < lengths.unsqueeze(1)
return mask
def remove_padding(tensors, lengths):
return [
tensor[:tensor_length]
for tensor, tensor_length in zip(tensors, lengths)
]
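A NumPy sketch of the two helpers above, to show the mask layout and how padding is trimmed; the `_np` suffixes mark these as hypothetical re-implementations, not the repository's torch versions:

```python
import numpy as np


def lengths_to_mask_np(lengths, max_len=None) -> np.ndarray:
    """[batch_size] lengths -> [batch_size, max_len] bool mask, True on valid positions."""
    lengths = np.asarray(lengths)
    max_len = max_len if max_len else int(lengths.max())
    return np.arange(max_len)[None, :] < lengths[:, None]


def remove_padding_np(arrays, lengths):
    """Trim each padded sequence to its true length."""
    return [a[:n] for a, n in zip(arrays, lengths)]


mask = lengths_to_mask_np([3, 1])
print(mask.astype(int))
# [[1 1 1]
#  [1 0 0]]
```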
================================================
FILE: configs/assets/biwi.yaml
================================================
FOLDER: './experiments/biwi' # Experiment files saving path
TEST:
FOLDER: './results' # Testing files saving path
DATASET:
BIWI:
ROOT: ./datasets/biwi
================================================
FILE: configs/assets/vocaset.yaml
================================================
FOLDER: './experiments/vocaset' # Experiment files saving path
TEST:
FOLDER: './results' # Testing files saving path
DATASET:
VOCASET:
ROOT: ./datasets/vocaset
================================================
FILE: configs/base.yaml
================================================
# FOLDER: ./experiments
SEED_VALUE: 1234
DEBUG: True
TRAIN:
SPLIT: 'train'
NUM_WORKERS: 1 # Number of workers
BATCH_SIZE: 4 # Size of batches
START_EPOCH: 0 # Start epoch
END_EPOCH: 2000 # End epoch
RESUME: '' # Experiment path to be resumed training
PRETRAINED: '' # Pretrained model path
OPTIM:
TYPE: 'AdamW' # Optimizer type
LR: 1e-4 # Learning rate
ABLATION:
SKIP_CONNECT: False # skip connection for denoiser va
# use linear to expand mean and std rather expand token nums
MLP_DIST: False
IS_DIST: False # Mcross distribution kl
EVAL:
SPLIT: 'gtest'
BATCH_SIZE: 1 # Evaluating Batch size
NUM_WORKERS: 1 # Number of workers
TEST:
TEST_DIR: ''
CHECKPOINTS: '' # Pretrained model path
NUM_WORKERS: 1 # Number of workers
BATCH_SIZE: 1 # Testing batch size
REPLICATION_TIMES: 1 # Number of times to replicate the test
SAVE_PREDICTIONS: False # Whether to save predictions
COUNT_TIME: False # Whether to count time during test
model:
target: 'modules'
LOSS:
METRIC:
None
DATASET:
VOCASET:
NONE: none
DEMO:
EAMPLE: null
ID: null
CHECKPOINTS: "templates"
TEMPLATE: "datasets/vocaset/templates.pkl"
PLY: "datasets/vocaset/templates/FLAME_sample.ply"
FPS: 30
LOGGER:
SACE_CHECKPOINT_EPOCH: 1
LOG_EVERY_STEPS: 1
VAL_EVERY_STEPS: 10
TENSORBOARD: true
WANDB:
OFFLINE: false
PROJECT: null
RESUME_ID: null
================================================
FILE: configs/diffusion/biwi/diffspeaker_hubert_biwi.yaml
================================================
NAME: diffspeaker_hubert_biwi # Experiment name
DEBUG: False # Debug mode
ACCELERATOR: 'gpu' # Device options: "cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"
DEVICE: [0, 1, 2, 3, 4, 5, 6, 7] # Index of gpus eg. [0] or [0,1,2,3]
# Training configuration
TRAIN:
#---------------------------------
DATASETS: ['biwi'] # Training datasets
NUM_WORKERS: 1 # Number of workers
BATCH_SIZE: 8 # Size of batches
START_EPOCH: 0 # Start epoch
END_EPOCH: 500 # End epoch
RESUME: '' # Resume training from this path
OPTIM:
TYPE: AdamW # Optimizer type
LR: 1e-4 # Learning rate
# Evaluating Configuration
EVAL:
DATASETS: ['biwi'] # Evaluating datasets
BATCH_SIZE: 8 # Evaluating Batch size
# Datasets Configuration
DATASET:
JOINT_TYPE: 'biwi' # joint type
TEST:
CHECKPOINTS: checkpoints/biwi/diffspeaker_hubert_biwi.ckpt # Pretrained model path
DATASETS: ['biwi'] # Testing datasets
BATCH_SIZE: 1 # Testing batch size
SPLIT: test # split type
REPLICATION_TIMES: 10 # replication times
# Losses Configuration
LOSS:
TYPE: voca # Losses type
VERTICE_ENC: 1 # lambda for vertex reconstruction loss
VERTICE_ENC_V: 1 # lambda for vertices velocity reconstruction loss
LIP_ENC: 0 # lambda for lip reconstruction loss
LIP_ENC_V: 0 # lambda for lip velocity reconstruction loss
DIST_SYNC_ON_STEP: True # Sync Losses on step when distributed trained
audio_encoder:
train_audio_encoder: True
model_name_or_path: "facebook/hubert-base-ls960"
# Model Configuration
model:
target: 'diffusion/diffusion_bias_modules'
audio_encoded_dim: 768 # audio hidden dimension
model_type: diffusion_bias # model type
latent_dim: 1024 # latent dimension
id_dim: 6 # the dimension of the id vector
ff_size: 2048 # latent_dim * 2
num_layers: 1 # number of layers
num_heads: 4 # number of attention heads
dropout: 0.1 # dropout rate
max_len: 600 # the attention mask maximum length
activation: gelu # activation type
normalize_before: True
require_start_token: True # start_token is needed for autoregressive generation only
arch: 'default'
predict_epsilon: False # noise or motion, motion here
freq_shift: 0
flip_sin_to_cos: True
mem_attn_scale: 1.
tgt_attn_scale: 1.
audio_fps: 50
hidden_fps: 25 # 30
guidance_scale: 0 # not used
guidance_uncondp: 0. # not used
period: 25
no_cross: False
smooth_output: False
# rewrite the template
DEMO:
EAMPLE: null
ID: null
TEMPLATE: "datasets/biwi/templates.pkl"
PLY: "datasets/biwi/templates/BIWI.ply"
FPS: 25
# Logger configuration
LOGGER:
SACE_CHECKPOINT_EPOCH: 100
LOG_EVERY_STEPS: 10
VAL_EVERY_STEPS: 100 # 200
TENSORBOARD: True
WANDB:
PROJECT: null
OFFLINE: False
RESUME_ID: null
================================================
FILE: configs/diffusion/biwi/diffspeaker_wav2vec2_biwi.yaml
================================================
NAME: diffspeaker_wav2vec2_biwi # Experiment name
DEBUG: False # Debug mode
ACCELERATOR: 'gpu' # Device options: "cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"
DEVICE: [0, 1, 2, 3, 4, 5, 6, 7] # Index of gpus eg. [0] or [0,1,2,3]
# Training configuration
TRAIN:
#---------------------------------
DATASETS: ['biwi'] # Training datasets
NUM_WORKERS: 1 # Number of workers
BATCH_SIZE: 16 # Size of batches
START_EPOCH: 0 # Start epoch
END_EPOCH: 700 # End epoch
RESUME: '' # Resume training from this path
OPTIM:
TYPE: AdamW # Optimizer type
LR: 1e-4 # Learning rate
# Evaluating Configuration
EVAL:
DATASETS: ['biwi'] # Evaluating datasets
BATCH_SIZE: 16 # Evaluating Batch size
# Datasets Configuration
DATASET:
JOINT_TYPE: 'biwi' # joint type
TEST:
CHECKPOINTS: checkpoints/biwi/diffspeaker_wav2vec2_biwi.ckpt #experiments/biwi/diffusion_bias/diffspeaker_wav2vec2_biwi/checkpoints/epoch=699.ckpt # Pretrained model path
DATASETS: ['biwi'] # Test datasets
BATCH_SIZE: 1 # Test batch size
SPLIT: test # split type
REPLICATION_TIMES: 10 # replication times for each test sample
# Losses Configuration
LOSS:
TYPE: voca # Loss type
VERTICE_ENC: 1 # Weight (lambda) for the vertex reconstruction loss
VERTICE_ENC_V: 1 # Weight (lambda) for the vertex velocity loss
LIP_ENC: 0 # Weight (lambda) for the lip reconstruction loss
LIP_ENC_V: 0 # Weight (lambda) for the lip velocity loss
DIST_SYNC_ON_STEP: True # Sync losses on every step during distributed training
audio_encoder:
train_audio_encoder: True
model_name_or_path: 'facebook/wav2vec2-base-960h'
# Model Configuration
model:
target: 'diffusion/diffusion_bias_modules'
audio_encoded_dim: 768 # audio hidden dimension
model_type: diffusion_bias # model type
latent_dim: 1024 # latent dimension
id_dim: 6 # the dimension of the id vector
ff_size: 2048 # latent_dim * 2
num_layers: 1 # number of layers
num_heads: 4 # number of attention heads
dropout: 0.1 # dropout rate
max_len: 600 # the attention mask maximum length
activation: gelu # activation type
normalize_before: True
require_start_token: True # start token is needed for autoregressive generation only
arch: 'default'
predict_epsilon: False # False: predict the clean motion sample; True: predict the noise
freq_shift: 0
flip_sin_to_cos: True
mem_attn_scale: 1.
tgt_attn_scale: 1.
audio_fps: 50
hidden_fps: 25 # 30
guidance_scale: 0 # not used
guidance_uncondp: 0. # not used
period: 25
no_cross: False
smooth_output: False
# Demo configuration: point the template and mesh at the BIWI assets
DEMO:
EAMPLE: null
ID: null
TEMPLATE: "datasets/biwi/templates.pkl"
PLY: "datasets/biwi/templates/BIWI.ply"
FPS: 25
# Logger configuration
LOGGER:
SACE_CHECKPOINT_EPOCH: 100
LOG_EVERY_STEPS: 10
VAL_EVERY_STEPS: 100 # 200
TENSORBOARD: True
WANDB:
PROJECT: null
OFFLINE: False
RESUME_ID: null
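`period` (25 for BIWI, 30 for VOCASET, i.e. one second of motion frames at `hidden_fps`) most likely parameterises a periodic, ALiBi-style attention bias of the kind FaceFormer introduced, where the penalty on attending to a distant frame grows by one step every `period` frames. The sketch below assumes that form; the exact mask built in `alm/models/architectures/adpt_bias_denoiser.py` may differ (for example, it may not be causal), and `alibi_slopes` / `periodic_bias` are hypothetical names.

```python
import math
from typing import List


def alibi_slopes(n_heads: int) -> List[float]:
    """Per-head slopes from ALiBi: a geometric sequence starting at 2**(-8/n)."""
    # This sketch only covers power-of-two head counts (num_heads is 4 here).
    if n_heads <= 0 or n_heads & (n_heads - 1):
        raise ValueError("sketch assumes a power-of-two number of heads")
    start = 2 ** (-8 / n_heads)
    return [start ** (i + 1) for i in range(n_heads)]


def periodic_bias(seq_len: int, period: int, slope: float, causal: bool = True) -> List[List[float]]:
    """Bias added to attention logits: -slope * (distance // period)."""
    rows = []
    for i in range(seq_len):  # query frame
        row = []
        for j in range(seq_len):  # key frame
            if causal and j > i:
                row.append(-math.inf)  # hide future frames
            else:
                # The penalty only grows once every `period` frames of distance.
                row.append(-slope * (abs(i - j) // period))
        rows.append(row)
    return rows


slopes = alibi_slopes(4)  # num_heads: 4
bias = periodic_bias(seq_len=50, period=25, slope=slopes[0])
print(slopes)  # [0.25, 0.0625, 0.015625, 0.00390625]
```

Frames within the same `period` window share the same penalty, so the bias encodes coarse temporal distance without punishing nearby frames individually.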
================================================
FILE: configs/diffusion/diffusion_bias_modules/denoiser.yaml
================================================
denoiser: # this is copied from configs/baselines/transformer_adpt_modules/transformer.yaml
target: alm.models.architectures.adpt_bias_denoiser.Adpt_Bias_Denoiser
params:
audio_encoded_dim: ${model.audio_encoded_dim}
ff_size: ${model.ff_size}
num_layers: ${model.num_layers}
num_heads: ${model.num_heads}
dropout: ${model.dropout}
normalize_before: ${model.normalize_before}
activation: ${model.activation}
return_intermediate_dec: False
arch: ${model.arch}
latent_dim: ${model.latent_dim}
nfeats: ${DATASET.NFEATS}
freq_shift: ${model.freq_shift}
flip_sin_to_cos: ${model.flip_sin_to_cos}
max_len: 3000 # the attention mask maximum length
id_dim: ${model.id_dim} # the number of identities
require_start_token: ${model.require_start_token} # start token is needed for autoregressive generation only
mem_attn_scale: ${model.mem_attn_scale}
tgt_attn_scale: ${model.tgt_attn_scale}
audio_fps: ${model.audio_fps}
hidden_fps: ${model.hidden_fps}
# unconditional generation
guidance_scale: ${model.guidance_scale}
guidance_uncondp: ${model.guidance_uncondp}
period: ${model.period}
no_cross: ${model.no_cross}
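Most `params` here are `${model.*}` or `${DATASET.*}` interpolations, so the experiment YAML's `model:` block is the single source of truth for the denoiser's hyperparameters. Note that `max_len` is hard-coded to 3000 in this file rather than interpolated, so `model.max_len` does not reach the denoiser through it. The toy resolver below only illustrates the lookup semantics for whole-string references; it is not OmegaConf, which supports much more (string concatenation, resolvers, relative keys). `lookup`, `resolve`, and the `NFEATS` value are illustrative.

```python
import re
from typing import Any, Dict

# Matches a value that is exactly one reference such as "${model.latent_dim}".
_REF = re.compile(r"^\$\{([A-Za-z_][\w.]*)\}$")


def lookup(root: Dict[str, Any], dotted: str) -> Any:
    """Follow a dotted path like 'model.latent_dim' through nested dicts."""
    node: Any = root
    for key in dotted.split("."):
        node = node[key]
    return node


def resolve(node: Any, root: Dict[str, Any]) -> Any:
    """Recursively replace whole-string ${...} references with their targets."""
    if isinstance(node, dict):
        return {k: resolve(v, root) for k, v in node.items()}
    if isinstance(node, list):
        return [resolve(v, root) for v in node]
    if isinstance(node, str):
        match = _REF.match(node)
        if match:
            return resolve(lookup(root, match.group(1)), root)
    return node


cfg = {
    "model": {"latent_dim": 1024, "num_heads": 4, "max_len": 600},
    "DATASET": {"NFEATS": 15069},  # placeholder value for illustration
    "denoiser": {
        "params": {
            "latent_dim": "${model.latent_dim}",
            "num_heads": "${model.num_heads}",
            "nfeats": "${DATASET.NFEATS}",
            "max_len": 3000,  # hard-coded, not interpolated
        }
    },
}
params = resolve(cfg, cfg)["denoiser"]["params"]
print(params)  # {'latent_dim': 1024, 'num_heads': 4, 'nfeats': 15069, 'max_len': 3000}
```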
================================================
FILE: configs/diffusion/diffusion_bias_modules/scheduler.yaml
================================================
scheduler:
target: diffusers.DDIMScheduler
num_inference_timesteps: 50
eta: 0.0
params:
num_train_timesteps: 1000
beta_start: 0.00085
beta_end: 0.012
beta_schedule: 'scaled_linear' # Optional: ['linear', 'scaled_linear', 'squaredcos_cap_v2']
# variance_type: 'fixed_small'
clip_sample: false # whether to clip samples to [-1, 1]
prediction_type: 'sample'
# DDIM-specific settings below
set_alpha_to_one: false
steps_offset: 1
noise_scheduler:
target: diffusers.DDPMScheduler
params:
num_train_timesteps: 1000
beta_start: 0.00085
beta_end: 0.012
beta_schedule: 'scaled_linear' # Optional: ['linear', 'scaled_linear', 'squaredcos_cap_v2']
variance_type: 'fixed_small'
prediction_type: 'sample'
clip_sample: false
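Both schedulers share the same `scaled_linear` noise schedule; presumably the DDPM `noise_scheduler` adds noise to motion during training, while the DDIM `scheduler` samples in `num_inference_timesteps: 50` steps at inference. With `prediction_type: 'sample'` the denoiser outputs the clean motion directly (matching `predict_epsilon: False`), and DDIM recovers the implied noise from it. The pure-Python sketch below re-derives those quantities on scalars (real use is elementwise over tensors), assuming diffusers' default `timestep_spacing="leading"` and `eta: 0.0`; it is not diffusers' code, and all function names are hypothetical.

```python
import math
from typing import List


def scaled_linear_betas(beta_start: float, beta_end: float, num_train_timesteps: int) -> List[float]:
    """'scaled_linear': space sqrt(beta) linearly, then square."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    n = num_train_timesteps
    return [(lo + (hi - lo) * i / (n - 1)) ** 2 for i in range(n)]


def cumulative_alphas(betas: List[float]) -> List[float]:
    """alpha_bar_t = prod over s <= t of (1 - beta_s)."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


def ddim_timesteps(num_train_timesteps: int, num_inference_steps: int, steps_offset: int) -> List[int]:
    """'leading' spacing: multiples of the step ratio, descending, shifted by steps_offset."""
    ratio = num_train_timesteps // num_inference_steps
    return [i * ratio + steps_offset for i in reversed(range(num_inference_steps))]


def ddim_step_sample(x_t: float, x0_pred: float, t: int, t_prev: int,
                     alpha_bar: List[float], final_alpha_bar: float) -> float:
    """One deterministic (eta = 0) DDIM update when the model predicts the clean sample."""
    a_t = alpha_bar[t]
    a_prev = alpha_bar[t_prev] if t_prev >= 0 else final_alpha_bar
    eps = (x_t - math.sqrt(a_t) * x0_pred) / math.sqrt(1.0 - a_t)  # implied noise
    return math.sqrt(a_prev) * x0_pred + math.sqrt(1.0 - a_prev) * eps


betas = scaled_linear_betas(0.00085, 0.012, 1000)
alpha_bar = cumulative_alphas(betas)
final_alpha_bar = alpha_bar[0]  # set_alpha_to_one: false
timesteps = ddim_timesteps(1000, 50, steps_offset=1)
print(timesteps[:3], timesteps[-1])  # [981, 961, 941] 1
```

`steps_offset: 1` shifts the grid so the last denoising step lands on t = 1; the step after it falls below zero and uses `final_alpha_bar`, which `set_alpha_to_one: false` takes from `alpha_bar[0]` instead of 1.0.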
================================================
FILE: configs/diffusion/vocaset/diffspeaker_hubert_vocaset.yaml
================================================
NAME: diffspeaker_hubert_vocaset # Experiment name
DEBUG: False # Debug mode
ACCELERATOR: 'gpu' # Device options: "cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"
DEVICE: [0] # Index of gpus eg. [0] or [0,1,2,3]
# Training configuration
TRAIN:
#---------------------------------
DATASETS: ['vocaset'] # Training datasets
NUM_WORKERS: 1 # Number of workers
BATCH_SIZE: 32 # Size of batches
START_EPOCH: 0 # Start epoch
END_EPOCH: 10000 # End epoch
RESUME: '' # Resume training from this path
OPTIM:
TYPE: AdamW # Optimizer type
LR: 1e-4 # Learning rate
# Evaluating Configuration
EVAL:
DATASETS: ['vocaset'] # Evaluating datasets
BATCH_SIZE: 32 # Evaluating Batch size
# Datasets Configuration
DATASET:
JOINT_TYPE: 'vocaset' # joint type
TEST:
CHECKPOINTS: checkpoints/vocaset/diffspeaker_hubert_vocaset.ckpt # Pretrained model path
DATASETS: ['vocaset'] # Test datasets
BATCH_SIZE: 1 # Test batch size
SPLIT: test # split type
REPLICATION_TIMES: 10 # replication times for each test sample
# Losses Configuration
LOSS:
TYPE: voca # Loss type
VERTICE_ENC: 1 # Weight (lambda) for the vertex reconstruction loss
VERTICE_ENC_V: 1 # Weight (lambda) for the vertex velocity loss
LIP_ENC: 0 # Weight (lambda) for the lip reconstruction loss
LIP_ENC_V: 0 # Weight (lambda) for the lip velocity loss
DIST_SYNC_ON_STEP: True # Sync losses on every step during distributed training
audio_encoder:
train_audio_encoder: True
model_name_or_path: 'facebook/hubert-base-ls960'
# Model Configuration
model:
target: 'diffusion/diffusion_bias_modules'
audio_encoded_dim: 768 # audio hidden dimension
model_type: diffusion_bias # model type
latent_dim: 512 # latent dimension
id_dim: 8 # the dimension of the id vector
ff_size: 1024 # latent_dim * 2
num_layers: 1 # number of layers
num_heads: 4 # number of attention heads
dropout: 0.1 # dropout rate
max_len: 600 # the attention mask maximum length
activation: gelu # activation type
normalize_before: True
require_start_token: True # start token is needed for autoregressive generation only
arch: 'default'
predict_epsilon: False # False: predict the clean motion sample; True: predict the noise
freq_shift: 0
flip_sin_to_cos: True
mem_attn_scale: 1.
tgt_attn_scale: 1.
audio_fps: 50
hidden_fps: 30
guidance_scale: 0 # not used
guidance_uncondp: 0. # not used
period: 30
no_cross: False
smooth_output: True
# Logger configuration
LOGGER:
SACE_CHECKPOINT_EPOCH: 1000
LOG_EVERY_STEPS: 100
VAL_EVERY_STEPS: 1000 # 200
TENSORBOARD: True
WANDB:
PROJECT: null
OFFLINE: False
RESUME_ID: null
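wav2vec 2.0 and HuBERT base encoders emit one feature vector per 20 ms of 16 kHz audio, hence `audio_fps: 50`. For VOCASET `hidden_fps: 30` does not divide 50 evenly (unlike BIWI's 50 to 25), so the audio features must be resampled to the motion frame rate. A minimal sketch using linear interpolation with both end frames pinned, in the style of FaceFormer's preprocessing; whether DiffSpeaker resamples exactly this way is an assumption, and `resample_features` is a hypothetical name.

```python
from typing import List, Sequence


def resample_features(frames: Sequence[Sequence[float]], src_fps: float, dst_fps: float) -> List[List[float]]:
    """Linearly interpolate a [T, D] feature sequence from src_fps to dst_fps."""
    if not frames:
        return []
    n_src = len(frames)
    n_dst = max(1, round(n_src * dst_fps / src_fps))
    out = []
    for k in range(n_dst):
        # Map output frame k onto the source axis, pinning both end frames.
        pos = k * (n_src - 1) / (n_dst - 1) if n_dst > 1 else 0.0
        i = int(pos)
        j = min(i + 1, n_src - 1)
        w = pos - i
        out.append([(1.0 - w) * a + w * b for a, b in zip(frames[i], frames[j])])
    return out


# Two seconds of 50 Hz features (D = 1 for readability) -> 30 fps motion frames.
audio_feats = [[float(t)] for t in range(100)]
at_30fps = resample_features(audio_feats, src_fps=50, dst_fps=30)
print(len(at_30fps), at_30fps[0], at_30fps[-1])  # 60 [0.0] [99.0]
```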
================================================
FILE: configs/diffusion/vocaset/diffspeaker_wav2vec2_vocaset.yaml
================================================
NAME: diffspeaker_wav2vec2_vocaset # Experiment name
DEBUG: False # Debug mode
ACCELERATOR: 'gpu' # Device options: "cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"
DEVICE: [0] # Index of gpus eg. [0] or [0,1,2,3]
# Training configuration
TRAIN:
#---------------------------------
DATASETS: ['vocaset'] # Training datasets
NUM_WORKERS: 1 # Number of workers
BATCH_SIZE: 32 # Size of batches
START_EPOCH: 0 # Start epoch
END_EPOCH: 9000 # End epoch
RESUME: '' # Resume training from this path
OPTIM:
TYPE: AdamW # Optimizer type
LR: 1e-4 # Learning rate
# Evaluating Configuration
EVAL:
DATASETS: ['vocaset'] # Evaluating datasets
BATCH_SIZE: 32 # Evaluating Batch size
# Datasets Configuration
DATASET:
JOINT_TYPE: 'vocaset' # joint type
TEST:
CHECKPOINTS: checkpoints/vocaset/diffspeaker_wav2vec2_vocaset.ckpt # Pretrained model path
DATASETS: ['vocaset'] # Test datasets
BATCH_SIZE: 1 # Test batch size
SPLIT: test # split type
REPLICATION_TIMES: 10 # replication times for each test sample
# Losses Configuration
LOSS:
TYPE: voca # Loss type
VERTICE_ENC: 1 # Weight (lambda) for the vertex reconstruction loss
VERTICE_ENC_V: 1 # Weight (lambda) for the vertex velocity loss
LIP_ENC: 0 # Weight (lambda) for the lip reconstruction loss
LIP_ENC_V: 0 # Weight (lambda) for the lip velocity loss
DIST_SYNC_ON_STEP: True # Sync losses on every step during distributed training
audio_encoder:
train_audio_encoder: True
model_name_or_path: 'facebook/wav2vec2-base-960h'
# Model Configuration
model:
target: 'diffusion/diffusion_bias_modules'
audio_encoded_dim: 768 # audio hidden dimension
model_type: diffusion_bias # model type
latent_dim: 512 # latent dimension
id_dim: 8 # the dimension of the id vector
ff_size: 1024 # latent_dim * 2
num_layers: 1 # number of layers
num_heads: 4 # number of attention heads
dropout: 0.1 # dropout rate
max_len: 600 # the attention mask maximum length
activation: gelu # activation type
normalize_before: True
require_start_token: True # start token is needed for autoregressive generation only
arch: 'default'
predict_epsilon: False # False: predict the clean motion sample; True: predict the noise
freq_shift: 0
flip_sin_to_cos: True
mem_attn_scale: 1.
tgt_attn_scale: 1.
audio_fps: 50
hidden_fps: 30
guidance_scale: 0 # not used
guidance_uncondp: 0. # not used
period: 30
no_cross: False
smooth_output: True
# Logger configuration
LOGGER:
SACE_CHECKPOINT_EPOCH: 1000
LOG_EVERY_STEPS: 100
VAL_EVERY_STEPS: 1000 # 200
TENSORBOARD: True
WANDB:
PROJECT: null
OFFLINE: False
RESUME_ID: null
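In the `LOSS` block, `VERTICE_ENC` and `VERTICE_ENC_V` weight a vertex position term and a vertex velocity term (frame-to-frame differences), while the lip terms are disabled with weight 0. The sketch below assumes both terms are mean squared errors; the actual `voca` loss in `alm/models/losses/voca.py` may use a different distance or reduction, and all names here are hypothetical.

```python
from typing import List, Sequence


def _mse(a: Sequence[Sequence[float]], b: Sequence[Sequence[float]]) -> float:
    """Mean squared error over all frames and coordinates."""
    diffs = [(x - y) ** 2 for fa, fb in zip(a, b) for x, y in zip(fa, fb)]
    return sum(diffs) / len(diffs) if diffs else 0.0


def _velocity(frames: Sequence[Sequence[float]]) -> List[List[float]]:
    """Frame-to-frame differences, shape [T - 1, D]."""
    return [[y - x for x, y in zip(f0, f1)] for f0, f1 in zip(frames, frames[1:])]


def vertex_loss(pred, gt, vertice_enc: float = 1.0, vertice_enc_v: float = 1.0) -> float:
    """Weighted sum of a position term and a velocity term (hypothetical form)."""
    rec = _mse(pred, gt)
    vel = _mse(_velocity(pred), _velocity(gt))
    return vertice_enc * rec + vertice_enc_v * vel


gt = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
shifted = [[x + 0.5 for x in frame] for frame in gt]  # same motion, constant offset
print(vertex_loss(shifted, gt))  # 0.25
```

A constant offset is penalised only by the position term, whereas jittery predictions are also penalised by the velocity term, which is what pushes generated motion toward temporal smoothness.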
================================================
FILE: datasets/biwi/README.md
================================================
This directory should contain:
```
regions
templates
templates.pkl
vertices_npy
wav
```
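A quick pre-flight check that the BIWI data directory contains the entries listed above. `missing_entries` is a hypothetical helper, not part of the repo, and the loader may expect more than these top-level names (for example, specific file names inside `wav/` and `vertices_npy/`).

```python
import os
import tempfile
from typing import Iterable, List

# Entries listed in datasets/biwi/README.md.
REQUIRED_BIWI_ENTRIES = ("regions", "templates", "templates.pkl", "vertices_npy", "wav")


def missing_entries(root: str, required: Iterable[str] = REQUIRED_BIWI_ENTRIES) -> List[str]:
    """Return the required entries that do not exist under root."""
    return [name for name in required if not os.path.exists(os.path.join(root, name))]


# Usage on a throwaway directory that is only partially populated.
with tempfile.TemporaryDirectory() as root:
    os.mkdir(os.path.join(root, "regions"))
    open(os.path.join(root, "templates.pkl"), "wb").close()
    missing = missing_entries(root)
print(missing)  # ['templates', 'vertices_npy', 'wav']
```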
================================================
FILE: datasets/biwi/regions/fdd.txt
================================================
7, 9, 59, 89, 94, 103, 122, 126, 152, 159, 160, 161, 164, 166, 169, 174, 177, 178, 179, 180, 182, 183, 184, 185, 195, 216, 217, 221, 223, 224, 226, 227, 230, 232, 233, 234, 235, 236, 238, 240, 241, 243, 244, 259, 270, 286, 300, 393, 394, 397, 400, 415, 416, 417, 418, 419, 422, 424, 440, 446, 447, 448, 449, 450, 452, 454, 456, 457, 459, 461, 462, 467, 468, 469, 470, 472, 473, 476, 477, 478, 497, 498, 499, 506, 508, 511, 513, 519, 525, 553, 588, 594, 601, 610, 613, 622, 627, 633, 635, 637, 642, 643, 645, 647, 663, 665, 681, 682, 690, 702, 705, 706, 712, 713, 714, 715, 716, 717, 718, 721, 727, 728, 730, 734, 735, 737, 738, 739, 741, 743, 744, 745, 746, 747, 748, 751, 753, 754, 756, 757, 759, 760, 763, 765, 767, 769, 776, 777, 778, 779, 780, 781, 783, 784, 788, 789, 790, 791, 795, 796, 797, 801, 802, 805, 806, 810, 813, 815, 830, 831, 832, 836, 839, 840, 841, 851, 860, 877, 883, 885, 888, 898, 900, 921, 922, 924, 925, 939, 940, 942, 947, 952, 956, 957, 968, 969, 971, 973, 975, 976, 977, 982, 983, 985, 986, 988, 990, 995, 999, 1000, 1001, 1002, 1003, 1005, 1008, 1009, 1011, 1012, 1013, 1014, 1015, 1019, 1020, 1021, 1022, 1025, 1027, 1033, 1042, 1107, 1110, 1114, 1117, 1132, 1134, 1141, 1142, 1147, 1149, 1150, 1152, 1155, 1156, 1158, 1160, 1161, 1163, 1166, 1167, 1168, 1170, 1171, 1172, 1175, 1186, 1208, 1211, 1242, 1250, 1258, 1261, 1264, 1304, 1310, 1317, 1319, 1372, 1388, 1398, 1400, 1402, 1403, 1405, 1408, 1409, 1410, 1411, 1412, 1416, 1418, 1420, 1424, 1426, 1433, 1435, 1436, 1439, 1441, 1445, 1446, 1456, 1459, 1476, 1483, 1490, 1491, 1506, 1507, 1510, 1513, 1527, 1528, 1529, 1530, 1531, 1532, 1533, 1534, 1535, 1540, 1549, 1550, 1552, 1554, 1558, 1559, 1561, 1562, 1563, 1564, 1565, 1571, 1572, 1573, 1574, 1576, 1577, 1578, 1579, 1580, 1581, 1582, 1583, 1585, 1590, 1591, 1593, 1594, 1595, 1596, 1597, 1598, 1600, 1601, 1602, 1603, 1604, 1605, 1606, 1607, 1611, 1613, 1615, 1620, 1622, 1625, 1626, 1627, 1628, 1629, 1631, 1638, 1639, 1641, 1643, 1645, 1652, 1653, 1655, 
1660, 1661, 1670, 1672, 1676, 1678, 1680, 1682, 1683, 1685, 1686, 1690, 1693, 1696, 1698, 1701, 1704, 1705, 1706, 1707, 1708, 1710, 1711, 1712, 1713, 1714, 1715, 1716, 1718, 1722, 1723, 1724, 1726, 1727, 1729, 1730, 1734, 1737, 1738, 1743, 1744, 1746, 1749, 1751, 1752, 1756, 1757, 1760, 1763, 1767, 1768, 1769, 1772, 1778, 1779, 1782, 1785, 1786, 1787, 1789, 1793, 1795, 1796, 1802, 1806, 1807, 1808, 1809, 1810, 1811, 1812, 1813, 1814, 1815, 1816, 1817, 1818, 1820, 1825, 1826, 1831, 1832, 1841, 1847, 1850, 1851, 1853, 1854, 1859, 1860, 1861, 1862, 1866, 1867, 1868, 1869, 1873, 1875, 1881, 1883, 1886, 1887, 1889, 1892, 1896, 1902, 1903, 1904, 1909, 1923, 1941, 1949, 1964, 1966, 1967, 1968, 1971, 1972, 1974, 1976, 1977, 1978, 1979, 1980, 1983, 1984, 1985, 1986, 1987, 1988, 1989, 1991, 1992, 1993, 1994, 1999, 2004, 2009, 2011, 2012, 2013, 2016, 2017, 2021, 2022, 2024, 2025, 2026, 2034, 2043, 2044, 2047, 2051, 2053, 2055, 2064, 2069, 2073, 2074, 2078, 2081, 2082, 2085, 2090, 2092, 2099, 2104, 2106, 2107, 2109, 2123, 2130, 2134, 2138, 2142, 2159, 2160, 2173, 2215, 2243, 2244, 2245, 2246, 2247, 2248, 2251, 2253, 2256, 2258, 2263, 2266, 2267, 2268, 2273, 2274, 2275, 2276, 2277, 2279, 2280, 2281, 2282, 2283, 2284, 2285, 2286, 2287, 2288, 2289, 2291, 2293, 2294, 2296, 2297, 2298, 2299, 2300, 2303, 2304, 2305, 2306, 2307, 2308, 2309, 2310, 2311, 2312, 2313, 2314, 2315, 2316, 2317, 2318, 2322, 2325, 2327, 2328, 2330, 2331, 2333, 2336, 2337, 2340, 2341, 2342, 2343, 2344, 2345, 2346, 2347, 2348, 2350, 2351, 2352, 2353, 2356, 2357, 2358, 2360, 2361, 2362, 2363, 2364, 2365, 2367, 2368, 2369, 2370, 2371, 2372, 2373, 2375, 2378, 2380, 2381, 2383, 2384, 2385, 2386, 2390, 2391, 2394, 2395, 2399, 2404, 2406, 2407, 2408, 2409, 2410, 2411, 2413, 2414, 2416, 2417, 2418, 2419, 2420, 2421, 2422, 2423, 2424, 2425, 2426, 2427, 2428, 2429, 2430, 2431, 2433, 2437, 2440, 2441, 2445, 2447, 2452, 2455, 2456, 2458, 2459, 2460, 2464, 2466, 2469, 2470, 2471, 2474, 2478, 2479, 2485, 2487, 2488, 2490, 
2494, 2496, 2497, 2503, 2505, 2508, 2509, 2512, 2516, 2517, 2519, 2520, 2522, 2523, 2524, 2528, 2530, 2531, 2532, 2533, 2535, 2536, 2537, 2539, 2542, 2543, 2546, 2547, 2548, 2552, 2554, 2564, 2567, 2583, 2593, 2602, 2605, 2608, 2609, 2610, 2612, 2614, 2615, 2617, 2620, 2622, 2624, 2625, 2627, 2628, 2629, 2632, 2633, 2634, 2637, 2699, 2723, 2724, 2725, 2738, 2741, 2743, 2755, 2756, 2776, 2779, 2780, 2796, 2797, 2798, 2799, 2803, 2808, 2812, 2813, 2819, 2829, 2842, 2844, 2846, 2851, 2852, 2853, 2855, 2857, 2861, 2862, 2868, 2874, 2879, 2881, 2882, 2883, 2885, 2893, 2899, 2944, 2961, 2964, 2965, 2970, 2971, 2972, 2973, 2978, 2982, 2983, 2984, 2985, 2986, 2988, 2991, 2995, 2996, 2998, 3005, 3020, 3021, 3027, 3028, 3030, 3031, 3036, 3039, 3041, 3044, 3045, 3046, 3048, 3051, 3056, 3057, 3060, 3065, 3068, 3069, 3073, 3079, 3083, 3085, 3088, 3092, 3096, 3102, 3118, 3125, 3127, 3128, 3131, 3133, 3140, 3141, 3142, 3144, 3147, 3149, 3151, 3152, 3153, 3164, 3165, 3170, 3171, 3172, 3174, 3178, 3184, 3186, 3188, 3190, 3192, 3193, 3195, 3196, 3197, 3198, 3200, 3201, 3203, 3204, 3205, 3206, 3207, 3209, 3211, 3212, 3213, 3214, 3216, 3217, 3220, 3222, 3223, 3224, 3225, 3227, 3228, 3230, 3231, 3232, 3233, 3234, 3236, 3240, 3244, 3245, 3247, 3248, 3249, 3250, 3257, 3258, 3259, 3260, 3261, 3262, 3268, 3269, 3270, 3274, 3275, 3279, 3282, 3283, 3287, 3291, 3292, 3293, 3294, 3295, 3298, 3300, 3302, 3303, 3308, 3312, 3313, 3314, 3316, 3317, 3319, 3320, 3324, 3325, 3327, 3330, 3331, 3333, 3339, 3340, 3343, 3349, 3352, 3353, 3355, 3356, 3357, 3358, 3359, 3362, 3365, 3366, 3368, 3369, 3370, 3371, 3374, 3378, 3379, 3380, 3381, 3383, 3384, 3386, 3387, 3388, 3389, 3390, 3391, 3394, 3398, 3399, 3401, 3402, 3404, 3413, 3414, 3415, 3421, 3422, 3423, 3429, 3430, 3432, 3434, 3436, 3437, 3438, 3439, 3441, 3442, 3444, 3445, 3446, 3451, 3453, 3454, 3455, 3456, 3458, 3461, 3462, 3463, 3464, 3465, 3467, 3468, 3470, 3471, 3472, 3473, 3476, 3478, 3482, 3484, 3485, 3491, 3493, 3494, 3495, 3500, 3502, 3503, 
3506, 3507, 3508, 3511, 3512, 3514, 3516, 3520, 3524, 3525, 3526, 3527, 3528, 3531, 3538, 3540, 3541, 3542, 3544, 3546, 3550, 3551, 3552, 3565, 3573, 3581, 3593, 3594, 3606, 3608, 3651, 3654, 3656, 3658, 3663, 3665, 3668, 3679, 3698, 3709, 3710, 3714, 3715, 3718, 3719, 3721, 3733, 3735, 3760, 3769, 3772, 3773, 3778, 3779, 3780, 3783, 3784, 3786, 3788, 3791, 3792, 3793, 3794, 3795, 3797, 3801, 3828, 3830, 3832, 3838, 3839, 3843, 3845, 3846, 3851, 3856, 3862, 3865, 3870, 3872, 3874, 3878, 3879, 3886, 3887, 3889, 3892, 3893, 3904, 3905, 3906, 3912, 3914, 3918, 3919, 3923, 3927, 3938, 3944, 3948, 3959, 3962, 3963, 3965, 3966, 3969, 3973, 3985, 3986, 3991, 3996, 3999, 4003, 4004, 4006, 4007, 4010, 4012, 4014, 4015, 4017, 4020, 4021, 4023, 4026, 4028, 4029, 4030, 4032, 4035, 4038, 4040, 4043, 4058, 4059, 4061, 4063, 4064, 4065, 4068, 4071, 4072, 4073, 4077, 4078, 4079, 4080, 4082, 4084, 4085, 4087, 4088, 4090, 4094, 4098, 4099, 4100, 4101, 4107, 4115, 4122, 4123, 4124, 4126, 4127, 4128, 4129, 4147, 4150, 4153, 4154, 4155, 4156, 4157, 4158, 4161, 4165, 4182, 4183, 4184, 4187, 4190, 4191, 4192, 4194, 4197, 4199, 4204, 4206, 4208, 4209, 4210, 4215, 4216, 4218, 4220, 4221, 4222, 4225, 4226, 4227, 4228, 4232, 4235, 4236, 4237, 4238, 4240, 4241, 4244, 4247, 4252, 4253, 4255, 4258, 4259, 4260, 4261, 4264, 4266, 4268, 4272, 4274, 4276, 4278, 4279, 4280, 4281, 4284, 4287, 4288, 4289, 4290, 4293, 4294, 4295, 4297, 4298, 4300, 4301, 4302, 4303, 4304, 4305, 4306, 4307, 4308, 4309, 4310, 4311, 4313, 4315, 4316, 4317, 4318, 4319, 4322, 4324, 4326, 4327, 4328, 4331, 4335, 4337, 4338, 4346, 4352, 4355, 4358, 4359, 4360, 4361, 4365, 4366, 4367, 4368, 4369, 4370, 4371, 4372, 4373, 4374, 4375, 4376, 4380, 4383, 4388, 4391, 4394, 4395, 4396, 4398, 4399, 4409, 4412, 4416, 4418, 4419, 4425, 4426, 4431, 4435, 4441, 4443, 4452, 4454, 4455, 4462, 4463, 4464, 4465, 4467, 4468, 4475, 4478, 4479, 4480, 4481, 4483, 4484, 4485, 4486, 4487, 4488, 4492, 4495, 4497, 4498, 4502, 4503, 4507, 4508, 4509, 
4511, 4512, 4514, 4515, 4516, 4517, 4518, 4519, 4520, 4523, 4524, 4525, 4526, 4527, 4528, 4529, 4530, 4531, 4532, 4533, 4534, 4535, 4536, 4537, 4538, 4539, 4541, 4542, 4543, 4544, 4545, 4546, 4548, 4549, 4552, 4553, 4554, 4555, 4556, 4557, 4558, 4559, 4560, 4561, 4562, 4563, 4564, 4565, 4566, 4567, 4569, 4571, 4572, 4573, 4574, 4575, 4576, 4577, 4578, 4579, 4580, 4581, 4582, 4583, 4584, 4590, 4591, 4592, 4594, 4595, 4598, 4600, 4602, 4604, 4605, 4606, 4608, 4611, 4616, 4617, 4618, 4619, 4631, 4632, 4633, 4634, 4635, 4636, 4637, 4638, 4641, 4642, 4645, 4648, 4649, 4651, 4656, 4657, 4664, 4666, 4667, 4668, 4673, 4675, 4678, 4679, 4684, 4687, 4690, 4693, 4695, 4696, 4697, 4700, 4702, 4704, 4707, 4709, 4710, 4712, 4715, 4719, 4723, 4730, 4733, 4735, 4736, 4737, 4739, 4740, 4743, 4744, 4745, 4747, 4750, 4754, 4755, 4756, 4758, 4759, 4761, 4763, 4765, 4766, 4768, 4769, 4770, 4771, 4772, 4773, 4774, 4775, 4776, 4777, 4778, 4779, 4780, 4781, 4782, 4783, 4784, 4785, 4786, 4787, 4788, 4789, 4790, 4791, 4794, 4795, 4800, 4802, 4804, 4805, 4806, 4807, 4810, 4813, 4814, 4816, 4817, 4818, 4819, 4821, 4823, 4825, 4826, 4827, 4828, 4829, 4830, 4831, 4832, 4834, 4835, 4836, 4837, 4838, 4839, 4840, 4841, 4842, 4845, 4848, 4849, 4852, 4854, 4855, 4857, 4858, 4860, 4861, 4863, 4864, 4866, 4867, 4868, 4869, 4870, 4871, 4872, 4873, 4874, 4875, 4877, 4878, 4880, 4881, 4887, 4889, 4890, 4891, 4892, 4893, 4894, 4895, 4896, 4897, 4898, 4899, 4900, 4901, 4902, 4903, 4905, 4906, 4909, 4910, 4914, 4918, 4920, 4922, 4924, 4926, 4928, 4929, 4931, 4946, 4953, 4998, 4999, 5001, 5003, 5005, 5009, 5010, 5012, 5013, 5014, 5016, 5017, 5019, 5023, 5024, 5028, 5030, 5032, 5034, 5036, 5039, 5042, 5049, 5051, 5053, 5057, 5063, 5074, 5083, 5087, 5089, 5093, 5094, 5095, 5099, 5106, 5109, 5115, 5116, 5118, 5119, 5121, 5124, 5127, 5131, 5132, 5175, 5176, 5177, 5178, 5179, 5182, 5185, 5188, 5191, 5192, 5194, 5203, 5205, 5214, 5216, 5217, 5223, 5233, 5235, 5237, 5240, 5242, 5243, 5245, 5250, 5251, 5257, 5263, 
5266, 5269, 5270, 5271, 5274, 5280, 5281, 5285, 5286, 5287, 5291, 5293, 5295, 5298, 5299, 5301, 5302, 5303, 5304, 5305, 5306, 5307, 5308, 5309, 5310, 5316, 5319, 5320, 5323, 5324, 5328, 5329, 5330, 5331, 5333, 5334, 5335, 5337, 5338, 5339, 5342, 5345, 5346, 5347, 5348, 5350, 5351, 5353, 5355, 5356, 5357, 5358, 5359, 5361, 5362, 5364, 5366, 5367, 5376, 5390, 5402, 5404, 5408, 5415, 5423, 5424, 5425, 5428, 5433, 5435, 5439, 5441, 5442, 5444, 5445, 5455, 5459, 5464, 5465, 5471, 5475, 5477, 5479, 5483, 5488, 5489, 5499, 5506, 5510, 5514, 5515, 5516, 5518, 5520, 5521, 5522, 5525, 5526, 5527, 5528, 5529, 5533, 5536, 5538, 5539, 5540, 5543, 5549, 5551, 5554, 5555, 5558, 5559, 5563, 5564, 5566, 5567, 5568, 5570, 5571, 5573, 5574, 5575, 5578, 5579, 5580, 5581, 5582, 5583, 5584, 5585, 5587, 5588, 5589, 5590, 5591, 5593, 5594, 5596, 5597, 5599, 5600, 5602, 5603, 5605, 5606, 5607, 5608, 5611, 5615, 5617, 5620, 5622, 5623, 5625, 5631, 5635, 5643, 5678, 5680, 5682, 5684, 5694, 5711, 5712, 5714, 5718, 5725, 5728, 5730, 5732, 5736, 5740, 5748, 5751, 5759, 5761, 5763, 5766, 5768, 5770, 5772, 5787, 5791, 5795, 5798, 5802, 5809, 5813, 5818, 5822, 5848, 5854, 5855, 5858, 5859, 5861, 5863, 5864, 5865, 5867, 5875, 5877, 5881, 5882, 5886, 5888, 5890, 5892, 5893, 5895, 5896, 5901, 5902, 5905, 5908, 5913, 5914, 5919, 5923, 5925, 5926, 5927, 5928, 5929, 5932, 5933, 5934, 5940, 5941, 5946, 5948, 5949, 5950, 5951, 5952, 5953, 5955, 5959, 5960, 5961, 5962, 5964, 5965, 5966, 5967, 5968, 5969, 5971, 5972, 5973, 5974, 5975, 5976, 5977, 5978, 5979, 5980, 5981, 5983, 5985, 5986, 5987, 5988, 5989, 5991, 5992, 5993, 5994, 5995, 5996, 5997, 6004, 6007, 6008, 6012, 6015, 6016, 6017, 6022, 6031, 6033, 6035, 6037, 6038, 6039, 6040, 6041, 6042, 6046, 6047, 6048, 6049, 6050, 6053, 6055, 6057, 6058, 6061, 6063, 6065, 6066, 6067, 6070, 6072, 6074, 6075, 6079, 6080, 6081, 6082, 6083, 6085, 6087, 6088, 6089, 6090, 6091, 6092, 6093, 6094, 6096, 6097, 6098, 6106, 6111, 6112, 6113, 6114, 6116, 6122, 6123, 6124, 
6126, 6127, 6131, 6132, 6133, 6134, 6135, 6136, 6137, 6138, 6139, 6140, 6141, 6143, 6145, 6146, 6147, 6148, 6149, 6150, 6151, 6152, 6153, 6157, 6158, 6159, 6160, 6161, 6164, 6169, 6170, 6171, 6172, 6174, 6175, 6176, 6177, 6180, 6181, 6182, 6183, 6184, 6185, 6187, 6191, 6192, 6193, 6195, 6196, 6198, 6200, 6204, 6206, 6209, 6211, 6212, 6215, 6219, 6221, 6222, 6225, 6235, 6236, 6241, 6244, 6245, 6247, 6248, 6249, 6251, 6253, 6255, 6257, 6259, 6260, 6265, 6273, 6274, 6276, 6280, 6282, 6283, 6284, 6285, 6286, 6287, 6288, 6289, 6291, 6293, 6294, 6295, 6298, 6299, 6300, 6311, 6314, 6334, 6335, 6336, 6338, 6340, 6343, 6351, 6353, 6354, 6357, 6359, 6360, 6361, 6364, 6368, 6372, 6380, 6384, 6391, 6395, 6401, 6402, 6403, 6406, 6412, 6414, 6415, 6416, 6418, 6419, 6421, 6425, 6426, 6427, 6428, 6429, 6430, 6434, 6435, 6436, 6443, 6446, 6450, 6452, 6453, 6454, 6459, 6460, 6461, 6463, 6466, 6472, 6479, 6490, 6491, 6492, 6496, 6497, 6498, 6500, 6501, 6503, 6504, 6508, 6509, 6510, 6511, 6512, 6515, 6519, 6520, 6522, 6523, 6527, 6534, 6536, 6538, 6540, 6543, 6545, 6546, 6549, 6552, 6553, 6554, 6562, 6565, 6567, 6568, 6569, 6571, 6573, 6575, 6576, 6578, 6579, 6581, 6584, 6586, 6587, 6591, 6593, 6597, 6598, 6604, 6611, 6615, 6616, 6617, 6618, 6621, 6622, 6625, 6626, 6627, 6629, 6630, 6639, 6640, 6642, 6643, 6644, 6645, 6646, 6652, 6655, 6659, 6665, 6673, 6675, 6677, 6679, 6680, 6698, 6710, 6737, 6746, 6747, 6748, 6749, 6754, 6760, 6764, 6771, 6781, 6782, 6792, 6794, 6804, 6815, 6822, 6831, 6842, 6853, 6855, 6856, 6859, 6864, 6867, 6871, 6876, 6879, 6882, 6883, 6887, 6890, 6913, 6914, 6917, 6928, 6945, 6956, 6957, 6959, 6960, 6962, 6964, 6965, 6968, 6970, 6971, 6972, 6974, 6975, 6976, 6977, 6981, 6982, 6983, 6984, 6986, 6988, 6989, 6990, 6991, 6992, 6993, 6994, 6995, 6996, 6997, 7005, 7006, 7012, 7015, 7020, 7024, 7025, 7026, 7027, 7028, 7029, 7030, 7031, 7032, 7033, 7034, 7035, 7036, 7037, 7038, 7039, 7040, 7042, 7043, 7044, 7045, 7046, 7047, 7050, 7051, 7052, 7053, 7054, 7055, 7056, 
7057, 7058, 7059, 7060, 7061, 7062, 7063, 7064, 7065, 7068, 7069, 7070, 7071, 7072, 7073, 7075, 7078, 7081, 7083, 7084, 7086, 7088, 7093, 7094, 7095, 7096, 7097, 7098, 7099, 7100, 7101, 7103, 7105, 7106, 7108, 7109, 7110, 7111, 7112, 7114, 7115, 7117, 7118, 7119, 7121, 7125, 7126, 7127, 7131, 7132, 7133, 7137, 7138, 7139, 7140, 7141, 7142, 7143, 7144, 7146, 7149, 7150, 7151, 7153, 7156, 7169, 7171, 7172, 7178, 7179, 7182, 7184, 7185, 7186, 7187, 7190, 7191, 7192, 7193, 7194, 7195, 7196, 7197, 7198, 7199, 7200, 7201, 7202, 7203, 7204, 7205, 7206, 7207, 7208, 7209, 7210, 7211, 7212, 7213, 7214, 7215, 7216, 7217, 7218, 7219, 7220, 7221, 7222, 7223, 7224, 7225, 7226, 7228, 7229, 7234, 7235, 7239, 7240, 7243, 7245, 7246, 7247, 7249, 7251, 7253, 7254, 7256, 7262, 7265, 7269, 7270, 7280, 7282, 7283, 7284, 7285, 7286, 7287, 7288, 7289, 7290, 7291, 7292, 7293, 7294, 7295, 7296, 7297, 7298, 7299, 7300, 7301, 7302, 7303, 7304, 7305, 7306, 7307, 7309, 7311, 7312, 7313, 7314, 7315, 7316, 7317, 7318, 7322, 7323, 7324, 7326, 7327, 7329, 7330, 7335, 7336, 7341, 7343, 7344, 7345, 7346, 7348, 7350, 7353, 7354, 7356, 7358, 7360, 7361, 7362, 7363, 7365, 7370, 7372, 7373, 7375, 7376, 7377, 7378, 7379, 7381, 7382, 7385, 7386, 7387, 7388, 7389, 7390, 7391, 7392, 7393, 7394, 7395, 7396, 7397, 7398, 7399, 7400, 7401, 7402, 7403, 7404, 7405, 7406, 7407, 7408, 7409, 7410, 7411, 7412, 7413, 7414, 7415, 7416, 7417, 7418, 7419, 7420, 7421, 7422, 7424, 7425, 7426, 7428, 7430, 7432, 7434, 7435, 7436, 7437, 7438, 7439, 7441, 7442, 7443, 7444, 7449, 7451, 7453, 7455, 7456, 7459, 7460, 7461, 7463, 7466, 7468, 7470, 7471, 7472, 7485, 7523, 7524, 7528, 7529, 7531, 7532, 7537, 7543, 7585, 7602, 7608, 7623, 7624, 7660, 7662, 7663, 7665, 7667, 7673, 7674, 7678, 7680, 7683, 7684, 7685, 7686, 7687, 7690, 7691, 7694, 7695, 7697, 7698, 7744, 7791, 7792, 7793, 7795, 7799, 7800, 7802, 7804, 7807, 7810, 7812, 7814, 7822, 7823, 7827, 7830, 7831, 7839, 7842, 7846, 7849, 7852, 7857, 7859, 7861, 7865, 7866, 7867, 
7871, 7872, 7879, 7885, 7886, 7887, 7891, 7892, 7893, 7895, 7897, 7899, 7901, 7902, 7914, 7923, 7924, 7929, 7932, 7933, 7934, 7936, 7937, 7948, 7951, 7954, 7956, 7957, 7958, 7962, 7965, 7966, 7973, 7979, 7984, 7986, 7989, 7990, 7998, 8006, 8007, 8010, 8014, 8015, 8017, 8020, 8022, 8023, 8025, 8026, 8029, 8030, 8031, 8036, 8039, 8041, 8042, 8043, 8045, 8047, 8048, 8049, 8051, 8052, 8053, 8054, 8055, 8056, 8057, 8058, 8059, 8063, 8064, 8065, 8067, 8068, 8069, 8070, 8071, 8072, 8073, 8074, 8075, 8077, 8078, 8081, 8084, 8085, 8086, 8087, 8088, 8089, 8090, 8093, 8094, 8095, 8097, 8102, 8108, 8126, 8128, 8131, 8132, 8134, 8135, 8138, 8139, 8159, 8162, 8163, 8173, 8174, 8177, 8178, 8183, 8186, 8187, 8188, 8190, 8195, 8201, 8216, 8229, 8231, 8234, 8236, 8238, 8245, 8247, 8250, 8252, 8253, 8256, 8257, 8258, 8260, 8261, 8262, 8269, 8271, 8279, 8280, 8281, 8286, 8289, 8292, 8294, 8298, 8300, 8301, 8302, 8305, 8306, 8321, 8323, 8325, 8328, 8333, 8335, 8352, 8354, 8357, 8359, 8362, 8363, 8364, 8365, 8366, 8369, 8370, 8371, 8374, 8379, 8382, 8383, 8385, 8386, 8387, 8388, 8389, 8393, 8394, 8396, 8397, 8400, 8402, 8406, 8408, 8409, 8411, 8412, 8414, 8415, 8416, 8417, 8418, 8420, 8422, 8428, 8430, 8433, 8436, 8438, 8442, 8443, 8445, 8451, 8461, 8463, 8481, 8485, 8486, 8489, 8490, 8493, 8495, 8497, 8498, 8500, 8502, 8507, 8513, 8517, 8523, 8526, 8530, 8533, 8537, 8538, 8539, 8540, 8543, 8545, 8546, 8547, 8548, 8549, 8563, 8573, 8578, 8579, 8581, 8584, 8587, 8591, 8592, 8593, 8600, 8604, 8607, 8609, 8618, 8619, 8621, 8625, 8627, 8628, 8633, 8634, 8636, 8639, 8641, 8645, 8648, 8649, 8650, 8652, 8654, 8658, 8660, 8661, 8665, 8671, 8672, 8674, 8676, 8678, 8680, 8686, 8688, 8689, 8690, 8696, 8702, 8703, 8704, 8705, 8715, 8716, 8717, 8727, 8731, 8737, 8739, 8740, 8747, 8758, 8785, 8787, 8788, 8793, 8797, 8817, 8838, 8850, 8856, 8860, 8865, 8870, 8871, 8873, 8885, 8888, 8892, 8895, 8896, 8898, 8901, 8904, 8907, 8909, 8911, 8913, 8933, 8936, 8937, 8939, 8942, 8946, 8949, 8950, 8951, 8952, 
8953, 8957, 8961, 8962, 8963, 8965, 8966, 8967, 8968, 8969, 8971, 8975, 8979, 8981, 8986, 8987, 8990, 9005, 9007, 9010, 9011, 9012, 9014, 9017, 9018, 9020, 9023, 9027, 9029, 9035, 9037, 9038, 9040, 9045, 9049, 9050, 9052, 9053, 9054, 9055, 9056, 9057, 9058, 9059, 9060, 9061, 9062, 9063, 9064, 9065, 9066, 9067, 9070, 9071, 9072, 9074, 9075, 9076, 9078, 9079, 9080, 9081, 9082, 9083, 9084, 9085, 9086, 9087, 9088, 9089, 9090, 9091, 9092, 9093, 9094, 9095, 9096, 9097, 9098, 9099, 9100, 9101, 9102, 9103, 9104, 9105, 9106, 9108, 9110, 9112, 9113, 9114, 9115, 9116, 9117, 9119, 9120, 9121, 9122, 9123, 9124, 9125, 9129, 9130, 9131, 9133, 9134, 9135, 9136, 9137, 9138, 9139, 9140, 9141, 9142, 9145, 9146, 9147, 9149, 9150, 9151, 9152, 9153, 9155, 9156, 9158, 9159, 9160, 9161, 9162, 9164, 9165, 9173, 9175, 9177, 9178, 9180, 9181, 9182, 9183, 9184, 9186, 9193, 9195, 9197, 9201, 9202, 9205, 9206, 9226, 9227, 9229, 9232, 9233, 9234, 9236, 9237, 9238, 9240, 9241, 9251, 9257, 9259, 9260, 9261, 9264, 9265, 9267, 9272, 9274, 9275, 9277, 9279, 9286, 9288, 9292, 9295, 9297, 9301, 9306, 9309, 9310, 9313, 9315, 9319, 9329, 9333, 9343, 9346, 9347, 9355, 9356, 9364, 9371, 9372, 9379, 9385, 9387, 9389, 9393, 9395, 9398, 9400, 9407, 9417, 9419, 9432, 9433, 9434, 9436, 9441, 9446, 9450, 9451, 9452, 9453, 9454, 9455, 9456, 9457, 9458, 9459, 9460, 9462, 9465, 9472, 9473, 9474, 9476, 9478, 9479, 9493, 9510, 9517, 9518, 9519, 9521, 9531, 9532, 9533, 9534, 9539, 9546, 9548, 9549, 9550, 9551, 9553, 9555, 9557, 9558, 9563, 9566, 9569, 9571, 9575, 9577, 9579, 9581, 9583, 9588, 9589, 9590, 9592, 9594, 9595, 9600, 9601, 9602, 9607, 9611, 9612, 9613, 9615, 9620, 9654, 9655, 9664, 9667, 9668, 9669, 9679, 9683, 9687, 9688, 9694, 9695, 9698, 9699, 9700, 9701, 9702, 9704, 9706, 9711, 9712, 9721, 9722, 9728, 9729, 9730, 9733, 9734, 9735, 9736, 9737, 9739, 9741, 9743, 9744, 9745, 9746, 9748, 9749, 9752, 9753, 9754, 9755, 9762, 9771, 9774, 9782, 9783, 9784, 9785, 9786, 9787, 9789, 9790, 9795, 9796, 9799, 9801, 
9803, 9804, 9805, 9806, 9807, 9810, 9812, 9813, 9814, 9818, 9822, 9825, 9827, 9833, 9836, 9839, 9842, 9847, 9850, 9851, 9852, 9854, 9858, 9863, 9878, 9879, 9880, 9883, 9886, 9887, 9891, 9892, 9898, 9901, 9904, 9908, 9913, 9919, 9928, 9931, 9942, 9943, 9947, 9954, 9955, 9959, 9975, 9976, 9985, 9997, 10011, 10012, 10013, 10015, 10016, 10020, 10022, 10023, 10024, 10025, 10027, 10028, 10029, 10031, 10032, 10033, 10036, 10037, 10038, 10040, 10041, 10042, 10043, 10044, 10045, 10046, 10051, 10054, 10078, 10082, 10083, 10085, 10087, 10088, 10090, 10092, 10093, 10097, 10098, 10100, 10101, 10102, 10103, 10109, 10113, 10118, 10122, 10123, 10124, 10128, 10129, 10134, 10135, 10136, 10138, 10139, 10148, 10150, 10152, 10157, 10164, 10169, 10170, 10173, 10175, 10176, 10179, 10181, 10185, 10187, 10190, 10191, 10192, 10193, 10197, 10199, 10200, 10202, 10204, 10206, 10210, 10213, 10215, 10216, 10217, 10220, 10222, 10227, 10228, 10233, 10236, 10237, 10238, 10245, 10246, 10248, 10249, 10255, 10256, 10261, 10262, 10264, 10265, 10269, 10270, 10272, 10274, 10278, 10280, 10281, 10282, 10283, 10287, 10289, 10291, 10293, 10296, 10297, 10301, 10304, 10305, 10308, 10312, 10315, 10317, 10322, 10329, 10335, 10343, 10344, 10345, 10351, 10371, 10376, 10391, 10393, 10394, 10395, 10396, 10397, 10398, 10416, 10418, 10421, 10427, 10428, 10430, 10435, 10438, 10443, 10445, 10447, 10452, 10453, 10455, 10456, 10457, 10458, 10460, 10463, 10464, 10465, 10466, 10467, 10468, 10469, 10470, 10471, 10472, 10473, 10474, 10475, 10476, 10477, 10478, 10479, 10480, 10481, 10483, 10485, 10487, 10488, 10489, 10492, 10493, 10494, 10495, 10496, 10497, 10498, 10499, 10500, 10501, 10504, 10505, 10506, 10507, 10508, 10509, 10510, 10513, 10514, 10515, 10517, 10518, 10519, 10520, 10521, 10527, 10528, 10529, 10530, 10531, 10532, 10533, 10540, 10541, 10545, 10547, 10551, 10556, 10562, 10563, 10564, 10566, 10571, 10572, 10575, 10577, 10578, 10579, 10580, 10581, 10584, 10586, 10587, 10590, 10592, 10596, 10597, 10600, 10601, 
10602, 10603, 10604, 10606, 10607, 10608, 10609, 10610, 10612, 10615, 10617, 10619, 10620, 10621, 10622, 10625, 10628, 10630, 10637, 10640, 10641, 10647, 10650, 10651, 10654, 10657, 10659, 10666, 10669, 10670, 10671, 10672, 10673, 10674, 10675, 10676, 10677, 10681, 10683, 10687, 10688, 10689, 10692, 10698, 10702, 10707, 10713, 10716, 10722, 10732, 10733, 10734, 10736, 10737, 10756, 10759, 10760, 10764, 10770, 10777, 10778, 10782, 10786, 10788, 10791, 10794, 10798, 10799, 10810, 10826, 10829, 10830, 10831, 10834, 10835, 10836, 10839, 10843, 10845, 10848, 10849, 10850, 10857, 10859, 10860, 10861, 10864, 10877, 10879, 10882, 10884, 10886, 10893, 10895, 10899, 10902, 10903, 10904, 10905, 10907, 10910, 10913, 10914, 10920, 10923, 10925, 10928, 10929, 10942, 10944, 10945, 10948, 10949, 10952, 10953, 10954, 10955, 10959, 10960, 10961, 10965, 10966, 10968, 10971, 10972, 10976, 10984, 10986, 10988, 10989, 10992, 11021, 11029, 11038, 11040, 11053, 11062, 11063, 11081, 11107, 11113, 11116, 11118, 11119, 11122, 11124, 11141, 11143, 11144, 11145, 11147, 11153, 11159, 11164, 11169, 11171, 11173, 11177, 11179, 11184, 11185, 11186, 11189, 11191, 11193, 11195, 11197, 11199, 11203, 11204, 11205, 11206, 11207, 11208, 11209, 11211, 11213, 11215, 11217, 11219, 11225, 11227, 11228, 11230, 11231, 11232, 11233, 11236, 11238, 11241, 11250, 11251, 11254, 11257, 11258, 11260, 11265, 11270, 11271, 11272, 11275, 11277, 11279, 11280, 11288, 11294, 11296, 11298, 11309, 11311, 11312, 11313, 11330, 11331, 11334, 11344, 11346, 11347, 11357, 11358, 11359, 11363, 11364, 11367, 11369, 11370, 11371, 11373, 11376, 11378, 11383, 11389, 11391, 11394, 11396, 11402, 11403, 11408, 11411, 11412, 11434, 11438, 11443, 11449, 11450, 11453, 11461, 11464, 11466, 11467, 11472, 11477, 11478, 11480, 11482, 11487, 11489, 11490, 11492, 11494, 11496, 11514, 11538, 11541, 11551, 11554, 11557, 11558, 11561, 11562, 11598, 11603, 11605, 11608, 11622, 11623, 11628, 11629, 11637, 11639, 11642, 11644, 11645, 11651, 11661, 
11662, 11666, 11667, 11668, 11677, 11679, 11687, 11690, 11692, 11694, 11695, 11696, 11700, 11702, 11709, 11715, 11718, 11725, 11727, 11728, 11739, 11749, 11750, 11754, 11755, 11756, 11757, 11761, 11764, 11766, 11779, 11782, 11788, 11790, 11793, 11794, 11796, 11805, 11809, 11810, 11812, 11813, 11816, 11817, 11819, 11820, 11823, 11825, 11826, 11828, 11829, 11831, 11833, 11839, 11843, 11845, 11846, 11849, 11850, 11852, 11853, 11861, 11862, 11868, 11870, 11872, 11914, 11919, 11920, 11923, 11934, 11937, 11938, 11943, 11944, 11945, 11946, 11947, 11948, 11952, 11956, 11960, 11961, 11970, 11972, 11974, 11979, 11980, 11981, 11983, 11988, 11991, 11993, 11996, 11997, 11998, 12000, 12002, 12007, 12011, 12019, 12022, 12025, 12031, 12033, 12036, 12046, 12051, 12054, 12058, 12061, 12063, 12065, 12066, 12070, 12071, 12072, 12073, 12074, 12075, 12079, 12082, 12083, 12085, 12087, 12090, 12092, 12093, 12094, 12097, 12098, 12102, 12103, 12104, 12105, 12108, 12109, 12110, 12113, 12114, 12116, 12117, 12118, 12119, 12120, 12123, 12124, 12125, 12126, 12127, 12128, 12129, 12132, 12133, 12134, 12135, 12137, 12138, 12139, 12140, 12142, 12143, 12145, 12147, 12148, 12150, 12152, 12153, 12155, 12156, 12159, 12160, 12161, 12167, 12168, 12171, 12172, 12173, 12174, 12175, 12176, 12178, 12180, 12181, 12183, 12186, 12188, 12194, 12195, 12199, 12201, 12202, 12203, 12205, 12207, 12208, 12209, 12210, 12213, 12215, 12216, 12217, 12220, 12221, 12222, 12223, 12230, 12231, 12233, 12234, 12235, 12239, 12240, 12246, 12250, 12251, 12252, 12258, 12259, 12260, 12261, 12262, 12263, 12264, 12268, 12269, 12270, 12271, 12272, 12276, 12278, 12281, 12282, 12285, 12286, 12287, 12288, 12289, 12290, 12291, 12292, 12293, 12294, 12295, 12296, 12297, 12298, 12299, 12300, 12301, 12302, 12303, 12304, 12305, 12306, 12307, 12309, 12311, 12313, 12314, 12317, 12319, 12321, 12325, 12334, 12340, 12342, 12343, 12344, 12345, 12346, 12347, 12348, 12349, 12350, 12351, 12352, 12353, 12354, 12355, 12356, 12357, 12359, 12361, 12362, 
12364, 12365, 12366, 12367, 12368, 12369, 12370, 12371, 12375, 12376, 12378, 12380, 12382, 12384, 12385, 12386, 12387, 12391, 12392, 12393, 12394, 12395, 12396, 12397, 12398, 12399, 12407, 12408, 12409, 12412, 12427, 12450, 12452, 12470, 12478, 12479, 12480, 12481, 12482, 12484, 12485, 12486, 12490, 12505, 12508, 12510, 12512, 12520, 12521, 12528, 12533, 12535, 12565, 12582, 12583, 12592, 12596, 12598, 12599, 12600, 12602, 12603, 12604, 12606, 12609, 12611, 12613, 12614, 12617, 12626, 12645, 12656, 12672, 12673, 12684, 12697, 12698, 12702, 12705, 12706, 12707, 12708, 12709, 12712, 12714, 12716, 12718, 12720, 12723, 12724, 12725, 12730, 12732, 12735, 12737, 12739, 12740, 12741, 12744, 12745, 12746, 12749, 12750, 12751, 12753, 12754, 12756, 12757, 12759, 12761, 12762, 12765, 12766, 12770, 12771, 12772, 12775, 12779, 12801, 12810, 12830, 12831, 12835, 12843, 12867, 12870, 12871, 12872, 12884, 12886, 12890, 12893, 12895, 12896, 12897, 12899, 12901, 12921, 12922, 12925, 12938, 12943, 12945, 12947, 12948, 12951, 12954, 12959, 12960, 12961, 12962, 12967, 12976, 12982, 12984, 12988, 12989, 12990, 12992, 12993, 12996, 12998, 13008, 13009, 13010, 13012, 13018, 13019, 13021, 13022, 13024, 13025, 13026, 13027, 13028, 13031, 13034, 13037, 13044, 13070, 13079, 13084, 13092, 13094, 13095, 13096, 13097, 13098, 13101, 13102, 13104, 13105, 13106, 13108, 13111, 13112, 13114, 13115, 13117, 13118, 13119, 13121, 13123, 13127, 13128, 13129, 13130, 13131, 13132, 13133, 13134, 13137, 13139, 13141, 13142, 13143, 13154, 13167, 13184, 13187, 13189, 13193, 13201, 13219, 13220, 13246, 13265, 13302, 13308, 13311, 13317, 13336, 13345, 13346, 13348, 13350, 13351, 13354, 13356, 13359, 13360, 13364, 13365, 13366, 13367, 13368, 13370, 13371, 13372, 13375, 13377, 13378, 13379, 13380, 13381, 13387, 13388, 13390, 13391, 13393, 13394, 13395, 13397, 13400, 13413, 13415, 13417, 13418, 13419, 13420, 13421, 13422, 13425, 13427, 13428, 13431, 13432, 13434, 13437, 13438, 13441, 13442, 13444, 13446, 13447, 
13449, 13451, 13452, 13453, 13455, 13459, 13462, 13464, 13465, 13471, 13472, 13474, 13475, 13476, 13478, 13480, 13481, 13483, 13516, 13518, 13519, 13520, 13522, 13529, 13530, 13534, 13540, 13541, 13549, 13559, 13561, 13568, 13569, 13571, 13573, 13575, 13577, 13579, 13581, 13582, 13583, 13584, 13585, 13586, 13590, 13591, 13593, 13595, 13603, 13605, 13606, 13608, 13609, 13611, 13612, 13613, 13614, 13618, 13620, 13626, 13627, 13629, 13630, 13631, 13632, 13634, 13635, 13637, 13638, 13641, 13644, 13646, 13647, 13649, 13650, 13651, 13652, 13653, 13656, 13657, 13659, 13662, 13663, 13664, 13666, 13667, 13670, 13671, 13673, 13674, 13676, 13677, 13680, 13681, 13682, 13683, 13685, 13688, 13690, 13691, 13692, 13724, 13727, 13729, 13731, 13735, 13738, 13739, 13745, 13746, 13749, 13753, 13755, 13756, 13757, 13758, 13759, 13767, 13769, 13771, 13777, 13779, 13781, 13785, 13786, 13788, 13791, 13792, 13793, 13797, 13799, 13801, 13802, 13803, 13804, 13808, 13814, 13815, 13816, 13817, 13819, 13821, 13824, 13826, 13827, 13828, 13829, 13831, 13834, 13835, 13838, 13840, 13841, 13842, 13843, 13845, 13846, 13847, 13849, 13855, 13862, 13863, 13864, 13865, 13866, 13867, 13868, 13876, 13878, 13879, 13880, 13881, 13884, 13885, 13886, 13887, 13888, 13890, 13891, 13895, 13897, 13903, 13910, 13919, 13920, 13937, 13943, 13946, 13950, 13951, 13952, 13962, 13978, 13987, 14006, 14010, 14015, 14016, 14017, 14024, 14026, 14039, 14045, 14046, 14051, 14074, 14088, 14100, 14106, 14107, 14111, 14115, 14116, 14123, 14129, 14131, 14139, 14142, 14145, 14147, 14150, 14151, 14153, 14154, 14160, 14168, 14171, 14175, 14177, 14186, 14220, 14239, 14241, 14252, 14253, 14279, 14303, 14314, 14319, 14320, 14322, 14337, 14344, 14345, 14349, 14351, 14356, 14357, 14358, 14360, 14363, 14364, 14366, 14367, 14375, 14376, 14377, 14381, 14383, 14384, 14385, 14386, 14387, 14388, 14389, 14391, 14397, 14398, 14401, 14403, 14407, 14415, 14418, 14420, 14426, 14427, 14428, 14429, 14431, 14432, 14433, 14434, 14435, 14437, 14438, 
14439, 14440, 14441, 14443, 14444, 14445, 14447, 14448, 14452, 14453, 14460, 14461, 14463, 14464, 14465, 14467, 14468, 14470, 14471, 14472, 14473, 14474, 14475, 14476, 14478, 14479, 14480, 14481, 14482, 14483, 14484, 14485, 14486, 14487, 14489, 14490, 14491, 14492, 14495, 14496, 14498, 14499, 14501, 14502, 14511, 14513, 14516, 14517, 14518, 14519, 14520, 14521, 14522, 14523, 14524, 14525, 14526, 14527, 14528, 14529, 14531, 14532, 14533, 14534, 14535, 14537, 14538, 14540, 14541, 14542, 14543, 14545, 14546, 14547, 14548, 14552, 14553, 14555, 14556, 14560, 14564, 14567, 14568, 14569, 14570, 14571, 14575, 14578, 14579, 14580, 14581, 14583, 14585, 14586, 14587, 14588, 14589, 14594, 14595, 14596, 14597, 14598, 14599, 14600, 14602, 14603, 14604, 14606, 14607, 14610, 14611, 14612, 14614, 14617, 14618, 14620, 14621, 14622, 14623, 14624, 14625, 14629, 14635, 14637, 14638, 14643, 14644, 14649, 14651, 14657, 14659, 14661, 14662, 14670, 14677, 14707, 14715, 14717, 14719, 14729, 14746, 14751, 14764, 14766, 14768, 14770, 14771, 14772, 14775, 14777, 14778, 14779, 14782, 14785, 14786, 14787, 14788, 14790, 14791, 14792, 14795, 14801, 14804, 14807, 14811, 14813, 14816, 14818, 14819, 14821, 14822, 14824, 14825, 14826, 14829, 14833, 14834, 14838, 14839
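The index list above is a run of comma-separated mesh vertex indices, the same format shown in the previews of `datasets/biwi/regions/fdd.txt` and `lve.txt`. The sketch below shows one way to load such a file and pick out the region's vertices. The helper names `parse_vertex_indices` and `select_region` are hypothetical and are not functions from the repository.

```python
import numpy as np

def parse_vertex_indices(text: str) -> np.ndarray:
    """Parse comma-separated vertex indices (e.g. "7, 9, 59") into a sorted,
    de-duplicated int64 array. Tolerates newlines and a trailing comma."""
    tokens = text.replace("\n", ",").split(",")
    indices = [int(tok) for tok in tokens if tok.strip()]
    return np.unique(np.asarray(indices, dtype=np.int64))

def select_region(vertices: np.ndarray, region: np.ndarray) -> np.ndarray:
    """Pick the region's vertices from a (frames, num_vertices, 3) array."""
    return vertices[:, region, :]

# Usage with a tiny stand-in for a region file's contents:
region = parse_vertex_indices("2, 6, 21,\n23, 24")
frames = np.zeros((10, 30, 3))               # 10 frames of a toy 30-vertex mesh
print(select_region(frames, region).shape)   # (10, 5, 3)
```

On a real checkout, `parse_vertex_indices(open("datasets/biwi/regions/lve.txt").read())` would return the lip-region indices, assuming the file holds only integers and commas as its preview suggests.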
gitextract_v4nqs5sn/
├── README.md
├── alm/
│   ├── callback/
│   │   ├── __init__.py
│   │   └── progress.py
│   ├── config.py
│   ├── data/
│   │   ├── BIWI/
│   │   │   ├── __init__.py
│   │   │   └── dataset.py
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── biwi.py
│   │   ├── get_data.py
│   │   ├── voca/
│   │   │   ├── __init__.py
│   │   │   └── dataset.py
│   │   └── vocaset.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── architectures/
│   │   │   ├── __init__.py
│   │   │   ├── adpt_bias_denoiser.py
│   │   │   └── tools/
│   │   │       ├── embeddings.py
│   │   │       ├── transformer_adpt.py
│   │   │       └── utils.py
│   │   ├── get_model.py
│   │   ├── losses/
│   │   │   ├── __init__.py
│   │   │   ├── utils.py
│   │   │   └── voca.py
│   │   └── modeltype/
│   │       ├── __init__.py
│   │       ├── base.py
│   │       └── diffusion_bias.py
│   └── utils/
│       ├── __init__.py
│       ├── demo_utils.py
│       ├── logger.py
│       └── temos_utils.py
├── configs/
│   ├── assets/
│   │   ├── biwi.yaml
│   │   └── vocaset.yaml
│   ├── base.yaml
│   └── diffusion/
│       ├── biwi/
│       │   ├── diffspeaker_hubert_biwi.yaml
│       │   └── diffspeaker_wav2vec2_biwi.yaml
│       ├── diffusion_bias_modules/
│       │   ├── denoiser.yaml
│       │   └── scheduler.yaml
│       └── vocaset/
│           ├── diffspeaker_hubert_vocaset.yaml
│           └── diffspeaker_wav2vec2_vocaset.yaml
├── datasets/
│   ├── biwi/
│   │   ├── README.md
│   │   ├── regions/
│   │   │   ├── fdd.txt
│   │   │   └── lve.txt
│   │   ├── templates/
│   │   │   └── BIWI.ply
│   │   └── templates.pkl
│   └── vocaset/
│       ├── FLAME_masks.pkl
│       ├── README.md
│       ├── templates/
│       │   ├── FLAME_sample.ply
│       │   └── README.md
│       └── templates.pkl
├── demo_biwi.py
├── demo_vocaset.py
├── demo_vocaset_text.py
├── eval_biwi.py
├── eval_vocaset.py
├── requirements.txt
├── scripts/
│   ├── demo/
│   │   ├── demo_biwi.sh
│   │   └── demo_vocaset.sh
│   └── diffusion/
│       ├── biwi_evaluation/
│       │   ├── diffspeaker_hubert_biwi.sh
│       │   └── diffspeaker_wav2vec2_biwi.sh
│       ├── biwi_training/
│       │   ├── diffspeaker_hubert_biwi.sh
│       │   └── diffspeaker_wav2vec2_biwi.sh
│       ├── vocaset_evaluation/
│       │   ├── diffspeaker_hubert_vocaset.sh
│       │   └── diffspeaker_wav2vec2_vocaset.sh
│       └── vocaset_training/
│           ├── diffspeaker_hubert_vocaset.sh
│           └── diffspeaker_wav2vec2_vocaset.sh
└── train.py
SYMBOL INDEX (157 symbols across 26 files)
FILE: alm/callback/progress.py
class ProgressLogger (line 10) | class ProgressLogger(Callback):
method __init__ (line 12) | def __init__(self, metric_monitor: dict, precision: int = 3):
method on_train_start (line 17) | def on_train_start(self, trainer: Trainer, pl_module: LightningModule,
method on_train_end (line 21) | def on_train_end(self, trainer: Trainer, pl_module: LightningModule,
method on_validation_epoch_end (line 25) | def on_validation_epoch_end(self, trainer: Trainer,
method on_train_epoch_end (line 30) | def on_train_epoch_end(self,
FILE: alm/config.py
function get_module_config (line 7) | def get_module_config(cfg_model, path="modules"):
function get_obj_from_str (line 16) | def get_obj_from_str(string, reload=False):
function instantiate_from_config (line 24) | def instantiate_from_config(config):
function parse_args (line 34) | def parse_args(phase="train"):
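`alm/config.py` exposes `get_obj_from_str(string, reload=False)` and `instantiate_from_config(config)`, and the `scheduler.yaml` preview uses a `target:` / `params:` layout (`target: diffusers.DDIMScheduler`). The sketch below shows the common "dotted target plus keyword params" pattern those names suggest. It is an assumption about the behavior, not the repository's code; the usage swaps in a standard-library class so it runs without `diffusers`.

```python
import importlib

def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as "package.module.ClassName" to the object."""
    module_name, attr_name = string.rsplit(".", 1)
    module = importlib.import_module(module_name)
    if reload:
        module = importlib.reload(module)
    return getattr(module, attr_name)

def instantiate_from_config(config):
    """Build config["target"] with keyword arguments taken from config["params"]."""
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", {}))

# Usage with a standard-library target instead of diffusers.DDIMScheduler:
frac = instantiate_from_config({"target": "fractions.Fraction",
                                "params": {"numerator": 3, "denominator": 4}})
print(frac)  # 3/4
```

Keeping the class path as a string lets a YAML file choose the scheduler or denoiser class without code changes.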
FILE: alm/data/BIWI/dataset.py
class BIWIDataset (line 13) | class BIWIDataset(data.Dataset):
method __init__ (line 15) | def __init__(self,
method __len__ (line 29) | def __len__(self):
method __getitem__ (line 32) | def __getitem__(self, index):
FILE: alm/data/base.py
class BASEDataModule (line 7) | class BASEDataModule(pl.LightningDataModule):
method __init__ (line 9) | def __init__(self, collate_fn, batch_size: int, num_workers: int):
method __getattr__ (line 32) | def __getattr__(self, item):
method setup (line 53) | def setup(self, stage=None):
method train_dataloader (line 62) | def train_dataloader(self):
method predict_dataloader (line 70) | def predict_dataloader(self):
method val_dataloader (line 82) | def val_dataloader(self):
method test_dataloader (line 94) | def test_dataloader(self):
FILE: alm/data/biwi.py
function load_data (line 15) | def load_data(args):
class BIWIDataModule (line 36) | class BIWIDataModule(BASEDataModule):
method __init__ (line 37) | def __init__(self,
method __getattr__ (line 175) | def __getattr__(self, item):
FILE: alm/data/get_data.py
function collate_tensors (line 4) | def collate_tensors(batch):
function vocaset_collate_fn (line 16) | def vocaset_collate_fn(batch):
function voxcelebinsta_collate_fn (line 31) | def voxcelebinsta_collate_fn(batch):
function voxcelebinstacoeflmdb_collate_fn (line 60) | def voxcelebinstacoeflmdb_collate_fn(batch):
function get_datasets (line 94) | def get_datasets(cfg, logger, phase='train'):
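The `get_data.py` preview shows `collate_tensors` starting with `dims = batch[0].dim()` and a per-dimension `max_size` computed over the batch, which points to zero-padding collation. The sketch below completes that idea under this assumption. Everything past the two lines visible in the preview is illustrative, not the repository's code.

```python
import torch

def collate_tensors(batch):
    """Zero-pad a list of equal-rank tensors to the per-dimension maximum size
    and stack them into one (len(batch), *max_size) tensor."""
    dims = batch[0].dim()
    max_size = [max(b.size(i) for b in batch) for i in range(dims)]
    canvas = batch[0].new_zeros((len(batch), *max_size))
    for i, b in enumerate(batch):
        # Narrow canvas[i] down to b's shape (a view), then copy b into it.
        sub = canvas[i]
        for d in range(dims):
            sub = sub.narrow(d, 0, b.size(d))
        sub.copy_(b)
    return canvas

# Usage: two vertex sequences of different lengths (frames x features).
out = collate_tensors([torch.ones(3, 4), torch.ones(5, 4)])
print(out.shape)  # torch.Size([2, 5, 4])
```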
FILE: alm/data/voca/dataset.py
class VOCASETDataset (line 13) | class VOCASETDataset(data.Dataset):
method __init__ (line 15) | def __init__(self,
method __getitem__ (line 27) | def __getitem__(self, index):
method __len__ (line 57) | def __len__(self):
FILE: alm/data/vocaset.py
function load_data (line 15) | def load_data(args):
class VOCASETDataModule (line 36) | class VOCASETDataModule(BASEDataModule):
method __init__ (line 37) | def __init__(self,
method __getattr__ (line 155) | def __getattr__(self, item):
FILE: alm/models/architectures/adpt_bias_denoiser.py
class Adpt_Bias_Denoiser (line 13) | class Adpt_Bias_Denoiser(nn.Module):
method __init__ (line 15) | def __init__(self,
method forward (line 104) | def forward(self,
FILE: alm/models/architectures/tools/embeddings.py
function get_activation (line 8) | def get_activation(activation_type):
class MaskedNorm (line 37) | class MaskedNorm(nn.Module):
method __init__ (line 43) | def __init__(self, norm_type, num_groups, num_features):
method forward (line 57) | def forward(self, x: Tensor, mask: Tensor):
class Embeddings (line 76) | class Embeddings(nn.Module):
method __init__ (line 83) | def __init__(
method forward (line 133) | def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
method __repr__ (line 155) | def __repr__(self):
class SpatialEmbeddings (line 163) | class SpatialEmbeddings(nn.Module):
method __init__ (line 171) | def __init__(
method forward (line 218) | def forward(self, x: Tensor, mask: Tensor) -> Tensor:
method __repr__ (line 238) | def __repr__(self):
function get_timestep_embedding (line 245) | def get_timestep_embedding(
class TimestepEmbedding (line 288) | class TimestepEmbedding(nn.Module):
method __init__ (line 289) | def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "s...
method forward (line 298) | def forward(self, sample):
class Timesteps (line 308) | class Timesteps(nn.Module):
method __init__ (line 309) | def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale...
method forward (line 315) | def forward(self, timesteps):
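`embeddings.py` lists `get_timestep_embedding`, `TimestepEmbedding`, and `Timesteps(num_channels, flip_sin_to_cos, downscale...)`, a signature that resembles the `Timesteps` module in `diffusers`. The sketch below shows the widely used sinusoidal timestep embedding that such a signature usually implies. The exact defaults in this repository are assumptions.

```python
import math
import torch

def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False,
                           downscale_freq_shift=1.0, max_period=10000):
    """Sinusoidal embedding of a 1-D tensor of diffusion timesteps -> (N, embedding_dim)."""
    assert timesteps.dim() == 1, "expects a 1-D tensor of timesteps"
    half_dim = embedding_dim // 2
    # Geometric ladder of frequencies, from 1 down to about 1/max_period.
    exponent = -math.log(max_period) * torch.arange(half_dim, dtype=torch.float32)
    freqs = torch.exp(exponent / (half_dim - downscale_freq_shift))
    args = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if flip_sin_to_cos:
        # Put the cosine half first.
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1))  # zero-pad an odd width
    return emb

emb = get_timestep_embedding(torch.tensor([0, 10, 999]), 8, flip_sin_to_cos=True)
print(emb.shape)  # torch.Size([3, 8])
```

A small MLP (what `TimestepEmbedding` presumably is) would then project this fixed embedding to the denoiser's width.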
FILE: alm/models/architectures/tools/transformer_adpt.py
class Transformer_Adpt (line 13) | class Transformer_Adpt(nn.Module):
method __init__ (line 15) | def __init__(self,
method obj_vector (line 85) | def obj_vector(self, id):
method _forward (line 91) | def _forward(self,
class TransformerDecoder_w_Adapter (line 131) | class TransformerDecoder_w_Adapter(nn.TransformerDecoder):
method __init__ (line 141) | def __init__(self, decoder_layer, num_layers, norm=None):
method forward (line 147) | def forward(self,
class TransformerDecoderLayer_w_Adapter (line 184) | class TransformerDecoderLayer_w_Adapter(nn.TransformerDecoderLayer):
method __init__ (line 201) | def __init__(self,
method forward (line 217) | def forward(self,
method _sa_block (line 252) | def _sa_block(self, x: Tensor,
method _mha_block (line 289) | def _mha_block(self, x: Tensor, mem: Tensor,
method _concate_adapter (line 326) | def _concate_adapter(self, adapter: Tensor, x: Tensor, batch_first: bo...
FILE: alm/models/architectures/tools/utils.py
class PeriodicPositionalEncoding (line 5) | class PeriodicPositionalEncoding(nn.Module):
method __init__ (line 6) | def __init__(self, d_model, dropout=0.1, period=25, max_seq_len=600):
method forward (line 19) | def forward(self, x):
function init_biased_mask (line 95) | def init_biased_mask(n_head, max_seq_len, period):
function init_bi_biased_mask (line 133) | def init_bi_biased_mask(max_seq_len, ):
function init_mem_mask_faceformer (line 147) | def init_mem_mask_faceformer(max_seq_len):
function init_bi_biased_mask_faceformer (line 153) | def init_bi_biased_mask_faceformer(n_head, max_seq_len, period):
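`PeriodicPositionalEncoding.__init__(self, d_model, dropout=0.1, period=25, max_seq_len=600)` suggests a sinusoidal encoding that repeats every `period` frames, as in FaceFormer-style talking-head models. The sketch below is written under that assumption (batch-first input, even `d_model`) and is not copied from the repository.

```python
import math
import torch
import torch.nn as nn

class PeriodicPositionalEncoding(nn.Module):
    """Sinusoidal positional encoding whose pattern repeats every `period` frames.
    Assumes an even d_model and batch-first input of shape (batch, seq_len, d_model)."""

    def __init__(self, d_model, dropout=0.1, period=25, max_seq_len=600):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(period, dtype=torch.float32).unsqueeze(1)  # (period, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float()
                             * (-math.log(10000.0) / d_model))
        pe = torch.zeros(period, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Tile one period until it covers at least max_seq_len frames: (1, L, d_model).
        repeats = max_seq_len // period + 1
        self.register_buffer("pe", pe.unsqueeze(0).repeat(1, repeats, 1))

    def forward(self, x):
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)

ppe = PeriodicPositionalEncoding(d_model=16, dropout=0.0)
out = ppe(torch.zeros(2, 60, 16))
print(torch.allclose(out[:, 0], out[:, 25]))  # True: frames 0 and 25 share an encoding
```

The periodic design lets sequences longer than those seen in training reuse the same positional pattern instead of extrapolating to unseen positions.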
FILE: alm/models/get_model.py
function get_model (line 3) | def get_model(cfg, datamodule):
function get_module (line 11) | def get_module(cfg, datamodule):
FILE: alm/models/losses/utils.py
class MaskedConsistency (line 4) | class MaskedConsistency:
method __init__ (line 5) | def __init__(self) -> None:
method __call__ (line 8) | def __call__(self, pred, gt, mask):
method __repr__ (line 11) | def __repr__(self):
class MaskedVelocityConsistency (line 14) | class MaskedVelocityConsistency:
method __init__ (line 15) | def __init__(self) -> None:
method __call__ (line 18) | def __call__(self, pred, gt, mask):
method velocity (line 23) | def velocity(self, term):
method __repr__ (line 27) | def __repr__(self):
function vocaset_upper_face_variance (line 164) | def vocaset_upper_face_variance(motion, ):
function vocaset_mouth_distance (line 172) | def vocaset_mouth_distance(vertices_gt, vertices_pred):
function biwi_upper_face_variance (line 179) | def biwi_upper_face_variance(motion, ):
function biwi_mouth_distance (line 187) | def biwi_mouth_distance(vertices_gt, vertices_pred):
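The functions `vocaset_mouth_distance(vertices_gt, vertices_pred)` and `vocaset_upper_face_variance(motion)` (plus their BIWI counterparts), together with the `lve.txt` / `fdd.txt` region files, suggest the lip vertex error (LVE) and upper-face dynamics deviation (FDD) metrics used in FaceFormer- and CodeTalker-style evaluations. The sketch below follows those common definitions. It is an assumption that the repository computes exactly these quantities, and the function names are hypothetical. Note that this LVE variant uses the squared L2 distance, as several published evaluation scripts do.

```python
import numpy as np

def lip_vertex_error(vertices_gt, vertices_pred, lip_idx):
    """LVE: per frame, the maximal squared L2 error over lip vertices,
    averaged over frames. Inputs are (frames, num_vertices, 3) arrays."""
    diff = vertices_gt[:, lip_idx, :] - vertices_pred[:, lip_idx, :]
    per_vertex = np.sum(diff ** 2, axis=-1)          # (frames, n_lip)
    return float(np.mean(np.max(per_vertex, axis=1)))

def upper_face_dynamics(motion, upper_idx):
    """Temporal std of each upper-face vertex's squared displacement,
    averaged over those vertices. `motion` is vertices minus the template."""
    disp = np.sum(motion[:, upper_idx, :] ** 2, axis=-1)  # (frames, n_upper)
    return float(np.mean(np.std(disp, axis=0)))

def upper_face_dynamics_deviation(vertices_gt, vertices_pred, template, upper_idx):
    """FDD: upper-face dynamics of the ground truth minus those of the prediction."""
    return (upper_face_dynamics(vertices_gt - template, upper_idx)
            - upper_face_dynamics(vertices_pred - template, upper_idx))

rng = np.random.default_rng(0)
gt = rng.normal(size=(20, 50, 3))
print(lip_vertex_error(gt, gt, np.arange(10)))  # 0.0
```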
FILE: alm/models/losses/voca.py
class VOCALosses (line 32) | class VOCALosses(Metric):
method __init__ (line 37) | def __init__(self, cfg, split):
method update (line 121) | def update(self, rs_set):
method compute (line 151) | def compute(self, split):
method _update_loss (line 156) | def _update_loss(self, loss: str, outputs, inputs, mask = None):
method loss2logname (line 167) | def loss2logname(self, loss: str, split: str):
class MaskedConsistency (line 175) | class MaskedConsistency:
method __init__ (line 176) | def __init__(self) -> None:
method __call__ (line 179) | def __call__(self, pred, gt, mask):
method __repr__ (line 186) | def __repr__(self):
class MaskedVelocityConsistency (line 189) | class MaskedVelocityConsistency:
method __init__ (line 190) | def __init__(self) -> None:
method __call__ (line 193) | def __call__(self, pred, gt, mask):
method velocity (line 198) | def velocity(self, term):
method __repr__ (line 202) | def __repr__(self):
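Both `losses/utils.py` and `losses/voca.py` list `MaskedConsistency` and `MaskedVelocityConsistency` with a `velocity(term)` method, and the `utils.py` preview shows `self.loss = nn.MSE...`. The sketch below is a hypothetical reimplementation. Only the class names, the `velocity` method, the `(pred, gt, mask)` call signature, and the use of `nn.MSELoss` come from the listing; the masking logic is an assumption.

```python
import torch
import torch.nn as nn

class MaskedConsistency:
    """MSE between prediction and ground truth over valid (unpadded) frames."""
    def __init__(self) -> None:
        self.loss = nn.MSELoss(reduction="mean")

    def __call__(self, pred, gt, mask):
        # pred, gt: (batch, frames, features); mask: (batch, frames) bool.
        return self.loss(pred[mask], gt[mask])

class MaskedVelocityConsistency:
    """MSE between frame-to-frame differences (velocities) over valid frame pairs."""
    def __init__(self) -> None:
        self.loss = nn.MSELoss(reduction="mean")

    def velocity(self, term):
        return term[:, 1:] - term[:, :-1]

    def __call__(self, pred, gt, mask):
        # A velocity is valid only when both of its frames are valid.
        vel_mask = mask[:, 1:] & mask[:, :-1]
        return self.loss(self.velocity(pred)[vel_mask], self.velocity(gt)[vel_mask])
```

The velocity term penalizes jitter that a per-frame loss alone can miss, since two sequences with similar per-frame error can differ in smoothness.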
FILE: alm/models/modeltype/base.py
class BaseModel (line 8) | class BaseModel(LightningModule):
method __init__ (line 10) | def __init__(self, *args, **kwargs):
method __post_init__ (line 14) | def __post_init__(self):
method training_step (line 25) | def training_step(self, batch, batch_idx):
method validation_step (line 28) | def validation_step(self, batch, batch_idx):
method test_step (line 31) | def test_step(self, batch, batch_idx):
method predict_step (line 36) | def predict_step(self, batch, batch_idx):
method allsplit_epoch_end (line 39) | def allsplit_epoch_end(self, split: str, outputs):
method training_epoch_end (line 73) | def training_epoch_end(self, outputs):
method validation_epoch_end (line 76) | def validation_epoch_end(self, outputs):
method test_epoch_end (line 79) | def test_epoch_end(self, outputs):
method configure_optimizers (line 115) | def configure_optimizers(self):
FILE: alm/models/modeltype/diffusion_bias.py
class DIFFUSION_BIAS (line 26) | class DIFFUSION_BIAS(BaseModel):
method __init__ (line 28) | def __init__(self, cfg, datamodule, **kwargs):
method allsplit_step (line 95) | def allsplit_step(self, split: str, batch, batch_idx):
method _memory_mask (line 273) | def _memory_mask(self, hidden_attention, ):
method _tgt_mask (line 312) | def _tgt_mask(self, vertice_attention, ):
method _mem_key_padding_mask (line 349) | def _mem_key_padding_mask(self, vertice_attention):
method _tgt_key_padding_mask (line 364) | def _tgt_key_padding_mask(self, vertice_attention):
method _audio_resize (line 379) | def _audio_resize(self, hidden_state: torch.Tensor, input_fps: Optiona...
method _audio_2_hidden (line 399) | def _audio_2_hidden(self, audio, audio_attention, length = None):
method _diffusion_forward (line 417) | def _diffusion_forward(self, batch, batch_idx, phase):
method smooth (line 487) | def smooth(self, vertices):
method predict (line 498) | def predict(self, batch, **kwargs):
method _diffusion_process (line 546) | def _diffusion_process(
method _diffusion_reverse (line 598) | def _diffusion_reverse(
method _visualize (line 675) | def _visualize(self, batch, rs_set, parrallel = True):
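`DIFFUSION_BIAS._audio_resize(self, hidden_state: torch.Tensor, input_fps: Optiona...` (signature truncated in the index) presumably aligns speech features with the mesh frame rate. wav2vec 2.0 and HuBERT produce roughly 50 feature frames per second of 16 kHz audio, so some temporal resampling is needed before pairing features with animation frames. The sketch below uses linear interpolation. The function name `resize_audio_features` and the 50 to 30 fps defaults are hypothetical, not taken from the repository.

```python
import torch
import torch.nn.functional as F

def resize_audio_features(hidden_state, input_fps=50.0, output_fps=30.0, output_len=None):
    """Linearly resample (batch, frames, channels) features along the time axis."""
    if output_len is None:
        output_len = int(round(hidden_state.shape[1] * output_fps / input_fps))
    x = hidden_state.transpose(1, 2)            # (batch, channels, frames)
    x = F.interpolate(x, size=output_len, mode="linear", align_corners=True)
    return x.transpose(1, 2)                    # (batch, output_len, channels)

feats = torch.randn(1, 100, 768)               # about 2 s of 50 Hz features
print(resize_audio_features(feats).shape)      # torch.Size([1, 60, 768])
```

Passing `output_len` explicitly (for example, the ground-truth vertex sequence length during training) avoids off-by-one mismatches from rounding.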
FILE: alm/utils/demo_utils.py
function load_example_input (line 25) | def load_example_input(audio_path, processor = None):
function render_mesh_helper (line 132) | def render_mesh_helper(mesh, t_center, rot=np.zeros(3), tex_img=None, z_...
function render_frame (line 218) | def render_frame(args):
function animate (line 225) | def animate(vertices: np.array, wav_path: str, file_name: str, ply: str,...
FILE: alm/utils/logger.py
function create_logger (line 9) | def create_logger(cfg, phase='train'):
function config_logger (line 38) | def config_logger(final_output_dir, time_str, phase, head):
function new_dir (line 56) | def new_dir(cfg, phase, time_str, final_output_dir):
FILE: alm/utils/temos_utils.py
function lengths_to_mask (line 7) | def lengths_to_mask(lengths: Tensor, # [batch_size]
function remove_padding (line 15) | def remove_padding(tensors, lengths):
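`temos_utils.py` lists `lengths_to_mask(lengths: Tensor, # [batch_size]` (truncated) and `remove_padding(tensors, lengths)`. The sketch below shows the usual behavior of helpers with these names: build a boolean validity mask from per-sequence lengths, and trim padded sequences back to their true lengths. The optional `max_len` parameter is an assumption.

```python
from typing import List, Optional

import torch
from torch import Tensor

def lengths_to_mask(lengths: Tensor, max_len: Optional[int] = None) -> Tensor:
    """Boolean (batch, max_len) mask, True where frame index < sequence length."""
    if max_len is None:
        max_len = int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) < lengths.unsqueeze(1)

def remove_padding(tensors, lengths) -> List[Tensor]:
    """Trim each padded sequence back to its true length."""
    return [t[: int(n)] for t, n in zip(tensors, lengths)]

print(lengths_to_mask(torch.tensor([2, 4])))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])
```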
FILE: demo_biwi.py
function main (line 12) | def main():
FILE: demo_vocaset.py
function main (line 36) | def main():
FILE: demo_vocaset_text.py
function main (line 13) | def main():
FILE: eval_biwi.py
function print_table (line 54) | def print_table(title, metrics):
function get_metric_statistics (line 66) | def get_metric_statistics(values, replication_times):
function main (line 72) | def main():
FILE: eval_vocaset.py
function print_table (line 27) | def print_table(title, metrics):
function get_metric_statistics (line 39) | def get_metric_statistics(values, replication_times):
function main (line 45) | def main():
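Both evaluation scripts define `get_metric_statistics(values, replication_times)`. Diffusion sampling is stochastic, so metrics are presumably averaged over several sampling runs and reported with a confidence interval. The sketch below assumes a normal-approximation 95% interval (factor 1.96), a convention found in several motion-generation evaluation scripts. Whether this repository uses the same formula is an assumption.

```python
import numpy as np

def get_metric_statistics(values, replication_times):
    """Mean and 95% confidence half-width of a metric over repeated runs."""
    values = np.asarray(values, dtype=np.float64)
    mean = values.mean(axis=0)
    std = values.std(axis=0)
    conf_interval = 1.96 * std / np.sqrt(replication_times)
    return mean, conf_interval

mean, ci = get_metric_statistics([1.0, 3.0], replication_times=2)
print(mean)  # 2.0
```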
FILE: train.py
function main (line 18) | def main():
Condensed preview: 66 files, each showing path, character count, and a content snippet (full structured content: 6,533K chars).
[
{
"path": "README.md",
"chars": 2256,
"preview": "# DiffSpeaker: Speech-Driven 3D Facial Animation with Diffusion Transformer\n## [Paper](https://arxiv.org/pdf/2402.05712."
},
{
"path": "alm/callback/__init__.py",
"chars": 37,
"preview": "from .progress import ProgressLogger\n"
},
{
"path": "alm/callback/progress.py",
"chars": 1906,
"preview": "import logging\n\nfrom pytorch_lightning import LightningModule, Trainer\nfrom pytorch_lightning.callbacks import Callback\n"
},
{
"path": "alm/config.py",
"chars": 6915,
"preview": "import importlib\nfrom argparse import ArgumentParser\nfrom omegaconf import OmegaConf\nimport os\n\n\ndef get_module_config(c"
},
{
"path": "alm/data/BIWI/__init__.py",
"chars": 32,
"preview": "from .dataset import BIWIDataset"
},
{
"path": "alm/data/BIWI/dataset.py",
"chars": 2063,
"preview": "import numpy as np\nimport torch\nfrom torch.utils import data\nfrom transformers import Wav2Vec2Processor\nfrom collections"
},
{
"path": "alm/data/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "alm/data/base.py",
"chars": 3992,
"preview": "from os.path import join as pjoin\nimport numpy as np\nimport pytorch_lightning as pl\nfrom torch.utils.data import DataLoa"
},
{
"path": "alm/data/biwi.py",
"chars": 6269,
"preview": "from .base import BASEDataModule\nfrom transformers import Wav2Vec2Processor\nfrom collections import defaultdict\nfrom .BI"
},
{
"path": "alm/data/get_data.py",
"chars": 8922,
"preview": "import numpy as np\nimport torch\n\ndef collate_tensors(batch):\n dims = batch[0].dim()\n max_size = [max([b.size(i) fo"
},
{
"path": "alm/data/voca/__init__.py",
"chars": 35,
"preview": "from .dataset import VOCASETDataset"
},
{
"path": "alm/data/voca/dataset.py",
"chars": 1945,
"preview": "import numpy as np\nimport torch\nfrom torch.utils import data\nfrom transformers import Wav2Vec2Processor\nfrom collections"
},
{
"path": "alm/data/vocaset.py",
"chars": 6090,
"preview": "from .base import BASEDataModule\nfrom alm.data.voca import VOCASETDataset\nfrom transformers import Wav2Vec2Processor\nfro"
},
{
"path": "alm/models/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "alm/models/architectures/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "alm/models/architectures/adpt_bias_denoiser.py",
"chars": 6677,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom alm.models.architecture"
},
{
"path": "alm/models/architectures/tools/embeddings.py",
"chars": 9714,
"preview": "# This file is taken from signjoey repository\nimport math\n\nimport torch\nfrom torch import Tensor, nn\n\n\ndef get_activatio"
},
{
"path": "alm/models/architectures/tools/transformer_adpt.py",
"chars": 14088,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom alm.models.architecture"
},
{
"path": "alm/models/architectures/tools/utils.py",
"chars": 7517,
"preview": "import math\nimport torch\nimport torch.nn as nn\n\nclass PeriodicPositionalEncoding(nn.Module):\n def __init__(self, d_mo"
},
{
"path": "alm/models/get_model.py",
"chars": 595,
"preview": "import importlib\n\ndef get_model(cfg, datamodule):\n modeltype = cfg.model.model_type\n return get_module(cfg, datamo"
},
{
"path": "alm/models/losses/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "alm/models/losses/utils.py",
"chars": 7112,
"preview": "import torch\nimport torch.nn as nn\n\nclass MaskedConsistency:\n def __init__(self) -> None:\n self.loss = nn.MSEL"
},
{
"path": "alm/models/losses/voca.py",
"chars": 7145,
"preview": "# import numpy as np\n# import torch\n# import torch.nn as nn\n# from torchmetrics import Metric\n\n# class almLosses(Metric)"
},
{
"path": "alm/models/modeltype/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "alm/models/modeltype/base.py",
"chars": 4311,
"preview": "import os\nfrom pathlib import Path\nimport numpy as np\nfrom pytorch_lightning import LightningModule\nimport torch\nfrom co"
},
{
"path": "alm/models/modeltype/diffusion_bias.py",
"chars": 34657,
"preview": "import torch\nfrom torch.optim import AdamW, Adam\nimport torch.nn.functional as F\nfrom torchmetrics import MetricCollecti"
},
{
"path": "alm/utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "alm/utils/demo_utils.py",
"chars": 13178,
"preview": "from transformers import Wav2Vec2Processor\nimport numpy as np\nimport librosa\nimport os\nimport torch\nimport cv2\nimport py"
},
{
"path": "alm/utils/logger.py",
"chars": 2454,
"preview": "from pathlib import Path\nimport os\nimport time\nimport logging\nfrom omegaconf import OmegaConf\nfrom pytorch_lightning.uti"
},
{
"path": "alm/utils/temos_utils.py",
"chars": 574,
"preview": "from typing import Dict, List\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\n\ndef lengths_to_mask(lengths: Te"
},
{
"path": "configs/assets/biwi.yaml",
"chars": 162,
"preview": "FOLDER: './experiments/biwi' # Experiment files saving path\n\nTEST:\n FOLDER: './results' # Testing files saving path\n\nDA"
},
{
"path": "configs/assets/vocaset.yaml",
"chars": 171,
"preview": "FOLDER: './experiments/vocaset' # Experiment files saving path\n\nTEST:\n FOLDER: './results' # Testing files saving path\n"
},
{
"path": "configs/base.yaml",
"chars": 1448,
"preview": "# FOLDER: ./experiments\nSEED_VALUE: 1234\nDEBUG: True\nTRAIN:\n SPLIT: 'train'\n NUM_WORKERS: 1 #2 # Number of workers\n B"
},
{
"path": "configs/diffusion/biwi/diffspeaker_hubert_biwi.yaml",
"chars": 2743,
"preview": "NAME: diffspeaker_hubert_biwi # Experiment name\nDEBUG: False # Debug mode\nACCELERATOR: 'gpu' # Devices optioncal: “cpu”,"
},
{
"path": "configs/diffusion/biwi/diffspeaker_wav2vec2_biwi.yaml",
"chars": 2861,
"preview": "NAME: diffspeaker_wav2vec2_biwi # Experiment name\nDEBUG: False # Debug mode\nACCELERATOR: 'gpu' # Devices optioncal: “cpu"
},
{
"path": "configs/diffusion/diffusion_bias_modules/denoiser.yaml",
"chars": 1207,
"preview": "denoiser: # this is copied from configs/baselines/transformer_adpt_modules/transformer.yaml\n target: alm.models.archite"
},
{
"path": "configs/diffusion/diffusion_bias_modules/scheduler.yaml",
"chars": 770,
"preview": "scheduler:\n target: diffusers.DDIMScheduler\n num_inference_timesteps: 50\n eta: 0.0\n params:\n num_train_timesteps:"
},
{
"path": "configs/diffusion/vocaset/diffspeaker_hubert_vocaset.yaml",
"chars": 2618,
"preview": "NAME: diffspeaker_hubert_vocaset # Experiment name\nDEBUG: False # Debug mode\nACCELERATOR: 'gpu' # Devices optioncal: “cp"
},
{
"path": "configs/diffusion/vocaset/diffspeaker_wav2vec2_vocaset.yaml",
"chars": 2621,
"preview": "NAME: diffspeaker_wav2vec2_vocaset # Experiment name\nDEBUG: False # Debug mode\nACCELERATOR: 'gpu' # Devices optioncal: “"
},
{
"path": "datasets/biwi/README.md",
"chars": 92,
"preview": "should contanin\n```\n regions\n templates\n templates.pkl\n vertices_npy\n wav\n```"
},
{
"path": "datasets/biwi/regions/fdd.txt",
"chars": 50854,
"preview": "7, 9, 59, 89, 94, 103, 122, 126, 152, 159, 160, 161, 164, 166, 169, 174, 177, 178, 179, 180, 182, 183, 184, 185, 195, 21"
},
{
"path": "datasets/biwi/regions/lve.txt",
"chars": 32611,
"preview": "2, 6, 21, 23, 24, 25, 27, 28, 29, 39, 45, 47, 49, 50, 55, 60, 68, 70, 73, 76, 80, 81, 82, 83, 85, 86, 87, 133, 170, 171,"
},
{
"path": "datasets/vocaset/FLAME_masks.pkl",
"chars": 215062,
"preview": "(dp0\nS'eye_region'\np1\ncnumpy.core.multiarray\n_reconstruct\np2\n(cnumpy\nndarray\np3\n(I0\ntp4\nS'b'\np5\ntp6\nRp7\n(I1\n(I751\ntp8\ncn"
},
{
"path": "datasets/vocaset/README.md",
"chars": 80,
"preview": "should contanin\n```\n templates\n templates.pkl\n vertices_npy\n wav\n```"
},
{
"path": "datasets/vocaset/templates/README.md",
"chars": 148,
"preview": "Put \"FLAME_sample.ply\" in this folder. \"FLAME_sample.ply\" can be downloaded from [voca](https://github.com/TimoBolkart/v"
},
{
"path": "datasets/vocaset/templates.pkl",
"chars": 4832756,
"preview": "(dp0\nS'FaceTalk_170904_00128_TA'\np1\ncnumpy.core.multiarray\n_reconstruct\np2\n(cnumpy\nndarray\np3\n(I0\ntp4\nS'b'\np5\ntp6\nRp7\n(I"
},
{
"path": "demo_biwi.py",
"chars": 3165,
"preview": "import os\nimport pickle\nimport torch\n\nfrom alm.config import parse_args\nfrom alm.models.get_model import get_model\nfrom "
},
{
"path": "demo_vocaset.py",
"chars": 7443,
"preview": "import os\nimport pickle\nimport torch\nimport torch.nn.functional as F\n\nfrom alm.config import parse_args\nfrom alm.models."
},
{
"path": "demo_vocaset_text.py",
"chars": 4681,
"preview": "import os\nimport pickle\nimport torch\n\nfrom alm.config import parse_args\nfrom alm.models.get_model import get_model\nfrom "
},
{
"path": "eval_biwi.py",
"chars": 5370,
"preview": "#############################################################################################################\n# this fil"
},
{
"path": "eval_vocaset.py",
"chars": 4626,
"preview": "#############################################################################################################\n# this fil"
},
{
"path": "requirements.txt",
"chars": 298,
"preview": "imageio==2.21.2\nlibrosa==0.9.2\nlmdb==1.3.0\nnumpy==1.23.5\nomegaconf==2.3.0\nopencv_python==4.7.0.72\npsbody_mesh==0.4\npsuti"
},
{
"path": "scripts/demo/demo_biwi.sh",
"chars": 781,
"preview": "export CUDA_VISIBLE_DEVICES=0\n\n# use hubert backbone\npython demo_biwi.py \\\n --cfg configs/diffusion/biwi/diffspeaker_"
},
{
"path": "scripts/demo/demo_vocaset.sh",
"chars": 887,
"preview": "export CUDA_VISIBLE_DEVICES=1\n\n# # use hubert backbone\n# python demo_vocaset.py \\\n# --cfg configs/diffusion/vocaset/"
},
{
"path": "scripts/diffusion/biwi_evaluation/diffspeaker_hubert_biwi.sh",
"chars": 160,
"preview": "export CUDA_VISIBLE_DEVICES=0\npython eval_biwi.py \\\n --cfg configs/diffusion/biwi/diffspeaker_hubert_biwi.yaml \\\n "
},
{
"path": "scripts/diffusion/biwi_evaluation/diffspeaker_wav2vec2_biwi.sh",
"chars": 162,
"preview": "export CUDA_VISIBLE_DEVICES=0\npython eval_biwi.py \\\n --cfg configs/diffusion/biwi/diffspeaker_wav2vec2_biwi.yaml \\\n "
},
{
"path": "scripts/diffusion/biwi_training/diffspeaker_hubert_biwi.sh",
"chars": 207,
"preview": "export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\npython -m train \\\n --cfg configs/diffusion/biwi/diffspeaker_hubert_biwi.y"
},
{
"path": "scripts/diffusion/biwi_training/diffspeaker_wav2vec2_biwi.sh",
"chars": 209,
"preview": "export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\npython -m train \\\n --cfg configs/diffusion/biwi/diffspeaker_wav2vec2_biwi"
},
{
"path": "scripts/diffusion/vocaset_evaluation/diffspeaker_hubert_vocaset.sh",
"chars": 172,
"preview": "export CUDA_VISIBLE_DEVICES=0\npython eval_vocaset.py \\\n --cfg configs/diffusion/vocaset/diffspeaker_hubert_vocaset.ya"
},
{
"path": "scripts/diffusion/vocaset_evaluation/diffspeaker_wav2vec2_vocaset.sh",
"chars": 174,
"preview": "export CUDA_VISIBLE_DEVICES=0\npython eval_vocaset.py \\\n --cfg configs/diffusion/vocaset/diffspeaker_wav2vec2_vocaset."
},
{
"path": "scripts/diffusion/vocaset_training/diffspeaker_hubert_vocaset.sh",
"chars": 203,
"preview": "export CUDA_VISIBLE_DEVICES=0\npython -m train \\\n --cfg configs/diffusion/vocaset/diffspeaker_hubert_vocaset.yaml \\\n "
},
{
"path": "scripts/diffusion/vocaset_training/diffspeaker_wav2vec2_vocaset.sh",
"chars": 205,
"preview": "export CUDA_VISIBLE_DEVICES=0\npython -m train \\\n --cfg configs/diffusion/vocaset/diffspeaker_wav2vec2_vocaset.yaml \\\n"
},
{
"path": "train.py",
"chars": 8560,
"preview": "import os\nfrom pprint import pformat\n\nimport pytorch_lightning as pl\nimport torch\nfrom omegaconf import OmegaConf\nfrom p"
}
]
// ... and 3 more files
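The `scheduler.yaml` preview configures `diffusers.DDIMScheduler` with `num_inference_timesteps: 50` and `eta: 0.0`, that is, deterministic DDIM sampling. The NumPy sketch below shows what one eta = 0 update does, assuming the denoiser predicts the clean sample x0 (if it predicted noise instead, `eps` would be used directly). The linear beta schedule, the even timestep spacing, and the final alpha_bar = 1 are illustrative assumptions. The configured `num_train_timesteps` value is cut off in the preview, so the example value of 1000 is only a placeholder.

```python
import numpy as np

def ddim_timesteps(num_train_timesteps, num_inference_steps):
    """Evenly spaced timesteps, ordered from noisiest to cleanest."""
    step = num_train_timesteps // num_inference_steps
    return (np.arange(num_inference_steps) * step)[::-1]

def ddim_step_eta0(x_t, x0_pred, alpha_bar_t, alpha_bar_prev):
    """One deterministic DDIM update (eta = 0) from a predicted clean sample x0."""
    # Recover the noise implied by x_t and the x0 prediction.
    eps = (x_t - np.sqrt(alpha_bar_t) * x0_pred) / np.sqrt(1.0 - alpha_bar_t)
    # Move to the previous (less noisy) level along the same noise direction.
    return np.sqrt(alpha_bar_prev) * x0_pred + np.sqrt(1.0 - alpha_bar_prev) * eps

def ddim_sample(denoise_fn, shape, num_train_timesteps, num_inference_steps=50, seed=0):
    """Full eta = 0 sampling loop; denoise_fn(x, t) must return an x0 estimate."""
    betas = np.linspace(1e-4, 0.02, num_train_timesteps)  # assumed linear schedule
    alphas_cumprod = np.cumprod(1.0 - betas)
    ts = ddim_timesteps(num_train_timesteps, num_inference_steps)
    x = np.random.default_rng(seed).standard_normal(shape)
    for i, t in enumerate(ts):
        a_t = alphas_cumprod[t]
        # After the last step, treat the target level as noise-free (alpha_bar = 1).
        a_prev = alphas_cumprod[ts[i + 1]] if i + 1 < len(ts) else 1.0
        x = ddim_step_eta0(x, denoise_fn(x, t), a_t, a_prev)
    return x

# Toy denoiser that always predicts the same clean sample (placeholder 1000 train steps).
out = ddim_sample(lambda x, t: np.full((2, 3), 0.5), (2, 3), num_train_timesteps=1000)
```

With eta = 0 no fresh noise is injected, so the same initial noise and conditioning always produce the same animation.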
About this extraction
This page contains the full source code of the theEricMa/DiffSpeaker GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 66 files (5.1 MB), approximately 1.3M tokens, and a symbol index with 157 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract, a free GitHub-repository-to-text converter for AI, built by Nikandr Surkov.