Full Code of theEricMa/DiffSpeaker for AI

main 0f2ae2d4e88f cached
66 files
5.1 MB
1.3M tokens
157 symbols
1 requests
Download .txt
Showing preview only (5,345K chars total). Download the full file or copy to clipboard to get everything.
Repository: theEricMa/DiffSpeaker
Branch: main
Commit: 0f2ae2d4e88f
Files: 66
Total size: 5.1 MB

Directory structure:
gitextract_v4nqs5sn/

├── README.md
├── alm/
│   ├── callback/
│   │   ├── __init__.py
│   │   └── progress.py
│   ├── config.py
│   ├── data/
│   │   ├── BIWI/
│   │   │   ├── __init__.py
│   │   │   └── dataset.py
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── biwi.py
│   │   ├── get_data.py
│   │   ├── voca/
│   │   │   ├── __init__.py
│   │   │   └── dataset.py
│   │   └── vocaset.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── architectures/
│   │   │   ├── __init__.py
│   │   │   ├── adpt_bias_denoiser.py
│   │   │   └── tools/
│   │   │       ├── embeddings.py
│   │   │       ├── transformer_adpt.py
│   │   │       └── utils.py
│   │   ├── get_model.py
│   │   ├── losses/
│   │   │   ├── __init__.py
│   │   │   ├── utils.py
│   │   │   └── voca.py
│   │   └── modeltype/
│   │       ├── __init__.py
│   │       ├── base.py
│   │       └── diffusion_bias.py
│   └── utils/
│       ├── __init__.py
│       ├── demo_utils.py
│       ├── logger.py
│       └── temos_utils.py
├── configs/
│   ├── assets/
│   │   ├── biwi.yaml
│   │   └── vocaset.yaml
│   ├── base.yaml
│   └── diffusion/
│       ├── biwi/
│       │   ├── diffspeaker_hubert_biwi.yaml
│       │   └── diffspeaker_wav2vec2_biwi.yaml
│       ├── diffusion_bias_modules/
│       │   ├── denoiser.yaml
│       │   └── scheduler.yaml
│       └── vocaset/
│           ├── diffspeaker_hubert_vocaset.yaml
│           └── diffspeaker_wav2vec2_vocaset.yaml
├── datasets/
│   ├── biwi/
│   │   ├── README.md
│   │   ├── regions/
│   │   │   ├── fdd.txt
│   │   │   └── lve.txt
│   │   ├── templates/
│   │   │   └── BIWI.ply
│   │   └── templates.pkl
│   └── vocaset/
│       ├── FLAME_masks.pkl
│       ├── README.md
│       ├── templates/
│       │   ├── FLAME_sample.ply
│       │   └── README.md
│       └── templates.pkl
├── demo_biwi.py
├── demo_vocaset.py
├── demo_vocaset_text.py
├── eval_biwi.py
├── eval_vocaset.py
├── requirements.txt
├── scripts/
│   ├── demo/
│   │   ├── demo_biwi.sh
│   │   └── demo_vocaset.sh
│   └── diffusion/
│       ├── biwi_evaluation/
│       │   ├── diffspeaker_hubert_biwi.sh
│       │   └── diffspeaker_wav2vec2_biwi.sh
│       ├── biwi_training/
│       │   ├── diffspeaker_hubert_biwi.sh
│       │   └── diffspeaker_wav2vec2_biwi.sh
│       ├── vocaset_evaluation/
│       │   ├── diffspeaker_hubert_vocaset.sh
│       │   └── diffspeaker_wav2vec2_vocaset.sh
│       └── vocaset_training/
│           ├── diffspeaker_hubert_vocaset.sh
│           └── diffspeaker_wav2vec2_vocaset.sh
└── train.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# DiffSpeaker: Speech-Driven 3D Facial Animation with Diffusion Transformer
## [Paper](https://arxiv.org/pdf/2402.05712.pdf) | [Demo](https://www.youtube.com/watch?v=4-NBygHePk0)

## Update
- [30/03/2024]: The evaluation code is updated. 
- [07/02/2024]: The inference script is released. 
- [06/02/2024]: The model weight is released.

## Get started
### Environment Setup
```
conda create --name diffspeaker python=3.9
conda activate diffspeaker
```
Install MPI-IS. Follow the command in [MPI-IS](https://github.com/MPI-IS/mesh) to install the package. Depending on if you have `/usr/include/boost/` directories, The command is likely to be
```
git clone https://github.com/MPI-IS/mesh.git
cd mesh
sudo apt-get install libboost-dev
python -m pip install pip==20.2.4
BOOST_INCLUDE_DIRS=/usr/include/boost/ make all
python -m pip install --upgrade pip
```
Then install the rest of the dependencies.
```
cd ..
git clone https://github.com/theEricMa/DiffSpeaker.git
cd DiffSpeaker
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
pip install imageio-ffmpeg
pip install -r requirements.txt
```
### Model Weights
You can access the model parameters by clicking [here](https://drive.google.com/drive/folders/1PezaNpQHIjyE8UE5YW0jpDPV8jtepxSL?usp=sharing). Place the `checkpoints` folder into the root directory of your project. This folder includes the models that have been trained on the `BIWI` and `vocaset` datasets, utilizing `wav2vec` and `hubert` as the backbones.
### Prediction
For the BIWI model, use the script below to perform inference on your chosen audio files. Specify the audio file using the `--example` argument.
```
sh scripts/demo/demo_biwi.sh
```
For the vocaset model, run the following script.
```
sh scripts/demo/demo_vocaset.sh
```
### Evaluation
To obtain the metrics reported in the paper, use the scripts in `scripts/diffusion/biwi_evaluation` and `scripts/diffusion/vocaset_evaluation`. For example, to evaluate DiffSpeaker in BIWI dataset with the hubert backbone, use the following script.
```
sh scripts/diffusion/biwi_evaluation/diffspeaker_hubert_biwi.sh
```

## Training
### Data Preparation 

### Model Training
```
mkdir experiments
```



================================================
FILE: alm/callback/__init__.py
================================================
from .progress import ProgressLogger


================================================
FILE: alm/callback/progress.py
================================================
import logging

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
import psutil

logger = logging.getLogger()


class ProgressLogger(Callback):

    def __init__(self, metric_monitor: dict, precision: int = 3):
        # Metric to monitor
        self.metric_monitor = metric_monitor
        self.precision = precision

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule,
                       **kwargs) -> None:
        logger.info("Training started")

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule,
                     **kwargs) -> None:
        logger.info("Training done")

    def on_validation_epoch_end(self, trainer: Trainer,
                                pl_module: LightningModule, **kwargs) -> None:
        if trainer.sanity_checking:
            logger.info("Sanity checking ok.")

    def on_train_epoch_end(self,
                           trainer: Trainer,
                           pl_module: LightningModule,
                           padding=False,
                           **kwargs) -> None:
        metric_format = f"{{:.{self.precision}e}}"
        line = f"Epoch {trainer.current_epoch}"
        if padding:
            line = f"{line:>{len('Epoch xxxx')}}"  # Right padding
        metrics_str = []

        losses_dict = trainer.callback_metrics
        for metric_name, dico_name in self.metric_monitor.items():
            if dico_name in losses_dict:
                metric = losses_dict[dico_name].item()
                metric = metric_format.format(metric)
                metric = f"{metric_name} {metric}"
                metrics_str.append(metric)

        if len(metrics_str) == 0:
            return

        memory = f"Memory {psutil.virtual_memory().percent}%"
        line = line + ": " + "   ".join(metrics_str) + "   " + memory
        logger.info(line)


================================================
FILE: alm/config.py
================================================
import importlib
from argparse import ArgumentParser
from omegaconf import OmegaConf
import os


def get_module_config(cfg_model, path="modules"):
    files = os.listdir(f'./configs/{path}/')
    for file in files:
        if file.endswith('.yaml'):
            with open(f'./configs/{path}/' + file, 'r') as f:
                cfg_model.merge_with(OmegaConf.load(f))
    return cfg_model


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    if not "target" in config:
        if config == '__is_first_stage__':
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def parse_args(phase="train"):
    parser = ArgumentParser()

    group = parser.add_argument_group("Training options")
    if phase in ["train", "test", "demo"]:
        group.add_argument(
            "--cfg",
            type=str,
            required=False,
            default="./configs/config.yaml",
            help="config file",
        )
        group.add_argument(
            "--cfg_assets",
            type=str,
            required=False,
            default="./configs/assets.yaml",
            help="config file for asset paths",
        )
        group.add_argument(
            "--frame_rate",
            type=float,
            default=30,
            help="the frame rate for the input/output motion",
        )
        group.add_argument(
            "--resume",
            type=str,
            required=False,
            help="resume from a checkpoint",
        )
        group.add_argument("--batch_size",
                           type=int,
                           required=False,
                           help="training batch size")
        group.add_argument("--device",
                           type=int,
                           nargs="+",
                           required=False,
                           help="training device")
        group.add_argument("--nodebug",
                           action="store_true",
                           required=False,
                           help="debug or not")
        group.add_argument("--dir",
                           type=str,
                           required=False,
                           help="evaluate existing npys")


    if phase == "demo":
        # group.add_argument("--motion_transfer", action='store_true', help="Motion Distribution Transfer")
        group.add_argument("--render",
                           action="store_true",
                           help="Render visulizaed figures")
        group.add_argument("--render_mode", type=str, help="video or sequence")
        group.add_argument(
            "--example",
            type=str,
            required=False,
            help="input text and lengths with txt format",
        )
        group.add_argument(
            "--out_dir",
            type=str,
            required=False,
            help="output dir",
        )
        group.add_argument(
            "--template",
            type=str,
            required=False,
            help="template path",
        )
        group.add_argument(
            "--checkpoint",
            type=str,
            required=True,
            help="output seperate or combined npy file",
        )
        group.add_argument(
            "--id",
            type=str,
            required=True,
            help="the candiate subect identity",
        )
        group.add_argument(
            "--ply",
            type=str,
            required=True,
            help="the candiate subect identity",
        )

    if phase == "render":
        group.add_argument(
            "--cfg",
            type=str,
            required=False,
            default="./configs/render.yaml",
            help="config file",
        )
        group.add_argument(
            "--cfg_assets",
            type=str,
            required=False,
            default="./configs/assets.yaml",
            help="config file for asset paths",
        )
        # group.add_argument("--motion_transfer", action='store_true', help="Motion Distribution Transfer")
        group.add_argument("--dir",
                           type=str,
                           required=False,
                           default=None,
                           help="npy motion folder")

    # remove None params, and create a dictionnary
    params = parser.parse_args()
    # params = {key: val for key, val in vars(opt).items() if val is not None}

    # update config from files
    cfg_base = OmegaConf.load('./configs/base.yaml')
    cfg_exp = OmegaConf.merge(cfg_base, OmegaConf.load(params.cfg))
    cfg_assets = OmegaConf.load(params.cfg_assets)
    cfg_model = get_module_config(cfg_exp.model, cfg_exp.model.target)
    cfg = OmegaConf.merge(cfg_exp, cfg_model, cfg_assets)

    if phase in ["train", "test"]:
        cfg.TRAIN.BATCH_SIZE = (params.batch_size
                                if params.batch_size else cfg.TRAIN.BATCH_SIZE)
        cfg.DEVICE = params.device if params.device else cfg.DEVICE
        cfg.DEBUG = not params.nodebug if params.nodebug is not None else cfg.DEBUG
        
        # resume from a checkpoint, added by Zhiyuan Ma
        cfg.TRAIN.RESUME = params.resume if params.resume else cfg.TRAIN.RESUME
        
        # no debug in test
        cfg.DEBUG = False if phase == "test" else cfg.DEBUG
        if phase == "test":
            cfg.DEBUG = False
            cfg.DEVICE = [0]
            print("Force no debugging and one gpu when testing")
        cfg.TEST.TEST_DIR = params.dir if params.dir else cfg.TEST.TEST_DIR

    if phase == "demo":
        # cfg.DEMO.MOTION_TRANSFER = params.motion_transfer
        cfg.DEMO.RENDER = params.render
        cfg.DEMO.FRAME_RATE = params.frame_rate
        cfg.DEMO.EXAMPLE = params.example
        cfg.DEMO.CHECKPOINTS = params.checkpoint
        cfg.DEMO.TEMPLATE = params.template
        cfg.DEMO.ID = params.id
        cfg.DEMO.PLY = params.ply
        cfg.TEST.FOLDER = params.out_dir if params.dir else cfg.TEST.FOLDER

    if phase == "render":
        if params.npy:
            cfg.RENDER.NPY = params.npy
            cfg.RENDER.INPUT_MODE = "npy"
        if params.dir:
            cfg.RENDER.DIR = params.dir
            cfg.RENDER.INPUT_MODE = "dir"
        cfg.RENDER.JOINT_TYPE = params.joint_type
        cfg.RENDER.MODE = params.mode

    # debug mode
    if cfg.DEBUG:
        cfg.NAME = "debug--" + cfg.NAME
        cfg.LOGGER.WANDB.OFFLINE = True
        cfg.LOGGER.VAL_EVERY_STEPS = 1

    return cfg


================================================
FILE: alm/data/BIWI/__init__.py
================================================
from .dataset import BIWIDataset

================================================
FILE: alm/data/BIWI/dataset.py
================================================
import numpy as np
import torch
from torch.utils import data
from transformers import Wav2Vec2Processor
from collections import defaultdict
import os
from tqdm import tqdm
import numpy as np
import pickle



class BIWIDataset(data.Dataset):

    def __init__(self, 
                data, 
                subjects_dict, 
                data_type="train",
                ):

        self.data = data
        self.len = len(self.data)
        self.subjects_dict = subjects_dict
        self.data_type = data_type
        self.one_hot_labels = np.eye(len(subjects_dict["train"]))

        self.repeat = 20 if data_type == 'train' else 1
        
    def __len__(self):
        return self.len * self.repeat
    
    def __getitem__(self, index):
        index = index % self.len

        # seq_len, fea_dim
        file_name = self.data[index]["name"]
        file_path = self.data[index]["path"]
        audio = self.data[index]["audio"]
        vertice = self.data[index]["vertice"]
        template = self.data[index]["template"]
        if self.data_type == "train":
            subject = "_".join(file_name.split("_")[:-1])
            one_hot = self.one_hot_labels[self.subjects_dict["train"].index(subject)]
        elif self.data_type == "val":
            one_hot = self.one_hot_labels
        elif self.data_type == "test":
            subject = "_".join(file_name.split("_")[:-1])
            if subject in self.subjects_dict["train"]:
                one_hot = self.one_hot_labels[self.subjects_dict["train"].index(subject)]
            else:
                one_hot = self.one_hot_labels


        return {
            'audio':torch.FloatTensor(audio),
            'audio_attention':torch.ones_like(torch.Tensor(audio)).long(),
            'vertice':torch.FloatTensor(vertice), 
            'vertice_attention':torch.ones_like(torch.Tensor(vertice)[..., 0]).long(),
            'template':torch.FloatTensor(template), 
            'id':torch.FloatTensor(one_hot), 
            'file_name':file_name,
            'file_path':file_path
        }


    


================================================
FILE: alm/data/__init__.py
================================================


================================================
FILE: alm/data/base.py
================================================
from os.path import join as pjoin
import numpy as np
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset


class BASEDataModule(pl.LightningDataModule):

    def __init__(self, collate_fn, batch_size: int, num_workers: int):
        super().__init__()

        # self.dataloader_options = {
        #     "batch_size": batch_size, "num_workers": num_workers,"collate_fn": collate_datastruct_and_text}
        self.dataloader_options = {
            "batch_size": batch_size,
            "num_workers": num_workers,
            "collate_fn": collate_fn,
            "prefetch_factor": 2,
            "pin_memory": True, 
        }

        # self.collate_fn = collate_fn
        self.persistent_workers = True
        self.is_mm = False
        # need to be overloaded:
        # - self.Dataset
        # - self._sample_set => load only a small subset
        #   There is an helper bellow (get_sample_set)
        # - self.nfeats
        # - self.transforms

    def __getattr__(self, item):
        # train_dataset/val_dataset etc cached like properties
        # question
        if item.endswith("_dataset") and not item.startswith("_"):
            subset = item[:-len("_dataset")]
            item_c = "_" + item
            if item_c not in self.__dict__:
                # todo: config name not consistent
                subset = subset.upper() if subset != "val" else "EVAL"
                split = eval(f"self.cfg.{subset}.SPLIT")
                split_file = pjoin(
                    eval(f"self.cfg.DATASET.{self.name.upper()}.SPLIT_ROOT"),
                    eval(f"self.cfg.{subset}.SPLIT") + ".txt",
                )
                self.__dict__[item_c] = self.Dataset(split_file=split_file,
                                                     split=split,
                                                     **self.hparams)
            return getattr(self, item_c)
        classname = self.__class__.__name__
        raise AttributeError(f"'{classname}' object has no attribute '{item}'")

    def setup(self, stage=None):
        self.stage = stage
        # Use the getter the first time to load the data
        if stage in (None, "fit"):
            _ = self.train_dataset
            _ = self.val_dataset
        if stage in (None, "test"):
            _ = self.test_dataset

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            shuffle=True,
            persistent_workers=True,
            **self.dataloader_options,
        )

    def predict_dataloader(self):
        dataloader_options = self.dataloader_options.copy()
        dataloader_options[
            "batch_size"] = 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE
        dataloader_options["num_workers"] = self.cfg.TEST.NUM_WORKERS
        dataloader_options["shuffle"] = False
        return DataLoader(
            self.test_dataset,
            persistent_workers=True,
            **dataloader_options,
        )

    def val_dataloader(self):
        # overrides batch_size and num_workers
        dataloader_options = self.dataloader_options.copy()
        dataloader_options["batch_size"] = self.cfg.EVAL.BATCH_SIZE
        dataloader_options["num_workers"] = self.cfg.EVAL.NUM_WORKERS
        dataloader_options["shuffle"] = False
        return DataLoader(
            self.val_dataset,
            persistent_workers=True,
            **dataloader_options,
        )

    def test_dataloader(self):
        # overrides batch_size and num_workers
        dataloader_options = self.dataloader_options.copy()
        dataloader_options[
            "batch_size"] = 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE
        dataloader_options["num_workers"] = self.cfg.TEST.NUM_WORKERS
        # dataloader_options["drop_last"] = True
        dataloader_options["shuffle"] = False
        return DataLoader(
            self.test_dataset,
            persistent_workers=True,
            **dataloader_options,
        )


================================================
FILE: alm/data/biwi.py
================================================
from .base import BASEDataModule
from transformers import Wav2Vec2Processor
from collections import defaultdict
from .BIWI import BIWIDataset

import os
from os.path import join as pjoin
import pickle
from tqdm import tqdm
import librosa
import numpy as np
from multiprocessing import Pool


def load_data(args):
    file, root_dir, processor, templates, audio_dir, vertice_dir = args
    if file.endswith('wav'):
        wav_path = os.path.join(root_dir, audio_dir, file)
        speech_array, sampling_rate = librosa.load(wav_path, sr=16000)
        input_values = np.squeeze(processor(speech_array,sampling_rate=16000).input_values)
        key = file.replace("wav", "npy")
        result = {}
        result["audio"] = input_values
        subject_id = "_".join(key.split("_")[:-1])
        temp = templates[subject_id]
        result["name"] = file.replace(".wav", "")
        result["path"] = os.path.abspath(wav_path)
        result["template"] = temp.reshape((-1)) 
        vertice_path = os.path.join(root_dir, vertice_dir, file.replace("wav", "npy"))
        if not os.path.exists(vertice_path):
            return None
        else:
            result["vertice"] = np.load(vertice_path,allow_pickle=True) # we do not need to [::2,:], as did in vocaset
            return (key, result)

class BIWIDataModule(BASEDataModule):
    def __init__(self,
                cfg,
                batch_size,
                num_workers,
                collate_fn = None,
                phase="train",
                **kwargs):
        super().__init__(batch_size=batch_size,
                            num_workers=num_workers,
                            collate_fn=collate_fn)
        self.save_hyperparameters(logger=False)
        self.name = 'VOCASET'
        self.Dataset = BIWIDataset # this one is the same
        self.cfg = cfg
        
        # customized to VOCASET
        self.subjects = {
            'train': [
                "F2",
                "F3",
                "F4",
                "M3",
                "M4",
                "M5",
            ],
            'val': [
                "F2",
                "F3",
                "F4",
                "M3",
                "M4",
                "M5",                
            ],
            # 'test': [ # for BIWI test B
            #     "F1",
            #     "F5",
            #     "F6",
            #     "F7",
            #     "F8",
            #     "M1",
            #     "M2",
            #     "M6"
            # ]
            'test': [ # for BIWI test A
                "F2",
                "F3",
                "F4",
                "M3",
                "M4",
                "M5",
            ]
        }

        self.root_dir = kwargs.get('data_root', 'datasets/BIWI')
        self.audio_dir = 'wav'
        self.vertice_dir = 'vertices_npy'
        processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
        self.template_file = 'templates.pkl'

        self.nfeats = 70110

        # load
        data = defaultdict(dict)
        with open(os.path.join(self.root_dir, self.template_file), 'rb') as fin:
            templates = pickle.load(fin, encoding='latin1')

        count = 0
        args_list = []
        for r, ds, fs in os.walk(os.path.join(self.root_dir, self.audio_dir)):
            for f in fs:
                args_list.append((f, self.root_dir, processor, templates, self.audio_dir, self.vertice_dir, ))
                
                # # comment off for full dataset
                # count += 1
                # if count > 50:
                #     break

        # split dataset
        self.data_splits = {
            'train':[],
            'val':[],
            'test':[],
        }

        splits = {
                    'train':range(1,33),
                    'val':range(33,37),
                    'test':range(37,41)
                }

        motion_list = []

        with Pool(processes=10) as pool:
            results = pool.map(load_data, args_list)
            for result in results:
                if result is not None:
                    key, value = result
                    data[key] = value

                # motion = value["vertice"] - value["template"]
                # motion_list.append(motion)

        # # calculate mean and std
        # import pdb; pdb.set_trace()
        # motion_list = np.concatenate(motion_list, axis=0)
        # self.mean = np.mean(motion_list, axis=0)
        # self.std = np.std(motion_list, axis=0)

        for k, v in data.items():
            subject_id = "_".join(k.split("_")[:-1])
            sentence_id = int(k.split(".")[0][-2:])
            for sub in ['train', 'val', 'test']:
                if subject_id in self.subjects[sub] and sentence_id in splits[sub]:
                    self.data_splits[sub].append(v)

        # split dataset
        self.data_splits = {
            'train':[],
            'val':[],
            'test':[],
        }

        # splits = {
        #             'train':range(1,41),
        #             'val':range(21,41),
        #             'test':range(21,41)
        #         }

        for k, v in data.items():
            subject_id = "_".join(k.split("_")[:-1])
            sentence_id = int(k.split(".")[0][-2:])
            for sub in ['train', 'val', 'test']:
                if subject_id in self.subjects[sub] and sentence_id in splits[sub]:
                    self.data_splits[sub].append(v)

        # self._sample_set = self.__getattr__("test_dataset")


    def __getattr__(self, item):
        # train_dataset/val_dataset etc cached like properties``
        # question
        if item.endswith("_dataset") and not item.startswith("_"):
            subset = item[:-len("_dataset")]
            item_c = "_" + item
            if item_c not in self.__dict__:
                # todo: config name not consistent
                self.__dict__[item_c] = self.Dataset(
                    data = self.data_splits[subset] ,
                    subjects_dict = self.subjects,
                    data_type = subset
                )
            return getattr(self, item_c)
        classname = self.__class__.__name__
        raise AttributeError(f"'{classname}' object has no attribute '{item}'")

================================================
FILE: alm/data/get_data.py
================================================
import numpy as np
import torch

def collate_tensors(batch):
    dims = batch[0].dim()
    max_size = [max([b.size(i) for b in batch]) for i in range(dims)]
    size = (len(batch), ) + tuple(max_size)
    canvas = batch[0].new_zeros(size=size)
    for i, b in enumerate(batch):
        sub_tensor = canvas[i]
        for d in range(dims):
            sub_tensor = sub_tensor.narrow(d, 0, b.size(d))
        sub_tensor.add_(b)
    return canvas

def vocaset_collate_fn(batch):
    notnone_batches = [b for b in batch if b is not None]
    # notnone_batches.sort(key=lambda x: x['vertice_length'], reverse=True)
    adapted_batch = {
        'audio': collate_tensors([b['audio'].float() for b in notnone_batches]),
        'audio_attention': collate_tensors([b['audio_attention'] for b in notnone_batches]),
        'vertice': collate_tensors([b['vertice'].float() for b in notnone_batches]),
        'vertice_attention': collate_tensors([b['vertice_attention'] for b in notnone_batches]),
        'template': collate_tensors([b['template'].float() for b in notnone_batches]),
        'id': collate_tensors([b['id'].float() for b in notnone_batches]),
        'file_name': [b['file_name'] for b in notnone_batches],
        'file_path': [b['file_path'] for b in notnone_batches],
    }
    return adapted_batch

def voxcelebinsta_collate_fn(batch):
    notnone_batches = [b for b in batch if b is not None]
    # notnone_batches.sort(key=lambda x: x['vertice_length'], reverse=True)
    adapted_batch = {
        'audio': collate_tensors([b['audio'].float() for b in notnone_batches]),
        'audio_attention': collate_tensors([b['audio_attention'] for b in notnone_batches]),
        'vertice': collate_tensors([b['vertice'].float() for b in notnone_batches]),
        'vertice_attention': collate_tensors([b['vertice_attention'] for b in notnone_batches]),
        'template': collate_tensors([b['template'].float() for b in notnone_batches]),
        'id': collate_tensors([b['id'].float() for b in notnone_batches]),
        'file_name': [b['file_name'] for b in notnone_batches],
    }
    if 'pose' in notnone_batches[0]:
        adapted_batch['pose'] = collate_tensors([b['pose'].float() for b in notnone_batches])
        adapted_batch['pose_attention'] = collate_tensors([b['pose_attention'] for b in notnone_batches])
    if 'exp' in notnone_batches[0]:
        adapted_batch['exp'] = collate_tensors([b['exp'].float() for b in notnone_batches])
        adapted_batch['exp_attention'] = collate_tensors([b['exp_attention'] for b in notnone_batches])
    if 'image' in notnone_batches[0]:
        adapted_batch['image'] = collate_tensors([b['image'].float() for b in notnone_batches])
        adapted_batch['image_attention'] = collate_tensors([b['image_attention'] for b in notnone_batches])
    if 'depth' in notnone_batches[0]:
        adapted_batch['depth'] = collate_tensors([b['depth'].float() for b in notnone_batches])
        adapted_batch['depth_attention'] = collate_tensors([b['depth_attention'] for b in notnone_batches])
    if 'seg' in notnone_batches[0]:
        adapted_batch['seg'] = collate_tensors([b['seg'].float() for b in notnone_batches])
        adapted_batch['seg_attention'] = collate_tensors([b['seg_attention'] for b in notnone_batches])
    return adapted_batch

def voxcelebinstacoeflmdb_collate_fn(batch):
    notnone_batches = [b for b in batch if b is not None]
    # notnone_batches.sort(key=lambda x: x['vertice_length'], reverse=True)
    adpated_batch = {
        ### audio related ##########################################
        'audio': collate_tensors([b['audio'].float() for b in notnone_batches]),
        'audio_attention': collate_tensors([b['audio_attention'] for b in notnone_batches]),
        ### none-predictive features ###############################
        'vertice': collate_tensors([b['vertice'].float() for b in notnone_batches]),
        'shape': collate_tensors([b['flame_shape'].float() for b in notnone_batches]),
        'template': collate_tensors([b['template'].float() for b in notnone_batches]),
        'id': collate_tensors([b['id'].float() for b in notnone_batches]),
        'coefficient_attention': collate_tensors([b['coef_attention'] for b in notnone_batches]),
        'file_name': [b['file_name'] for b in notnone_batches],
        ### predictive features ####################################
        'exp': collate_tensors([b['flame_exp'].float() for b in notnone_batches]),
        'jaw': collate_tensors([b['flame_jaw'].float() for b in notnone_batches]),
        'eyes': collate_tensors([b['flame_eyes'].float() for b in notnone_batches]),
        'eyelids': collate_tensors([b['flame_eyelids'].float() for b in notnone_batches]),
    }
    # pose-related features ####################################
    ### none-predictive features ###############################
    if 'flame_fl' in notnone_batches[0]:
        adpated_batch['fl'] = collate_tensors([b['flame_fl'].float() for b in notnone_batches])
    if 'flame_pp' in notnone_batches[0]:
        adpated_batch['pp'] = collate_tensors([b['flame_pp'].float() for b in notnone_batches])
    ### predictive features ####################################
    if 'flame_R' in notnone_batches[0]:
        adpated_batch['R'] = collate_tensors([b['flame_R'].float() for b in notnone_batches])
    if 'flame_t' in notnone_batches[0]:
        adpated_batch['T'] = collate_tensors([b['flame_t'].float() for b in notnone_batches])
    return adpated_batch


def get_datasets(cfg, logger, phase='train'):
    dataset_names = eval(f"cfg.{phase.upper()}.DATASETS")
    datasets = []
    for dataset_name in dataset_names:
        if dataset_name.lower() in ['vocaset']:
            from .vocaset import VOCASETDataModule
            data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT")
            collate_fn = vocaset_collate_fn
            dataset = VOCASETDataModule(
                cfg = cfg,
                data_root = data_root,
                batch_size=cfg.TRAIN.BATCH_SIZE,
                num_workers=cfg.TRAIN.NUM_WORKERS,
                debug=cfg.DEBUG,
                collate_fn=collate_fn,
            )
            datasets.append(dataset)
        if dataset_name.lower() in ['voxcelebinsta']:
            from .voxceleb_insta import VoxCelebInstaDataModule
            data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT")
            collate_fn = voxcelebinsta_collate_fn
            dataset = VoxCelebInstaDataModule(
                cfg = cfg,
                data_root = data_root,
                batch_size=cfg.TRAIN.BATCH_SIZE,
                num_workers=cfg.TRAIN.NUM_WORKERS,
                debug=cfg.DEBUG,
                collate_fn=collate_fn,
                predict_pose=cfg.model.predict_pose,
                predict_exp=cfg.model.predict_exp,
                use_image=cfg.model.use_image,
            )
            datasets.append(dataset)
        if dataset_name.lower() in ['voxcelebinstalmdb']:
            from .voxceleb_insta_lmdb import VoxCelebInstalmDBDataModule
            data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT")
            collate_fn = voxcelebinsta_collate_fn
            dataset = VoxCelebInstalmDBDataModule(
                cfg = cfg,
                data_root = data_root,
                batch_size=cfg.TRAIN.BATCH_SIZE,
                num_workers=cfg.TRAIN.NUM_WORKERS,
                debug=cfg.DEBUG,
                collate_fn=collate_fn,
                predict_pose=cfg.model.predict_pose,
                predict_exp=cfg.model.predict_exp,
                use_image=cfg.model.use_image,
            )
            datasets.append(dataset)
        if dataset_name.lower() in ['biwi']:
            from .biwi import BIWIDataModule
            data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT")
            collate_fn = vocaset_collate_fn # this is the same as vocaset
            dataset = BIWIDataModule(
                cfg = cfg,
                data_root = data_root,
                batch_size=cfg.TRAIN.BATCH_SIZE,
                num_workers=cfg.TRAIN.NUM_WORKERS,
                debug=cfg.DEBUG,
                collate_fn=collate_fn,
            )
            datasets.append(dataset)
        if dataset_name.lower() in ['voxcelebinstacoeflmdb']:
            from .voxceleb_insta_coef_lmdb import VoxCelebInstaCoefLMDBDataModule
            data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT")
            collate_fn = voxcelebinstacoeflmdb_collate_fn
            dataset = VoxCelebInstaCoefLMDBDataModule(
                cfg = cfg,
                data_root = data_root,
                batch_size=cfg.TRAIN.BATCH_SIZE,
                num_workers=cfg.TRAIN.NUM_WORKERS,
                debug=cfg.DEBUG,
                collate_fn=collate_fn,
            )
            datasets.append(dataset)

    cfg.DATASET.NFEATS = datasets[0].nfeats
    return datasets




================================================
FILE: alm/data/voca/__init__.py
================================================
from .dataset import VOCASETDataset

================================================
FILE: alm/data/voca/dataset.py
================================================
import numpy as np
import torch
from torch.utils import data
from transformers import Wav2Vec2Processor
from collections import defaultdict
import os
from tqdm import tqdm
import numpy as np
import pickle



class VOCASETDataset(data.Dataset):

    def __init__(self, 
                data, 
                subjects_dict, 
                data_type="train",
                ):

        self.data = data
        self.len = len(self.data)
        self.subjects_dict = subjects_dict
        self.data_type = data_type
        self.one_hot_labels = np.eye(len(subjects_dict["train"]))

    def __getitem__(self, index):
        # seq_len, fea_dim
        file_name = self.data[index]["name"]
        file_path = self.data[index]["path"]
        audio = self.data[index]["audio"]
        vertice = self.data[index]["vertice"]
        template = self.data[index]["template"]
        if self.data_type == "train":
            subject = "_".join(file_name.split("_")[:-1])
            one_hot = self.one_hot_labels[self.subjects_dict["train"].index(subject)]
        elif self.data_type == "val":
            one_hot = self.one_hot_labels
        elif self.data_type == "test":
            subject = "_".join(file_name.split("_")[:-1])
            if subject in self.subjects_dict["train"]:
                one_hot = self.one_hot_labels[self.subjects_dict["train"].index(subject)]
            else:
                one_hot = self.one_hot_labels

        return {
            'audio':torch.FloatTensor(audio),
            'audio_attention':torch.ones_like(torch.Tensor(audio)).long(),
            'vertice':torch.FloatTensor(vertice), 
            'vertice_attention':torch.ones_like(torch.Tensor(vertice)[..., 0]).long(),
            'template':torch.FloatTensor(template), 
            'id':torch.FloatTensor(one_hot), 
            'file_name':file_name,
            'file_path':file_path
        }

    def __len__(self):
        return self.len
    

================================================
FILE: alm/data/vocaset.py
================================================
from .base import BASEDataModule
from alm.data.voca import VOCASETDataset
from transformers import Wav2Vec2Processor
from collections import defaultdict

import os
from os.path import join as pjoin
import pickle
from tqdm import tqdm
import librosa
import numpy as np
from multiprocessing import Pool


def load_data(args):
    file, root_dir, processor, templates, audio_dir, vertice_dir = args
    if file.endswith('wav'):
        wav_path = os.path.join(root_dir, audio_dir, file)
        speech_array, sampling_rate = librosa.load(wav_path, sr=16000)
        input_values = np.squeeze(processor(speech_array,sampling_rate=16000).input_values)
        key = file.replace("wav", "npy")
        result = {}
        result["audio"] = input_values
        subject_id = "_".join(key.split("_")[:-1])
        temp = templates[subject_id]
        result["name"] = file.replace(".wav", "")
        result["path"] = os.path.abspath(wav_path)
        result["template"] = temp.reshape((-1)) 
        vertice_path = os.path.join(root_dir, vertice_dir, file.replace("wav", "npy"))
        if not os.path.exists(vertice_path):
            return None
        else:
            result["vertice"] = np.load(vertice_path,allow_pickle=True)[::2,:]
            return (key, result)

class VOCASETDataModule(BASEDataModule):
    def __init__(self,
                cfg,
                batch_size,
                num_workers,
                collate_fn = None,
                phase="train",
                **kwargs):
        super().__init__(batch_size=batch_size,
                            num_workers=num_workers,
                            collate_fn=collate_fn)
        self.save_hyperparameters(logger=False)
        self.name = 'VOCASET'
        self.Dataset = VOCASETDataset
        self.cfg = cfg
        
        # customized to VOCASET
        self.subjects = {
            'train': [
                'FaceTalk_170728_03272_TA',
                'FaceTalk_170904_00128_TA',
                'FaceTalk_170725_00137_TA',
                'FaceTalk_170915_00223_TA',
                'FaceTalk_170811_03274_TA',
                'FaceTalk_170913_03279_TA',
                'FaceTalk_170904_03276_TA',
                'FaceTalk_170912_03278_TA'
            ],
            'val': [
                'FaceTalk_170811_03275_TA',
                'FaceTalk_170908_03277_TA'
            ],
            'test': [
                'FaceTalk_170809_00138_TA',
                'FaceTalk_170731_00024_TA'
            ]
            # 'test': [
            #     'FaceTalk_170728_03272_TA',
            #     'FaceTalk_170904_00128_TA',
            #     'FaceTalk_170725_00137_TA',
            #     'FaceTalk_170915_00223_TA',
            #     'FaceTalk_170811_03274_TA',
            #     'FaceTalk_170913_03279_TA',
            #     'FaceTalk_170904_03276_TA',
            #     'FaceTalk_170912_03278_TA'
            # ]
        }

        self.root_dir = kwargs.get('data_root', 'datasets/vocaset')
        self.audio_dir = 'wav'
        self.vertice_dir = 'vertices_npy'
        processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
        self.template_file = 'templates.pkl'

        self.nfeats = 15069

        # load
        data = defaultdict(dict)
        with open(os.path.join(self.root_dir, self.template_file), 'rb') as fin:
            templates = pickle.load(fin, encoding='latin1')

        count = 0
        args_list = []
        for r, ds, fs in os.walk(os.path.join(self.root_dir, self.audio_dir)):
            for f in fs:
                args_list.append((f, self.root_dir, processor, templates, self.audio_dir, self.vertice_dir, ))

                # # comment off for full dataset
                # count += 1
                # if count > 10:
                #     break

        # split dataset
        self.data_splits = {
            'train':[],
            'val':[],
            'test':[],
        }

        motion_list = []

        if True: # multi-process
            with Pool(processes=os.cpu_count()) as pool:
                results = pool.map(load_data, args_list)
                for result in results:
                    if result is not None:
                        key, value = result
                        data[key] = value
        else: # single process
            for args in tqdm(args_list, desc="Loading data"):
                result = load_data(args)
                if result is not None:
                    key, value = result
                    data[key] = value
                else:
                    print("Warning: data not found")


        # # calculate mean and std
        # motion_list = np.concatenate(motion_list, axis=0)
        # self.mean = np.mean(motion_list, axis=0)
        # self.std = np.std(motion_list, axis=0)

        splits = {
                    'train':range(1,41),
                    'val':range(21,41),
                    'test':range(21,41)
                }
        
        for k, v in data.items():
            subject_id = "_".join(k.split("_")[:-1])
            sentence_id = int(k.split(".")[0][-2:])
            for sub in ['train', 'val', 'test']:
                if subject_id in self.subjects[sub] and sentence_id in splits[sub]:
                    self.data_splits[sub].append(v)

        # self._sample_set = self.__getattr__("test_dataset")


    def __getattr__(self, item):
        # train_dataset/val_dataset etc cached like properties
        # question
        if item.endswith("_dataset") and not item.startswith("_"):
            subset = item[:-len("_dataset")]
            item_c = "_" + item
            if item_c not in self.__dict__:
                # todo: config name not consistent
                self.__dict__[item_c] = self.Dataset(
                    data = self.data_splits[subset] ,
                    subjects_dict = self.subjects,
                    data_type = subset
                )
            return getattr(self, item_c)
        classname = self.__class__.__name__
        raise AttributeError(f"'{classname}' object has no attribute '{item}'")

================================================
FILE: alm/models/__init__.py
================================================


================================================
FILE: alm/models/architectures/__init__.py
================================================


================================================
FILE: alm/models/architectures/adpt_bias_denoiser.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from alm.models.architectures.tools.embeddings import (TimestepEmbedding,
                                                       Timesteps)

from .tools.utils import PeriodicPositionalEncoding, init_bi_biased_mask_faceformer, init_mem_mask_faceformer

from typing import Optional, Tuple, Union, Callable
from .tools.transformer_adpt import TransformerDecoderLayer_w_Adapter, TransformerDecoder_w_Adapter

class Adpt_Bias_Denoiser(nn.Module):
    # this model is based on the trasnformer_adpt.py but with some modifications for the diffusion denoising task
    def __init__(self,
                 nfeats: int = 15069,
                 latent_dim: list = 174,
                 ff_size: int = 1024,
                 num_layers: int = 6,
                 num_heads: int = 4,
                 dropout: float = 0.1,
                 normalize_before: bool = False,
                 activation: str = "gelu",
                 arch: str = "trans_dec",
                 audio_encoded_dim: int = 768,
                 max_len: int = 600,
                 id_dim: int = 10,
                 return_intermediate_dec: bool = False,
                 flip_sin_to_cos: bool = True,
                 freq_shift: int = 0,
                 mem_attn_scale: float = 1.0,
                 tgt_attn_scale: float = 0.1,
                 period: int = 30,
                 no_cross: bool = False,
                 **kwargs) -> None:

        super().__init__()
        self.latent_dim = latent_dim
        self.arch = arch
        self.audio_encoded_dim = audio_encoded_dim

        # audio projecter
        self.audio_feature_map = nn.Linear(audio_encoded_dim, latent_dim)

        # motion projecter
        self.vertice_map = nn.Linear(nfeats, latent_dim)

        # periodic positional encoding
        self.PPE = PeriodicPositionalEncoding(latent_dim, period = period, max_seq_len=5000) # max_seq_len can be adjusted if thit reporst an error

        # attention bias
        assert mem_attn_scale in [-1.0, 0.0, 1.0]
        self.use_mem_attn_bias = mem_attn_scale != 0.0
        self.use_tgt_attn_bias = tgt_attn_scale != 0.0
        self.memory_bi_bias = init_mem_mask_faceformer(max_len)
        
        if tgt_attn_scale < 0.0: # means we only use the causal attention
            self.target_bi_bias = init_bi_biased_mask_faceformer(num_heads, max_len, period)
            mask = torch.triu(torch.ones(max_len, max_len), diagonal=1) == 1
            self.target_bi_bias = self.target_bi_bias.masked_fill(mask, float('-inf'))
        else:
            self.target_bi_bias = init_bi_biased_mask_faceformer(num_heads, max_len, period)



        # init decoder
        decoder_layer = TransformerDecoderLayer_w_Adapter(
            d_model=latent_dim, 
            nhead=num_heads, 
            dim_feedforward=ff_size,
            dropout=dropout, 
            activation=activation, 
            norm_first=normalize_before,
            batch_first=True
        )

        self.transformer_decoder = TransformerDecoder_w_Adapter(
            decoder_layer=decoder_layer,
            num_layers=num_layers,
            )

        # used for diffusion denoising
        self.time_proj = Timesteps(
            audio_encoded_dim, 
            flip_sin_to_cos=flip_sin_to_cos, # because baseline models is trained with this
            downscale_freq_shift=freq_shift, # same as above
        )
        self.time_embedding = TimestepEmbedding(
            audio_encoded_dim,
            latent_dim * num_layers
        )
        
        # motion decoder
        self.motion_decoder = nn.Linear(latent_dim, nfeats)
        nn.init.constant_(self.motion_decoder.weight, 0)
        nn.init.constant_(self.motion_decoder.bias, 0)

        # style embedding
        self.obj_vector = nn.Embedding(id_dim, latent_dim * num_layers, )

        # whether we do not use cross attention
        self.no_cross = no_cross

    def forward(self,
                vertice_input: torch.Tensor,
                hidden_state: torch.Tensor,
                timesteps: torch.Tensor,
                adapter: torch.Tensor = None, # conditions other than the time embedding
                tgt_mask: torch.Tensor = None,
                tgt_key_padding_mask: torch.Tensor = None,
                memory_mask: torch.Tensor = None,
                memory_key_padding_mask: torch.Tensor = None,
                **kwargs):
        """
        Auto-regressive forward pass for the decoder.
        To be used during training.
        Args:
            vertice_input: [N, T, E]
            hidden_state: [N, S, E]
            adapter: [N, A, E]
            tgt_mask: [N * H, T, T]
            tgt_key_padding_mask: [N, T]
            memory_mask: [T, S]
            memory_key_padding_mask: [N, S]
        """
        
        # vertice projection
        vertice_input = self.vertice_map(vertice_input)
        vertice_input = self.PPE(vertice_input)

        # time projection
        time_emb = self.time_proj(timesteps).to(vertice_input.device)
        time_emb = self.time_embedding(time_emb).unsqueeze(1) # time_emb.shape = [N, 1, E]

        # treat the time embedding as an adapter
        if adapter is not None:
            adapter = torch.concat([adapter, time_emb], dim=1)
        else:
            adapter = time_emb

        vertice_out = vertice_input
        # split the adpater in to num_layers pieces, in order to feed them into the transformer
        adapters = adapter.split(self.latent_dim, dim=-1)

        # concat the hidden state and the vertice input
        if self.no_cross:
            hidden_len = hidden_state.shape[1]
            vertice_out = torch.cat([hidden_state, vertice_out], dim=1)
            hidden_state = torch.cat([hidden_state, hidden_state], dim=1)

        for mod,adapter in zip(self.transformer_decoder.layers, adapters):
            vertice_out = mod(
                tgt=vertice_out,
                memory=hidden_state,
                adapter=adapter,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask = tgt_key_padding_mask,
                memory_mask=memory_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                **kwargs
            )

        if self.no_cross: # remove the hidden state
            vertice_out = vertice_out[:, hidden_len:]

        if self.transformer_decoder.norm is not None:
            vertice_out = self.transformer_decoder.norm(vertice_out)

        self.transformer_decoder.layers[0].self_attn
        vertice_out = self.motion_decoder(vertice_out)

        return vertice_out





        


================================================
FILE: alm/models/architectures/tools/embeddings.py
================================================
# This file is taken from signjoey repository
import math

import torch
from torch import Tensor, nn


def get_activation(activation_type):
    if activation_type == "relu":
        return nn.ReLU()
    elif activation_type == "relu6":
        return nn.ReLU6()
    elif activation_type == "prelu":
        return nn.PReLU()
    elif activation_type == "selu":
        return nn.SELU()
    elif activation_type == "celu":
        return nn.CELU()
    elif activation_type == "gelu":
        return nn.GELU()
    elif activation_type == "sigmoid":
        return nn.Sigmoid()
    elif activation_type == "softplus":
        return nn.Softplus()
    elif activation_type == "softshrink":
        return nn.Softshrink()
    elif activation_type == "softsign":
        return nn.Softsign()
    elif activation_type == "tanh":
        return nn.Tanh()
    elif activation_type == "tanhshrink":
        return nn.Tanhshrink()
    else:
        raise ValueError("Unknown activation type {}".format(activation_type))


class MaskedNorm(nn.Module):
    """
        Original Code from:
        https://discuss.pytorch.org/t/batchnorm-for-different-sized-samples-in-batch/44251/8
    """

    def __init__(self, norm_type, num_groups, num_features):
        super().__init__()
        self.norm_type = norm_type
        if self.norm_type == "batch":
            self.norm = nn.BatchNorm1d(num_features=num_features)
        elif self.norm_type == "group":
            self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=num_features)
        elif self.norm_type == "layer":
            self.norm = nn.LayerNorm(normalized_shape=num_features)
        else:
            raise ValueError("Unsupported Normalization Layer")

        self.num_features = num_features

    def forward(self, x: Tensor, mask: Tensor):
        if self.training:
            reshaped = x.reshape([-1, self.num_features])
            reshaped_mask = mask.reshape([-1, 1]) > 0
            selected = torch.masked_select(reshaped, reshaped_mask).reshape(
                [-1, self.num_features]
            )
            batch_normed = self.norm(selected)
            scattered = reshaped.masked_scatter(reshaped_mask, batch_normed)
            return scattered.reshape([x.shape[0], -1, self.num_features])
        else:
            reshaped = x.reshape([-1, self.num_features])
            batched_normed = self.norm(reshaped)
            return batched_normed.reshape([x.shape[0], -1, self.num_features])


# TODO (Cihan): Spatial and Word Embeddings are pretty much the same
#       We might as well convert them into a single module class.
#       Only difference is the lut vs linear layers.
class Embeddings(nn.Module):

    """
    Simple embeddings class
    """

    # pylint: disable=unused-argument
    def __init__(
        self,
        embedding_dim: int = 64,
        num_heads: int = 8,
        scale: bool = False,
        scale_factor: float = None,
        norm_type: str = None,
        activation_type: str = None,
        vocab_size: int = 0,
        padding_idx: int = 1,
        freeze: bool = False,
        **kwargs
    ):
        """
        Create new embeddings for the vocabulary.
        Use scaling for the Transformer.

        :param embedding_dim:
        :param scale:
        :param vocab_size:
        :param padding_idx:
        :param freeze: freeze the embeddings during training
        """
        super().__init__()

        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.lut = nn.Embedding(vocab_size, self.embedding_dim, padding_idx=padding_idx)

        self.norm_type = norm_type
        if self.norm_type:
            self.norm = MaskedNorm(
                norm_type=norm_type, num_groups=num_heads, num_features=embedding_dim
            )

        self.activation_type = activation_type
        if self.activation_type:
            self.activation = get_activation(activation_type)

        self.scale = scale
        if self.scale:
            if scale_factor:
                self.scale_factor = scale_factor
            else:
                self.scale_factor = math.sqrt(self.embedding_dim)

        if freeze:
            freeze_params(self)

    # pylint: disable=arguments-differ
    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """
        Perform lookup for input `x` in the embedding table.

        :param mask: token masks
        :param x: index in the vocabulary
        :return: embedded representation for `x`
        """

        x = self.lut(x)

        if self.norm_type:
            x = self.norm(x, mask)

        if self.activation_type:
            x = self.activation(x)

        if self.scale:
            return x * self.scale_factor
        else:
            return x

    def __repr__(self):
        return "%s(embedding_dim=%d, vocab_size=%d)" % (
            self.__class__.__name__,
            self.embedding_dim,
            self.vocab_size,
        )


class SpatialEmbeddings(nn.Module):

    """
    Simple Linear Projection Layer
    (For encoder outputs to predict glosses)
    """

    # pylint: disable=unused-argument
    def __init__(
        self,
        embedding_dim: int,
        input_size: int,
        num_heads: int,
        freeze: bool = False,
        norm_type: str = "batch",
        activation_type: str = "softsign",
        scale: bool = False,
        scale_factor: float = None,
        **kwargs
    ):
        """
        Create new embeddings for the vocabulary.
        Use scaling for the Transformer.

        :param embedding_dim:
        :param input_size:
        :param freeze: freeze the embeddings during training
        """
        super().__init__()

        self.embedding_dim = embedding_dim
        self.input_size = input_size
        self.ln = nn.Linear(self.input_size, self.embedding_dim)

        self.norm_type = norm_type
        if self.norm_type:
            self.norm = MaskedNorm(
                norm_type=norm_type, num_groups=num_heads, num_features=embedding_dim
            )

        self.activation_type = activation_type
        if self.activation_type:
            self.activation = get_activation(activation_type)

        self.scale = scale
        if self.scale:
            if scale_factor:
                self.scale_factor = scale_factor
            else:
                self.scale_factor = math.sqrt(self.embedding_dim)

        if freeze:
            freeze_params(self)

    # pylint: disable=arguments-differ
    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        """
        :param mask: frame masks
        :param x: input frame features
        :return: embedded representation for `x`
        """

        x = self.ln(x)

        if self.norm_type:
            x = self.norm(x, mask)

        if self.activation_type:
            x = self.activation(x)

        if self.scale:
            return x * self.scale_factor
        else:
            return x

    def __repr__(self):
        return "%s(embedding_dim=%d, input_size=%d)" % (
            self.__class__.__name__,
            self.embedding_dim,
            self.input_size,
        )

def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
    embeddings. :return: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


class TimestepEmbedding(nn.Module):
    def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
        super().__init__()

        self.linear_1 = nn.Linear(channel, time_embed_dim)
        self.act = None
        if act_fn == "silu":
            self.act = nn.SiLU()
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)

    def forward(self, sample):
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)
        return sample


class Timesteps(nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
        return t_emb


================================================
FILE: alm/models/architectures/tools/transformer_adpt.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from alm.models.architectures.tools.embeddings import (TimestepEmbedding,
                                                       Timesteps)

from .utils import PeriodicPositionalEncoding, init_bi_biased_mask

from typing import Optional, Tuple, Union, Callable
import math

class Transformer_Adpt(nn.Module):

    def __init__(self,
                 nfeats: int = 15069,
                 latent_dim: list = 174,
                 ff_size: int = 1024,
                 num_layers: int = 6,
                 num_heads: int = 4,
                 dropout: float = 0.1,
                 normalize_before: bool = False,
                 activation: str = "gelu",
                 arch: str = "trans_dec",
                 audio_encoded_dim: int = 768,
                 max_len: int = 3000,
                 id_dim: int = 10,
                 return_intermediate_dec: bool = False,
                 require_start_token: bool = False,   
                 require_time_encoding: bool = True,              
                 **kwargs) -> None:

        super().__init__()
        self.latent_dim = latent_dim
        self.arch = arch
        self.audio_encoded_dim = audio_encoded_dim

        # audio projecter
        self.audio_feature_map = nn.Linear(audio_encoded_dim, latent_dim)

        # motion projecter
        self.vertice_map = nn.Linear(nfeats, latent_dim)

        # periodic positional encoding
        self.PPE = PeriodicPositionalEncoding(latent_dim, period = max_len)

        # temporal bias
        self.memory_bi_bias = init_bi_biased_mask(max_len) # this is for the memory bias, not all the model

        # init decoder
        decoder_layer = TransformerDecoderLayer_w_Adapter(
            d_model=latent_dim, 
            nhead=num_heads, 
            dim_feedforward=ff_size,
            dropout=dropout, 
            activation=activation, 
            norm_first=normalize_before,
            batch_first=True
        )
        self.transformer_decoder = TransformerDecoder_w_Adapter(
            decoder_layer=decoder_layer,
            num_layers=num_layers,
            )

        # motion decoder
        self.motion_decoder = nn.Linear(latent_dim, nfeats)

        # used for auto-regressive decoding
        if require_start_token:
            self.start_token = nn.Parameter(torch.randn(1, 1, latent_dim), requires_grad=True)

        # style embedding
        # self.obj_vector = nn.Linear(id_dim, latent_dim, bias=False)

        id_len = 1000
        self.id_enc = torch.zeros([id_len, latent_dim])
        position = torch.arange(0, id_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, latent_dim, 2).float() * (-math.log(10000.0) / latent_dim))
        self.id_enc[:, 0::2] = torch.sin(position * div_term)
        self.id_enc[:, 1::2] = torch.cos(position * div_term)        

        nn.init.constant_(self.motion_decoder.weight, 0)
        nn.init.constant_(self.motion_decoder.bias, 0)

    def obj_vector(self, id):
        # if id is a one-hot vector
        if id.dim() == 2: # [N, id_dim]
            id = id.argmax(dim=1)
        return self.id_enc[id] # [N, id_dim]    

    def _forward(self,
                vertice_input: torch.Tensor,
                hidden_state: torch.Tensor,
                adapter: torch.Tensor = None,
                tgt_mask: torch.Tensor = None,
                tgt_key_padding_mask: torch.Tensor = None,
                memory_mask: torch.Tensor = None,
                memory_key_padding_mask: torch.Tensor = None,
                **kwargs):
        """
        Auto-regressive forward pass for the decoder.
        To be used during training.
        Args:
            vertice_input: [N, T, E]
            hidden_state: [N, S, E]
            adapter: [N, A, E]
            tgt_mask: [N * H, T, T]
            tgt_key_padding_mask: [N, T]
            memory_mask: [T, S]
            memory_key_padding_mask: [N, S]
        """
        
        vertice_input = self.PPE(vertice_input)

        vertice_out = self.transformer_decoder(
            tgt=vertice_input,
            memory=hidden_state,
            adapter = adapter,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask = tgt_key_padding_mask,
            memory_mask=memory_mask,
            memory_key_padding_mask = memory_key_padding_mask,
            **kwargs
        )

        vertice_out = self.motion_decoder(vertice_out)

        return vertice_out

from torch.nn.modules.transformer import _get_clones
class TransformerDecoder_w_Adapter(nn.TransformerDecoder):
    """
    A transformer decoder with adapter layer.
    Args:
        decoder_layer: an instance of the TransformerDecoderLayer() class (required).
        num_layers: the number of sub-DecoderLayer in the decoder (required).
        norm: the layer normalization component (optional).
    """
    __constants__ = ['norm']

    def __init__(self, decoder_layer, num_layers, norm=None):
        super(TransformerDecoder_w_Adapter, self).__init__(decoder_layer, num_layers, norm)
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, 
                tgt: Tensor, 
                memory: Tensor, 
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, 
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                adapter: Optional[Tensor] = None,
        ) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer in turn.

        Args:
            tgt: the sequence to the decoder (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
            adapter: the adapter for the decoder layer (optional).
        Shape:
            see the docs in Transformer class.
        """
        output = tgt

        for mod in self.layers:
            output = mod(output, memory, tgt_mask=tgt_mask,
                         memory_mask=memory_mask,
                         tgt_key_padding_mask=tgt_key_padding_mask,
                         memory_key_padding_mask=memory_key_padding_mask,
                         adapter=adapter)
            
        if self.norm is not None:
            output = self.norm(output)

        return output
        

class TransformerDecoderLayer_w_Adapter(nn.TransformerDecoderLayer):
    """
    A single layer of the transformer decoder with adapter.
    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: if ``True``, then the input and output tensors are provided as (batch, seq, feature). Default: ``False``.
        norm_first: if ``True``, then the input and output tensors are provided as (seq, batch, feature). Default: ``False``.
        device: the desired device of the encoder layer. Default: if ``None`` will use ``torch.device("cuda")`` if ``torch.cuda.is_available()`` else ``torch.device("cpu")``
        dtype: the desired dtype of the encoder layer. Default: if ``None`` will use ``torch.float32``
    """
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, 
                 d_model: int, 
                 nhead: int, 
                 dim_feedforward: int = 2048, 
                 dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, 
                 batch_first: bool = False, 
                 norm_first: bool = False,
                 device=None, dtype=None) -> None:

        # folow the original transformer decoder layer
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(TransformerDecoderLayer_w_Adapter, self).__init__(
            d_model, nhead, dim_feedforward, dropout, activation, layer_norm_eps, batch_first, norm_first, **factory_kwargs)

    def forward(self, 
                tgt: Tensor, 
                memory: Tensor, 
                tgt_mask: Optional[Tensor] = None, 
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, 
                memory_key_padding_mask: Optional[Tensor] = None,
                adapter: Optional[Tensor] = None,
        ) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
            adapter: the adapter for the decoder layer (optional).
        Shape:
            see the docs in Transformer class.
        """
        x = tgt
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, adapter=adapter)
            x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, adapter=adapter)
            x = x + self._ff_block(self.norm3(x))
        else:
            x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, adapter=adapter))
            x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, adapter=adapter))
            x = self.norm3(x + self._ff_block(x))

        return x

    # self-attention block with adapter
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], 
                  key_padding_mask: Optional[Tensor],
                  adapter: Optional[Tensor] = None,
        ) -> Tensor:
        """
        Args:
            x: [B, T, E] if batch_first else [T, B, E]
            attn_mask: [T, T]
            key_padding_mask: [B, T]
            adapter: [B, A, E] if batch_first else [A, B, E]
        Returns:
            [B, T, E] if batch_first else [T, B, E]
        """
        batch_first = self.self_attn.batch_first
        # concate adapter to key and value if it is not None
        if adapter is not None:
            x_adpt = self._concate_adapter(adapter, x, batch_first=batch_first)
        else:
            x_adpt = x

        # # original self-attention block
        # tmp = self.self_attn(x, x_adpt, x_adpt, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=True, )[1]
        # # visualize attention, use sns
        # import matplotlib.pyplot as plt
        # import seaborn as sns
        # length = 100
        # fig, ax = plt.subplots(figsize=(15, 10))
        # sns.heatmap(tmp[0, :length, :length+2].detach().cpu().numpy())
        # # save to disk
        # plt.savefig('self_attention.png')

        
        x = self.self_attn(x, x_adpt, x_adpt, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False, )[0]
        return self.dropout1(x)

    # cross-attention block with adapter
    def _mha_block(self, x: Tensor, mem: Tensor,
                   attn_mask: Optional[Tensor], 
                   key_padding_mask: Optional[Tensor],
                   adapter: Optional[Tensor] = None,
        ) -> Tensor:
        """
        Args:
            x: [B, T, E] if batch_first else [T, B, E]
            mem: [B, S, E] if batch_first else [S, B, E]
            attn_mask: [T, S]
            key_padding_mask: [B, T]
            adapter: [B, A, E] if batch_first else [A, B, E]
        Returns:
            [B, T, E] if batch_first else [T, B, E]
        """

        batch_first = self.multihead_attn.batch_first
        # concate adapter to key and value if it is not None
        if adapter is not None:
            mem_adpt = self._concate_adapter(adapter, mem, batch_first=batch_first)
        else:
            mem_adpt = x

        # tmp = self.multihead_attn(x, mem_adpt, mem_adpt, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=True, )[1]
        # # visualize attention, use sns
        # import matplotlib.pyplot as plt
        # import seaborn as sns
        # length = 100
        # fig, ax = plt.subplots(figsize=(15, 10))
        # sns.heatmap(tmp[0, :length, :length+2].detach().cpu().numpy())
        # # save to disk
        # plt.savefig('cross_attention.png')

        # original cross-attention block
        x = self.multihead_attn(x, mem_adpt, mem_adpt, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False, )[0]
        return self.dropout2(x)
        
    def _concate_adapter(self, adapter: Tensor, x: Tensor, batch_first: bool = True):
        """
        concate adapter ahead of x
        Args:
            adapter: [B, A, E] if batch_first else [A, B, E]
            x: [B, T, E] if batch_first else [T, B, E]
        Returns:
            x_adapted: [B, A+T, E] if batch_first else [A+T, B, E]
        """
        if batch_first:
            x_adapted = torch.concat([adapter, x], dim=1) # [B, A, E] + [B, T, E] -> [B, A+T, E]  
        else: # batch_first
            x_adapted = torch.concat([adapter, x], dim=0) # [A, B, E] + [T, B, E] -> [A+T, B, E]
        return x_adapted



        


================================================
FILE: alm/models/architectures/tools/utils.py
================================================
import math
import torch
import torch.nn as nn

class PeriodicPositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, period=25, max_seq_len=600):
        super(PeriodicPositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(period, d_model)
        position = torch.arange(0, period, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0) # (1, period, d_model)
        repeat_num = (max_seq_len//period) + 1
        pe = pe.repeat(1, repeat_num, 1)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
    
# def init_biased_mask(n_head, max_seq_len, period):
#     # this code is from https://github.com/EvelynFan/FaceFormer/blob/dfaea81983665b22b99af336a80574208cfcc099/faceformer.py#L10
#     # however, the original code is not working for the case where the batch size is not 1
#     # so I modified it a little bit
#     def get_slopes(n):
#         def get_slopes_power_of_2(n):
#             start = (2**(-2**-(math.log2(n)-3)))
#             ratio = start
#             return [start*ratio**i for i in range(n)]
#         if math.log2(n).is_integer():
#             return get_slopes_power_of_2(n)                   
#         else:                                                 
#             closest_power_of_2 = 2**math.floor(math.log2(n)) 
#             return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]
    
#     slopes = torch.Tensor(get_slopes(n_head))
#     bias = torch.div(
#         torch.arange(
#             start=0, 
#             end=max_seq_len, 
#             step=period,
#             dtype=torch.float
#         ).unsqueeze(1).repeat(1,period).view(-1),
#         period,
#         rounding_mode='floor'
#     )
#     bias = - torch.flip(bias,dims=[0])
#     alibi = torch.zeros(max_seq_len, max_seq_len)
#     for i in range(max_seq_len):
#         alibi[i, :i+1] = bias[-(i+1):]

#     alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0)
#     mask = (torch.triu(torch.ones(max_seq_len, max_seq_len)) == 1).transpose(0, 1)
#     mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
#     mask = mask.unsqueeze(0) + alibi
#     return mask


# def init_bi_biased_mask_dev(n_head, max_seq_len, period):
#     # this code is from # https://github.com/tensorflow/tensor2tensor
#     # and I modified it a little bit to match the original code in  https://github.com/EvelynFan/FaceFormer/blob/dfaea81983665b22b99af336a80574208cfcc099/faceformer.py#L10
#     # such the code is working for the case where the batch size is not 1
#     def get_slopes(n):
#         def get_slopes_power_of_2(n):
#             start = (2**(-2**-(math.log2(n)-3)))
#             ratio = start
#             return [start*ratio**i for i in range(n)]
#         if math.log2(n).is_integer():
#             return get_slopes_power_of_2(n)                   
#         else:                                                 
#             closest_power_of_2 = 2**math.floor(math.log2(n)) 
#             return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]

#     slopes = torch.Tensor(get_slopes(n_head))

#     range_vec = torch.div(
#         torch.arange(
#             start=0,
#             end=max_seq_len,
#             step=period,
#             dtype=torch.float
#         ).unsqueeze(1).repeat(1,period).view(-1),
#         period,
#         rounding_mode='floor'
#     )

#     relative_matrix = range_vec[None, :] - range_vec[:, None]
#     relative_matrix[torch.where(relative_matrix > 0)] *= -1 

#     alibi = slopes.unsqueeze(1).unsqueeze(1) * relative_matrix.unsqueeze(0)
#     return alibi

def init_biased_mask(n_head, max_seq_len, period):
    # this code is from https://github.com/EvelynFan/FaceFormer/blob/dfaea81983665b22b99af336a80574208cfcc099/faceformer.py#L10
    # however, the original code is not working for the case where the batch size is not 1
    # so I modified it a little bit
    def get_slopes(n):
        def get_slopes_power_of_2(n):
            start = (2**(-2**-(math.log2(n)-3)))
            ratio = start
            return [start*ratio**i for i in range(n)]
        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)                   
        else:                                                 
            closest_power_of_2 = 2**math.floor(math.log2(n)) 
            return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]
    
    slopes = torch.Tensor(get_slopes(n_head))
    bias = torch.div(
        torch.arange(
            start=0, 
            end=max_seq_len, 
            step=period,
            dtype=torch.float
        ).unsqueeze(1).repeat(1,period).view(-1),
        period,
        rounding_mode='floor'
    )
    bias = - torch.flip(bias,dims=[0])
    alibi = torch.zeros(max_seq_len, max_seq_len)
    for i in range(max_seq_len):
        alibi[i, :i+1] = bias[-(i+1):]

    alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0)
    mask = (torch.triu(torch.ones(max_seq_len, max_seq_len)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    mask = mask.unsqueeze(0) + alibi
    return mask


def init_bi_biased_mask(max_seq_len, ):
    # any attention mask that is as more than 3 elements is not working

    range_vec = torch.arange(
            start=0,
            end=max_seq_len,
            dtype=torch.float
        )

    relative_matrix = range_vec[None, :] - range_vec[:, None]
    relative_matrix[torch.where(relative_matrix > 0)] *= -1 

    return relative_matrix

def init_mem_mask_faceformer(max_seq_len):
    mask = torch.ones(max_seq_len, max_seq_len)
    # set the diagonal to 0
    mask = mask.masked_fill(torch.eye(max_seq_len) == 1, 0)
    return mask    

def init_bi_biased_mask_faceformer(n_head, max_seq_len, period):
    # any attention mask that is as more than 3 elements is not working
    def get_slopes(n):
            def get_slopes_power_of_2(n):
                start = (2**(-2**-(math.log2(n)-3)))
                ratio = start
                return [start*ratio**i for i in range(n)]
            if math.log2(n).is_integer():
                return get_slopes_power_of_2(n)                   
            else:                                                 
                closest_power_of_2 = 2**math.floor(math.log2(n)) 
                return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]

    slopes = torch.Tensor(get_slopes(n_head))
    bias = torch.div(torch.arange(start=0, end=max_seq_len, step=period).unsqueeze(1).repeat(1,period).view(-1), period, rounding_mode='floor')
    bias = - torch.flip(bias,dims=[0])
    alibi = torch.zeros(max_seq_len, max_seq_len)
    for i in range(max_seq_len):
        alibi[i, :i+1] = bias[-(i+1):]
        if i+1 < max_seq_len:
            alibi[i, i+1:] = bias[-(max_seq_len-(i+1)):].flip(dims=[0])

    alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0)

    return alibi



================================================
FILE: alm/models/get_model.py
================================================
import importlib

def get_model(cfg, datamodule):
    modeltype = cfg.model.model_type
    return get_module(cfg, datamodule)
    # if modeltype in ["proto","faceformer"]:
    #     return get_module(cfg, datamodule)
    # else:
    #     raise ValueError(f"Invalid model {modeltype}.")
    
def get_module(cfg, datamodule):
    modeltype = cfg.model.model_type
    model_module = importlib.import_module(
        f".modeltype.{cfg.model.model_type}", package="alm.models")
    Model = model_module.__getattribute__(f"{modeltype.upper()}")
    return Model(cfg=cfg, datamodule=datamodule)
 
    

================================================
FILE: alm/models/losses/__init__.py
================================================


================================================
FILE: alm/models/losses/utils.py
================================================
import torch
import torch.nn as nn

class MaskedConsistency:
    def __init__(self) -> None:
        self.loss = nn.MSELoss(reduction="mean")

    def __call__(self, pred, gt, mask):
        return self.loss(mask * pred, mask * gt)
    
    def __repr__(self):
        return self.loss.__repr__()
    
class MaskedVelocityConsistency:
    def __init__(self) -> None:
        self.loss = nn.MSELoss(reduction="mean")

    def __call__(self, pred, gt, mask):
        term1 = self.velocity(mask * pred)
        temr2 = self.velocity(mask * gt)
        return self.loss(term1, temr2)

    def velocity(self, term):
        velocity = term[:, 1:] - term[:, :-1]
        return velocity

    def __repr__(self):
        return self.loss.__repr__()


# flame_lmk_faces = [[2210, 2212, 2213],
#          [3060, 3059, 1962],
#          [3485, 3060, 1961],
#          [3382, 3384, 3381],
#          [3385, 3388, 3386],
#          [3387, 3389, 3390],
#          [3392, 3418, 3419],
#          [3415, 3395, 3393],
#          [3414, 3399, 3397],
#          [3634, 3595, 3598],
#          [3643, 3637, 3594],
#          [3588, 3587, 3583],
#          [3584, 3581, 3582],
#          [3742, 3580, 3577],
#          [2012, 3756,  566],
#          [2009, 2012, 2011],
#          [ 728,  731,  730],
#          [1983, 1984, 1985],
#          [3157, 3708, 3158],
#          [ 338, 3153,  335],
#          [3154, 3712, 3705],
#          [2179, 2178, 3684],
#          [ 674, 3851,  673],
#          [3863, 3868, 2135],
#          [2134,   27,   16],
#          [2138, 2139, 3865],
#          [ 572,  571,  570],
#          [2194, 3553, 3542],
#          [3561,  739, 3518],
#          [1757, 3521, 3501],
#          [3526, 3564, 1819],
#          [2748, 2746, 2750],
#          [2792, 2794, 2795],
#          [1692, 3556, 3507],
#          [1678, 1677, 1675],
#          [1612, 1618, 1610],
#          [2440, 2428, 2437],
#          [2383, 2453, 2495],
#          [2493, 3689, 2494],
#          [2509, 3631, 3632],
#          [2293, 2298, 2299],
#          [2333, 2296, 2295],
#          [3833, 3832, 1358],
#          [1342, 1343, 3855],
#          [1344, 1218, 1034],
#          [1182, 1175, 1154],
#          [ 955,  883,  884],
#          [ 881,  897,  896],
#          [2845, 2715, 2714],
#          [2849, 2813, 2850],
#          [2811, 2866, 2774],
#          [1657, 3543, 3546],
#          [1694, 1657, 1751],
#          [1734, 1735, 1696],
#          [1730, 1578, 1579],
#          [1774, 1795, 1796],
#          [1802, 1865, 1866],
#          [1850, 3506, 3503],
#          [2905, 2949, 2948],
#          [2899, 2898, 2881],
#          [2719, 2718, 2845],
#          [3533, 2786, 2785],
#          [3533, 3531, 2786],
#          [1669, 3533, 1668],
#          [1730, 1578, 1579],
#          [1826, 1848, 1849],
#          [3504, 3509, 2937],
#          [2938, 2937, 2928]] # extracted from the mica-trakcer code

# def flame_vertice_2_lmk(vertice, lmk_faces = flame_lmk_faces):
        
#     if len(vertice.shape) == 3: # if the vertices is resizes to (B, N, V * 3), reshape it back to (B, N, V, 3)
#         resized = True
#         B, N, _ = vertice.shape
#         vertice = vertice.reshape(B, N, -1, 3)
#         V = vertice.shape[-2]
#     else:
#         resized = False
#         B, N, V, _ = vertice.shape

#     _vertice = vertice.view(-1, V, 3) # (B * N, V, 3) -> (_B, V, 3)
#     _B = _vertice.shape[0]

#     lmk_faces = torch.tensor(lmk_faces,).clone() # (68, 3)

#     lmk_faces = lmk_faces.unsqueeze(0).expand(_B, -1, -1).to(vertice.device) # (68, 3) -> (_B, 68, 3)
#     lmk_faces += torch.arange(_B, device=vertice.device).view(-1, 1, 1) * V # (_B, 68, 3)
#     lmk_vertices = _vertice.reshape(-1, 3)[lmk_faces].view(_B, -1, 3, 3) # (_B, 68, 3, 3), here we have to use .reshape(-1, 3) to make sure the index is correct, view reports error
#     landmarks = lmk_vertices.mean(dim=-2) # (_B, 68, 3), every vertice in the face contributes to the landmark

#     if resized:
#         landmarks = landmarks.view(B, N, -1) # (B, N, 68 * 3)
#     else:
#         landmarks = landmarks.view(B, N, -1, 3) # (B, N, 68, 3)

#     return landmarks

# def cusum(data, threshold, drift):
#     """Cumulative sum algorithm (CUSUM) to detect abrupt changes in data."""
#     # Initialize variables
#     thres = torch.zeros_like(data)
#     # Compute the cumulative sum using torch.cumsum
#     for i in range(1, data.shape[1]):
#         # Update the cumulative sum
#         prev_thres = thres[:, i-1].unsqueeze(1)
#         delta = (data[:, i, :] - data[:, i-1, :]).abs().unsqueeze(1) #torch.norm(data[:, i, :] - data[:, i-1, :], dim=1).unsqueeze(1)
#         thres[:, i] = torch.max(torch.zeros_like(prev_thres), prev_thres + delta - threshold).squeeze(1)
#         # Update the threshold
#         threshold += drift

#     return thres

# the following is required by BIWI
import os
import numpy as np
import pickle

with open(os.path.join("datasets/biwi/regions", "lve.txt")) as f:
    maps = f.read().split(", ")
    mouth_map = [int(i) for i in maps]

with open(os.path.join("datasets/biwi/regions", "fdd.txt")) as f:
    maps = f.read().split(", ")
    upper_map = [int(i) for i in maps]

# open /home/zhiyuan_ma/code/FaceDiffusion/datasets/vocaset/FLAME_masks.pkl
with open(os.path.join("datasets/vocaset", "FLAME_masks.pkl"), "rb") as f:
    masks = pickle.load(f, encoding='latin1')
    vocaset_mouth_map = masks["lips"].tolist()
    vocaset_upper_map = masks["forehead"].tolist() + masks["eye_region"].tolist()
    vocaset_upper_map = list(set(vocaset_upper_map))


def vocaset_upper_face_variance(motion, ):
    L2_dis_upper = np.array([np.square(motion[:,v, :]) for v in vocaset_upper_map])
    L2_dis_upper = np.transpose(L2_dis_upper, (1,0,2))
    L2_dis_upper = np.sum(L2_dis_upper,axis=2)
    L2_dis_upper = np.std(L2_dis_upper, axis=0)
    motion_std = np.mean(L2_dis_upper)
    return torch.tensor(motion_std).float() #torch.from_numpy(motion_std).float()

def vocaset_mouth_distance(vertices_gt, vertices_pred):
    L2_dis = np.array([np.square(vertices_gt[:,v, :] - vertices_pred[:,v, :]) for v in vocaset_mouth_map])
    L2_dis = np.transpose(L2_dis, (1,0,2)) # (V, N, 3) -> (N, V, 3)
    L2_dis = np.sum(L2_dis, axis=2) # (N, V, 3) -> (N, V)
    L2_dis = np.max(L2_dis, axis=1) # (N, V) -> (N)
    return torch.tensor(L2_dis).float()

def biwi_upper_face_variance(motion, ):
    L2_dis_upper = np.array([np.square(motion[:,v, :]) for v in upper_map])
    L2_dis_upper = np.transpose(L2_dis_upper, (1,0,2))
    L2_dis_upper = np.sum(L2_dis_upper,axis=2)
    L2_dis_upper = np.std(L2_dis_upper, axis=0)
    motion_std = np.mean(L2_dis_upper)
    return torch.tensor(motion_std).float() #torch.from_numpy(motion_std).float()

def biwi_mouth_distance(vertices_gt, vertices_pred):
    L2_dis = np.array([np.square(vertices_gt[:,v, :] - vertices_pred[:,v, :]) for v in mouth_map])
    L2_dis = np.transpose(L2_dis, (1,0,2)) # (V, N, 3) -> (N, V, 3)
    L2_dis = np.sum(L2_dis, axis=2) # (N, V, 3) -> (N, V)
    L2_dis = np.max(L2_dis, axis=1) # (N, V) -> (N)
    return torch.tensor(L2_dis).float()



================================================
FILE: alm/models/losses/voca.py
================================================
# import numpy as np
# import torch
# import torch.nn as nn
# from torchmetrics import Metric

# class almLosses(Metric):
#     """
#     Audio latent motion losses
#     """
#     def __init__(self, cfg):
#         super().__init__(dist_sync_on_step=cfg.LOSS.DIST_SYNC_ON_STEP)

#         # Save parameters
#         # self.vae = vae
#         self.cfg = cfg
#         self.stage = cfg.TRAIN.STAGE

#     def updata(self, rs_set):
#         return None

#     def compute(self, split):
#         count = getattr(self, "count")
#         return {loss: getattr(self, loss) / count for loss in self.losses}

import numpy as np
import torch
import torch.nn as nn
from torchmetrics import Metric
import os
import pickle

class VOCALosses(Metric):
    """
    MLD Loss
    """

    def __init__(self, cfg, split):
        super().__init__(dist_sync_on_step=cfg.LOSS.DIST_SYNC_ON_STEP)

        self.cfg = cfg
        

        # set up loss 
        losses = []
        self._losses_func = {}
        self._params = {}

        reconstruct = MaskedConsistency()
        reconstruct_v = MaskedVelocityConsistency()

        self.split = split
        is_train = split in ['losses_train']

        if split in ['losses_train', 'losses_val']:
            # vertice 
            name = "vertice_enc" # enc here means encoding, just for matching the name in the the diffusion-denoising experiment
            losses.append(name)
            self._losses_func[name] = reconstruct
            self._params[name] = cfg.LOSS.VERTICE_ENC if is_train else 1.0
            self.add_state(name, default=torch.tensor(0.0), dist_reduce_fx="sum")

            name = "vertice_encv" # encv here means encoding velocity, just for matching the name in the the diffusion-denoising experiment
            losses.append(name)
            self._losses_func[name] = reconstruct_v
            self._params[name] = cfg.LOSS.VERTICE_ENC_V if is_train else 1.0
            self.add_state(name, default=torch.tensor(0.0), dist_reduce_fx="sum")

            name = "lip_enc"
            losses.append(name)
            self._losses_func[name] = reconstruct
            self._params[name] = cfg.LOSS.LIP_ENC if is_train else 1.0
            self.add_state(name, default=torch.tensor(0.0), dist_reduce_fx="sum")

            name = "lip_encv"
            losses.append(name)
            self._losses_func[name] = reconstruct_v
            self._params[name] = cfg.LOSS.LIP_ENC_V if is_train else 1.0
            self.add_state(name, default=torch.tensor(0.0), dist_reduce_fx="sum")

        elif split in ['losses_test']:
            pass # no loss for test
        else:
            raise ValueError(f"split {split} not supported")

        name = "total"
        losses.append(name)
        self.add_state(name, default=torch.tensor(0.0), dist_reduce_fx="sum")

        name = 'count'
        self.add_state(name, default=torch.tensor(0), dist_reduce_fx="sum")
        
        self.losses = losses

    #     # obtain the lmk index
    #     # the following is required by BIWI
    #     with open(os.path.join("datasets/biwi/regions", "lve.txt")) as f:
    #         maps = f.read().split(", ")
    #         self.biwi_mouth_map = [int(i) for i in maps]
    #         self.biwi_mouth_map = torch.tensor(self.biwi_mouth_map).long()

    #     # the following is required by vocaset
    #     with open(os.path.join("datasets/vocaset", "FLAME_masks.pkl"), "rb") as f:
    #         masks = pickle.load(f, encoding='latin1')
    #         self.vocaset_mouth_map = masks["lips"].tolist()     
    #         self.vocaset_mouth_map = torch.tensor(self.vocaset_mouth_map).long()   

    # def vert2lip(self, vertice):

    #     num_verts = vertice.shape[-1] // 3
    #     if num_verts == 5023:
    #         mouth_map = self.vocaset_mouth_map.to(vertice.device)
    #     elif num_verts == 23370:
    #         mouth_map = self.biwi_mouth_map.to(vertice.device)
    #     else:
    #         raise ValueError(f"num_verts {num_verts} not supported")
        
    #     shape = vertice.shape
    #     lip_vertice = vertice.view(shape[0], shape[1], -1, 3)[:, :, mouth_map, :].view(shape[0], shape[1], -1)
    #     return lip_vertice

    def update(self, rs_set):
        # rs_set.keys() = dict_keys(['latent', 'latent_pred', 'vertice', 'vertice_recon', 'vertice_pred', 'vertice_attention'])

        total: float = 0.0
        # Compute the losses
        # Compute instance loss

        # padding mask
        mask = rs_set['vertice_attention'].unsqueeze(-1)

        if self.split in ['losses_train', 'losses_val']: 
            # vertice loss
            total += self._update_loss("vertice_enc", rs_set['vertice'], rs_set['vertice_pred'], mask = mask)
            total += self._update_loss("vertice_encv", rs_set['vertice'], rs_set['vertice_pred'], mask = mask)

            # lip loss
            # lip_vertice = self.vert2lip(rs_set['vertice'])
            # lip_vertice_pred = self.vert2lip(rs_set['vertice_pred'])
            # total += self._update_loss("lip_enc", lip_vertice, lip_vertice_pred, mask = mask)
            # total += self._update_loss("lip_encv", lip_vertice, lip_vertice_pred, mask = mask)

            self.total += total.detach()
            self.count += 1

            return total
        
        if self.split in ['losses_test']:
            raise ValueError(f"split {self.split} not supported")


    def compute(self, split):
        count = getattr(self, "count")
        return {loss: getattr(self, loss) / count for loss in self.losses}


    def _update_loss(self, loss: str, outputs, inputs, mask = None):
        # Update the loss
        if mask is not None:
            val = self._losses_func[loss](outputs, inputs, mask)
        else:
            val = self._losses_func[loss](outputs, inputs)
        getattr(self, loss).__iadd__(val.detach())
        # Return a weighted sum
        weighted_loss = self._params[loss] * val
        return weighted_loss

    def loss2logname(self, loss: str, split: str):
        if loss == "total":
            log_name = f"{loss}/{split}"
        else:
            loss_type, name = loss.split("_")
            log_name = f"{loss_type}/{name}/{split}"
        return log_name
    
class MaskedConsistency:
    def __init__(self) -> None:
        self.loss = nn.MSELoss(reduction="mean")

    def __call__(self, pred, gt, mask):
        # # masking nan
        # is_nan = torch.logical_or(torch.isnan(pred), torch.isnan(gt))
        # nan_mask = torch.logical_not(is_nan).long()
        # torch.where(nan_mask[0, ..., 0] != mask[0].squeeze())
        return self.loss(mask * pred, mask * gt)
    
    def __repr__(self):
        return self.loss.__repr__()
    
class MaskedVelocityConsistency:
    def __init__(self) -> None:
        self.loss = nn.MSELoss(reduction="mean")

    def __call__(self, pred, gt, mask):
        term1 = self.velocity(mask * pred)
        temr2 = self.velocity(mask * gt)
        return self.loss(term1, temr2)

    def velocity(self, term):
        velocity = term[:, 1:] - term[:, :-1]
        return velocity

    def __repr__(self):
        return self.loss.__repr__()

================================================
FILE: alm/models/modeltype/__init__.py
================================================


================================================
FILE: alm/models/modeltype/base.py
================================================
import os
from pathlib import Path
import numpy as np
from pytorch_lightning import LightningModule
import torch
from collections import OrderedDict

class BaseModel(LightningModule):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.times = []

    def __post_init__(self):
        trainable, nontrainable = 0, 0
        for p in self.parameters():
            if p.requires_grad:
                trainable += np.prod(p.size())
            else:
                nontrainable += np.prod(p.size())

        self.hparams.n_params_trainable = trainable
        self.hparams.n_params_nontrainable = nontrainable

    def training_step(self, batch, batch_idx):
        return self.allsplit_step("train", batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        return self.allsplit_step("val", batch, batch_idx)

    def test_step(self, batch, batch_idx):
        if len(self.times) *self.cfg.TEST.BATCH_SIZE % (100) > 0 and len(self.times) > 0:
            print(f"Average time per sample ({self.cfg.TEST.BATCH_SIZE*len(self.times)}): ", np.mean(self.times)/self.cfg.TEST.BATCH_SIZE)
        return self.allsplit_step("test", batch, batch_idx)

    def predict_step(self, batch, batch_idx):
        return self.forward(batch)

    def allsplit_epoch_end(self, split: str, outputs):
        dico = {}

        if split in ["train", "val"]:
            losses = self.losses[split]
            loss_dict = losses.compute(split)
            losses.reset()
            dico.update({
                losses.loss2logname(loss, split): value.item()
                for loss, value in loss_dict.items() #if not torch.isnan(value)
            })

            dico.update({
                "epoch": float(self.trainer.current_epoch),
                "step": float(self.trainer.current_epoch),
            })

        if split == "test":
            metircs = {key: [] for key in outputs[0].keys()}
            for output in outputs: # collect the results from all batches
                for key, value in output.items():
                    metircs[key].append(value)

            lengths = torch.stack(metircs.pop("Length"))
            for key, value in metircs.items():
                if key == 'Lip Vertex Error':
                    metircs[key] = torch.mean(torch.stack(value) * lengths) / torch.mean(lengths)
                metircs[key] = torch.mean(torch.stack(value))
            dico.update(metircs)

        if not self.trainer.sanity_checking:
            self.log_dict(dico, sync_dist=True, rank_zero_only=True)


    def training_epoch_end(self, outputs):
        return self.allsplit_epoch_end("train", outputs)

    def validation_epoch_end(self, outputs):
        return self.allsplit_epoch_end("val", outputs)

    def test_epoch_end(self, outputs):
        return self.allsplit_epoch_end("test", outputs)
    
    # def on_save_checkpoint(self, checkpoint):
    #     # don't save audio_encoder to checkpoint
    #     state_dict = checkpoint['state_dict']
    #     clip_k = []
    #     for k, v in state_dict.items():
    #         if 'audio_encoder' in k:
    #             clip_k.append(k)
    #     for k in clip_k:
    #         del checkpoint['state_dict'][k]

    # def on_load_checkpoint(self, checkpoint):
    #     # restore clip state_dict to checkpoint
    #     clip_state_dict = self.audio_encoder.state_dict()
    #     new_state_dict = OrderedDict()
    #     for k, v in clip_state_dict.items():
    #         new_state_dict['audio_encoder.' + k] = v
    #     for k, v in checkpoint['state_dict'].items():
    #         if 'audio_encoder' not in k:
    #             new_state_dict[k] = v
    #     checkpoint['audio_dict'] = new_state_dict

    # def load_state_dict(self, state_dict, strict=True):
    #     # load clip state_dict to checkpoint
    #     clip_state_dict = self.audio_encoder.state_dict()
    #     new_state_dict = OrderedDict()
    #     for k, v in clip_state_dict.items():
    #         new_state_dict['audio_encoder.' + k] = v
    #     for k, v in state_dict.items():
    #         if 'audio_encoder' not in k:
    #             new_state_dict[k] = v
    #     super().load_state_dict(new_state_dict, strict)


    def configure_optimizers(self):
        return {"optimizer": self.optimizer}




================================================
FILE: alm/models/modeltype/diffusion_bias.py
================================================
import torch
from torch.optim import AdamW, Adam
import torch.nn.functional as F
from torchmetrics import MetricCollection
from transformers import Wav2Vec2Model

from alm.config import instantiate_from_config
from alm.models.modeltype.base import BaseModel
from alm.models.losses.voca import VOCALosses
from alm.utils.demo_utils import animate
from .base import BaseModel

import inspect
from typing import Optional, Tuple, Union, Callable
import os
import time
from multiprocessing import Process
from tqdm import tqdm

import numpy as np

from time import time as infer_time
import pickle


class DIFFUSION_BIAS(BaseModel):

    def __init__(self, cfg, datamodule, **kwargs):
        """
        Initialize the model
        """
        # we only use the functions in the GPt_ADPT_LOCAL_ATTEN class, so we don't need to call the __init__ function of the GPT_ADPT_LOCAL_ATTEN class
        super().__init__()
        self.cfg = cfg
        self.datamodule = datamodule

        # set up losses
        self._losses = MetricCollection({
                split: VOCALosses(cfg=cfg, split=split)
                for split in ["losses_train", "losses_test", "losses_val",] # "losses_train_val"
            })

        self.losses = {
            key: self._losses["losses_" + key]
            for key in ["train", "test", "val", ] # "train_val"
        }

        # set up model
        self.audio_encoder = Wav2Vec2Model.from_pretrained(cfg.audio_encoder.model_name_or_path)
        if cfg.audio_encoder.train_audio_encoder:
            self.audio_encoder.feature_extractor._freeze_parameters() # we don't want to train the feature extractor
        else:
            for param in self.audio_encoder.parameters():
                param.requires_grad = False
        self.denoiser = instantiate_from_config(cfg.model.denoiser)

        # set up optimizer
        if cfg.TRAIN.OPTIM.TYPE.lower() == "adamw":
            self.optimizer = AdamW(lr=cfg.TRAIN.OPTIM.LR,
                                   params=filter(lambda p: p.requires_grad,self.parameters())
                                   )
        elif cfg.TRAIN.OPTIM.TYPE.lower() == "adam":
            self.optimizer = Adam(lr=cfg.TRAIN.OPTIM.LR,
                                  params=filter(lambda p: p.requires_grad,self.parameters())
                                  )
        else:
            raise NotImplementedError(
                "Do not support other optimizer for now.")

        # set up diffusion specific initialization
        if not cfg.model.predict_epsilon:
            cfg.model.scheduler.params['prediction_type'] = 'sample'
            cfg.model.noise_scheduler.params['prediction_type'] = 'sample'
        self.scheduler = instantiate_from_config(cfg.model.scheduler)
        self.noise_scheduler = instantiate_from_config(cfg.model.noise_scheduler)

        # set up the hidden state resizing parameters
        self.audio_fps = cfg.denoiser.params.audio_fps
        self.hidden_fps = cfg.denoiser.params.hidden_fps
        # set up the vertice dimension
        self.nfeats = cfg.denoiser.params.nfeats

        # guided diffusion
        self.guidance_uncondp = cfg.model.guidance_uncondp if hasattr(cfg.model, "guidance_uncondp") else 0.0
        self.guidance_scale = cfg.model.guidance_scale if hasattr(cfg.model, "guidance_scale") else 1.0
        # assert self.guidance_scale >= 0.0 and self.guidance_scale <= 1.0
        assert self.guidance_scale >= 0.0 
        self.do_classifier_free_guidence = self.guidance_scale > 0.0

        self.smooth_output = False
        if hasattr(cfg.model, "smooth_output") and cfg.model.smooth_output:
            self.smooth_output = True


    def allsplit_step(self, split: str, batch, batch_idx):
        """
        One step
        Args:
            split (str): train, test, val
            batch (dict): batch
            batch contains:
                template (torch.Tensor): [batch_size, vert_dim]
                vertice (torch.Tensor): [batch_size, vert_len, vert_dim ]
                vertice_attention (torch.Tensor): [batch_size, vert_len]
                audio (torch.Tensor): [batch_size, aud_len]
                audio_attention (torch.Tensor): [batch_size, aud_len]
                id (torch.Tensor): [batch_size, id_dim]
            batch_idx (int): batch index
        """
        # training
        if split == "train":
            if self.guidance_uncondp > 0: # we randomly mask the audio feature
                audio_mask = torch.rand(batch['audio'].shape[0]) < self.guidance_uncondp
                batch['audio'][audio_mask] = 0

            rs_set = self._diffusion_forward(batch, batch_idx, phase="train")
            loss = self.losses[split].update(rs_set)
            return loss


        if split in ["val", ]:
            # the id is not used in the validation
            # because the id in the validation is not the same as anyone in the training
            # so we set the id to be any one of the id in the training
            bs = batch["vertice"].shape[0]
            id_dim = self.cfg.denoiser.params.id_dim

            # collect the results for each id
            loss_list = []
            for idx in range(id_dim):
                batch["id"] = torch.zeros(bs, id_dim).to(batch["vertice"].device)
                batch["id"][:, idx] = 1
                with torch.no_grad():
                    # same as the training, we use the autoregressive inference
                    rs_set = self._diffusion_forward(batch, batch_idx, phase="val")
                    loss = self.losses[split].update(rs_set)

                    if loss is None:
                        return ValueError("loss is None")
                    
                    loss_list.append(loss)

                # visualize the result for the first id and the first batch
                if batch_idx == 0 and idx == 0:
                    self._visualize(batch, rs_set)

            loss = torch.stack(loss_list, dim=0).mean(dim=0)
            return loss

        if split in ["test"]:
            
            from alm.models.losses.utils import biwi_upper_face_variance, biwi_mouth_distance
            from alm.models.losses.utils import vocaset_upper_face_variance, vocaset_mouth_distance

            # we also need to collect the results for each id in the test
            bs = batch["vertice"].shape[0]
            id_dim = self.cfg.denoiser.params.id_dim

            # we don't need the vertice_attention in the test time
            # because we do not have the ground truth vertice
            if 'vertice_attention' in batch:
                batch.pop('vertice_attention') 

            # collect the results for each id
            ndim = batch['id'].ndim
            if ndim == 2:
                batch_id_list = [batch['id']] # the id is given
            elif ndim == 3:
                batch_id_list = []
                for i in range(id_dim): # the id is not given, we use the id of each person in the training set
                    batch["id"] = torch.zeros(bs, id_dim).to(batch["vertice"].device)
                    batch["id"][:, i] = 1
                    batch_id_list.append(batch['id'])
            else:
                raise ValueError(f"the dimension of the id should be 2 or 3, but got {ndim}")

            # collect the results for each id
            metrics_list = {}
            # for idx in tqdm(range(id_dim), desc="alternating among identities"):
            for idx, batch_id in enumerate(batch_id_list):
                # batch["id"] = torch.zeros(bs, id_dim).to(batch["vertice"].device)
                # batch["id"][:, idx] = 1

                id_idx = torch.where(batch_id == 1)[1][0].item()
                batch["id"] = batch_id.to(batch["vertice"].device)

                with torch.no_grad():

                    # start time
                    start_time = infer_time()

                    # same as the validation, we use the autoregressive inference
                    rs_set = self._diffusion_forward(batch, batch_idx, phase="val")

                    # end time
                    end_time = infer_time()

                    # calculate the metrics
                    if rs_set['vertice_pred'].shape[-1] == 70110: # BIWI
                        exp = 'BIWI'
                        pred = rs_set['vertice_pred'].view(-1, 23370, 3).detach().cpu().numpy()
                        gt = rs_set['vertice'].view(-1, 23370, 3).detach().cpu().numpy()
                        template =  batch['template'].view(-1, 23370, 3).detach().cpu().numpy()
                        min_len = min(pred.shape[0], gt.shape[0])
                        pred = pred[:min_len, :]
                        gt = gt[:min_len, :]
                        
                        metrics = {
                            "FDD": biwi_upper_face_variance( motion = (gt - template), ) \
                                - biwi_upper_face_variance(motion = (pred - template), ),
                            "Lip Vertex Error": biwi_mouth_distance(
                                vertices_gt=gt,
                                vertices_pred=pred,
                            ).mean(),
                            "Length": torch.tensor(min_len).float(),
                        }
                    else:                                         # VOCASET
                        exp = 'vocaset'
                        pred = rs_set['vertice_pred'].view(-1, 5023, 3).detach().cpu().numpy()
                        gt = rs_set['vertice'].view(-1, 5023, 3).detach().cpu().numpy()
                        template =  batch['template'].view(-1, 5023, 3).detach().cpu().numpy()
                        min_len = min(pred.shape[0], gt.shape[0])
                        pred = pred[:min_len, :]
                        gt = gt[:min_len, :]

                        metrics = {
                            "FDD": vocaset_upper_face_variance( motion = (gt - template), ) \
                                - vocaset_upper_face_variance(motion = (pred - template), ),
                            "Lip Vertex Error": vocaset_mouth_distance(
                                vertices_gt=gt,
                                vertices_pred=pred,
                            ).mean(),
                            "Length": torch.tensor(min_len).float(),
                        }

                    # save the results
                    # the code is a little bit messy, sorry for that
                    result_dir = self.cfg.FOLDER_EXP + '/results_{}'.format(exp)
                    os.makedirs(result_dir, exist_ok=True)
                    result_file = batch['file_name'][0].split('/')[-1].split('.')[0] + '_condition_' + str(id_idx) + '.pkl'

                    with open(os.path.join(result_dir, result_file), 'wb') as f:
                        report = {
                            'prediction': rs_set['vertice_pred'].detach().cpu().numpy(),
                            'ground_truth': rs_set['vertice'].detach().cpu().numpy(),
                            'template': batch['template'].detach().cpu().numpy(),
                            'audio_length': rs_set['vertice_pred'].shape[1] / self.cfg.hidden_fps,
                            'time': end_time - start_time,
                            'fdd': metrics['FDD'],
                            've': metrics['Lip Vertex Error'],
                        }
                        pickle.dump(report, f)

                    # save the metrics
                    if metrics is None:
                        return ValueError("metrics is None")
                    
                    for key in metrics: # collect the metrics for each id
                        if key not in metrics_list:
                            metrics_list[key] = []
                        metrics_list[key].append(metrics[key])

                # visualize the result for the first id and the first batch
                if batch_idx == 0 and idx == 0:
                    None #TODO: visualize the result

            # average the metrics for each id
            for key in metrics_list:
                metrics_list[key] = torch.stack(metrics_list[key], dim=0).mean(dim=0)
                
            return metrics_list
    
    def _memory_mask(self, hidden_attention, ):
        """
        Create memory_mask for transformer decoder, which is used to mask the padding information
        Args:
            hidden_attention: [batch_len, source_len]
            frame_num: int
        """

        if self.denoiser.use_mem_attn_bias:
            # since the source_len is the same as the target_len, we can use the same size to create the mask
            memory_mask = self.denoiser.memory_bi_bias[:hidden_attention.shape[1], :hidden_attention.shape[1]]

            # since the adapter is used, we need to unmask another position to make the adapter work, this position is the first and the second positions of the memory_mask
            adpater_mask = torch.zeros_like(memory_mask[:, :2]) # [1, source_len], since the apdater length = id + time = 2
            memory_mask = torch.cat([adpater_mask, memory_mask], dim = 1) # [source_len, latent_len + 2]

            # # visualize the attention bias using sns.heatmap
            # import seaborn as sns
            # import matplotlib.pyplot as plt
            # fig, ax = plt.subplots(figsize=(15, 10))
            # length = 100
            # # visualize the memory_mask
            # minimum = 1
            # mask = (-minimum*  memory_mask[:length, :length+2].long()).detach().cpu().numpy().astype(int)
            # sns.heatmap(mask, ax=ax,)

            # # set the cbar to be discrete
            # colorbar = ax.collections[0].colorbar
            # colorbar.set_ticks([-minimum, 0])
            # colorbar.set_ticklabels(['-inf','0'])
            # # save the figure
            # plt.savefig('memory_mask.png')



            return  memory_mask.bool().to(hidden_attention.device) # [source_len, latent_len + 2]
        else:
            return None
        
    def _tgt_mask(self, vertice_attention, ):
        """
        Create tgt_key_padding_mask for transformer decoder
        Args:
            vertice_attention: [batch_len, source_len]
            frame_num: int
        """


        if self.denoiser.use_tgt_attn_bias:
            batch_size = vertice_attention.shape[0]
            tgt_mask = self.denoiser.target_bi_bias[:, :vertice_attention.shape[1], :vertice_attention.shape[1]] # [num_heads, target_len, target_len]
            adapter_mask = torch.zeros_like(tgt_mask[..., :2]) # [num_heads, target_len, 2], since the apdater length = id + time = 2
            tgt_mask = torch.cat([adapter_mask, tgt_mask], dim = -1) # [num_heads, target_len, target_len + 2]

            # # visualize the attention bias using sns.heatmap
            # import seaborn as sns
            # import matplotlib.pyplot as plt
            # fig, ax = plt.subplots(figsize=(15, 10))
            # length = 100
            # # visualize the tgt_mask
            # mask = (5 * tgt_mask[0, :length, :length+2]).long().detach().cpu().numpy().astype(int)
            # # set the cbar to be discrete
            # cbar_kws = {
            #     "ticks": np.arange(mask.min(), mask.max()+1),
            #     "boundaries": np.arange(mask.min() - 0.5, mask.max() + 1.5)
            # }
            # sns.heatmap(mask, ax=ax, cbar_kws=cbar_kws)
            # # save the figure
            # plt.savefig('tgt_mask.png')

            # repeat the mask for each batch
            tgt_mask = tgt_mask.repeat(batch_size, 1, 1) # [batch_size * num_heads, target_len, target_len + 2]
            return tgt_mask.to(vertice_attention.device, non_blocking=True) # [batch_size * num_heads, target_len, target_len + 2]
        else:
            return None

    def _mem_key_padding_mask(self, vertice_attention):
        """
        Create mem_key_padding_mask for transformer decoder, which is used to mask the padding information
        Args:
            hidden_attention: [batch_len, source_len]
        """

        # since the adapter is used, we need to unmask another position to make the adapter work
        # this position is the first and the second positions of the mem_key_padding_mask
        adpater_mask = torch.ones_like(vertice_attention[:, :2]) # [batch_size, 2], since the apdater length = id + time = 2
        vertice_attention = torch.cat([adpater_mask, vertice_attention], dim = 1) # [batch_size, source_len + 2]

        # mask with 1 means that the position is masked
        return ~vertice_attention.bool()
    
    def _tgt_key_padding_mask(self, vertice_attention):
        """
        Create tgt_key_padding_mask for transformer decoder, which is used to mask the padding information
        Args:
            hidden_attention: [batch_len, target_len]
        """
        # since the adapter is used, we need to unmask another position to make the adapter work
        # this position is the first and the second positions of the tgt_key_padding_mask

        adpater_mask = torch.ones_like(vertice_attention[:, :2])
        vertice_attention = torch.cat([adpater_mask, vertice_attention], dim = 1)

        # mask with 1 means that the position is masked
        return ~vertice_attention.bool()

    def _audio_resize(self, hidden_state: torch.Tensor, input_fps: Optional[float] = None , output_fps: Optional[float] = None, output_len = None):
        """
        Resize the audio feature to the same length as the vertice
        Args:
            hidden_state (torch.Tensor): [batch_size, hidden_size, seq_len]
            input_fps (float): input fps
            output_fps (float): output fps
            output_len (int): output length
        """
        # if the input_fps and output_fps is not given, we use the default value
        input_fps = input_fps if input_fps is not None else self.cfg.denoiser.params.audio_fps
        output_fps = output_fps if output_fps is not None else self.cfg.denoiser.params.hidden_fps

        hidden_state = hidden_state.transpose(1,2)
        if output_len is None:
            seq_len = hidden_state.shape[2] / input_fps
            output_len = int(seq_len * output_fps)
        output_features = F.interpolate(hidden_state, size = output_len, align_corners=True, mode="linear")
        return output_features.transpose(2,1)

    def _audio_2_hidden(self, audio, audio_attention, length = None):
        """
        This function takes in an audio tensor and its corresponding attention mask, 
        and returns a hidden state tensor that represents the audio feature map. 
        The function first passes the audio tensor through an audio encoder to obtain the last hidden state. 
        It then resizes the hidden state to match the length of the input sequence, using the _audio_resize function. 
        Finally, the function passes the resized hidden state through the audio_feature_map layer of the denoiser to obtain the final hidden state tensor. 
        The output tensor has shape [batch_size, seq_len, latent_dim], where seq_len is the length of the input audio sequence and latent_dim is the dimensionality of the latent space.
        """
        hidden_state = self.audio_encoder(audio, attention_mask = audio_attention).last_hidden_state
        hidden_state = self._audio_resize(
            hidden_state, 
            output_len = length # if vertice is not given, we use the full length of the audio
        )    

        hidden_state = self.denoiser.audio_feature_map(hidden_state) # hidden_state.shape = [batch_size, seq_len, latent_dim]
        return hidden_state
    
    def _diffusion_forward(self, batch, batch_idx, phase):
        """
        Forward pass for training
        Args:
            batch (dict): batch
            batch contains:
                template (torch.Tensor): [batch_size, vert_dim]
                vertice (torch.Tensor): [batch_size, vert_len, vert_dim ]
                vertice_attention (torch.Tensor): [batch_size, vert_len]
                audio (torch.Tensor): [batch_size, aud_len]
                audio_attention (torch.Tensor): [batch_size, aud_len]
                id (torch.Tensor): [batch_size, id_dim]
                phase (str): eihter 'train' or 'val'
            batch_idx (int): batch index
        """

        # process audio condition
        hidden_state = self._audio_2_hidden(batch['audio'], batch['audio_attention'], length = batch['vertice'].shape[1] if 'vertice' in batch else None) # hidden_state.shape = [batch_size, seq_len, latent_dim]
        if 'vertice_attention' not in batch:
            # if the vertice_attention is not given, we assume that all the vertices are valid, so we set the attention to be all ones
            batch['vertice_attention'] = torch.ones(
                hidden_state.shape[0], 
                hidden_state.shape[1], # in our setting, the length of the vertice_attention should be the same as the length of the hidden_state
            ).long().to(hidden_state.device) # this attention should be long type
        
        # template is subtracted from the vertice_input to make the template as the origin of the vertice
        template = batch['template'].unsqueeze(1) # template.shape = [batch_size, 1, vert_dim]

        if phase == 'train':
            vertice_input = batch['vertice'] - template # vertice_input.shape = [batch_size, vert_len, vert_dim]
            # perform the diffusion forward process
            vertice_output = self._diffusion_process(
                vertice_input, 
                hidden_state, 
                batch['id'],
                vertice_attention = batch['vertice_attention'],
            ) + template # vertice_output.shape = [batch_size, vert_len, vert_dim]
        
        elif phase == 'val':

            if self.do_classifier_free_guidence:
                silent_hidden_state = self._audio_2_hidden(
                    torch.zeros_like(batch['audio']), # we use the silent audio as the input
                    batch['audio_attention'],
                    length=hidden_state.shape[1], # just use the length of the hidden_state, in case their length is different
                )
            else:
                silent_hidden_state = None

            # perform the diffusion revise process
            vertice_output = self._diffusion_reverse(
                hidden_state,
                batch['id'],
                vertice_attention = batch['vertice_attention'],
                silent_hidden_state = silent_hidden_state,
            ) + template # vertice_output.shape = [batch_size, vert_len, vert_dim]
            
            if self.smooth_output:
                # # smooth the prediction does not significantly affect the metric but makes the animation smoother
                vertice_output = self.smooth(vertice_output)
        else:
            raise ValueError(f"phase should be either 'train' or 'val', but got {phase}")

        rs_set = {
            "vertice_pred": vertice_output,
            "vertice": batch['vertice'] if 'vertice' in batch else None,
            "vertice_attention": batch['vertice_attention'],
        }
        return rs_set

    def smooth(self, vertices):
        vertices_smooth = F.avg_pool1d(
            vertices.permute(0, 2, 1),
            kernel_size=3, 
            stride=1, 
            padding=1
        ).permute(0, 2, 1)  # smooth the prediction with a moving average filter
        vertices[:, 1:-1] = vertices_smooth[:, 1:-1]
        return vertices


    def predict(self, batch, **kwargs):
        """
        Predict the result in the test time
        Here the length of the vertice_attention is decided by the length of the audio
        """

        if 'audio_attention' not in batch:
            # if the audio_attention is not given, we assume that all the audio is valid, so we set the attention to be all ones
            batch['audio_attention'] = torch.ones(
                batch['audio'].shape[0], 
                batch['audio'].shape[1], 
            ).long().to(batch['audio'].device) # this attention should be long type

        if 'id' not in batch:
            # if the id is not given, we use the id of the first person in the training set
            batch['id'] = kwargs.get(
                'id',
                torch.zeros(
                    1, #batch['vertice'].shape[0], 
                    self.cfg.denoiser.params.id_dim
                ).to(batch['audio'].device)
            )
            batch['id'][:, 0] = 1
        else:
            assert batch['id'].shape[1] == self.cfg.denoiser.params.id_dim, \
                f"the id dimension should be {self.cfg.denoiser.params.id_dim}, but got {batch['id'].shape[1]}"

        if 'vertice' in batch:
            # if the vertice is given, we use the given vertice as the vertice attention mask
            batch['vertice_attention'] = torch.ones(
                batch['vertice'].shape[0],
                batch['vertice'].shape[1]
            ).long().to(batch['vertice'].device)
            # this attention should be long type

        # add the batch dimension to the template
        batch['template'] = batch['template'][None, ...]

        # perform the diffusion forward process
        vertice_output = self._diffusion_forward(batch, 0, 'val')['vertice_pred'] # vertice_output.shape = [batch_size, vert_len, vert_dim]                
        
        rs_set = {
            "vertice_pred": vertice_output,
            "vertice": batch['vertice'] if 'vertice' in batch else None,
            "vertice_attention": batch['vertice_attention'],
        }
        return rs_set

    def _diffusion_process(
        self,
        vertice_input: torch.Tensor,
        hidden_state: torch.Tensor,
        id: torch.Tensor,
        vertice_attention: Optional[torch.Tensor] = None,
    ):  
        """
        Perform the diffusion forward process during training
        Args:
            vertice_input (torch.Tensor): [batch_size, vert_len, vert_dim], the grount truth vertices, padding may included
            hidden_state (torch.Tensor): [batch_size, seq_len, latent_dim], the audio feature, padding may included
            id (torch.Tensor): [batch_size, id_dim], the id of the subject
            vertice_attention (torch.Tensor): [batch_size, vert_len], the attention of the vertices to indicate which vertices are valid, since the audio feature has the same length as the vertices, the vertice_attention should be the same length as the hidden_state
        """

        # extract the id style
        object_emb = self.denoiser.obj_vector(torch.argmax(id, dim = 1)).unsqueeze(1) # object_emb.shape = [batch_size, 1, latent_dim]

        # sample noise
        noise = torch.randn_like(vertice_input) # noise.shape = [batch_size, vert_len, vert_dim]

        # sample a random timestep for the minibatch
        bsz = vertice_input.shape[0]
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (bsz,),
            device = vertice_input.device
        ) # timesteps.shape = [batch_size]

        # add noise to the latents
        noise_input = self.noise_scheduler.add_noise(
            vertice_input,
            noise,
            timesteps,
        ) # noise_input.shape = [batch_size, vert_len, vert_dim]

        # predict the noise or the input
        vertice_pred = self.denoiser(
            vertice_input = noise_input, # noise_input.shape = [batch_size, vert_len, vert_dim]
            hidden_state = hidden_state, # hidden_state.shape = [batch_size, seq_len, latent_dim]
            timesteps = timesteps, # timesteps.shape = [batch_size]
            adapter = object_emb, # object_emb.shape = [batch_size, 1, latent_dim]
            tgt_mask = self._tgt_mask(vertice_attention), # tgt_mask.shape = [vert_len, vert_len]
            memory_mask = self._memory_mask(vertice_attention), # memory_mask.shape = [vert_len, seq_len]
            tgt_key_padding_mask = self._tgt_key_padding_mask(vertice_attention), # tgt_key_padding_mask.shape = [batch_size, vert_len]
            memory_key_padding_mask = self._mem_key_padding_mask(vertice_attention), # memory_key_padding_mask.shape = [batch_size, seq_len]
        )

        return vertice_pred
    
    def _diffusion_reverse(
        self,
        hidden_state: torch.Tensor,
        id: torch.Tensor,
        vertice_attention: torch.Tensor,
        silent_hidden_state: Optional[torch.Tensor] = None,
    ):  
        """
        Perform the diffusion reverse process during inference
        Args:
            hidden_state (torch.Tensor): [batch_size, seq_len, latent_dim], the audio feature, padding may included
            id (torch.Tensor): [batch_size, id_dim], the id of the subject
            vertice_attention (torch.Tensor): [batch_size, vert_len], the attention of the vertices to indicate which vertices are valid, since the audio feature has the same length as the vertices, the vertice_attention should be the same length as the hidden_state
        """

        # extract the id style
        object_emb = self.denoiser.obj_vector(torch.argmax(id, dim = 1)).unsqueeze(1) # object_emb.shape = [batch_size, 1, latent_dim]

        # sample noise
        vertices = torch.randn(
            (
                hidden_state.shape[0], # batch_size
                hidden_state.shape[1], # vert_len
                self.nfeats, # latent_dim
            ),
            device = hidden_state.device,
            dtype = torch.float,
        )

        # scale the initial noise by the standard deviation required by the scheduler
        vertices = vertices * self.scheduler.init_noise_sigma

        # set timesteps
        self.scheduler.set_timesteps(self.cfg.model.scheduler.num_inference_timesteps)
        timesteps = self.scheduler.timesteps.to(hidden_state.device, non_blocking=True)
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, and between [0, 1]
        extra_step_kwargs = {}
        if "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()):
            extra_step_kwargs["eta"] = self.cfg.model.scheduler.eta

        if silent_hidden_state is not None: # self.do_classifier_free_guidence is True
            hidden_state = torch.cat([hidden_state, silent_hidden_state], dim = 0) # hidden_state.shape = [batch_size * 2, seq_len, latent_dim
            vertice_attention = torch.cat([vertice_attention, ] * 2, dim = 0) # vertice_attention.shape = [batch_size * 2, vert_len]
            object_emb = torch.cat([object_emb, ] * 2, dim = 0) # object_emb.shape = [batch_size * 2, 1, latent_dim]

        # perform denoising
        for i, t in enumerate(timesteps):
            if silent_hidden_state is not None: # self.do_classifier_free_guidence is True
                vertices = torch.cat(
                    [vertices] * 2,
                    dim = 0,
                ) # vertices.shape = [batch_size * 2, vert_len, latent_dim]

            # perform denoising step
            vertices_pred = self.denoiser(
                vertice_input = vertices, # vertices.shape = [batch_size, vert_len, latent_dim]
                hidden_state = hidden_state, # hidden_state.shape = [batch_size, seq_len, latent_dim]
                timesteps = t.expand(hidden_state.shape[0]), # timesteps.shape = [batch_size]
                adapter = object_emb, # object_emb.shape = [batch_size, 1, latent_dim]
                tgt_mask = self._tgt_mask(vertice_attention), # tgt_mask.shape = [vert_len, vert_len]
                memory_mask = self._memory_mask(vertice_attention), # memory_mask.shape = [vert_len, seq_len]
                tgt_key_padding_mask = self._tgt_key_padding_mask(vertice_attention), # tgt_key_padding_mask.shape = [batch_size, vert_len]
                memory_key_padding_mask = self._mem_key_padding_mask(vertice_attention), # memory_key_padding_mask.shape = [batch_size, seq_len]
            )
            
            # perform guided denoising step
            if silent_hidden_state is not None: # self.do_classifier_free_guidence is True
                vertices_pred_audio, vertices_pred_uncond = vertices_pred.chunk(2, dim = 0)
                vertices_pred = vertices_pred_audio + (vertices_pred_audio - vertices_pred_uncond)* self.guidance_scale

                vertices, _ = vertices.chunk(2, dim = 0)

            vertices = self.scheduler.step(vertices_pred, t, vertices, **extra_step_kwargs).prev_sample
                
        return vertices

    def _visualize(self, batch, rs_set, parrallel = True):
        """
        Visualize the result
        Args:
            batch (dict): batch
                batch contains:
                    file_path (list): audio file path
                    vertice (torch.Tensor): [batch_size, vert_len, vert_dim ]
            rs_set (dict): result set
                rs_set contains:
                    vertice_pred (torch.Tensor): [batch_size, vert_len, vert_dim ]

            parrallel (bool): if True, the visualization will be performed in the current thread,
                otherwise, the visualization will be performed in a new thread
        """
        # visualize the result only for the first data in the batch
        data_idx = 0

        audio_path = batch["file_path"][data_idx]
        vis_path = os.path.join(
            self.cfg.FOLDER_EXP,
            "visualization",
            "{}_{}.{}".format(
                time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()), # current time, it should be better to use the epoch number, but I am lazy
                audio_path.split("/")[-1].split(".")[0], # audio name
                'mp4' # video format
            )
        )

        # visualize the result
        if not parrallel:
            # use the current thread to visualize the result
            animate(
                vertices = rs_set["vertice_pred"][data_idx, ...].squeeze().cpu().numpy(),
                wav_path = audio_path,
                file_name = vis_path,
                ply = self.cfg.DEMO.PLY,
                fps = self.cfg.DEMO.FPS,
            )
        else:
            # use another thread to visualize the result
            p = Process(
                target=animate,
                args=(
                    rs_set["vertice_pred"][data_idx, ...].squeeze().cpu().numpy(),
                    audio_path,
                    vis_path,
                    self.cfg.DEMO.PLY,
                    self.cfg.DEMO.FPS,
                    ),
                )
            p.start()

================================================
FILE: alm/utils/__init__.py
================================================


================================================
FILE: alm/utils/demo_utils.py
================================================
from transformers import Wav2Vec2Processor
import numpy as np
import librosa
import os
import torch
import cv2
import pyrender
import trimesh

import tempfile
import imageio

from tqdm import tqdm
try:
    from psbody.mesh import Mesh
except:
    Mesh = None

import platform
if platform.system() == "Linux":
    # os.environ['PYOPENGL_PLATFORM'] = 'osmesa'
    os.environ['PYOPENGL_PLATFORM'] = 'egl'


def load_example_input(audio_path, processor = None):
    if processor is None:
        processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
    
    speech_array, sampling_rate = librosa.load(
            os.path.join(audio_path), 
            sr=16000
        )

    audio_feature = np.squeeze(
        processor(
            speech_array,
            sampling_rate = sampling_rate
        ).input_values
    )

    audio_feature = np.reshape(
        audio_feature,
        (-1,audio_feature.shape[0])
    )

    return torch.FloatTensor(audio_feature)


# # The implementation of rendering is borrowed from VOCA: https://github.com/TimoBolkart/voca/blob/master/utils/rendering.py
# def render_mesh_helper(mesh, t_center, rot=np.zeros(3), tex_img=None, z_offset=0, template_type: str = "flame"):

#     assert template_type in ["flame", "biwi"], "template_type should be one of ['flame', 'biwi'],but got {}".format(template_type)


#     if template_type == "flame":
#         camera_params = {'c': np.array([400, 400]),
#                             'k': np.array([-0.19816071, 0.92822711, 0, 0, 0]),
#                             'f': np.array([4754.97941935 / 2, 4754.97941935 / 2])}
#     elif template_type == "biwi":
#         camera_params = {'c': np.array([400, 400]),
#                          'k': np.array([-0.19816071, 0.92822711, 0, 0, 0]),
#                          'f': np.array([4754.97941935 / 8, 4754.97941935 / 8])}
        
#     frustum = {'near': 0.01, 'far': 3.0, 'height': 800, 'width': 800}

#     mesh_copy = Mesh(mesh.v, mesh.f)
#     mesh_copy.v[:] = cv2.Rodrigues(rot)[0].dot((mesh_copy.v-t_center).T).T+t_center
#     intensity = 2.0
#     rgb_per_v = None

#     primitive_material = pyrender.material.MetallicRoughnessMaterial(
#                 alphaMode='BLEND',
#                 baseColorFactor=[0.3, 0.3, 0.3, 1.0],
#                 metallicFactor=0.8, 
#                 roughnessFactor=0.8 
#             )

#     tri_mesh = trimesh.Trimesh(vertices=mesh_copy.v, faces=mesh_copy.f, vertex_colors=rgb_per_v)
#     render_mesh = pyrender.Mesh.from_trimesh(tri_mesh, material=primitive_material,smooth=True)

#     # scene = pyrender.Scene(ambient_light=[.2, .2, .2], bg_color=[0, 0, 0])
#     scene = pyrender.Scene(ambient_light=[.2, .2, .2], bg_color=[255, 255, 255])

#     camera = pyrender.IntrinsicsCamera(fx=camera_params['f'][0],
#                                       fy=camera_params['f'][1],
#                                       cx=camera_params['c'][0],
#                                       cy=camera_params['c'][1],
#                                       znear=frustum['near'],
#                                       zfar=frustum['far'])

#     scene.add(render_mesh, pose=np.eye(4))

#     camera_pose = np.eye(4)
#     camera_pose[:3,3] = np.array([0, 0, 1.0-z_offset])
#     scene.add(camera, pose=[[1, 0, 0, 0],
#                             [0, 1, 0, 0],
#                             [0, 0, 1, 1],
#                             [0, 0, 0, 1]])

#     angle = np.pi / 6.0
#     pos = camera_pose[:3,3]
#     light_color = np.array([1., 1., 1.])
#     light = pyrender.DirectionalLight(color=light_color, intensity=intensity)

#     light_pose = np.eye(4)
#     light_pose[:3,3] = pos
#     scene.add(light, pose=light_pose.copy())
    
#     light_pose[:3,3] = cv2.Rodrigues(np.array([angle, 0, 0]))[0].dot(pos)
#     scene.add(light, pose=light_pose.copy())

#     light_pose[:3,3] =  cv2.Rodrigues(np.array([-angle, 0, 0]))[0].dot(pos)
#     scene.add(light, pose=light_pose.copy())

#     light_pose[:3,3] = cv2.Rodrigues(np.array([0, -angle, 0]))[0].dot(pos)
#     scene.add(light, pose=light_pose.copy())

#     light_pose[:3,3] = cv2.Rodrigues(np.array([0, angle, 0]))[0].dot(pos)
#     scene.add(light, pose=light_pose.copy())

#     flags = pyrender.RenderFlags.SKIP_CULL_FACES
#     # try:
#     r = pyrender.OffscreenRenderer(viewport_width=frustum['width'], viewport_height=frustum['height'])
#     color, _ = r.render(scene, flags=flags)
#     # except:
#     #     print('pyrender: Failed rendering frame')
#     #     color = np.zeros((frustum['height'], frustum['width'], 3), dtype='uint8')

#     return color[..., ::-1]

# The implementation of rendering is borrowed from VOCA: https://github.com/TimoBolkart/voca/blob/master/utils/rendering.py
def render_mesh_helper(mesh, t_center, rot=np.zeros(3), tex_img=None, z_offset=0, template_type: str = "flame", rgb_per_v = None):
    

    assert template_type in ["flame", "biwi"], "template_type should be one of ['flame', 'biwi'],but got {}".format(template_type)


    if template_type == "flame":
        camera_params = {'c': np.array([400, 400]),
                            'k': np.array([-0.19816071, 0.92822711, 0, 0, 0]),
                            'f': np.array([4754.97941935 / 2, 4754.97941935 / 2])}
    elif template_type == "biwi":
        camera_params = {'c': np.array([400, 400]),
                         'k': np.array([-0.19816071, 0.92822711, 0, 0, 0]),
                         'f': np.array([4754.97941935 / 8, 4754.97941935 / 8])}
        
    frustum = {'near': 0.01, 'far': 3.0, 'height': 800, 'width': 800}

    mesh_copy = Mesh(mesh.v, mesh.f)
    mesh_copy.v[:] = cv2.Rodrigues(rot)[0].dot((mesh_copy.v-t_center).T).T+t_center

    if rgb_per_v is None:
        intensity = 2.0
        primitive_material = pyrender.material.MetallicRoughnessMaterial(
                    alphaMode='BLEND',
                    baseColorFactor=[0.3, 0.3, 0.3, 1.0],
                    metallicFactor=0.8, 
                    roughnessFactor=0.8 
                )

        tri_mesh = trimesh.Trimesh(vertices=mesh_copy.v, faces=mesh_copy.f, vertex_colors=rgb_per_v)
        render_mesh = pyrender.Mesh.from_trimesh(tri_mesh, material=primitive_material,smooth=True)
    else:
        intensity = 0.5
        tri_mesh = trimesh.Trimesh(vertices=mesh_copy.v, faces=mesh_copy.f, vertex_colors=rgb_per_v)
        render_mesh = pyrender.Mesh.from_trimesh(tri_mesh, smooth=True)

    # scene = pyrender.Scene(ambient_light=[.2, .2, .2], bg_color=[0, 0, 0])
    scene = pyrender.Scene(ambient_light=[.2, .2, .2], bg_color=[255, 255, 255])

    camera = pyrender.IntrinsicsCamera(fx=camera_params['f'][0],
                                      fy=camera_params['f'][1],
                                      cx=camera_params['c'][0],
                                      cy=camera_params['c'][1],
                                      znear=frustum['near'],
                                      zfar=frustum['far'])

    scene.add(render_mesh, pose=np.eye(4))

    camera_pose = np.eye(4)
    camera_pose[:3,3] = np.array([0, 0, 1.0-z_offset])
    scene.add(camera, pose=[[1, 0, 0, 0],
                            [0, 1, 0, 0],
                            [0, 0, 1, 1],
                            [0, 0, 0, 1]])

    angle = np.pi / 6.0
    pos = camera_pose[:3,3]
    light_color = np.array([1., 1., 1.])
    light = pyrender.DirectionalLight(color=light_color, intensity=intensity)

    light_pose = np.eye(4)
    light_pose[:3,3] = pos
    scene.add(light, pose=light_pose.copy())
    
    light_pose[:3,3] = cv2.Rodrigues(np.array([angle, 0, 0]))[0].dot(pos)
    scene.add(light, pose=light_pose.copy())

    light_pose[:3,3] =  cv2.Rodrigues(np.array([-angle, 0, 0]))[0].dot(pos)
    scene.add(light, pose=light_pose.copy())

    light_pose[:3,3] = cv2.Rodrigues(np.array([0, -angle, 0]))[0].dot(pos)
    scene.add(light, pose=light_pose.copy())

    light_pose[:3,3] = cv2.Rodrigues(np.array([0, angle, 0]))[0].dot(pos)
    scene.add(light, pose=light_pose.copy())

    flags = pyrender.RenderFlags.SKIP_CULL_FACES
    # try:
    r = pyrender.OffscreenRenderer(viewport_width=frustum['width'], viewport_height=frustum['height'])
    color, _ = r.render(scene, flags=flags)
    # except:
    #     print('pyrender: Failed rendering frame')
    #     color = np.zeros((frustum['height'], frustum['width'], 3), dtype='uint8')

    return color[..., ::-1]

def render_frame(args):
    predicted_vertice, f, center,  template_type = args
    render_mesh = Mesh(predicted_vertice, f)
    pred_img = render_mesh_helper(render_mesh, center, template_type=template_type)
    pred_img = pred_img.astype(np.uint8)
    return pred_img

def animate(vertices: np.array, wav_path: str, file_name: str, ply: str, fps: int = 25, vertice_gt: np.array = None, use_tqdm: bool = False, multi_process = False):
    """
    Animate the predicted vertices with the synchronized audio and save the video to the output directory.
    Args:
        vertices: (num_frames, num_vertices*3)
        wav_path: path to wav file
        file_name: name of the output file
        ply: path to the ply file
        fps: frames per second
        use_tqdm: whether to use tqdm to show the progress
        vertice_gt: (num_frames, num_vertices*3)
        template: template to use, can be "flame" or "biwi"
    """
    # make output dir
    output_dir = os.path.dirname(file_name)
    os.makedirs(output_dir, exist_ok=True)

    template = Mesh(filename=ply)
    # determine biwi or flame
    if "FLAME" in ply:
        template_type = "flame"
    elif "BIWI" in ply:
        template_type = "biwi"
    else:
        raise ValueError("Template type not recognized, please use either BIWI or FLAME")

    # reshape vertices
    predicted_vertices = vertices.reshape(-1, vertices.shape[1]//3, 3) if vertices.ndim < 3 else vertices

    num_frames = predicted_vertices.shape[0]
    if vertice_gt is not None:
        vertice_gt = vertice_gt.reshape(-1, vertice_gt.shape[1]//3, 3) if vertice_gt.ndim < 3 else vertice_gt
        num_frames = np.where(np.sum(vertice_gt, axis=(1, 2)) != 0)[0][-1] + 1 # find the number of frames where the vertices are not all zeros

    tmp_video_file = tempfile.NamedTemporaryFile('w', suffix='.mp4', dir=output_dir)
    center = np.mean(predicted_vertices[0], axis=0)


    # make animation
    if multi_process:

        from multiprocessing import Pool, cpu_count
        from itertools import cycle
        # get maximum num of process
        frames = []
        max_processes = cpu_count()
        with Pool(processes=max_processes) as pool:
            args = [(
                predicted_vertice,
                template.f,
                center,
                template_type
            ) for predicted_vertice in predicted_vertices]

            for pred_img in pool.imap(render_frame, tqdm(args)):
                frames.append(pred_img)

        if vertice_gt is not None:
            frames_gt = []
            with Pool(processes=max_processes) as pool:
                args = [(
                    gt_vertice,
                    template.f,
                    center,
                    template_type
                ) for gt_vertice in vertice_gt]
                
                for gt_img in pool.imap(render_frame, tqdm(args)):
                    frames_gt.append(gt_img)

            # concat two videos
            frames_final = []
            for i in range(num_frames):
                frames_final.append(np.concatenate([frames_gt[i], frames[i]], axis=1))
            frames = frames_final

    else:
        frames = []
        for i_frame in tqdm(range(num_frames)) if use_tqdm else range(num_frames):
            render_mesh = Mesh(predicted_vertices[i_frame], template.f)
            pred_img = render_mesh_helper(render_mesh, center, template_type=template_type)
            pred_img = pred_img.astype(np.uint8)
            frames.append(pred_img)

        if vertice_gt is not None:
            frames_gt = []
            for i_frame in tqdm(range(num_frames)) if use_tqdm else range(num_frames):
                render_mesh = Mesh(vertice_gt[i_frame], template.f)
                pred_img = render_mesh_helper(render_mesh, center)
                pred_img = pred_img.astype(np.uint8)
                frames_gt.append(pred_img)
        
            # concat two videos
            frames_final = []
            for i in range(num_frames):
                frames_final.append(np.concatenate([frames_gt[i], frames[i]], axis=1))
            frames = frames_final

    imageio.mimsave(tmp_video_file.name, frames, fps = fps)

    cmd = " ".join(['ffmpeg', '-hide_banner -loglevel error', '-y', '-i', tmp_video_file.name, '-i', wav_path, '-c:v copy -c:a aac', '-pix_fmt yuv420p -qscale 0',file_name, ])
    cmd = " ".join(['ffmpeg', '-i', tmp_video_file.name, '-i', wav_path, '-c:v copy -c:a aac', '-pix_fmt yuv420p -qscale 0',file_name, ])
    
    os.system(cmd)
    tmp_dir = tempfile.gettempdir() # check if the wav file is in the tmp dir
    if os.path.exists(wav_path) and tmp_dir in wav_path: 
        os.remove(wav_path)

    print(f"Video saved to {file_name}")


================================================
FILE: alm/utils/logger.py
================================================
from pathlib import Path
import os
import time
import logging
from omegaconf import OmegaConf
from pytorch_lightning.utilities.rank_zero import rank_zero_only


def create_logger(cfg, phase='train'):
    # root dir set by cfg
    root_output_dir = Path(cfg.FOLDER)
    # set up logger
    if not root_output_dir.exists():
        print('=> creating {}'.format(root_output_dir))
        root_output_dir.mkdir()

    cfg_name = cfg.NAME
    model = cfg.model.model_type
    cfg_name = os.path.basename(cfg_name).split('.')[0]

    final_output_dir = root_output_dir / model / cfg_name
    cfg.FOLDER_EXP = str(final_output_dir)

    time_str = time.strftime('%Y-%m-%d-%H-%M-%S')

    new_dir(cfg, phase, time_str, final_output_dir)

    head = '%(asctime)-15s %(message)s'
    logger = config_logger(final_output_dir, time_str, phase, head)
    if logger is None:
        logger = logging.getLogger()
        logger.setLevel(logging.CRITICAL)
        logging.basicConfig(format=head)
    return logger


@rank_zero_only
def config_logger(final_output_dir, time_str, phase, head):
    log_file = '{}_{}_{}.log'.format('log', time_str, phase)
    final_log_file = final_output_dir / log_file
    logging.basicConfig(filename=str(final_log_file))
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    console = logging.StreamHandler()
    formatter = logging.Formatter(head)
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)
    file_handler = logging.FileHandler(final_log_file, 'w')
    file_handler.setFormatter(logging.Formatter(head))
    file_handler.setLevel(logging.INFO)
    logging.getLogger('').addHandler(file_handler)
    return logger


@rank_zero_only
def new_dir(cfg, phase, time_str, final_output_dir):
    # new experiment folder
    cfg.TIME = str(time_str)
    if os.path.exists(
            final_output_dir) and cfg.TRAIN.RESUME is None and not cfg.DEBUG:
        file_list = sorted(os.listdir(final_output_dir), reverse=True)
        for item in file_list:
            if item.endswith('.log'):
                os.rename(str(final_output_dir),
                          str(final_output_dir) + '_' + cfg.TIME)
                break
    final_output_dir.mkdir(parents=True, exist_ok=True)
    # write config yaml
    config_file = '{}_{}_{}.yaml'.format('config', time_str, phase)
    final_config_file = final_output_dir / config_file
    OmegaConf.save(config=cfg, f=final_config_file)


================================================
FILE: alm/utils/temos_utils.py
================================================
from typing import Dict, List

import numpy as np
import torch
from torch import Tensor

def lengths_to_mask(lengths: Tensor, # [batch_size]
                    device: torch.device,
                    max_len: int = None) -> Tensor:
    max_len = max_len if max_len else max(lengths.cpu().tolist())
    mask = torch.arange(max_len, device=device).expand(
        len(lengths), max_len) < lengths.unsqueeze(1)
    return mask

def remove_padding(tensors, lengths):
    return [
        tensor[:tensor_length]
        for tensor, tensor_length in zip(tensors, lengths)
    ]

================================================
FILE: configs/assets/biwi.yaml
================================================
FOLDER: './experiments/biwi' # Experiment files saving path

TEST:
  FOLDER: './results' # Testing files saving path

DATASET:
  BIWI:
    ROOT: ./datasets/biwi



================================================
FILE: configs/assets/vocaset.yaml
================================================
FOLDER: './experiments/vocaset' # Experiment files saving path

TEST:
  FOLDER: './results' # Testing files saving path

DATASET:
  VOCASET:
    ROOT: ./datasets/vocaset



================================================
FILE: configs/base.yaml
================================================
# FOLDER: ./experiments
SEED_VALUE: 1234
DEBUG: True
TRAIN:
  SPLIT: 'train'
  NUM_WORKERS: 1 #2 # Number of workers
  BATCH_SIZE: 4 # Size of batches
  START_EPOCH: 0 # Start epoch
  END_EPOCH: 2000 # End epoch
  RESUME: '' # Experiment path to be resumed training
  PRETRAINED: '' # Pretrained model path

  OPTIM:
    OPTIM.TYPE: 'AdamW' # Optimizer type
    OPTIM.LR: 1e-4 # Learning rate

  ABLATION:
    SKIP_CONNECT: False # skip connection for denoiser va
    # use linear to expand mean and std rather expand token nums
    MLP_DIST: False
    IS_DIST: False # Mcross distribution kl

EVAL:
  SPLIT: 'gtest'
  BATCH_SIZE: 1 # Evaluating Batch size
  NUM_WORKERS: 1 #12 # Evaluating Batch size

TEST:
  TEST_DIR: ''
  CHECKPOINTS: '' # Pretrained model path
  NUM_WORKERS: 1 #12 # Evaluating Batch size
  BATCH_SIZE: 1 # Evaluating Batch size
  REPLICATION_TIMES: 1 # Number of times to replicate the test
  SAVE_PREDICTIONS: False # Weather to save predictions
  COUNT_TIME: False # Weather to count time during test

model:
  target: 'modules'

LOSS:
  
METRIC:
  None
DATASET:
  VOCASET:
    NONE: none

DEMO:
  EAMPLE: null
  ID: null
  CHECKPOINTS: "templates"
  TEMPLATE: "datasets/vocaset/templates.pkl"
  PLY: "datasets/vocaset/templates/FLAME_sample.ply"
  FPS: 30


LOGGER:
  SACE_CHECKPOINT_EPOCH: 1
  LOG_EVERY_STEPS: 1
  VAL_EVERY_STEPS: 10
  TENSORBOARD: true
  WANDB:
    OFFLINE: false
    PROJECT: null
    RESUME_ID: null


================================================
FILE: configs/diffusion/biwi/diffspeaker_hubert_biwi.yaml
================================================
NAME: diffspeaker_hubert_biwi # Experiment name
DEBUG: False # Debug mode
ACCELERATOR: 'gpu' # Devices optioncal: “cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps, “auto”
DEVICE: [0, 1, 2, 3, 4, 5, 6, 7] # Index of gpus eg. [0] or [0,1,2,3]

# Training configuration
TRAIN:
  #---------------------------------
  DATASETS: ['biwi'] # Training datasets
  NUM_WORKERS: 1 # Number of workers
  BATCH_SIZE: 8 # Size of batches
  START_EPOCH: 0 # Start epochMMOTIONENCODER
  END_EPOCH: 500 # End epoch
  RESUME: '' # Resume training from this path
  OPTIM:
    TYPE: AdamW # Optimizer type
    LR: 1e-4 # Learning rate

# Evaluating Configuration
EVAL:
  DATASETS: ['biwi'] # Evaluating datasets
  BATCH_SIZE: 8 # Evaluating Batch size

# Datasets Configuration
DATASET:
  JOINT_TYPE: 'biwi' # join type

TEST:
  CHECKPOINTS: checkpoints/biwi/diffspeaker_hubert_biwi.ckpt # Pretrained model path
  DATASETS: ['biwi'] # training datasets
  BATCH_SIZE: 1 # training Batch size
  SPLIT: test # split type
  REPLICATION_TIMES: 10 # replication times

# Losses Configuration
LOSS:
  TYPE: voca # Losses type
  VERTICE_ENC: 1 # Lambda for vertices reconstruction Losses
  VERTICE_ENC_V: 1 # lambda for vertices velocity reconstruction loss
  LIP_ENC: 0 # lambda for lip reconstruction loss
  LIP_ENC_V: 0 # lambda for lip velocity reconstruction loss
  DIST_SYNC_ON_STEP: True # Sync Losses on step when distributed trained

audio_encoder:
  train_audio_encoder: True
  model_name_or_path: "facebook/hubert-base-ls960"

# Model Configuration
model:
  target: 'diffusion/diffusion_bias_modules'
  audio_encoded_dim: 768 # audio hidden dimension
  model_type: diffusion_bias # model type
  latent_dim: 1024 # latent dimension
  id_dim: 6 # the dimension of the id vector
  ff_size: 2048 # latent_dim * 2
  num_layers: 1 # number of layers
  num_heads: 4 # number of head layers
  dropout: 0.1 # dropout rate
  max_len: 600 # the attention mask maximum length
  activation: gelu # activation type
  normalize_before: True 
  require_start_token: True # start_token is need for autogressive generation only
  arch: 'default'
  predict_epsilon: False # noise or motion, motion here
  freq_shift: 0
  flip_sin_to_cos: True
  mem_attn_scale: 1.
  tgt_attn_scale: 1.
  audio_fps: 50
  hidden_fps: 25 # 30
  guidance_scale: 0 # not used
  guidance_uncondp: 0. # not used
  period: 25
  no_cross: False
  smooth_output: False

# rewrite the template
DEMO:
  EAMPLE: null
  ID: null
  TEMPLATE: "datasets/biwi/templates.pkl"
  PLY: "datasets/biwi/templates/BIWI.ply"
  FPS: 25

# Logger configuration
LOGGER:
  SACE_CHECKPOINT_EPOCH: 100
  LOG_EVERY_STEPS: 10
  VAL_EVERY_STEPS: 100 # 200
  TENSORBOARD: True
  WANDB:
    PROJECT: null
    OFFLINE: False
    RESUME_ID: null

================================================
FILE: configs/diffusion/biwi/diffspeaker_wav2vec2_biwi.yaml
================================================
NAME: diffspeaker_wav2vec2_biwi # Experiment name
DEBUG: False # Debug mode
ACCELERATOR: 'gpu' # Devices optioncal: “cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps, “auto”
DEVICE: [0, 1, 2, 3, 4, 5, 6, 7] # Index of gpus eg. [0] or [0,1,2,3]

# Training configuration
TRAIN:
  #---------------------------------
  DATASETS: ['biwi'] # Training datasets
  NUM_WORKERS: 1 # Number of workers
  BATCH_SIZE: 16 # Size of batches
  START_EPOCH: 0 # Start epochMMOTIONENCODER
  END_EPOCH: 700 # End epoch
  RESUME: '' # Resume training from this path
  OPTIM:
    TYPE: AdamW # Optimizer type
    LR: 1e-4 # Learning rate

# Evaluating Configuration
EVAL:
  DATASETS: ['biwi'] # Evaluating datasets
  BATCH_SIZE: 16 # Evaluating Batch size

# Datasets Configuration
DATASET:
  JOINT_TYPE: 'biwi' # join type

TEST:
  CHECKPOINTS: checkpoints/biwi/diffspeaker_wav2vec2_biwi.ckpt #experiments/biwi/diffusion_bias/diffspeaker_wav2vec2_biwi/checkpoints/epoch=699.ckpt # Pretrained model path
  DATASETS: ['biwi'] # training datasets
  BATCH_SIZE: 1 # training Batch size
  SPLIT: test # split type
  REPLICATION_TIMES: 10 # replication times for each test sample

# Losses Configuration
LOSS:
  TYPE: voca # Losses type
  VERTICE_ENC: 1 # Lambda for vertices reconstruction Losses
  VERTICE_ENC_V: 1 # lambda for vertices velocity reconstruction loss
  LIP_ENC: 0 # lambda for lip reconstruction loss
  LIP_ENC_V: 0 # lambda for lip velocity reconstruction loss
  DIST_SYNC_ON_STEP: True # Sync Losses on step when distributed trained

audio_encoder:
  train_audio_encoder: True
  model_name_or_path: 'facebook/wav2vec2-base-960h'
  
# Model Configuration
model:
  target: 'diffusion/diffusion_bias_modules'
  audio_encoded_dim: 768 # audio hidden dimension
  model_type: diffusion_bias # model type
  latent_dim: 1024 # latent dimension
  id_dim: 6 # the dimension of the id vector
  ff_size: 2048 # latent_dim * 2
  num_layers: 1 # number of layers
  num_heads: 4 # number of head layers
  dropout: 0.1 # dropout rate
  max_len: 600 # the attention mask maximum length
  activation: gelu # activation type
  normalize_before: True 
  require_start_token: True # start_token is need for autogressive generation only
  arch: 'default'
  predict_epsilon: False # noise or motion, motion here
  freq_shift: 0
  flip_sin_to_cos: True
  mem_attn_scale: 1.
  tgt_attn_scale: 1.
  audio_fps: 50
  hidden_fps: 25 # 30
  guidance_scale: 0 # not used
  guidance_uncondp: 0. # not used
  period: 25
  no_cross: False
  smooth_output: False
  
# rewrite the template
DEMO:
  EAMPLE: null
  ID: null
  TEMPLATE: "datasets/biwi/templates.pkl"
  PLY: "datasets/biwi/templates/BIWI.ply"
  FPS: 25

# Logger configuration
LOGGER:
  SACE_CHECKPOINT_EPOCH: 100
  LOG_EVERY_STEPS: 10
  VAL_EVERY_STEPS: 100 # 200
  TENSORBOARD: True
  WANDB:
    PROJECT: null
    OFFLINE: False
    RESUME_ID: null

================================================
FILE: configs/diffusion/diffusion_bias_modules/denoiser.yaml
================================================
denoiser: # this is copied from configs/baselines/transformer_adpt_modules/transformer.yaml
  target: alm.models.architectures.adpt_bias_denoiser.Adpt_Bias_Denoiser
  params:
    audio_encoded_dim: ${model.audio_encoded_dim}
    ff_size: ${model.ff_size}
    num_layers: ${model.num_layers}
    num_heads: ${model.num_heads}
    dropout: ${model.dropout}
    normalize_before: ${model.normalize_before}
    activation: ${model.activation}
    return_intermediate_dec: False
    arch: ${model.arch}
    latent_dim: ${model.latent_dim}
    nfeats: ${DATASET.NFEATS}
    freq_shift: ${model.freq_shift}
    flip_sin_to_cos: ${model.flip_sin_to_cos}
    max_len: 3000 # the attention mask maximum length
    id_dim: ${model.id_dim} # the number of identities
    require_start_token: ${model.require_start_token} # start_token is need for autogressive generation only
    mem_attn_scale: ${model.mem_attn_scale}
    tgt_attn_scale: ${model.tgt_attn_scale}
    audio_fps: ${model.audio_fps} # 
    hidden_fps: ${model.hidden_fps}
    # unconditional generation
    guidance_scale: ${model.guidance_scale}
    guidance_uncondp: ${model.guidance_uncondp}
    period: ${model.period}
    no_cross: ${model.no_cross}

================================================
FILE: configs/diffusion/diffusion_bias_modules/scheduler.yaml
================================================
scheduler:
  target: diffusers.DDIMScheduler
  num_inference_timesteps: 50
  eta: 0.0
  params:
    num_train_timesteps: 1000
    beta_start: 0.00085
    beta_end: 0.012
    beta_schedule: 'scaled_linear' # Optional: ['linear', 'scaled_linear', 'squaredcos_cap_v2']
    # variance_type: 'fixed_small'
    clip_sample: false # clip sample to -1~1
    prediction_type: 'sample'
    # below are for ddim
    set_alpha_to_one: false
    steps_offset: 1

noise_scheduler:
  target: diffusers.DDPMScheduler
  params:
    num_train_timesteps: 1000
    beta_start: 0.00085
    beta_end: 0.012
    beta_schedule: 'scaled_linear' # Optional: ['linear', 'scaled_linear', 'squaredcos_cap_v2']
    variance_type: 'fixed_small'
    prediction_type: 'sample'
    clip_sample: false # 


================================================
FILE: configs/diffusion/vocaset/diffspeaker_hubert_vocaset.yaml
================================================
NAME: diffspeaker_hubert_vocaset # Experiment name
DEBUG: False # Debug mode
ACCELERATOR: 'gpu' # Devices optioncal: “cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps, “auto”
DEVICE: [0] # Index of gpus eg. [0] or [0,1,2,3]

# Training configuration
TRAIN:
  #---------------------------------
  DATASETS: ['vocaset'] # Training datasets
  NUM_WORKERS: 1 # Number of workers
  BATCH_SIZE: 32 # Size of batches
  START_EPOCH: 0 # Start epochMMOTIONENCODER
  END_EPOCH: 10000 # End epoch
  RESUME: '' # Resume training from this path
  OPTIM:
    TYPE: AdamW # Optimizer type
    LR: 1e-4 # Learning rate

# Evaluating Configuration
EVAL:
  DATASETS: ['vocaset'] # Evaluating datasets
  BATCH_SIZE: 32 # Evaluating Batch size

# Datasets Configuration
DATASET:
  JOINT_TYPE: 'vocaset' # join type

TEST:
  CHECKPOINTS: checkpoints/vocaset/diffspeaker_hubert_vocaset.ckpt # Pretrained model path
  DATASETS: ['vocaset'] # training datasets
  BATCH_SIZE: 1 # training Batch size
  SPLIT: test # split type
  REPLICATION_TIMES: 10 # replication times for each test sample

# Losses Configuration
LOSS:
  TYPE: voca # Losses type
  VERTICE_ENC: 1 # Lambda for vertices reconstruction Losses
  VERTICE_ENC_V: 1 # lambda for vertices velocity reconstruction loss
  LIP_ENC: 0 # lambda for lip reconstruction loss
  LIP_ENC_V: 0 # lambda for lip velocity reconstruction loss
  DIST_SYNC_ON_STEP: True # Sync Losses on step when distributed trained

audio_encoder:
  train_audio_encoder: True
  model_name_or_path: 'facebook/wav2vec2-base-960h'
  
# Model Configuration
model:
  target: 'diffusion/diffusion_bias_modules'
  audio_encoded_dim: 768 # audio hidden dimension
  model_type: diffusion_bias # model type
  latent_dim: 512 # latent dimension
  id_dim: 8 # the dimension of the id vector
  ff_size: 1024 # latent_dim * 2
  num_layers: 1 # number of layers
  num_heads: 4 # number of head layers
  dropout: 0.1 # dropout rate
  max_len: 600 # the attention mask maximum length
  activation: gelu # activation type
  normalize_before: True 
  require_start_token: True # start_token is need for autogressive generation only
  arch: 'default'
  predict_epsilon: False # noise or motion, motion here
  freq_shift: 0
  flip_sin_to_cos: True
  mem_attn_scale: 1.
  tgt_attn_scale: 1.
  audio_fps: 50
  hidden_fps: 30
  guidance_scale: 0 # not used
  guidance_uncondp: 0. # not used
  period: 30
  no_cross: False
  smooth_output: True


# Logger configuration
LOGGER:
  SACE_CHECKPOINT_EPOCH: 1000
  LOG_EVERY_STEPS: 100
  VAL_EVERY_STEPS: 1000 # 200
  TENSORBOARD: True
  WANDB:
    PROJECT: null
    OFFLINE: False
    RESUME_ID: null

================================================
FILE: configs/diffusion/vocaset/diffspeaker_wav2vec2_vocaset.yaml
================================================
NAME: diffspeaker_wav2vec2_vocaset # Experiment name
DEBUG: False # Debug mode
ACCELERATOR: 'gpu' # Devices optioncal: “cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps, “auto”
DEVICE: [0] # Index of gpus eg. [0] or [0,1,2,3]

# Training configuration
TRAIN:
  #---------------------------------
  DATASETS: ['vocaset'] # Training datasets
  NUM_WORKERS: 1 # Number of workers
  BATCH_SIZE: 32 # Size of batches
  START_EPOCH: 0 # Start epochMMOTIONENCODER
  END_EPOCH: 9000 # End epoch
  RESUME: '' # Resume training from this path
  OPTIM:
    TYPE: AdamW # Optimizer type
    LR: 1e-4 # Learning rate

# Evaluating Configuration
EVAL:
  DATASETS: ['vocaset'] # Evaluating datasets
  BATCH_SIZE: 32 # Evaluating Batch size

# Datasets Configuration
DATASET:
  JOINT_TYPE: 'vocaset' # join type

TEST:
  CHECKPOINTS: checkpoints/vocaset/diffspeaker_wav2vec2_vocaset.ckpt # Pretrained model path
  DATASETS: ['vocaset'] # training datasets
  BATCH_SIZE: 1 # training Batch size
  SPLIT: test # split type
  REPLICATION_TIMES: 10 # replication times for each test sample

# Losses Configuration
LOSS:
  TYPE: voca # Losses type
  VERTICE_ENC: 1 # Lambda for vertices reconstruction Losses
  VERTICE_ENC_V: 1 # lambda for vertices velocity reconstruction loss
  LIP_ENC: 0 # lambda for lip reconstruction loss
  LIP_ENC_V: 0 # lambda for lip velocity reconstruction loss
  DIST_SYNC_ON_STEP: True # Sync Losses on step when distributed trained

audio_encoder:
  train_audio_encoder: True
  model_name_or_path: 'facebook/wav2vec2-base-960h'
  
# Model Configuration
model:
  target: 'diffusion/diffusion_bias_modules'
  audio_encoded_dim: 768 # audio hidden dimension
  model_type: diffusion_bias # model type
  latent_dim: 512 # latent dimension
  id_dim: 8 # the dimension of the id vector
  ff_size: 1024 # latent_dim * 2
  num_layers: 1 # number of layers
  num_heads: 4 # number of head layers
  dropout: 0.1 # dropout rate
  max_len: 600 # the attention mask maximum length
  activation: gelu # activation type
  normalize_before: True 
  require_start_token: True # start_token is need for autogressive generation only
  arch: 'default'
  predict_epsilon: False # noise or motion, motion here
  freq_shift: 0
  flip_sin_to_cos: True
  mem_attn_scale: 1.
  tgt_attn_scale: 1.
  audio_fps: 50
  hidden_fps: 30
  guidance_scale: 0 # not used
  guidance_uncondp: 0. # not used
  period: 30
  no_cross: False
  smooth_output: True


# Logger configuration
LOGGER:
  SACE_CHECKPOINT_EPOCH: 1000
  LOG_EVERY_STEPS: 100
  VAL_EVERY_STEPS: 1000 # 200
  TENSORBOARD: True
  WANDB:
    PROJECT: null
    OFFLINE: False
    RESUME_ID: null

================================================
FILE: datasets/biwi/README.md
================================================
should contanin
```
    regions
    templates
    templates.pkl
    vertices_npy
    wav
```

================================================
FILE: datasets/biwi/regions/fdd.txt
================================================
7, 9, 59, 89, 94, 103, 122, 126, 152, 159, 160, 161, 164, 166, 169, 174, 177, 178, 179, 180, 182, 183, 184, 185, 195, 216, 217, 221, 223, 224, 226, 227, 230, 232, 233, 234, 235, 236, 238, 240, 241, 243, 244, 259, 270, 286, 300, 393, 394, 397, 400, 415, 416, 417, 418, 419, 422, 424, 440, 446, 447, 448, 449, 450, 452, 454, 456, 457, 459, 461, 462, 467, 468, 469, 470, 472, 473, 476, 477, 478, 497, 498, 499, 506, 508, 511, 513, 519, 525, 553, 588, 594, 601, 610, 613, 622, 627, 633, 635, 637, 642, 643, 645, 647, 663, 665, 681, 682, 690, 702, 705, 706, 712, 713, 714, 715, 716, 717, 718, 721, 727, 728, 730, 734, 735, 737, 738, 739, 741, 743, 744, 745, 746, 747, 748, 751, 753, 754, 756, 757, 759, 760, 763, 765, 767, 769, 776, 777, 778, 779, 780, 781, 783, 784, 788, 789, 790, 791, 795, 796, 797, 801, 802, 805, 806, 810, 813, 815, 830, 831, 832, 836, 839, 840, 841, 851, 860, 877, 883, 885, 888, 898, 900, 921, 922, 924, 925, 939, 940, 942, 947, 952, 956, 957, 968, 969, 971, 973, 975, 976, 977, 982, 983, 985, 986, 988, 990, 995, 999, 1000, 1001, 1002, 1003, 1005, 1008, 1009, 1011, 1012, 1013, 1014, 1015, 1019, 1020, 1021, 1022, 1025, 1027, 1033, 1042, 1107, 1110, 1114, 1117, 1132, 1134, 1141, 1142, 1147, 1149, 1150, 1152, 1155, 1156, 1158, 1160, 1161, 1163, 1166, 1167, 1168, 1170, 1171, 1172, 1175, 1186, 1208, 1211, 1242, 1250, 1258, 1261, 1264, 1304, 1310, 1317, 1319, 1372, 1388, 1398, 1400, 1402, 1403, 1405, 1408, 1409, 1410, 1411, 1412, 1416, 1418, 1420, 1424, 1426, 1433, 1435, 1436, 1439, 1441, 1445, 1446, 1456, 1459, 1476, 1483, 1490, 1491, 1506, 1507, 1510, 1513, 1527, 1528, 1529, 1530, 1531, 1532, 1533, 1534, 1535, 1540, 1549, 1550, 1552, 1554, 1558, 1559, 1561, 1562, 1563, 1564, 1565, 1571, 1572, 1573, 1574, 1576, 1577, 1578, 1579, 1580, 1581, 1582, 1583, 1585, 1590, 1591, 1593, 1594, 1595, 1596, 1597, 1598, 1600, 1601, 1602, 1603, 1604, 1605, 1606, 1607, 1611, 1613, 1615, 1620, 1622, 1625, 1626, 1627, 1628, 1629, 1631, 1638, 1639, 1641, 1643, 1645, 1652, 1653, 1655, 1660, 1661, 1670, 1672, 1676, 1678, 1680, 1682, 1683, 1685, 1686, 1690, 1693, 1696, 1698, 1701, 1704, 1705, 1706, 1707, 1708, 1710, 1711, 1712, 1713, 1714, 1715, 1716, 1718, 1722, 1723, 1724, 1726, 1727, 1729, 1730, 1734, 1737, 1738, 1743, 1744, 1746, 1749, 1751, 1752, 1756, 1757, 1760, 1763, 1767, 1768, 1769, 1772, 1778, 1779, 1782, 1785, 1786, 1787, 1789, 1793, 1795, 1796, 1802, 1806, 1807, 1808, 1809, 1810, 1811, 1812, 1813, 1814, 1815, 1816, 1817, 1818, 1820, 1825, 1826, 1831, 1832, 1841, 1847, 1850, 1851, 1853, 1854, 1859, 1860, 1861, 1862, 1866, 1867, 1868, 1869, 1873, 1875, 1881, 1883, 1886, 1887, 1889, 1892, 1896, 1902, 1903, 1904, 1909, 1923, 1941, 1949, 1964, 1966, 1967, 1968, 1971, 1972, 1974, 1976, 1977, 1978, 1979, 1980, 1983, 1984, 1985, 1986, 1987, 1988, 1989, 1991, 1992, 1993, 1994, 1999, 2004, 2009, 2011, 2012, 2013, 2016, 2017, 2021, 2022, 2024, 2025, 2026, 2034, 2043, 2044, 2047, 2051, 2053, 2055, 2064, 2069, 2073, 2074, 2078, 2081, 2082, 2085, 2090, 2092, 2099, 2104, 2106, 2107, 2109, 2123, 2130, 2134, 2138, 2142, 2159, 2160, 2173, 2215, 2243, 2244, 2245, 2246, 2247, 2248, 2251, 2253, 2256, 2258, 2263, 2266, 2267, 2268, 2273, 2274, 2275, 2276, 2277, 2279, 2280, 2281, 2282, 2283, 2284, 2285, 2286, 2287, 2288, 2289, 2291, 2293, 2294, 2296, 2297, 2298, 2299, 2300, 2303, 2304, 2305, 2306, 2307, 2308, 2309, 2310, 2311, 2312, 2313, 2314, 2315, 2316, 2317, 2318, 2322, 2325, 2327, 2328, 2330, 2331, 2333, 2336, 2337, 2340, 2341, 2342, 2343, 2344, 2345, 2346, 2347, 2348, 2350, 2351, 2352, 2353, 2356, 2357, 2358, 2360, 2361, 2362, 2363, 2364, 2365, 2367, 2368, 2369, 2370, 2371, 2372, 2373, 2375, 2378, 2380, 2381, 2383, 2384, 2385, 2386, 2390, 2391, 2394, 2395, 2399, 2404, 2406, 2407, 2408, 2409, 2410, 2411, 2413, 2414, 2416, 2417, 2418, 2419, 2420, 2421, 2422, 2423, 2424, 2425, 2426, 2427, 2428, 2429, 2430, 2431, 2433, 2437, 2440, 2441, 2445, 2447, 2452, 2455, 2456, 2458, 2459, 2460, 2464, 2466, 2469, 2470, 2471, 2474, 2478, 2479, 2485, 2487, 2488, 2490, 2494, 2496, 2497, 2503, 2505, 2508, 2509, 2512, 2516, 2517, 2519, 2520, 2522, 2523, 2524, 2528, 2530, 2531, 2532, 2533, 2535, 2536, 2537, 2539, 2542, 2543, 2546, 2547, 2548, 2552, 2554, 2564, 2567, 2583, 2593, 2602, 2605, 2608, 2609, 2610, 2612, 2614, 2615, 2617, 2620, 2622, 2624, 2625, 2627, 2628, 2629, 2632, 2633, 2634, 2637, 2699, 2723, 2724, 2725, 2738, 2741, 2743, 2755, 2756, 2776, 2779, 2780, 2796, 2797, 2798, 2799, 2803, 2808, 2812, 2813, 2819, 2829, 2842, 2844, 2846, 2851, 2852, 2853, 2855, 2857, 2861, 2862, 2868, 2874, 2879, 2881, 2882, 2883, 2885, 2893, 2899, 2944, 2961, 2964, 2965, 2970, 2971, 2972, 2973, 2978, 2982, 2983, 2984, 2985, 2986, 2988, 2991, 2995, 2996, 2998, 3005, 3020, 3021, 3027, 3028, 3030, 3031, 3036, 3039, 3041, 3044, 3045, 3046, 3048, 3051, 3056, 3057, 3060, 3065, 3068, 3069, 3073, 3079, 3083, 3085, 3088, 3092, 3096, 3102, 3118, 3125, 3127, 3128, 3131, 3133, 3140, 3141, 3142, 3144, 3147, 3149, 3151, 3152, 3153, 3164, 3165, 3170, 3171, 3172, 3174, 3178, 3184, 3186, 3188, 3190, 3192, 3193, 3195, 3196, 3197, 3198, 3200, 3201, 3203, 3204, 3205, 3206, 3207, 3209, 3211, 3212, 3213, 3214, 3216, 3217, 3220, 3222, 3223, 3224, 3225, 3227, 3228, 3230, 3231, 3232, 3233, 3234, 3236, 3240, 3244, 3245, 3247, 3248, 3249, 3250, 3257, 3258, 3259, 3260, 3261, 3262, 3268, 3269, 3270, 3274, 3275, 3279, 3282, 3283, 3287, 3291, 3292, 3293, 3294, 3295, 3298, 3300, 3302, 3303, 3308, 3312, 3313, 3314, 3316, 3317, 3319, 3320, 3324, 3325, 3327, 3330, 3331, 3333, 3339, 3340, 3343, 3349, 3352, 3353, 3355, 3356, 3357, 3358, 3359, 3362, 3365, 3366, 3368, 3369, 3370, 3371, 3374, 3378, 3379, 3380, 3381, 3383, 3384, 3386, 3387, 3388, 3389, 3390, 3391, 3394, 3398, 3399, 3401, 3402, 3404, 3413, 3414, 3415, 3421, 3422, 3423, 3429, 3430, 3432, 3434, 3436, 3437, 3438, 3439, 3441, 3442, 3444, 3445, 3446, 3451, 3453, 3454, 3455, 3456, 3458, 3461, 3462, 3463, 3464, 3465, 3467, 3468, 3470, 3471, 3472, 3473, 3476, 3478, 3482, 3484, 3485, 3491, 3493, 3494, 3495, 3500, 3502, 3503, 3506, 3507, 3508, 3511, 3512, 3514, 3516, 3520, 3524, 3525, 3526, 3527, 3528, 3531, 3538, 3540, 3541, 3542, 3544, 3546, 3550, 3551, 3552, 3565, 3573, 3581, 3593, 3594, 3606, 3608, 3651, 3654, 3656, 3658, 3663, 3665, 3668, 3679, 3698, 3709, 3710, 3714, 3715, 3718, 3719, 3721, 3733, 3735, 3760, 3769, 3772, 3773, 3778, 3779, 3780, 3783, 3784, 3786, 3788, 3791, 3792, 3793, 3794, 3795, 3797, 3801, 3828, 3830, 3832, 3838, 3839, 3843, 3845, 3846, 3851, 3856, 3862, 3865, 3870, 3872, 3874, 3878, 3879, 3886, 3887, 3889, 3892, 3893, 3904, 3905, 3906, 3912, 3914, 3918, 3919, 3923, 3927, 3938, 3944, 3948, 3959, 3962, 3963, 3965, 3966, 3969, 3973, 3985, 3986, 3991, 3996, 3999, 4003, 4004, 4006, 4007, 4010, 4012, 4014, 4015, 4017, 4020, 4021, 4023, 4026, 4028, 4029, 4030, 4032, 4035, 4038, 4040, 4043, 4058, 4059, 4061, 4063, 4064, 4065, 4068, 4071, 4072, 4073, 4077, 4078, 4079, 4080, 4082, 4084, 4085, 4087, 4088, 4090, 4094, 4098, 4099, 4100, 4101, 4107, 4115, 4122, 4123, 4124, 4126, 4127, 4128, 4129, 4147, 4150, 4153, 4154, 4155, 4156, 4157, 4158, 4161, 4165, 4182, 4183, 4184, 4187, 4190, 4191, 4192, 4194, 4197, 4199, 4204, 4206, 4208, 4209, 4210, 4215, 4216, 4218, 4220, 4221, 4222, 4225, 4226, 4227, 4228, 4232, 4235, 4236, 4237, 4238, 4240, 4241, 4244, 4247, 4252, 4253, 4255, 4258, 4259, 4260, 4261, 4264, 4266, 4268, 4272, 4274, 4276, 4278, 4279, 4280, 4281, 4284, 4287, 4288, 4289, 4290, 4293, 4294, 4295, 4297, 4298, 4300, 4301, 4302, 4303, 4304, 4305, 4306, 4307, 4308, 4309, 4310, 4311, 4313, 4315, 4316, 4317, 4318, 4319, 4322, 4324, 4326, 4327, 4328, 4331, 4335, 4337, 4338, 4346, 4352, 4355, 4358, 4359, 4360, 4361, 4365, 4366, 4367, 4368, 4369, 4370, 4371, 4372, 4373, 4374, 4375, 4376, 4380, 4383, 4388, 4391, 4394, 4395, 4396, 4398, 4399, 4409, 4412, 4416, 4418, 4419, 4425, 4426, 4431, 4435, 4441, 4443, 4452, 4454, 4455, 4462, 4463, 4464, 4465, 4467, 4468, 4475, 4478, 4479, 4480, 4481, 4483, 4484, 4485, 4486, 4487, 4488, 4492, 4495, 4497, 4498, 4502, 4503, 4507, 4508, 4509, 4511, 4512, 4514, 4515, 4516, 4517, 4518, 4519, 4520, 4523, 4524, 4525, 4526, 4527, 4528, 4529, 4530, 4531, 4532, 4533, 4534, 4535, 4536, 4537, 4538, 4539, 4541, 4542, 4543, 4544, 4545, 4546, 4548, 4549, 4552, 4553, 4554, 4555, 4556, 4557, 4558, 4559, 4560, 4561, 4562, 4563, 4564, 4565, 4566, 4567, 4569, 4571, 4572, 4573, 4574, 4575, 4576, 4577, 4578, 4579, 4580, 4581, 4582, 4583, 4584, 4590, 4591, 4592, 4594, 4595, 4598, 4600, 4602, 4604, 4605, 4606, 4608, 4611, 4616, 4617, 4618, 4619, 4631, 4632, 4633, 4634, 4635, 4636, 4637, 4638, 4641, 4642, 4645, 4648, 4649, 4651, 4656, 4657, 4664, 4666, 4667, 4668, 4673, 4675, 4678, 4679, 4684, 4687, 4690, 4693, 4695, 4696, 4697, 4700, 4702, 4704, 4707, 4709, 4710, 4712, 4715, 4719, 4723, 4730, 4733, 4735, 4736, 4737, 4739, 4740, 4743, 4744, 4745, 4747, 4750, 4754, 4755, 4756, 4758, 4759, 4761, 4763, 4765, 4766, 4768, 4769, 4770, 4771, 4772, 4773, 4774, 4775, 4776, 4777, 4778, 4779, 4780, 4781, 4782, 4783, 4784, 4785, 4786, 4787, 4788, 4789, 4790, 4791, 4794, 4795, 4800, 4802, 4804, 4805, 4806, 4807, 4810, 4813, 4814, 4816, 4817, 4818, 4819, 4821, 4823, 4825, 4826, 4827, 4828, 4829, 4830, 4831, 4832, 4834, 4835, 4836, 4837, 4838, 4839, 4840, 4841, 4842, 4845, 4848, 4849, 4852, 4854, 4855, 4857, 4858, 4860, 4861, 4863, 4864, 4866, 4867, 4868, 4869, 4870, 4871, 4872, 4873, 4874, 4875, 4877, 4878, 4880, 4881, 4887, 4889, 4890, 4891, 4892, 4893, 4894, 4895, 4896, 4897, 4898, 4899, 4900, 4901, 4902, 4903, 4905, 4906, 4909, 4910, 4914, 4918, 4920, 4922, 4924, 4926, 4928, 4929, 4931, 4946, 4953, 4998, 4999, 5001, 5003, 5005, 5009, 5010, 5012, 5013, 5014, 5016, 5017, 5019, 5023, 5024, 5028, 5030, 5032, 5034, 5036, 5039, 5042, 5049, 5051, 5053, 5057, 5063, 5074, 5083, 5087, 5089, 5093, 5094, 5095, 5099, 5106, 5109, 5115, 5116, 5118, 5119, 5121, 5124, 5127, 5131, 5132, 5175, 5176, 5177, 5178, 5179, 5182, 5185, 5188, 5191, 5192, 5194, 5203, 5205, 5214, 5216, 5217, 5223, 5233, 5235, 5237, 5240, 5242, 5243, 5245, 5250, 5251, 5257, 5263, 5266, 5269, 5270, 5271, 5274, 5280, 5281, 5285, 5286, 5287, 5291, 5293, 5295, 5298, 5299, 5301, 5302, 5303, 5304, 5305, 5306, 5307, 5308, 5309, 5310, 5316, 5319, 5320, 5323, 5324, 5328, 5329, 5330, 5331, 5333, 5334, 5335, 5337, 5338, 5339, 5342, 5345, 5346, 5347, 5348, 5350, 5351, 5353, 5355, 5356, 5357, 5358, 5359, 5361, 5362, 5364, 5366, 5367, 5376, 5390, 5402, 5404, 5408, 5415, 5423, 5424, 5425, 5428, 5433, 5435, 5439, 5441, 5442, 5444, 5445, 5455, 5459, 5464, 5465, 5471, 5475, 5477, 5479, 5483, 5488, 5489, 5499, 5506, 5510, 5514, 5515, 5516, 5518, 5520, 5521, 5522, 5525, 5526, 5527, 5528, 5529, 5533, 5536, 5538, 5539, 5540, 5543, 5549, 5551, 5554, 5555, 5558, 5559, 5563, 5564, 5566, 5567, 5568, 5570, 5571, 5573, 5574, 5575, 5578, 5579, 5580, 5581, 5582, 5583, 5584, 5585, 5587, 5588, 5589, 5590, 5591, 5593, 5594, 5596, 5597, 5599, 5600, 5602, 5603, 5605, 5606, 5607, 5608, 5611, 5615, 5617, 5620, 5622, 5623, 5625, 5631, 5635, 5643, 5678, 5680, 5682, 5684, 5694, 5711, 5712, 5714, 5718, 5725, 5728, 5730, 5732, 5736, 5740, 5748, 5751, 5759, 5761, 5763, 5766, 5768, 5770, 5772, 5787, 5791, 5795, 5798, 5802, 5809, 5813, 5818, 5822, 5848, 5854, 5855, 5858, 5859, 5861, 5863, 5864, 5865, 5867, 5875, 5877, 5881, 5882, 5886, 5888, 5890, 5892, 5893, 5895, 5896, 5901, 5902, 5905, 5908, 5913, 5914, 5919, 5923, 5925, 5926, 5927, 5928, 5929, 5932, 5933, 5934, 5940, 5941, 5946, 5948, 5949, 5950, 5951, 5952, 5953, 5955, 5959, 5960, 5961, 5962, 5964, 5965, 5966, 5967, 5968, 5969, 5971, 5972, 5973, 5974, 5975, 5976, 5977, 5978, 5979, 5980, 5981, 5983, 5985, 5986, 5987, 5988, 5989, 5991, 5992, 5993, 5994, 5995, 5996, 5997, 6004, 6007, 6008, 6012, 6015, 6016, 6017, 6022, 6031, 6033, 6035, 6037, 6038, 6039, 6040, 6041, 6042, 6046, 6047, 6048, 6049, 6050, 6053, 6055, 6057, 6058, 6061, 6063, 6065, 6066, 6067, 6070, 6072, 6074, 6075, 6079, 6080, 6081, 6082, 6083, 6085, 6087, 6088, 6089, 6090, 6091, 6092, 6093, 6094, 6096, 6097, 6098, 6106, 6111, 6112, 6113, 6114, 6116, 6122, 6123, 6124, 6126, 6127, 6131, 6132, 6133, 6134, 6135, 6136, 6137, 6138, 6139, 6140, 6141, 6143, 6145, 6146, 6147, 6148, 6149, 6150, 6151, 6152, 6153, 6157, 6158, 6159, 6160, 6161, 6164, 6169, 6170, 6171, 6172, 6174, 6175, 6176, 6177, 6180, 6181, 6182, 6183, 6184, 6185, 6187, 6191, 6192, 6193, 6195, 6196, 6198, 6200, 6204, 6206, 6209, 6211, 6212, 6215, 6219, 6221, 6222, 6225, 6235, 6236, 6241, 6244, 6245, 6247, 6248, 6249, 6251, 6253, 6255, 6257, 6259, 6260, 6265, 6273, 6274, 6276, 6280, 6282, 6283, 6284, 6285, 6286, 6287, 6288, 6289, 6291, 6293, 6294, 6295, 6298, 6299, 6300, 6311, 6314, 6334, 6335, 6336, 6338, 6340, 6343, 6351, 6353, 6354, 6357, 6359, 6360, 6361, 6364, 6368, 6372, 6380, 6384, 6391, 6395, 6401, 6402, 6403, 6406, 6412, 6414, 6415, 6416, 6418, 6419, 6421, 6425, 6426, 6427, 6428, 6429, 6430, 6434, 6435, 6436, 6443, 6446, 6450, 6452, 6453, 6454, 6459, 6460, 6461, 6463, 6466, 6472, 6479, 6490, 6491, 6492, 6496, 6497, 6498, 6500, 6501, 6503, 6504, 6508, 6509, 6510, 6511, 6512, 6515, 6519, 6520, 6522, 6523, 6527, 6534, 6536, 6538, 6540, 6543, 6545, 6546, 6549, 6552, 6553, 6554, 6562, 6565, 6567, 6568, 6569, 6571, 6573, 6575, 6576, 6578, 6579, 6581, 6584, 6586, 6587, 6591, 6593, 6597, 6598, 6604, 6611, 6615, 6616, 6617, 6618, 6621, 6622, 6625, 6626, 6627, 6629, 6630, 6639, 6640, 6642, 6643, 6644, 6645, 6646, 6652, 6655, 6659, 6665, 6673, 6675, 6677, 6679, 6680, 6698, 6710, 6737, 6746, 6747, 6748, 6749, 6754, 6760, 6764, 6771, 6781, 6782, 6792, 6794, 6804, 6815, 6822, 6831, 6842, 6853, 6855, 6856, 6859, 6864, 6867, 6871, 6876, 6879, 6882, 6883, 6887, 6890, 6913, 6914, 6917, 6928, 6945, 6956, 6957, 6959, 6960, 6962, 6964, 6965, 6968, 6970, 6971, 6972, 6974, 6975, 6976, 6977, 6981, 6982, 6983, 6984, 6986, 6988, 6989, 6990, 6991, 6992, 6993, 6994, 6995, 6996, 6997, 7005, 7006, 7012, 7015, 7020, 7024, 7025, 7026, 7027, 7028, 7029, 7030, 7031, 7032, 7033, 7034, 7035, 7036, 7037, 7038, 7039, 7040, 7042, 7043, 7044, 7045, 7046, 7047, 7050, 7051, 7052, 7053, 7054, 7055, 7056, 7057, 7058, 7059, 7060, 7061, 7062, 7063, 7064, 7065, 7068, 7069, 7070, 7071, 7072, 7073, 7075, 7078, 7081, 7083, 7084, 7086, 7088, 7093, 7094, 7095, 7096, 7097, 7098, 7099, 7100, 7101, 7103, 7105, 7106, 7108, 7109, 7110, 7111, 7112, 7114, 7115, 7117, 7118, 7119, 7121, 7125, 7126, 7127, 7131, 7132, 7133, 7137, 7138, 7139, 7140, 7141, 7142, 7143, 7144, 7146, 7149, 7150, 7151, 7153, 7156, 7169, 7171, 7172, 7178, 7179, 7182, 7184, 7185, 7186, 7187, 7190, 7191, 7192, 7193, 7194, 7195, 7196, 7197, 7198, 7199, 7200, 7201, 7202, 7203, 7204, 7205, 7206, 7207, 7208, 7209, 7210, 7211, 7212, 7213, 7214, 7215, 7216, 7217, 7218, 7219, 7220, 7221, 7222, 7223, 7224, 7225, 7226, 7228, 7229, 7234, 7235, 7239, 7240, 7243, 7245, 7246, 7247, 7249, 7251, 7253, 7254, 7256, 7262, 7265, 7269, 7270, 7280, 7282, 7283, 7284, 7285, 7286, 7287, 7288, 7289, 7290, 7291, 7292, 7293, 7294, 7295, 7296, 7297, 7298, 7299, 7300, 7301, 7302, 7303, 7304, 7305, 7306, 7307, 7309, 7311, 7312, 7313, 7314, 7315, 7316, 7317, 7318, 7322, 7323, 7324, 7326, 7327, 7329, 7330, 7335, 7336, 7341, 7343, 7344, 7345, 7346, 7348, 7350, 7353, 7354, 7356, 7358, 7360, 7361, 7362, 7363, 7365, 7370, 7372, 7373, 7375, 7376, 7377, 7378, 7379, 7381, 7382, 7385, 7386, 7387, 7388, 7389, 7390, 7391, 7392, 7393, 7394, 7395, 7396, 7397, 7398, 7399, 7400, 7401, 7402, 7403, 7404, 7405, 7406, 7407, 7408, 7409, 7410, 7411, 7412, 7413, 7414, 7415, 7416, 7417, 7418, 7419, 7420, 7421, 7422, 7424, 7425, 7426, 7428, 7430, 7432, 7434, 7435, 7436, 7437, 7438, 7439, 7441, 7442, 7443, 7444, 7449, 7451, 7453, 7455, 7456, 7459, 7460, 7461, 7463, 7466, 7468, 7470, 7471, 7472, 7485, 7523, 7524, 7528, 7529, 7531, 7532, 7537, 7543, 7585, 7602, 7608, 7623, 7624, 7660, 7662, 7663, 7665, 7667, 7673, 7674, 7678, 7680, 7683, 7684, 7685, 7686, 7687, 7690, 7691, 7694, 7695, 7697, 7698, 7744, 7791, 7792, 7793, 7795, 7799, 7800, 7802, 7804, 7807, 7810, 7812, 7814, 7822, 7823, 7827, 7830, 7831, 7839, 7842, 7846, 7849, 7852, 7857, 7859, 7861, 7865, 7866, 7867, 7871, 7872, 7879, 7885, 7886, 7887, 7891, 7892, 7893, 7895, 7897, 7899, 7901, 7902, 7914, 7923, 7924, 7929, 7932, 7933, 7934, 7936, 7937, 7948, 7951, 7954, 7956, 7957, 7958, 7962, 7965, 7966, 7973, 7979, 7984, 7986, 7989, 7990, 7998, 8006, 8007, 8010, 8014, 8015, 8017, 8020, 8022, 8023, 8025, 8026, 8029, 8030, 8031, 8036, 8039, 8041, 8042, 8043, 8045, 8047, 8048, 8049, 8051, 8052, 8053, 8054, 8055, 8056, 8057, 8058, 8059, 8063, 8064, 8065, 8067, 8068, 8069, 8070, 8071, 8072, 8073, 8074, 8075, 8077, 8078, 8081, 8084, 8085, 8086, 8087, 8088, 8089, 8090, 8093, 8094, 8095, 8097, 8102, 8108, 8126, 8128, 8131, 8132, 8134, 8135, 8138, 8139, 8159, 8162, 8163, 8173, 8174, 8177, 8178, 8183, 8186, 8187, 8188, 8190, 8195, 8201, 8216, 8229, 8231, 8234, 8236, 8238, 8245, 8247, 8250, 8252, 8253, 8256, 8257, 8258, 8260, 8261, 8262, 8269, 8271, 8279, 8280, 8281, 8286, 8289, 8292, 8294, 8298, 8300, 8301, 8302, 8305, 8306, 8321, 8323, 8325, 8328, 8333, 8335, 8352, 8354, 8357, 8359, 8362, 8363, 8364, 8365, 8366, 8369, 8370, 8371, 8374, 8379, 8382, 8383, 8385, 8386, 8387, 8388, 8389, 8393, 8394, 8396, 8397, 8400, 8402, 8406, 8408, 8409, 8411, 8412, 8414, 8415, 8416, 8417, 8418, 8420, 8422, 8428, 8430, 8433, 8436, 8438, 8442, 8443, 8445, 8451, 8461, 8463, 8481, 8485, 8486, 8489, 8490, 8493, 8495, 8497, 8498, 8500, 8502, 8507, 8513, 8517, 8523, 8526, 8530, 8533, 8537, 8538, 8539, 8540, 8543, 8545, 8546, 8547, 8548, 8549, 8563, 8573, 8578, 8579, 8581, 8584, 8587, 8591, 8592, 8593, 8600, 8604, 8607, 8609, 8618, 8619, 8621, 8625, 8627, 8628, 8633, 8634, 8636, 8639, 8641, 8645, 8648, 8649, 8650, 8652, 8654, 8658, 8660, 8661, 8665, 8671, 8672, 8674, 8676, 8678, 8680, 8686, 8688, 8689, 8690, 8696, 8702, 8703, 8704, 8705, 8715, 8716, 8717, 8727, 8731, 8737, 8739, 8740, 8747, 8758, 8785, 8787, 8788, 8793, 8797, 8817, 8838, 8850, 8856, 8860, 8865, 8870, 8871, 8873, 8885, 8888, 8892, 8895, 8896, 8898, 8901, 8904, 8907, 8909, 8911, 8913, 8933, 8936, 8937, 8939, 8942, 8946, 8949, 8950, 8951, 8952, 8953, 8957, 8961, 8962, 8963, 8965, 8966, 8967, 8968, 8969, 8971, 8975, 8979, 8981, 8986, 8987, 8990, 9005, 9007, 9010, 9011, 9012, 9014, 9017, 9018, 9020, 9023, 9027, 9029, 9035, 9037, 9038, 9040, 9045, 9049, 9050, 9052, 9053, 9054, 9055, 9056, 9057, 9058, 9059, 9060, 9061, 9062, 9063, 9064, 9065, 9066, 9067, 9070, 9071, 9072, 9074, 9075, 9076, 9078, 9079, 9080, 9081, 9082, 9083, 9084, 9085, 9086, 9087, 9088, 9089, 9090, 9091, 9092, 9093, 9094, 9095, 9096, 9097, 9098, 9099, 9100, 9101, 9102, 9103, 9104, 9105, 9106, 9108, 9110, 9112, 9113, 9114, 9115, 9116, 9117, 9119, 9120, 9121, 9122, 9123, 9124, 9125, 9129, 9130, 9131, 9133, 9134, 9135, 9136, 9137, 9138, 9139, 9140, 9141, 9142, 9145, 9146, 9147, 9149, 9150, 9151, 9152, 9153, 9155, 9156, 9158, 9159, 9160, 9161, 9162, 9164, 9165, 9173, 9175, 9177, 9178, 9180, 9181, 9182, 9183, 9184, 9186, 9193, 9195, 9197, 9201, 9202, 9205, 9206, 9226, 9227, 9229, 9232, 9233, 9234, 9236, 9237, 9238, 9240, 9241, 9251, 9257, 9259, 9260, 9261, 9264, 9265, 9267, 9272, 9274, 9275, 9277, 9279, 9286, 9288, 9292, 9295, 9297, 9301, 9306, 9309, 9310, 9313, 9315, 9319, 9329, 9333, 9343, 9346, 9347, 9355, 9356, 9364, 9371, 9372, 9379, 9385, 9387, 9389, 9393, 9395, 9398, 9400, 9407, 9417, 9419, 9432, 9433, 9434, 9436, 9441, 9446, 9450, 9451, 9452, 9453, 9454, 9455, 9456, 9457, 9458, 9459, 9460, 9462, 9465, 9472, 9473, 9474, 9476, 9478, 9479, 9493, 9510, 9517, 9518, 9519, 9521, 9531, 9532, 9533, 9534, 9539, 9546, 9548, 9549, 9550, 9551, 9553, 9555, 9557, 9558, 9563, 9566, 9569, 9571, 9575, 9577, 9579, 9581, 9583, 9588, 9589, 9590, 9592, 9594, 9595, 9600, 9601, 9602, 9607, 9611, 9612, 9613, 9615, 9620, 9654, 9655, 9664, 9667, 9668, 9669, 9679, 9683, 9687, 9688, 9694, 9695, 9698, 9699, 9700, 9701, 9702, 9704, 9706, 9711, 9712, 9721, 9722, 9728, 9729, 9730, 9733, 9734, 9735, 9736, 9737, 9739, 9741, 9743, 9744, 9745, 9746, 9748, 9749, 9752, 9753, 9754, 9755, 9762, 9771, 9774, 9782, 9783, 9784, 9785, 9786, 9787, 9789, 9790, 9795, 9796, 9799, 9801, 9803, 9804, 9805, 9806, 9807, 9810, 9812, 9813, 9814, 9818, 9822, 9825, 9827, 9833, 9836, 9839, 9842, 9847, 9850, 9851, 9852, 9854, 9858, 9863, 9878, 9879, 9880, 9883, 9886, 9887, 9891, 9892, 9898, 9901, 9904, 9908, 9913, 9919, 9928, 9931, 9942, 9943, 9947, 9954, 9955, 9959, 9975, 9976, 9985, 9997, 10011, 10012, 10013, 10015, 10016, 10020, 10022, 10023, 10024, 10025, 10027, 10028, 10029, 10031, 10032, 10033, 10036, 10037, 10038, 10040, 10041, 10042, 10043, 10044, 10045, 10046, 10051, 10054, 10078, 10082, 10083, 10085, 10087, 10088, 10090, 10092, 10093, 10097, 10098, 10100, 10101, 10102, 10103, 10109, 10113, 10118, 10122, 10123, 10124, 10128, 10129, 10134, 10135, 10136, 10138, 10139, 10148, 10150, 10152, 10157, 10164, 10169, 10170, 10173, 10175, 10176, 10179, 10181, 10185, 10187, 10190, 10191, 10192, 10193, 10197, 10199, 10200, 10202, 10204, 10206, 10210, 10213, 10215, 10216, 10217, 10220, 10222, 10227, 10228, 10233, 10236, 10237, 10238, 10245, 10246, 10248, 10249, 10255, 10256, 10261, 10262, 10264, 10265, 10269, 10270, 10272, 10274, 10278, 10280, 10281, 10282, 10283, 10287, 10289, 10291, 10293, 10296, 10297, 10301, 10304, 10305, 10308, 10312, 10315, 10317, 10322, 10329, 10335, 10343, 10344, 10345, 10351, 10371, 10376, 10391, 10393, 10394, 10395, 10396, 10397, 10398, 10416, 10418, 10421, 10427, 10428, 10430, 10435, 10438, 10443, 10445, 10447, 10452, 10453, 10455, 10456, 10457, 10458, 10460, 10463, 10464, 10465, 10466, 10467, 10468, 10469, 10470, 10471, 10472, 10473, 10474, 10475, 10476, 10477, 10478, 10479, 10480, 10481, 10483, 10485, 10487, 10488, 10489, 10492, 10493, 10494, 10495, 10496, 10497, 10498, 10499, 10500, 10501, 10504, 10505, 10506, 10507, 10508, 10509, 10510, 10513, 10514, 10515, 10517, 10518, 10519, 10520, 10521, 10527, 10528, 10529, 10530, 10531, 10532, 10533, 10540, 10541, 10545, 10547, 10551, 10556, 10562, 10563, 10564, 10566, 10571, 10572, 10575, 10577, 10578, 10579, 10580, 10581, 10584, 10586, 10587, 10590, 10592, 10596, 10597, 10600, 10601, 10602, 10603, 10604, 10606, 10607, 10608, 10609, 10610, 10612, 10615, 10617, 10619, 10620, 10621, 10622, 10625, 10628, 10630, 10637, 10640, 10641, 10647, 10650, 10651, 10654, 10657, 10659, 10666, 10669, 10670, 10671, 10672, 10673, 10674, 10675, 10676, 10677, 10681, 10683, 10687, 10688, 10689, 10692, 10698, 10702, 10707, 10713, 10716, 10722, 10732, 10733, 10734, 10736, 10737, 10756, 10759, 10760, 10764, 10770, 10777, 10778, 10782, 10786, 10788, 10791, 10794, 10798, 10799, 10810, 10826, 10829, 10830, 10831, 10834, 10835, 10836, 10839, 10843, 10845, 10848, 10849, 10850, 10857, 10859, 10860, 10861, 10864, 10877, 10879, 10882, 10884, 10886, 10893, 10895, 10899, 10902, 10903, 10904, 10905, 10907, 10910, 10913, 10914, 10920, 10923, 10925, 10928, 10929, 10942, 10944, 10945, 10948, 10949, 10952, 10953, 10954, 10955, 10959, 10960, 10961, 10965, 10966, 10968, 10971, 10972, 10976, 10984, 10986, 10988, 10989, 10992, 11021, 11029, 11038, 11040, 11053, 11062, 11063, 11081, 11107, 11113, 11116, 11118, 11119, 11122, 11124, 11141, 11143, 11144, 11145, 11147, 11153, 11159, 11164, 11169, 11171, 11173, 11177, 11179, 11184, 11185, 11186, 11189, 11191, 11193, 11195, 11197, 11199, 11203, 11204, 11205, 11206, 11207, 11208, 11209, 11211, 11213, 11215, 11217, 11219, 11225, 11227, 11228, 11230, 11231, 11232, 11233, 11236, 11238, 11241, 11250, 11251, 11254, 11257, 11258, 11260, 11265, 11270, 11271, 11272, 11275, 11277, 11279, 11280, 11288, 11294, 11296, 11298, 11309, 11311, 11312, 11313, 11330, 11331, 11334, 11344, 11346, 11347, 11357, 11358, 11359, 11363, 11364, 11367, 11369, 11370, 11371, 11373, 11376, 11378, 11383, 11389, 11391, 11394, 11396, 11402, 11403, 11408, 11411, 11412, 11434, 11438, 11443, 11449, 11450, 11453, 11461, 11464, 11466, 11467, 11472, 11477, 11478, 11480, 11482, 11487, 11489, 11490, 11492, 11494, 11496, 11514, 11538, 11541, 11551, 11554, 11557, 11558, 11561, 11562, 11598, 11603, 11605, 11608, 11622, 11623, 11628, 11629, 11637, 11639, 11642, 11644, 11645, 11651, 11661, 11662, 11666, 11667, 11668, 11677, 11679, 11687, 11690, 11692, 11694, 11695, 11696, 11700, 11702, 11709, 11715, 11718, 11725, 11727, 11728, 11739, 11749, 11750, 11754, 11755, 11756, 11757, 11761, 11764, 11766, 11779, 11782, 11788, 11790, 11793, 11794, 11796, 11805, 11809, 11810, 11812, 11813, 11816, 11817, 11819, 11820, 11823, 11825, 11826, 11828, 11829, 11831, 11833, 11839, 11843, 11845, 11846, 11849, 11850, 11852, 11853, 11861, 11862, 11868, 11870, 11872, 11914, 11919, 11920, 11923, 11934, 11937, 11938, 11943, 11944, 11945, 11946, 11947, 11948, 11952, 11956, 11960, 11961, 11970, 11972, 11974, 11979, 11980, 11981, 11983, 11988, 11991, 11993, 11996, 11997, 11998, 12000, 12002, 12007, 12011, 12019, 12022, 12025, 12031, 12033, 12036, 12046, 12051, 12054, 12058, 12061, 12063, 12065, 12066, 12070, 12071, 12072, 12073, 12074, 12075, 12079, 12082, 12083, 12085, 12087, 12090, 12092, 12093, 12094, 12097, 12098, 12102, 12103, 12104, 12105, 12108, 12109, 12110, 12113, 12114, 12116, 12117, 12118, 12119, 12120, 12123, 12124, 12125, 12126, 12127, 12128, 12129, 12132, 12133, 12134, 12135, 12137, 12138, 12139, 12140, 12142, 12143, 12145, 12147, 12148, 12150, 12152, 12153, 12155, 12156, 12159, 12160, 12161, 12167, 12168, 12171, 12172, 12173, 12174, 12175, 12176, 12178, 12180, 12181, 12183, 12186, 12188, 12194, 12195, 12199, 12201, 12202, 12203, 12205, 12207, 12208, 12209, 12210, 12213, 12215, 12216, 12217, 12220, 12221, 12222, 12223, 12230, 12231, 12233, 12234, 12235, 12239, 12240, 12246, 12250, 12251, 12252, 12258, 12259, 12260, 12261, 12262, 12263, 12264, 12268, 12269, 12270, 12271, 12272, 12276, 12278, 12281, 12282, 12285, 12286, 12287, 12288, 12289, 12290, 12291, 12292, 12293, 12294, 12295, 12296, 12297, 12298, 12299, 12300, 12301, 12302, 12303, 12304, 12305, 12306, 12307, 12309, 12311, 12313, 12314, 12317, 12319, 12321, 12325, 12334, 12340, 12342, 12343, 12344, 12345, 12346, 12347, 12348, 12349, 12350, 12351, 12352, 12353, 12354, 12355, 12356, 12357, 12359, 12361, 12362, 12364, 12365, 12366, 12367, 12368, 12369, 12370, 12371, 12375, 12376, 12378, 12380, 12382, 12384, 12385, 12386, 12387, 12391, 12392, 12393, 12394, 12395, 12396, 12397, 12398, 12399, 12407, 12408, 12409, 12412, 12427, 12450, 12452, 12470, 12478, 12479, 12480, 12481, 12482, 12484, 12485, 12486, 12490, 12505, 12508, 12510, 12512, 12520, 12521, 12528, 12533, 12535, 12565, 12582, 12583, 12592, 12596, 12598, 12599, 12600, 12602, 12603, 12604, 12606, 12609, 12611, 12613, 12614, 12617, 12626, 12645, 12656, 12672, 12673, 12684, 12697, 12698, 12702, 12705, 12706, 12707, 12708, 12709, 12712, 12714, 12716, 12718, 12720, 12723, 12724, 12725, 12730, 12732, 12735, 12737, 12739, 12740, 12741, 12744, 12745, 12746, 12749, 12750, 12751, 12753, 12754, 12756, 12757, 12759, 12761, 12762, 12765, 12766, 12770, 12771, 12772, 12775, 12779, 12801, 12810, 12830, 12831, 12835, 12843, 12867, 12870, 12871, 12872, 12884, 12886, 12890, 12893, 12895, 12896, 12897, 12899, 12901, 12921, 12922, 12925, 12938, 12943, 12945, 12947, 12948, 12951, 12954, 12959, 12960, 12961, 12962, 12967, 12976, 12982, 12984, 12988, 12989, 12990, 12992, 12993, 12996, 12998, 13008, 13009, 13010, 13012, 13018, 13019, 13021, 13022, 13024, 13025, 13026, 13027, 13028, 13031, 13034, 13037, 13044, 13070, 13079, 13084, 13092, 13094, 13095, 13096, 13097, 13098, 13101, 13102, 13104, 13105, 13106, 13108, 13111, 13112, 13114, 13115, 13117, 13118, 13119, 13121, 13123, 13127, 13128, 13129, 13130, 13131, 13132, 13133, 13134, 13137, 13139, 13141, 13142, 13143, 13154, 13167, 13184, 13187, 13189, 13193, 13201, 13219, 13220, 13246, 13265, 13302, 13308, 13311, 13317, 13336, 13345, 13346, 13348, 13350, 13351, 13354, 13356, 13359, 13360, 13364, 13365, 13366, 13367, 13368, 13370, 13371, 13372, 13375, 13377, 13378, 13379, 13380, 13381, 13387, 13388, 13390, 13391, 13393, 13394, 13395, 13397, 13400, 13413, 13415, 13417, 13418, 13419, 13420, 13421, 13422, 13425, 13427, 13428, 13431, 13432, 13434, 13437, 13438, 13441, 13442, 13444, 13446, 13447, 13449, 13451, 13452, 13453, 13455, 13459, 13462, 13464, 13465, 13471, 13472, 13474, 13475, 13476, 13478, 13480, 13481, 13483, 13516, 13518, 13519, 13520, 13522, 13529, 13530, 13534, 13540, 13541, 13549, 13559, 13561, 13568, 13569, 13571, 13573, 13575, 13577, 13579, 13581, 13582, 13583, 13584, 13585, 13586, 13590, 13591, 13593, 13595, 13603, 13605, 13606, 13608, 13609, 13611, 13612, 13613, 13614, 13618, 13620, 13626, 13627, 13629, 13630, 13631, 13632, 13634, 13635, 13637, 13638, 13641, 13644, 13646, 13647, 13649, 13650, 13651, 13652, 13653, 13656, 13657, 13659, 13662, 13663, 13664, 13666, 13667, 13670, 13671, 13673, 13674, 13676, 13677, 13680, 13681, 13682, 13683, 13685, 13688, 13690, 13691, 13692, 13724, 13727, 13729, 13731, 13735, 13738, 13739, 13745, 13746, 13749, 13753, 13755, 13756, 13757, 13758, 13759, 13767, 13769, 13771, 13777, 13779, 13781, 13785, 13786, 13788, 13791, 13792, 13793, 13797, 13799, 13801, 13802, 13803, 13804, 13808, 13814, 13815, 13816, 13817, 13819, 13821, 13824, 13826, 13827, 13828, 13829, 13831, 13834, 13835, 13838, 13840, 13841, 13842, 13843, 13845, 13846, 13847, 13849, 13855, 13862, 13863, 13864, 13865, 13866, 13867, 13868, 13876, 13878, 13879, 13880, 13881, 13884, 13885, 13886, 13887, 13888, 13890, 13891, 13895, 13897, 13903, 13910, 13919, 13920, 13937, 13943, 13946, 13950, 13951, 13952, 13962, 13978, 13987, 14006, 14010, 14015, 14016, 14017, 14024, 14026, 14039, 14045, 14046, 14051, 14074, 14088, 14100, 14106, 14107, 14111, 14115, 14116, 14123, 14129, 14131, 14139, 14142, 14145, 14147, 14150, 14151, 14153, 14154, 14160, 14168, 14171, 14175, 14177, 14186, 14220, 14239, 14241, 14252, 14253, 14279, 14303, 14314, 14319, 14320, 14322, 14337, 14344, 14345, 14349, 14351, 14356, 14357, 14358, 14360, 14363, 14364, 14366, 14367, 14375, 14376, 14377, 14381, 14383, 14384, 14385, 14386, 14387, 14388, 14389, 14391, 14397, 14398, 14401, 14403, 14407, 14415, 14418, 14420, 14426, 14427, 14428, 14429, 14431, 14432, 14433, 14434, 14435, 14437, 14438, 14439, 14440, 14441, 14443, 14444, 14445, 14447, 14448, 14452, 14453, 14460, 14461, 14463, 14464, 14465, 14467, 14468, 14470, 14471, 14472, 14473, 14474, 14475, 14476, 14478, 14479, 14480, 14481, 14482, 14483, 14484, 14485, 14486, 14487, 14489, 14490, 14491, 14492, 14495, 14496, 14498, 14499, 14501, 14502, 14511, 14513, 14516, 14517, 14518, 14519, 14520, 14521, 14522, 14523, 14524, 14525, 14526, 14527, 14528, 14529, 14531, 14532, 14533, 14534, 14535, 14537, 14538, 14540, 14541, 14542, 14543, 14545, 14546, 14547, 14548, 14552, 14553, 14555, 14556, 14560, 14564, 14567, 14568, 14569, 14570, 14571, 14575, 14578, 14579, 14580, 14581, 14583, 14585, 14586, 14587, 14588, 14589, 14594, 14595, 14596, 14597, 14598, 14599, 14600, 14602, 14603, 14604, 14606, 14607, 14610, 14611, 14612, 14614, 14617, 14618, 14620, 14621, 14622, 14623, 14624, 14625, 14629, 14635, 14637, 14638, 14643, 14644, 14649, 14651, 14657, 14659, 14661, 14662, 14670, 14677, 14707, 14715, 14717, 14719, 14729, 14746, 14751, 14764, 14766, 14768, 14770, 14771, 14772, 14775, 14777, 14778, 14779, 14782, 14785, 14786, 14787, 14788, 14790, 14791, 14792, 14795, 14801, 14804, 14807, 14811, 14813, 14816, 14818, 14819, 14821, 14822, 14824, 14825, 14826, 14829, 14833, 14834, 14838, 14839
Download .txt
gitextract_v4nqs5sn/

├── README.md
├── alm/
│   ├── callback/
│   │   ├── __init__.py
│   │   └── progress.py
│   ├── config.py
│   ├── data/
│   │   ├── BIWI/
│   │   │   ├── __init__.py
│   │   │   └── dataset.py
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── biwi.py
│   │   ├── get_data.py
│   │   ├── voca/
│   │   │   ├── __init__.py
│   │   │   └── dataset.py
│   │   └── vocaset.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── architectures/
│   │   │   ├── __init__.py
│   │   │   ├── adpt_bias_denoiser.py
│   │   │   └── tools/
│   │   │       ├── embeddings.py
│   │   │       ├── transformer_adpt.py
│   │   │       └── utils.py
│   │   ├── get_model.py
│   │   ├── losses/
│   │   │   ├── __init__.py
│   │   │   ├── utils.py
│   │   │   └── voca.py
│   │   └── modeltype/
│   │       ├── __init__.py
│   │       ├── base.py
│   │       └── diffusion_bias.py
│   └── utils/
│       ├── __init__.py
│       ├── demo_utils.py
│       ├── logger.py
│       └── temos_utils.py
├── configs/
│   ├── assets/
│   │   ├── biwi.yaml
│   │   └── vocaset.yaml
│   ├── base.yaml
│   └── diffusion/
│       ├── biwi/
│       │   ├── diffspeaker_hubert_biwi.yaml
│       │   └── diffspeaker_wav2vec2_biwi.yaml
│       ├── diffusion_bias_modules/
│       │   ├── denoiser.yaml
│       │   └── scheduler.yaml
│       └── vocaset/
│           ├── diffspeaker_hubert_vocaset.yaml
│           └── diffspeaker_wav2vec2_vocaset.yaml
├── datasets/
│   ├── biwi/
│   │   ├── README.md
│   │   ├── regions/
│   │   │   ├── fdd.txt
│   │   │   └── lve.txt
│   │   ├── templates/
│   │   │   └── BIWI.ply
│   │   └── templates.pkl
│   └── vocaset/
│       ├── FLAME_masks.pkl
│       ├── README.md
│       ├── templates/
│       │   ├── FLAME_sample.ply
│       │   └── README.md
│       └── templates.pkl
├── demo_biwi.py
├── demo_vocaset.py
├── demo_vocaset_text.py
├── eval_biwi.py
├── eval_vocaset.py
├── requirements.txt
├── scripts/
│   ├── demo/
│   │   ├── demo_biwi.sh
│   │   └── demo_vocaset.sh
│   └── diffusion/
│       ├── biwi_evaluation/
│       │   ├── diffspeaker_hubert_biwi.sh
│       │   └── diffspeaker_wav2vec2_biwi.sh
│       ├── biwi_training/
│       │   ├── diffspeaker_hubert_biwi.sh
│       │   └── diffspeaker_wav2vec2_biwi.sh
│       ├── vocaset_evaluation/
│       │   ├── diffspeaker_hubert_vocaset.sh
│       │   └── diffspeaker_wav2vec2_vocaset.sh
│       └── vocaset_training/
│           ├── diffspeaker_hubert_vocaset.sh
│           └── diffspeaker_wav2vec2_vocaset.sh
└── train.py
Download .txt
SYMBOL INDEX (157 symbols across 26 files)

FILE: alm/callback/progress.py
  class ProgressLogger (line 10) | class ProgressLogger(Callback):
    method __init__ (line 12) | def __init__(self, metric_monitor: dict, precision: int = 3):
    method on_train_start (line 17) | def on_train_start(self, trainer: Trainer, pl_module: LightningModule,
    method on_train_end (line 21) | def on_train_end(self, trainer: Trainer, pl_module: LightningModule,
    method on_validation_epoch_end (line 25) | def on_validation_epoch_end(self, trainer: Trainer,
    method on_train_epoch_end (line 30) | def on_train_epoch_end(self,

FILE: alm/config.py
  function get_module_config (line 7) | def get_module_config(cfg_model, path="modules"):
  function get_obj_from_str (line 16) | def get_obj_from_str(string, reload=False):
  function instantiate_from_config (line 24) | def instantiate_from_config(config):
  function parse_args (line 34) | def parse_args(phase="train"):

FILE: alm/data/BIWI/dataset.py
  class BIWIDataset (line 13) | class BIWIDataset(data.Dataset):
    method __init__ (line 15) | def __init__(self,
    method __len__ (line 29) | def __len__(self):
    method __getitem__ (line 32) | def __getitem__(self, index):

FILE: alm/data/base.py
  class BASEDataModule (line 7) | class BASEDataModule(pl.LightningDataModule):
    method __init__ (line 9) | def __init__(self, collate_fn, batch_size: int, num_workers: int):
    method __getattr__ (line 32) | def __getattr__(self, item):
    method setup (line 53) | def setup(self, stage=None):
    method train_dataloader (line 62) | def train_dataloader(self):
    method predict_dataloader (line 70) | def predict_dataloader(self):
    method val_dataloader (line 82) | def val_dataloader(self):
    method test_dataloader (line 94) | def test_dataloader(self):

FILE: alm/data/biwi.py
  function load_data (line 15) | def load_data(args):
  class BIWIDataModule (line 36) | class BIWIDataModule(BASEDataModule):
    method __init__ (line 37) | def __init__(self,
    method __getattr__ (line 175) | def __getattr__(self, item):

FILE: alm/data/get_data.py
  function collate_tensors (line 4) | def collate_tensors(batch):
  function vocaset_collate_fn (line 16) | def vocaset_collate_fn(batch):
  function voxcelebinsta_collate_fn (line 31) | def voxcelebinsta_collate_fn(batch):
  function voxcelebinstacoeflmdb_collate_fn (line 60) | def voxcelebinstacoeflmdb_collate_fn(batch):
  function get_datasets (line 94) | def get_datasets(cfg, logger, phase='train'):

FILE: alm/data/voca/dataset.py
  class VOCASETDataset (line 13) | class VOCASETDataset(data.Dataset):
    method __init__ (line 15) | def __init__(self,
    method __getitem__ (line 27) | def __getitem__(self, index):
    method __len__ (line 57) | def __len__(self):

FILE: alm/data/vocaset.py
  function load_data (line 15) | def load_data(args):
  class VOCASETDataModule (line 36) | class VOCASETDataModule(BASEDataModule):
    method __init__ (line 37) | def __init__(self,
    method __getattr__ (line 155) | def __getattr__(self, item):

FILE: alm/models/architectures/adpt_bias_denoiser.py
  class Adpt_Bias_Denoiser (line 13) | class Adpt_Bias_Denoiser(nn.Module):
    method __init__ (line 15) | def __init__(self,
    method forward (line 104) | def forward(self,

FILE: alm/models/architectures/tools/embeddings.py
  function get_activation (line 8) | def get_activation(activation_type):
  class MaskedNorm (line 37) | class MaskedNorm(nn.Module):
    method __init__ (line 43) | def __init__(self, norm_type, num_groups, num_features):
    method forward (line 57) | def forward(self, x: Tensor, mask: Tensor):
  class Embeddings (line 76) | class Embeddings(nn.Module):
    method __init__ (line 83) | def __init__(
    method forward (line 133) | def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
    method __repr__ (line 155) | def __repr__(self):
  class SpatialEmbeddings (line 163) | class SpatialEmbeddings(nn.Module):
    method __init__ (line 171) | def __init__(
    method forward (line 218) | def forward(self, x: Tensor, mask: Tensor) -> Tensor:
    method __repr__ (line 238) | def __repr__(self):
  function get_timestep_embedding (line 245) | def get_timestep_embedding(
  class TimestepEmbedding (line 288) | class TimestepEmbedding(nn.Module):
    method __init__ (line 289) | def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "s...
    method forward (line 298) | def forward(self, sample):
  class Timesteps (line 308) | class Timesteps(nn.Module):
    method __init__ (line 309) | def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale...
    method forward (line 315) | def forward(self, timesteps):

FILE: alm/models/architectures/tools/transformer_adpt.py
  class Transformer_Adpt (line 13) | class Transformer_Adpt(nn.Module):
    method __init__ (line 15) | def __init__(self,
    method obj_vector (line 85) | def obj_vector(self, id):
    method _forward (line 91) | def _forward(self,
  class TransformerDecoder_w_Adapter (line 131) | class TransformerDecoder_w_Adapter(nn.TransformerDecoder):
    method __init__ (line 141) | def __init__(self, decoder_layer, num_layers, norm=None):
    method forward (line 147) | def forward(self,
  class TransformerDecoderLayer_w_Adapter (line 184) | class TransformerDecoderLayer_w_Adapter(nn.TransformerDecoderLayer):
    method __init__ (line 201) | def __init__(self,
    method forward (line 217) | def forward(self,
    method _sa_block (line 252) | def _sa_block(self, x: Tensor,
    method _mha_block (line 289) | def _mha_block(self, x: Tensor, mem: Tensor,
    method _concate_adapter (line 326) | def _concate_adapter(self, adapter: Tensor, x: Tensor, batch_first: bo...

FILE: alm/models/architectures/tools/utils.py
  class PeriodicPositionalEncoding (line 5) | class PeriodicPositionalEncoding(nn.Module):
    method __init__ (line 6) | def __init__(self, d_model, dropout=0.1, period=25, max_seq_len=600):
    method forward (line 19) | def forward(self, x):
  function init_biased_mask (line 95) | def init_biased_mask(n_head, max_seq_len, period):
  function init_bi_biased_mask (line 133) | def init_bi_biased_mask(max_seq_len, ):
  function init_mem_mask_faceformer (line 147) | def init_mem_mask_faceformer(max_seq_len):
  function init_bi_biased_mask_faceformer (line 153) | def init_bi_biased_mask_faceformer(n_head, max_seq_len, period):

FILE: alm/models/get_model.py
  function get_model (line 3) | def get_model(cfg, datamodule):
  function get_module (line 11) | def get_module(cfg, datamodule):

FILE: alm/models/losses/utils.py
  class MaskedConsistency (line 4) | class MaskedConsistency:
    method __init__ (line 5) | def __init__(self) -> None:
    method __call__ (line 8) | def __call__(self, pred, gt, mask):
    method __repr__ (line 11) | def __repr__(self):
  class MaskedVelocityConsistency (line 14) | class MaskedVelocityConsistency:
    method __init__ (line 15) | def __init__(self) -> None:
    method __call__ (line 18) | def __call__(self, pred, gt, mask):
    method velocity (line 23) | def velocity(self, term):
    method __repr__ (line 27) | def __repr__(self):
  function vocaset_upper_face_variance (line 164) | def vocaset_upper_face_variance(motion, ):
  function vocaset_mouth_distance (line 172) | def vocaset_mouth_distance(vertices_gt, vertices_pred):
  function biwi_upper_face_variance (line 179) | def biwi_upper_face_variance(motion, ):
  function biwi_mouth_distance (line 187) | def biwi_mouth_distance(vertices_gt, vertices_pred):

FILE: alm/models/losses/voca.py
  class VOCALosses (line 32) | class VOCALosses(Metric):
    method __init__ (line 37) | def __init__(self, cfg, split):
    method update (line 121) | def update(self, rs_set):
    method compute (line 151) | def compute(self, split):
    method _update_loss (line 156) | def _update_loss(self, loss: str, outputs, inputs, mask = None):
    method loss2logname (line 167) | def loss2logname(self, loss: str, split: str):
  class MaskedConsistency (line 175) | class MaskedConsistency:
    method __init__ (line 176) | def __init__(self) -> None:
    method __call__ (line 179) | def __call__(self, pred, gt, mask):
    method __repr__ (line 186) | def __repr__(self):
  class MaskedVelocityConsistency (line 189) | class MaskedVelocityConsistency:
    method __init__ (line 190) | def __init__(self) -> None:
    method __call__ (line 193) | def __call__(self, pred, gt, mask):
    method velocity (line 198) | def velocity(self, term):
    method __repr__ (line 202) | def __repr__(self):

FILE: alm/models/modeltype/base.py
  class BaseModel (line 8) | class BaseModel(LightningModule):
    method __init__ (line 10) | def __init__(self, *args, **kwargs):
    method __post_init__ (line 14) | def __post_init__(self):
    method training_step (line 25) | def training_step(self, batch, batch_idx):
    method validation_step (line 28) | def validation_step(self, batch, batch_idx):
    method test_step (line 31) | def test_step(self, batch, batch_idx):
    method predict_step (line 36) | def predict_step(self, batch, batch_idx):
    method allsplit_epoch_end (line 39) | def allsplit_epoch_end(self, split: str, outputs):
    method training_epoch_end (line 73) | def training_epoch_end(self, outputs):
    method validation_epoch_end (line 76) | def validation_epoch_end(self, outputs):
    method test_epoch_end (line 79) | def test_epoch_end(self, outputs):
    method configure_optimizers (line 115) | def configure_optimizers(self):

FILE: alm/models/modeltype/diffusion_bias.py
  class DIFFUSION_BIAS (line 26) | class DIFFUSION_BIAS(BaseModel):
    method __init__ (line 28) | def __init__(self, cfg, datamodule, **kwargs):
    method allsplit_step (line 95) | def allsplit_step(self, split: str, batch, batch_idx):
    method _memory_mask (line 273) | def _memory_mask(self, hidden_attention, ):
    method _tgt_mask (line 312) | def _tgt_mask(self, vertice_attention, ):
    method _mem_key_padding_mask (line 349) | def _mem_key_padding_mask(self, vertice_attention):
    method _tgt_key_padding_mask (line 364) | def _tgt_key_padding_mask(self, vertice_attention):
    method _audio_resize (line 379) | def _audio_resize(self, hidden_state: torch.Tensor, input_fps: Optiona...
    method _audio_2_hidden (line 399) | def _audio_2_hidden(self, audio, audio_attention, length = None):
    method _diffusion_forward (line 417) | def _diffusion_forward(self, batch, batch_idx, phase):
    method smooth (line 487) | def smooth(self, vertices):
    method predict (line 498) | def predict(self, batch, **kwargs):
    method _diffusion_process (line 546) | def _diffusion_process(
    method _diffusion_reverse (line 598) | def _diffusion_reverse(
    method _visualize (line 675) | def _visualize(self, batch, rs_set, parrallel = True):

FILE: alm/utils/demo_utils.py
  function load_example_input (line 25) | def load_example_input(audio_path, processor = None):
  function render_mesh_helper (line 132) | def render_mesh_helper(mesh, t_center, rot=np.zeros(3), tex_img=None, z_...
  function render_frame (line 218) | def render_frame(args):
  function animate (line 225) | def animate(vertices: np.array, wav_path: str, file_name: str, ply: str,...

FILE: alm/utils/logger.py
  function create_logger (line 9) | def create_logger(cfg, phase='train'):
  function config_logger (line 38) | def config_logger(final_output_dir, time_str, phase, head):
  function new_dir (line 56) | def new_dir(cfg, phase, time_str, final_output_dir):

FILE: alm/utils/temos_utils.py
  function lengths_to_mask (line 7) | def lengths_to_mask(lengths: Tensor, # [batch_size]
  function remove_padding (line 15) | def remove_padding(tensors, lengths):

FILE: demo_biwi.py
  function main (line 12) | def main():

FILE: demo_vocaset.py
  function main (line 36) | def main():

FILE: demo_vocaset_text.py
  function main (line 13) | def main():

FILE: eval_biwi.py
  function print_table (line 54) | def print_table(title, metrics):
  function get_metric_statistics (line 66) | def get_metric_statistics(values, replication_times):
  function main (line 72) | def main():

FILE: eval_vocaset.py
  function print_table (line 27) | def print_table(title, metrics):
  function get_metric_statistics (line 39) | def get_metric_statistics(values, replication_times):
  function main (line 45) | def main():

FILE: train.py
  function main (line 18) | def main():
Condensed preview — 66 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (6,533K chars).
[
  {
    "path": "README.md",
    "chars": 2256,
    "preview": "# DiffSpeaker: Speech-Driven 3D Facial Animation with Diffusion Transformer\n## [Paper](https://arxiv.org/pdf/2402.05712."
  },
  {
    "path": "alm/callback/__init__.py",
    "chars": 37,
    "preview": "from .progress import ProgressLogger\n"
  },
  {
    "path": "alm/callback/progress.py",
    "chars": 1906,
    "preview": "import logging\n\nfrom pytorch_lightning import LightningModule, Trainer\nfrom pytorch_lightning.callbacks import Callback\n"
  },
  {
    "path": "alm/config.py",
    "chars": 6915,
    "preview": "import importlib\nfrom argparse import ArgumentParser\nfrom omegaconf import OmegaConf\nimport os\n\n\ndef get_module_config(c"
  },
  {
    "path": "alm/data/BIWI/__init__.py",
    "chars": 32,
    "preview": "from .dataset import BIWIDataset"
  },
  {
    "path": "alm/data/BIWI/dataset.py",
    "chars": 2063,
    "preview": "import numpy as np\nimport torch\nfrom torch.utils import data\nfrom transformers import Wav2Vec2Processor\nfrom collections"
  },
  {
    "path": "alm/data/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "alm/data/base.py",
    "chars": 3992,
    "preview": "from os.path import join as pjoin\nimport numpy as np\nimport pytorch_lightning as pl\nfrom torch.utils.data import DataLoa"
  },
  {
    "path": "alm/data/biwi.py",
    "chars": 6269,
    "preview": "from .base import BASEDataModule\nfrom transformers import Wav2Vec2Processor\nfrom collections import defaultdict\nfrom .BI"
  },
  {
    "path": "alm/data/get_data.py",
    "chars": 8922,
    "preview": "import numpy as np\nimport torch\n\ndef collate_tensors(batch):\n    dims = batch[0].dim()\n    max_size = [max([b.size(i) fo"
  },
  {
    "path": "alm/data/voca/__init__.py",
    "chars": 35,
    "preview": "from .dataset import VOCASETDataset"
  },
  {
    "path": "alm/data/voca/dataset.py",
    "chars": 1945,
    "preview": "import numpy as np\nimport torch\nfrom torch.utils import data\nfrom transformers import Wav2Vec2Processor\nfrom collections"
  },
  {
    "path": "alm/data/vocaset.py",
    "chars": 6090,
    "preview": "from .base import BASEDataModule\nfrom alm.data.voca import VOCASETDataset\nfrom transformers import Wav2Vec2Processor\nfro"
  },
  {
    "path": "alm/models/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "alm/models/architectures/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "alm/models/architectures/adpt_bias_denoiser.py",
    "chars": 6677,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom alm.models.architecture"
  },
  {
    "path": "alm/models/architectures/tools/embeddings.py",
    "chars": 9714,
    "preview": "# This file is taken from signjoey repository\nimport math\n\nimport torch\nfrom torch import Tensor, nn\n\n\ndef get_activatio"
  },
  {
    "path": "alm/models/architectures/tools/transformer_adpt.py",
    "chars": 14088,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom alm.models.architecture"
  },
  {
    "path": "alm/models/architectures/tools/utils.py",
    "chars": 7517,
    "preview": "import math\nimport torch\nimport torch.nn as nn\n\nclass PeriodicPositionalEncoding(nn.Module):\n    def __init__(self, d_mo"
  },
  {
    "path": "alm/models/get_model.py",
    "chars": 595,
    "preview": "import importlib\n\ndef get_model(cfg, datamodule):\n    modeltype = cfg.model.model_type\n    return get_module(cfg, datamo"
  },
  {
    "path": "alm/models/losses/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "alm/models/losses/utils.py",
    "chars": 7112,
    "preview": "import torch\nimport torch.nn as nn\n\nclass MaskedConsistency:\n    def __init__(self) -> None:\n        self.loss = nn.MSEL"
  },
  {
    "path": "alm/models/losses/voca.py",
    "chars": 7145,
    "preview": "# import numpy as np\n# import torch\n# import torch.nn as nn\n# from torchmetrics import Metric\n\n# class almLosses(Metric)"
  },
  {
    "path": "alm/models/modeltype/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "alm/models/modeltype/base.py",
    "chars": 4311,
    "preview": "import os\nfrom pathlib import Path\nimport numpy as np\nfrom pytorch_lightning import LightningModule\nimport torch\nfrom co"
  },
  {
    "path": "alm/models/modeltype/diffusion_bias.py",
    "chars": 34657,
    "preview": "import torch\nfrom torch.optim import AdamW, Adam\nimport torch.nn.functional as F\nfrom torchmetrics import MetricCollecti"
  },
  {
    "path": "alm/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "alm/utils/demo_utils.py",
    "chars": 13178,
    "preview": "from transformers import Wav2Vec2Processor\nimport numpy as np\nimport librosa\nimport os\nimport torch\nimport cv2\nimport py"
  },
  {
    "path": "alm/utils/logger.py",
    "chars": 2454,
    "preview": "from pathlib import Path\nimport os\nimport time\nimport logging\nfrom omegaconf import OmegaConf\nfrom pytorch_lightning.uti"
  },
  {
    "path": "alm/utils/temos_utils.py",
    "chars": 574,
    "preview": "from typing import Dict, List\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\n\ndef lengths_to_mask(lengths: Te"
  },
  {
    "path": "configs/assets/biwi.yaml",
    "chars": 162,
    "preview": "FOLDER: './experiments/biwi' # Experiment files saving path\n\nTEST:\n  FOLDER: './results' # Testing files saving path\n\nDA"
  },
  {
    "path": "configs/assets/vocaset.yaml",
    "chars": 171,
    "preview": "FOLDER: './experiments/vocaset' # Experiment files saving path\n\nTEST:\n  FOLDER: './results' # Testing files saving path\n"
  },
  {
    "path": "configs/base.yaml",
    "chars": 1448,
    "preview": "# FOLDER: ./experiments\nSEED_VALUE: 1234\nDEBUG: True\nTRAIN:\n  SPLIT: 'train'\n  NUM_WORKERS: 1 #2 # Number of workers\n  B"
  },
  {
    "path": "configs/diffusion/biwi/diffspeaker_hubert_biwi.yaml",
    "chars": 2743,
    "preview": "NAME: diffspeaker_hubert_biwi # Experiment name\nDEBUG: False # Debug mode\nACCELERATOR: 'gpu' # Devices optioncal: “cpu”,"
  },
  {
    "path": "configs/diffusion/biwi/diffspeaker_wav2vec2_biwi.yaml",
    "chars": 2861,
    "preview": "NAME: diffspeaker_wav2vec2_biwi # Experiment name\nDEBUG: False # Debug mode\nACCELERATOR: 'gpu' # Devices optioncal: “cpu"
  },
  {
    "path": "configs/diffusion/diffusion_bias_modules/denoiser.yaml",
    "chars": 1207,
    "preview": "denoiser: # this is copied from configs/baselines/transformer_adpt_modules/transformer.yaml\n  target: alm.models.archite"
  },
  {
    "path": "configs/diffusion/diffusion_bias_modules/scheduler.yaml",
    "chars": 770,
    "preview": "scheduler:\n  target: diffusers.DDIMScheduler\n  num_inference_timesteps: 50\n  eta: 0.0\n  params:\n    num_train_timesteps:"
  },
  {
    "path": "configs/diffusion/vocaset/diffspeaker_hubert_vocaset.yaml",
    "chars": 2618,
    "preview": "NAME: diffspeaker_hubert_vocaset # Experiment name\nDEBUG: False # Debug mode\nACCELERATOR: 'gpu' # Devices optioncal: “cp"
  },
  {
    "path": "configs/diffusion/vocaset/diffspeaker_wav2vec2_vocaset.yaml",
    "chars": 2621,
    "preview": "NAME: diffspeaker_wav2vec2_vocaset # Experiment name\nDEBUG: False # Debug mode\nACCELERATOR: 'gpu' # Devices optioncal: “"
  },
  {
    "path": "datasets/biwi/README.md",
    "chars": 92,
    "preview": "should contanin\n```\n    regions\n    templates\n    templates.pkl\n    vertices_npy\n    wav\n```"
  },
  {
    "path": "datasets/biwi/regions/fdd.txt",
    "chars": 50854,
    "preview": "7, 9, 59, 89, 94, 103, 122, 126, 152, 159, 160, 161, 164, 166, 169, 174, 177, 178, 179, 180, 182, 183, 184, 185, 195, 21"
  },
  {
    "path": "datasets/biwi/regions/lve.txt",
    "chars": 32611,
    "preview": "2, 6, 21, 23, 24, 25, 27, 28, 29, 39, 45, 47, 49, 50, 55, 60, 68, 70, 73, 76, 80, 81, 82, 83, 85, 86, 87, 133, 170, 171,"
  },
  {
    "path": "datasets/vocaset/FLAME_masks.pkl",
    "chars": 215062,
    "preview": "(dp0\nS'eye_region'\np1\ncnumpy.core.multiarray\n_reconstruct\np2\n(cnumpy\nndarray\np3\n(I0\ntp4\nS'b'\np5\ntp6\nRp7\n(I1\n(I751\ntp8\ncn"
  },
  {
    "path": "datasets/vocaset/README.md",
    "chars": 80,
    "preview": "should contanin\n```\n    templates\n    templates.pkl\n    vertices_npy\n    wav\n```"
  },
  {
    "path": "datasets/vocaset/templates/README.md",
    "chars": 148,
    "preview": "Put \"FLAME_sample.ply\" in this folder. \"FLAME_sample.ply\" can be downloaded from [voca](https://github.com/TimoBolkart/v"
  },
  {
    "path": "datasets/vocaset/templates.pkl",
    "chars": 4832756,
    "preview": "(dp0\nS'FaceTalk_170904_00128_TA'\np1\ncnumpy.core.multiarray\n_reconstruct\np2\n(cnumpy\nndarray\np3\n(I0\ntp4\nS'b'\np5\ntp6\nRp7\n(I"
  },
  {
    "path": "demo_biwi.py",
    "chars": 3165,
    "preview": "import os\nimport pickle\nimport torch\n\nfrom alm.config import parse_args\nfrom alm.models.get_model import get_model\nfrom "
  },
  {
    "path": "demo_vocaset.py",
    "chars": 7443,
    "preview": "import os\nimport pickle\nimport torch\nimport torch.nn.functional as F\n\nfrom alm.config import parse_args\nfrom alm.models."
  },
  {
    "path": "demo_vocaset_text.py",
    "chars": 4681,
    "preview": "import os\nimport pickle\nimport torch\n\nfrom alm.config import parse_args\nfrom alm.models.get_model import get_model\nfrom "
  },
  {
    "path": "eval_biwi.py",
    "chars": 5370,
    "preview": "#############################################################################################################\n# this fil"
  },
  {
    "path": "eval_vocaset.py",
    "chars": 4626,
    "preview": "#############################################################################################################\n# this fil"
  },
  {
    "path": "requirements.txt",
    "chars": 298,
    "preview": "imageio==2.21.2\nlibrosa==0.9.2\nlmdb==1.3.0\nnumpy==1.23.5\nomegaconf==2.3.0\nopencv_python==4.7.0.72\npsbody_mesh==0.4\npsuti"
  },
  {
    "path": "scripts/demo/demo_biwi.sh",
    "chars": 781,
    "preview": "export CUDA_VISIBLE_DEVICES=0\n\n# use hubert backbone\npython demo_biwi.py \\\n    --cfg configs/diffusion/biwi/diffspeaker_"
  },
  {
    "path": "scripts/demo/demo_vocaset.sh",
    "chars": 887,
    "preview": "export CUDA_VISIBLE_DEVICES=1\n\n# # use hubert backbone\n# python demo_vocaset.py \\\n#     --cfg configs/diffusion/vocaset/"
  },
  {
    "path": "scripts/diffusion/biwi_evaluation/diffspeaker_hubert_biwi.sh",
    "chars": 160,
    "preview": "export CUDA_VISIBLE_DEVICES=0\npython eval_biwi.py \\\n    --cfg configs/diffusion/biwi/diffspeaker_hubert_biwi.yaml \\\n    "
  },
  {
    "path": "scripts/diffusion/biwi_evaluation/diffspeaker_wav2vec2_biwi.sh",
    "chars": 162,
    "preview": "export CUDA_VISIBLE_DEVICES=0\npython eval_biwi.py \\\n    --cfg configs/diffusion/biwi/diffspeaker_wav2vec2_biwi.yaml \\\n  "
  },
  {
    "path": "scripts/diffusion/biwi_training/diffspeaker_hubert_biwi.sh",
    "chars": 207,
    "preview": "export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\npython -m train \\\n    --cfg configs/diffusion/biwi/diffspeaker_hubert_biwi.y"
  },
  {
    "path": "scripts/diffusion/biwi_training/diffspeaker_wav2vec2_biwi.sh",
    "chars": 209,
    "preview": "export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\npython -m train \\\n    --cfg configs/diffusion/biwi/diffspeaker_wav2vec2_biwi"
  },
  {
    "path": "scripts/diffusion/vocaset_evaluation/diffspeaker_hubert_vocaset.sh",
    "chars": 172,
    "preview": "export CUDA_VISIBLE_DEVICES=0\npython eval_vocaset.py \\\n    --cfg configs/diffusion/vocaset/diffspeaker_hubert_vocaset.ya"
  },
  {
    "path": "scripts/diffusion/vocaset_evaluation/diffspeaker_wav2vec2_vocaset.sh",
    "chars": 174,
    "preview": "export CUDA_VISIBLE_DEVICES=0\npython eval_vocaset.py \\\n    --cfg configs/diffusion/vocaset/diffspeaker_wav2vec2_vocaset."
  },
  {
    "path": "scripts/diffusion/vocaset_training/diffspeaker_hubert_vocaset.sh",
    "chars": 203,
    "preview": "export CUDA_VISIBLE_DEVICES=0\npython -m train \\\n    --cfg configs/diffusion/vocaset/diffspeaker_hubert_vocaset.yaml \\\n  "
  },
  {
    "path": "scripts/diffusion/vocaset_training/diffspeaker_wav2vec2_vocaset.sh",
    "chars": 205,
    "preview": "export CUDA_VISIBLE_DEVICES=0\npython -m train \\\n    --cfg configs/diffusion/vocaset/diffspeaker_wav2vec2_vocaset.yaml \\\n"
  },
  {
    "path": "train.py",
    "chars": 8560,
    "preview": "import os\nfrom pprint import pformat\n\nimport pytorch_lightning as pl\nimport torch\nfrom omegaconf import OmegaConf\nfrom p"
  }
]

// ... and 3 more files (download for full content)

About this extraction

This page contains the full source code of the theEricMa/DiffSpeaker GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 66 files (5.1 MB), approximately 1.3M tokens, and a symbol index with 157 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!