Showing preview only (1,443K chars total). Download the full file or copy to clipboard to get everything.
Repository: Hanbo-Cheng/DAWN-pytorch
Branch: main
Commit: 76c44b8396c3
Files: 291
Total size: 72.5 MB
Directory structure:
gitextract_jmo9ls54/
├── .gitignore
├── DAWN_256.yaml
├── DM_3/
│ ├── datasets_hdtf_wpose_lmk_block_lmk.py
│ ├── datasets_hdtf_wpose_lmk_block_lmk_rand.py
│ ├── modules/
│ │ ├── local_attention.py
│ │ ├── text.py
│ │ ├── video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_mouth_mask_6D.py
│ │ ├── video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_mouth_mask_rand_6D.py
│ │ ├── video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_test.py
│ │ ├── video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi.py
│ │ ├── video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test.py
│ │ └── video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test_local_opt.py
│ ├── test_lr.py
│ ├── train_vdm_hdtf_wpose_plus_faceemb_init_cond_liploss_6D.py
│ ├── train_vdm_hdtf_wpose_plus_faceemb_init_cond_liploss_6D_s2.py
│ └── utils.py
├── LFG/
│ ├── __init__.py
│ ├── augmentation.py
│ ├── frames_dataset.py
│ ├── hdtf_dataset.py
│ ├── modules/
│ │ ├── avd_network.py
│ │ ├── bg_motion_predictor.py
│ │ ├── flow_autoenc.py
│ │ ├── generator.py
│ │ ├── model.py
│ │ ├── pixelwise_flow_predictor.py
│ │ ├── region_predictor.py
│ │ └── util.py
│ ├── run_hdtf.py
│ ├── run_hdtf_crema.py
│ ├── sync_batchnorm/
│ │ ├── __init__.py
│ │ ├── batchnorm.py
│ │ ├── comm.py
│ │ ├── replicate.py
│ │ └── unittest.py
│ ├── test_flowautoenc_crema_video.py
│ ├── test_flowautoenc_hdtf_video.py
│ ├── test_flowautoenc_hdtf_video_256.py
│ ├── train.py
│ └── vis_flow.py
├── PBnet/
│ ├── run_cvae_h_ann_reemb_rope_eye_3.sh
│ └── src/
│ ├── __init__.py
│ ├── config.py
│ ├── datasets/
│ │ ├── __init__.py
│ │ ├── datasets_hdtf_pos_chunk_norm_2_fast.py
│ │ ├── datasets_hdtf_pos_chunk_norm_eye_fast.py
│ │ ├── datasets_hdtf_pos_df.py
│ │ ├── datasets_hdtf_pos_dict_norm_2.py
│ │ ├── datasets_hdtf_wpose_lmk_block.py
│ │ ├── get_dataset.py
│ │ └── tools.py
│ ├── evaluate/
│ │ ├── __init__.py
│ │ ├── action2motion/
│ │ │ ├── accuracy.py
│ │ │ ├── diversity.py
│ │ │ ├── evaluate.py
│ │ │ ├── fid.py
│ │ │ └── models.py
│ │ ├── evaluate_cvae.py
│ │ ├── evaluate_cvae_debug.py
│ │ ├── evaluate_cvae_f3.py
│ │ ├── evaluate_cvae_f3_debug.py
│ │ ├── evaluate_cvae_f3_mel.py
│ │ ├── evaluate_cvae_norm.py
│ │ ├── evaluate_cvae_norm_all.py
│ │ ├── evaluate_cvae_norm_all_seg.py
│ │ ├── evaluate_cvae_norm_all_seg_weye.py
│ │ ├── evaluate_cvae_norm_all_seg_weye2.py
│ │ ├── evaluate_cvae_norm_eye_pose.py
│ │ ├── evaluate_cvae_norm_eye_pose_test.py
│ │ ├── evaluate_cvae_onlyeye_all_seg.py
│ │ ├── othermetrics/
│ │ │ ├── acceleration.py
│ │ │ └── evaluation.py
│ │ ├── stgcn/
│ │ │ ├── accuracy.py
│ │ │ ├── diversity.py
│ │ │ ├── evaluate.py
│ │ │ └── fid.py
│ │ ├── tables/
│ │ │ ├── archtable.py
│ │ │ ├── bstable.py
│ │ │ ├── easy_table.py
│ │ │ ├── easy_table_A2M.py
│ │ │ ├── kltable.py
│ │ │ ├── latexmodela2m.py
│ │ │ ├── latexmodelsa2m.py
│ │ │ ├── latexmodelsstgcn.py
│ │ │ ├── losstable.py
│ │ │ ├── maketable.py
│ │ │ ├── numlayertable.py
│ │ │ └── posereptable.py
│ │ ├── tools.py
│ │ ├── tvae_eval.py
│ │ ├── tvae_eval_norm.py
│ │ ├── tvae_eval_norm_all.py
│ │ ├── tvae_eval_norm_eye_pose.py
│ │ ├── tvae_eval_norm_eye_pose_seg.py
│ │ ├── tvae_eval_norm_seg.py
│ │ ├── tvae_eval_onlyeye_all_seg.py
│ │ ├── tvae_eval_single.py
│ │ ├── tvae_eval_single_both_eye_pose.py
│ │ ├── tvae_eval_std.py
│ │ ├── tvae_eval_train.py
│ │ ├── tvae_eval_train_norm.py
│ │ └── tvae_eval_train_std.py
│ ├── generate/
│ │ └── generate_sequences.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── architectures/
│ │ │ ├── __init__.py
│ │ │ ├── autotrans.py
│ │ │ ├── fc.py
│ │ │ ├── gru.py
│ │ │ ├── grutrans.py
│ │ │ ├── mlp.py
│ │ │ ├── resnet34.py
│ │ │ ├── tools/
│ │ │ │ ├── embeddings.py
│ │ │ │ ├── resnet.py
│ │ │ │ ├── transformer_layers.py
│ │ │ │ └── util.py
│ │ │ ├── transformer.py
│ │ │ ├── transformerdecoder.py
│ │ │ ├── transformerdecoder4.py
│ │ │ ├── transformerdecoder5.py
│ │ │ ├── transformerreemb.py
│ │ │ ├── transformerreemb5.py
│ │ │ ├── transformerreemb6.py
│ │ │ └── transgru.py
│ │ ├── get_model.py
│ │ ├── modeltype/
│ │ │ ├── __init__.py
│ │ │ ├── cae.py
│ │ │ ├── cae_0.py
│ │ │ ├── cvae.py
│ │ │ └── lstm.py
│ │ ├── rotation2xyz.py
│ │ ├── smpl.py
│ │ └── tools/
│ │ ├── __init__.py
│ │ ├── graphconv.py
│ │ ├── hessian_penalty.py
│ │ ├── losses.py
│ │ ├── mmd.py
│ │ ├── msssim_loss.py
│ │ ├── normalize_data.py
│ │ ├── ssim_loss.py
│ │ └── tools.py
│ ├── parser/
│ │ ├── base.py
│ │ ├── checkpoint.py
│ │ ├── dataset.py
│ │ ├── evaluation.py
│ │ ├── finetunning.py
│ │ ├── generate.py
│ │ ├── model.py
│ │ ├── recognition.py
│ │ ├── tools.py
│ │ ├── training.py
│ │ └── visualize.py
│ ├── preprocess/
│ │ ├── humanact12_process.py
│ │ ├── phspdtools.py
│ │ └── uestc_vibe_postprocessing.py
│ ├── recognition/
│ │ ├── compute_accuracy.py
│ │ ├── get_model.py
│ │ └── models/
│ │ ├── stgcn.py
│ │ └── stgcnutils/
│ │ ├── graph.py
│ │ └── tgcn.py
│ ├── render/
│ │ ├── renderer.py
│ │ └── rendermotion.py
│ ├── train/
│ │ ├── __init__.py
│ │ ├── train_cvae_ganloss_ann_eye.py
│ │ ├── train_cvae_ganloss_ann_fast.py
│ │ ├── trainer.py
│ │ ├── trainer_gan.py
│ │ └── trainer_gan_ann.py
│ ├── utils/
│ │ ├── PYTORCH3D_LICENSE
│ │ ├── __init__.py
│ │ ├── fixseed.py
│ │ ├── get_model_and_data.py
│ │ ├── misc.py
│ │ ├── rotation_conversions.py
│ │ ├── tensors.py
│ │ ├── tensors_eye.py
│ │ ├── tensors_eye_eval.py
│ │ ├── tensors_hdtf.py
│ │ ├── tensors_onlyeye.py
│ │ ├── utils.py
│ │ └── video.py
│ └── visualize/
│ ├── __init__.py
│ ├── anim.py
│ ├── visualize.py
│ ├── visualize_checkpoint.py
│ ├── visualize_dataset.py
│ ├── visualize_latent_space.py
│ ├── visualize_nturefined.py
│ └── visualize_sequence.py
├── README.md
├── README_CN.md
├── config/
│ ├── DAWN_128.yaml
│ ├── DAWN_256.yaml
│ ├── hdtf128.yaml
│ ├── hdtf128_1000ep.yaml
│ ├── hdtf128_1000ep_crema.yaml
│ ├── hdtf256.yaml
│ └── hdtf256_400ep.yaml
├── extract_init_states/
│ ├── FaceBoxes/
│ │ ├── FaceBoxes.py
│ │ ├── FaceBoxes_ONNX.py
│ │ ├── __init__.py
│ │ ├── build_cpu_nms.sh
│ │ ├── models/
│ │ │ ├── __init__.py
│ │ │ └── faceboxes.py
│ │ ├── onnx.py
│ │ ├── readme.md
│ │ ├── utils/
│ │ │ ├── .gitignore
│ │ │ ├── __init__.py
│ │ │ ├── box_utils.py
│ │ │ ├── build.py
│ │ │ ├── config.py
│ │ │ ├── functions.py
│ │ │ ├── nms/
│ │ │ │ ├── .gitignore
│ │ │ │ ├── __init__.py
│ │ │ │ ├── cpu_nms.cp38-win_amd64.pyd
│ │ │ │ ├── cpu_nms.pyx
│ │ │ │ └── py_cpu_nms.py
│ │ │ ├── nms_wrapper.py
│ │ │ ├── prior_box.py
│ │ │ └── timer.py
│ │ └── weights/
│ │ ├── .gitignore
│ │ ├── FaceBoxesProd.pth
│ │ └── readme.md
│ ├── TDDFA_ONNX.py
│ ├── bfm/
│ │ ├── .gitignore
│ │ ├── __init__.py
│ │ ├── bfm.py
│ │ ├── bfm_onnx.py
│ │ └── readme.md
│ ├── build.sh
│ ├── configs/
│ │ ├── .gitignore
│ │ ├── BFM_UV.mat
│ │ ├── bfm_noneck_v3.onnx
│ │ ├── bfm_noneck_v3.pkl
│ │ ├── indices.npy
│ │ ├── mb05_120x120.yml
│ │ ├── mb1_120x120.yml
│ │ ├── ncc_code.npy
│ │ ├── param_mean_std_62d_120x120.pkl
│ │ ├── readme.md
│ │ ├── resnet_120x120.yml
│ │ └── tri.pkl
│ ├── demo_pose_extract_2d_lmk_img.py
│ ├── functions.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── mobilenet_v1.py
│ │ ├── mobilenet_v3.py
│ │ └── resnet.py
│ ├── pose.py
│ ├── readme.md
│ ├── utils/
│ │ ├── __init__.py
│ │ ├── asset/
│ │ │ ├── .gitignore
│ │ │ ├── build_render_ctypes.sh
│ │ │ └── render.c
│ │ ├── depth.py
│ │ ├── functions.py
│ │ ├── io.py
│ │ ├── onnx.py
│ │ ├── pncc.py
│ │ ├── pose.py
│ │ ├── render.py
│ │ ├── render_ctypes.py
│ │ ├── serialization.py
│ │ ├── tddfa_util.py
│ │ └── uv.py
│ └── weights/
│ ├── .gitignore
│ ├── mb05_120x120.pth
│ ├── mb1_120x120.onnx
│ ├── mb1_120x120.pth
│ └── readme.md
├── filter_fourier.py
├── hubert_extract/
│ └── data_gen/
│ └── process_lrs3/
│ ├── binarizer.py
│ ├── process_audio_hubert.py
│ ├── process_audio_hubert_interpolate.py
│ ├── process_audio_hubert_interpolate_batch.py
│ ├── process_audio_hubert_interpolate_demo.py
│ ├── process_audio_hubert_interpolate_single.py
│ └── process_audio_mel_f0.py
├── misc.py
├── requirements.txt
├── run_ood_test/
│ ├── run_DM_v0_df_test_128_both_pose_blink.sh
│ ├── run_DM_v0_df_test_128_separate_pose_blink.sh
│ ├── run_DM_v0_df_test_256.sh
│ ├── run_DM_v0_df_test_256_1.sh
│ └── run_DM_v0_df_test_256_1_separate_pose_blink.sh
├── sync_batchnorm/
│ ├── __init__.py
│ ├── batchnorm.py
│ ├── comm.py
│ ├── replicate.py
│ ├── replicate_ddp.py
│ └── unittest.py
└── unified_video_generator.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
.idea/
__pycache__
**/__pycache__
cache
submit*
pretrain_models/*
*.mp3
*.mp4
tmp*
output/*
================================================
FILE: DAWN_256.yaml
================================================
input_size: 256
max_n_frames: 200
random_seed: 1234
mean: [0.0, 0.0, 0.0]
win_width: 40
sampling_step: 20
ddim_sampling_eta: 1.0
cond_scale: 1.0
model_config:
is_train: true
pose_dim: 6
config_pth: './config/hdtf256.yaml'
ae_pretrained_pth: './pretrain_models/LFG_256_400ep.pth'
diffusion_pretrained_pth: './pretrain_models/DAWN_256.pth'
================================================
FILE: DM_3/datasets_hdtf_wpose_lmk_block_lmk.py
================================================
# dataset for HDTF, stage 1
from os import name
import sys
sys.path.append('your_path')
import os
import random
import torch
import numpy as np
import torch.utils.data as data
import torch.nn.functional as Ft
import imageio.v2 as imageio
import cv2
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt
from PIL import Image
from scipy.interpolate import interp1d
import decord
from torchvision.transforms.functional import to_pil_image
from torchvision import transforms
import time
import pickle as pkl
decord.bridge.set_bridge('torch')
def resize(im, desired_size, interpolation):
    """Scale ``im`` so its longer side equals ``desired_size``, then pad it to a square.

    The aspect ratio is preserved; the shorter side is padded symmetrically with black
    (an odd padding amount puts the extra pixel on the bottom / right).

    Args:
        im: image array whose first two axes are (height, width).
        desired_size: side length of the square output.
        interpolation: OpenCV interpolation flag passed to ``cv2.resize``.

    Returns:
        The resized and padded image.
    """
    height, width = im.shape[:2]
    scale = float(desired_size) / max(height, width)
    new_h = int(height * scale)
    new_w = int(width * scale)
    scaled = cv2.resize(im, (new_w, new_h), interpolation=interpolation)

    pad_h = desired_size - new_h
    pad_w = desired_size - new_w
    top = pad_h // 2
    left = pad_w // 2
    return cv2.copyMakeBorder(
        scaled, top, pad_h - top, left, pad_w - left,
        cv2.BORDER_CONSTANT, value=[0, 0, 0],
    )
class HDTF(data.Dataset):
    """HDTF talking-head dataset (stage 1): random clips read from chunked ``.npy`` files.

    Every modality (frames, HuBERT features, pose, eye blink) is stored per video as a
    directory of ``chunk_%04d.npy`` files; chunk ``k`` holds frames
    ``[k * CHUNK_SIZE, (k + 1) * CHUNK_SIZE)``. ``__getitem__`` samples a window of at most
    ``max_num_frames`` frames and returns the aligned slice of every modality.
    """

    # Number of frames stored in each ``chunk_%04d.npy`` file.
    CHUNK_SIZE = 25

    def __init__(self, data_dir, pose_dir, eye_blink_dir, max_num_frames=80, image_size=128, mode='train',
                 mean=(128, 128, 128), color_jitter=True):
        """
        Args:
            data_dir: root with one chunk directory of frames per video.
            pose_dir: root with one chunk directory of pose vectors per video.
            eye_blink_dir: root with one chunk directory of eye-blink values per video.
            max_num_frames: maximum clip length returned by ``__getitem__``.
            image_size: stored on the instance; not used by the code in this class.
            mode: ``'train'`` (every video under ``data_dir`` except held-out / bad ones)
                or ``'test'`` (the fixed held-out list).
            mean: per-channel value subtracted from every frame.
            color_jitter: stored as ``self.is_jitter``; not used by the code in this class.

        Raises:
            ValueError: if ``mode`` is neither ``'train'`` nor ``'test'``.
        """
        super(HDTF, self).__init__()
        if mode not in ('train', 'test'):
            # Previously an unknown mode left ``self.videos`` unset and only failed later,
            # with an AttributeError in ``__len__``.
            raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
        # Shaped (1, C, 1, 1) so it broadcasts over (T, C, H, W) frames.
        self.mean = torch.tensor(mean)[None, :, None, None]
        self.data_dir = data_dir
        self.pose_dir = pose_dir
        self.eye_blink_dir = eye_blink_dir
        self.is_jitter = color_jitter
        self.max_num_frames = max_num_frames
        self.image_size = image_size
        self.mode = mode

        # CREMA-D variant (not active): held-out speakers were s64, s76, s88, s90, s91.

        # Held-out HDTF videos forming the test split (list order defines test indices).
        vid_id_name_list = ['RD_Radio14_000','RD_Radio30_000','RD_Radio47_000','RD_Radio56_000','WDA_AmyKlobuchar1_001',
                            'WDA_BarbaraLee0_000','WDA_BobCasey0_000','WDA_CatherineCortezMasto_000','WDA_DebbieDingell1_000','WDA_DonaldMcEachin_000',
                            'WDA_EricSwalwell_000','WDA_HenryWaxman_000','WDA_JanSchakowsky1_000','WDA_JoeDonnelly_000','WDA_JohnSarbanes1_000',
                            'WDA_JoeNeguse_001','WDA_KatieHill_000','WDA_LucyMcBath_000','WDA_MazieHirono0_000','WDA_NancyPelosi1_000',
                            'WDA_PattyMurray0_000','WDA_RaulRuiz_000','WDA_SeanPatrickMaloney_000','WDA_TammyBaldwin0_000','WDA_TerriSewell0_000',
                            'WDA_TomCarper_000','WDA_WhipJimClyburn_000','WRA_AdamKinzinger0_000','WRA_AnnWagner_000','WRA_BobCorker_000',
                            'WRA_CandiceMiller0_000','WRA_CathyMcMorrisRodgers2_000','WRA_CoryGardner1_000','WRA_DebFischer1_000','WRA_DianeBlack1_000',
                            'WRA_ErikPaulsen_000','WRA_GeorgeLeMieux_000','WRA_JebHensarling0_001','WRA_JoeHeck1_000','WRA_JohnKasich1_001',
                            'WRA_MarcoRubio_000']
        # Videos never used for training (presumably unusable recordings — TODO confirm).
        bad_id_name = ['WDA_DanKildee_000', 'WDA_PatrickLeahy1_000', 'WRA_KristiNoem2_000','RD_Radio39_000']

        self.hubert_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert_interpolate_chunk'
        self.mouth_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/mouth_ratio_bar'
        self.lmk_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/lmk_25hz_chunk'
        with open('/train20/intern/permanent/hbcheng2/data/HDTF/length_dict.pkl', 'rb') as f:
            # Maps video name -> total number of frames (see ``check_len``).
            self.len_dict = pkl.load(f)

        if mode == 'train':
            excluded = set(vid_id_name_list) | set(bad_id_name)
            # Order follows os.listdir, i.e. it is filesystem dependent.
            self.videos = [name for name in os.listdir(data_dir) if name not in excluded]
        else:
            self.videos = vid_id_name_list

    def check_head(self, frame_list, video_name, start, end):
        '''
        Check if the pose files for frames ``start`` and ``end`` exist.

        NOTE(review): ``get_pose_path`` is not defined on this class, so calling this
        method raises AttributeError; it appears to be unused.
        '''
        start_path = self.get_pose_path(frame_list, video_name, start)
        end_path = self.get_pose_path(frame_list, video_name, end)
        return os.path.exists(start_path) and os.path.exists(end_path)

    def get_block_data_for_two(self, path, start, end):
        '''
        Return frames ``[start, end)`` of the chunked array under ``path``.

        Kept for backward compatibility. The previous implementation derived the in-chunk
        offsets from the chunk indices instead of the frame indices, and tuple-indexed
        ``[st, ed]`` instead of slicing. It now delegates to ``get_block_data``, which
        handles any number of chunks.
        '''
        return self.get_block_data(path, start, end)

    def get_block_data(self, path, start, end):
        '''
        Return frames ``[start, end)`` of the chunked array stored under ``path``.

        input:
            path: directory containing ``chunk_%04d.npy`` files of ``CHUNK_SIZE`` frames each
            start: first frame id (inclusive)
            end: last frame id (exclusive)
        output:
            np.ndarray with the requested frames concatenated along axis 0
            (empty if ``end <= start``).

        Only the chunks that actually overlap ``[start, end)`` are loaded. ``end`` may
        therefore be a multiple of ``CHUNK_SIZE`` (e.g. the full video length) without
        requiring a non-existent following chunk.
        '''
        size = self.CHUNK_SIZE
        first = start // size
        # Last chunk containing a requested frame; ``end`` itself is exclusive.
        last = max((end - 1) // size, first)
        pieces = []
        for chunk_idx in range(first, last + 1):
            chunk = np.load(os.path.join(path, 'chunk_%04d.npy' % chunk_idx))
            base = chunk_idx * size
            lo = start - base if chunk_idx == first else 0
            hi = max(end - base, 0) if chunk_idx == last else None
            pieces.append(chunk[lo:hi])
        return np.concatenate(pieces)

    def check_len(self, name):
        """Return the total number of frames of video ``name``."""
        return self.len_dict[name]

    def __len__(self):
        return len(self.videos)

    def __getitem__(self, idx):
        """Load a random clip of at most ``max_num_frames`` frames from video ``idx``.

        Returns:
            ``(frames, hubert, pose, eye_blink, video_name, extra)``:

            - ``frames`` is permuted to ``(C, T, H, W)`` with ``self.mean`` subtracted,
              assuming the stored frames are ``(T, H, W, C)`` — TODO confirm.
            - ``hubert`` is a float32 tensor with time on axis 0.
            - ``pose`` and ``eye_blink`` are float32 arrays transposed so time is the last axis.
            - ``extra`` is the clip's start frame in ``'test'`` mode, and the video's total
              frame count otherwise.
        """
        video_name = self.videos[idx]
        frame_path = os.path.join(self.data_dir, video_name)
        hubert_path = os.path.join(self.hubert_dir, video_name)
        pose_path = os.path.join(self.pose_dir, video_name)
        eye_blink_path = os.path.join(self.eye_blink_dir, video_name)

        total_num_frames = self.check_len(video_name)
        if total_num_frames <= self.max_num_frames:
            sample_frames = total_num_frames
            start = 0
        else:
            sample_frames = self.max_num_frames
            # NOTE(review): randint's upper bound is exclusive, so the window ending on the
            # very last frame is never sampled; confirm before changing.
            start = np.random.randint(total_num_frames - self.max_num_frames)
        stop = start + sample_frames

        frames = self.get_block_data(path=frame_path, start=start, end=stop)
        hubert = self.get_block_data(path=hubert_path, start=start, end=stop).astype(np.float32)
        pose = self.get_block_data(path=pose_path, start=start, end=stop).astype(np.float32)
        eye_blink = self.get_block_data(path=eye_blink_path, start=start, end=stop).astype(np.float32)

        # (T, H, W, C) -> (T, C, H, W), subtract the mean, then -> (C, T, H, W).
        frame_tensor = torch.tensor(frames).permute(0, 3, 1, 2) - self.mean
        frame_tensor = frame_tensor.permute(1, 0, 2, 3)
        hubert_tensor = torch.tensor(hubert)

        # Flatten nested names (a leftover from the CREMA-D "speaker/sentence" layout).
        video_name = video_name.replace('/', '_')
        pose = pose.transpose(1, 0)  # for compatibility: time on the last axis
        eye_blink = eye_blink.transpose(1, 0)

        if self.mode == 'test':
            return frame_tensor, hubert_tensor, pose, eye_blink, video_name, start
        return frame_tensor, hubert_tensor, pose, eye_blink, video_name, total_num_frames
if __name__ == "__main__":
    # Quick smoke test of the stage-1 HDTF dataset.
    # hdtf
    data_dir = "/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz"
    pose_dir = "/train20/intern/permanent/hbcheng2/data/HDTF/pose"
    # ``eye_blink_dir`` is a required HDTF argument (omitting it raised TypeError). This path
    # mirrors the one used by datasets_hdtf_wpose_lmk_block_lmk_rand.py — TODO confirm.
    # NOTE(review): HDTF.get_block_data reads ``chunk_%04d.npy`` files, so each directory
    # must contain chunked data.
    eye_blink_dir = "/train20/intern/permanent/hbcheng2/data/HDTF/eye_blink_bbox_from_xpc_bar_2_chunk"
    # crema
    # data_dir='/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images'
    dataset = HDTF(data_dir=data_dir, pose_dir=pose_dir, eye_blink_dir=eye_blink_dir, mode='train')
    for i in range(10):
        dataset[i]
    print('------')
    test_dataset = data.DataLoader(dataset=dataset,
                                   batch_size=10,
                                   num_workers=8,
                                   shuffle=False)
    for i, batch in enumerate(test_dataset):
        print(i)
================================================
FILE: DM_3/datasets_hdtf_wpose_lmk_block_lmk_rand.py
================================================
# dataset for HDTF, stage 2
from os import name
import sys
sys.path.append('your_path')
import os
import random
import torch
import numpy as np
import torch.utils.data as data
import torch.nn.functional as Ft
import imageio.v2 as imageio
import cv2
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt
from PIL import Image
from scipy.interpolate import interp1d
import decord
from torchvision.transforms.functional import to_pil_image
from torchvision import transforms
import time
import pickle as pkl
decord.bridge.set_bridge('torch')
def resize(im, desired_size, interpolation):
    """Scale ``im`` so its longer side equals ``desired_size``, then pad it to a square.

    The aspect ratio is preserved; the shorter side is padded symmetrically with black
    (an odd padding amount puts the extra pixel on the bottom / right).

    Args:
        im: image array whose first two axes are (height, width).
        desired_size: side length of the square output.
        interpolation: OpenCV interpolation flag passed to ``cv2.resize``.

    Returns:
        The resized and padded image.
    """
    height, width = im.shape[:2]
    scale = float(desired_size) / max(height, width)
    new_h = int(height * scale)
    new_w = int(width * scale)
    scaled = cv2.resize(im, (new_w, new_h), interpolation=interpolation)

    pad_h = desired_size - new_h
    pad_w = desired_size - new_w
    top = pad_h // 2
    left = pad_w // 2
    return cv2.copyMakeBorder(
        scaled, top, pad_h - top, left, pad_w - left,
        cv2.BORDER_CONSTANT, value=[0, 0, 0],
    )
class HDTF(data.Dataset):
    """HDTF talking-head dataset (stage 2): random clips plus one reference frame.

    Every modality (frames, HuBERT features, landmarks, pose, eye blink) is stored per
    video as a directory of ``chunk_%04d.npy`` files; chunk ``k`` holds frames
    ``[k * CHUNK_SIZE, (k + 1) * CHUNK_SIZE)``. ``__getitem__`` samples a clip of at most
    ``max_num_frames`` frames and prepends a reference frame to the frames, audio, pose
    and blink sequences.
    """

    # Number of frames stored in each ``chunk_%04d.npy`` file.
    CHUNK_SIZE = 25

    def __init__(self, data_dir, pose_dir, eye_blink_dir, max_num_frames=80, image_size=128, audio_dir=None, ref_id = None, mode='train',
                 mean=(128, 128, 128), color_jitter=True):
        """
        Args:
            data_dir: root with one chunk directory of frames per video.
            pose_dir: root with one chunk directory of pose vectors per video.
            eye_blink_dir: root with one chunk directory of eye-blink values per video.
            max_num_frames: maximum clip length returned by ``__getitem__``.
            image_size: stored on the instance; not used by the code in this class.
            audio_dir: root of the chunked HuBERT features. ``None`` uses the default
                HDTF path.
            ref_id: how ``__getitem__`` picks the reference frame:
                ``None`` -> any frame of the video; ``"clip"`` -> a frame inside the
                sampled clip; anything else -> frame 0.
            mode: ``'train'`` (every video under ``data_dir`` except held-out / bad ones)
                or ``'test'`` (the fixed held-out list).
            mean: per-channel value subtracted from every frame.
            color_jitter: stored as ``self.is_jitter``; not used by the code in this class.

        Raises:
            ValueError: if ``mode`` is neither ``'train'`` nor ``'test'``.
        """
        super(HDTF, self).__init__()
        if mode not in ('train', 'test'):
            # Previously an unknown mode left ``self.videos`` unset and only failed later,
            # with an AttributeError in ``__len__``.
            raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
        # Shaped (1, C, 1, 1) so it broadcasts over (T, C, H, W) frames.
        self.mean = torch.tensor(mean)[None, :, None, None]
        self.data_dir = data_dir
        self.pose_dir = pose_dir
        self.eye_blink_dir = eye_blink_dir
        self.is_jitter = color_jitter
        self.max_num_frames = max_num_frames
        self.image_size = image_size
        self.mode = mode

        # CREMA-D variant (not active): held-out speakers were s64, s76, s88, s90, s91.

        # Held-out HDTF videos forming the test split (list order defines test indices).
        vid_id_name_list = ['RD_Radio14_000','RD_Radio30_000','RD_Radio47_000','RD_Radio56_000','WDA_AmyKlobuchar1_001',
                            'WDA_BarbaraLee0_000','WDA_BobCasey0_000','WDA_CatherineCortezMasto_000','WDA_DebbieDingell1_000','WDA_DonaldMcEachin_000',
                            'WDA_EricSwalwell_000','WDA_HenryWaxman_000','WDA_JanSchakowsky1_000','WDA_JoeDonnelly_000','WDA_JohnSarbanes1_000',
                            'WDA_JoeNeguse_001','WDA_KatieHill_000','WDA_LucyMcBath_000','WDA_MazieHirono0_000','WDA_NancyPelosi1_000',
                            'WDA_PattyMurray0_000','WDA_RaulRuiz_000','WDA_SeanPatrickMaloney_000','WDA_TammyBaldwin0_000','WDA_TerriSewell0_000',
                            'WDA_TomCarper_000','WDA_WhipJimClyburn_000','WRA_AdamKinzinger0_000','WRA_AnnWagner_000','WRA_BobCorker_000',
                            'WRA_CandiceMiller0_000','WRA_CathyMcMorrisRodgers2_000','WRA_CoryGardner1_000','WRA_DebFischer1_000','WRA_DianeBlack1_000',
                            'WRA_ErikPaulsen_000','WRA_GeorgeLeMieux_000','WRA_JebHensarling0_001','WRA_JoeHeck1_000','WRA_JohnKasich1_001',
                            'WRA_MarcoRubio_000']
        # Videos never used for training (presumably unusable recordings — TODO confirm).
        bad_id_name = ['WDA_DanKildee_000', 'WDA_PatrickLeahy1_000', 'WRA_KristiNoem2_000']

        if audio_dir is None:
            self.hubert_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert_interpolate_chunk'
        else:
            self.hubert_dir = audio_dir
        self.ref_id = ref_id
        self.mouth_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/mouth_ratio_bar'
        self.lmk_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/lmk_25hz_chunk'
        with open('/train20/intern/permanent/hbcheng2/data/HDTF/length_dict.pkl', 'rb') as f:
            # Maps video name -> total number of frames (see ``check_len``).
            self.len_dict = pkl.load(f)

        if mode == 'train':
            excluded = set(vid_id_name_list) | set(bad_id_name)
            # Order follows os.listdir, i.e. it is filesystem dependent.
            self.videos = [name for name in os.listdir(data_dir) if name not in excluded]
        else:
            self.videos = vid_id_name_list

    def check_head(self, frame_list, video_name, start, end):
        '''
        Check if the pose files for frames ``start`` and ``end`` exist.

        NOTE(review): ``get_pose_path`` is not defined on this class, so calling this
        method raises AttributeError; it appears to be unused.
        '''
        start_path = self.get_pose_path(frame_list, video_name, start)
        end_path = self.get_pose_path(frame_list, video_name, end)
        return os.path.exists(start_path) and os.path.exists(end_path)

    def get_block_data_for_two(self, path, start, end):
        '''
        Return frames ``[start, end)`` of the chunked array under ``path``.

        Kept for backward compatibility. The previous implementation derived the in-chunk
        offsets from the chunk indices instead of the frame indices, and tuple-indexed
        ``[st, ed]`` instead of slicing. It now delegates to ``get_block_data``, which
        handles any number of chunks.
        '''
        return self.get_block_data(path, start, end)

    def get_block_data(self, path, start, end):
        '''
        Return frames ``[start, end)`` of the chunked array stored under ``path``.

        input:
            path: directory containing ``chunk_%04d.npy`` files of ``CHUNK_SIZE`` frames each
            start: first frame id (inclusive)
            end: last frame id (exclusive)
        output:
            np.ndarray with the requested frames concatenated along axis 0
            (empty if ``end <= start``).

        Only the chunks that actually overlap ``[start, end)`` are loaded. ``end`` may
        therefore be a multiple of ``CHUNK_SIZE`` (e.g. ``ref_id + 1`` for the last frame
        of a video) without requiring a non-existent following chunk.
        '''
        size = self.CHUNK_SIZE
        first = start // size
        # Last chunk containing a requested frame; ``end`` itself is exclusive.
        last = max((end - 1) // size, first)
        pieces = []
        for chunk_idx in range(first, last + 1):
            chunk = np.load(os.path.join(path, 'chunk_%04d.npy' % chunk_idx))
            base = chunk_idx * size
            lo = start - base if chunk_idx == first else 0
            hi = max(end - base, 0) if chunk_idx == last else None
            pieces.append(chunk[lo:hi])
        return np.concatenate(pieces)

    def check_len(self, name):
        """Return the total number of frames of video ``name``."""
        return self.len_dict[name]

    def __len__(self):
        return len(self.videos)

    def _frames_to_tensor(self, frames):
        """Permute stored frames to ``(T, C, H, W)`` and subtract ``self.mean``.

        Assumes ``frames`` is ``(T, H, W, C)`` — TODO confirm against the chunk files.
        """
        return torch.tensor(frames).permute(0, 3, 1, 2) - self.mean

    def __getitem__(self, idx):
        """Load a random clip from video ``idx`` together with one reference frame.

        The reference frame is prepended (index 0 on the time axis) to the frames, HuBERT
        features, pose and eye-blink sequences. The mouth landmarks cover only the clip.

        Returns:
            ``(frames, hubert, pose, eye_blink, mouth_lmk, video_name, extra)``:

            - ``frames`` is a ``(C, 1 + T, H, W)`` tensor with ``self.mean`` subtracted.
            - ``hubert`` is float32 with time on axis 0.
            - ``pose`` and ``eye_blink`` are float32 arrays with time on the last axis.
            - ``extra`` is the clip's start frame in ``'test'`` mode, and the video's total
              frame count otherwise.
        """
        video_name = self.videos[idx]
        frame_path = os.path.join(self.data_dir, video_name)
        hubert_path = os.path.join(self.hubert_dir, video_name)
        lmk_path = os.path.join(self.lmk_dir, video_name)
        pose_path = os.path.join(self.pose_dir, video_name)
        eye_blink_path = os.path.join(self.eye_blink_dir, video_name)

        total_num_frames = self.check_len(video_name)
        if total_num_frames <= self.max_num_frames:
            sample_frames = total_num_frames
            start = 0
        else:
            sample_frames = self.max_num_frames
            # NOTE(review): randint's upper bound is exclusive, so the window ending on the
            # very last frame is never sampled; confirm before changing.
            start = np.random.randint(total_num_frames - self.max_num_frames)
        stop = start + sample_frames

        if self.ref_id is None:
            ref_id = np.random.randint(total_num_frames)
        elif self.ref_id == "clip":
            ref_id = np.random.randint(sample_frames) + start
        else:
            ref_id = 0

        frames = self.get_block_data(path=frame_path, start=start, end=stop)
        hubert = self.get_block_data(path=hubert_path, start=start, end=stop).astype(np.float32)
        lmk = self.get_block_data(path=lmk_path, start=start, end=stop).astype(np.float32)
        pose = self.get_block_data(path=pose_path, start=start, end=stop).astype(np.float32)
        eye_blink = self.get_block_data(path=eye_blink_path, start=start, end=stop).astype(np.float32)

        ref_frame = self.get_block_data(path=frame_path, start=ref_id, end=ref_id + 1)
        ref_hubert = self.get_block_data(path=hubert_path, start=ref_id, end=ref_id + 1).astype(np.float32)
        ref_pose = self.get_block_data(path=pose_path, start=ref_id, end=ref_id + 1).astype(np.float32)
        ref_eye_blink = self.get_block_data(path=eye_blink_path, start=ref_id, end=ref_id + 1).astype(np.float32)

        # NOTE(review): 48:67 selects 19 points. If the landmarks follow the common 68-point
        # layout (assumption), the mouth spans indices 48-67 inclusive (20 points). Confirm
        # what the downstream model expects before changing the slice.
        mouth_lmk_tensor = torch.tensor(lmk[:, 48:67])

        frame_tensor = torch.concat([self._frames_to_tensor(ref_frame), self._frames_to_tensor(frames)], dim=0)
        frame_tensor = frame_tensor.permute(1, 0, 2, 3)  # (T, C, H, W) -> (C, T, H, W)
        hubert_tensor = torch.concat([torch.tensor(ref_hubert), torch.tensor(hubert)], dim=0)

        # Flatten nested names (a leftover from the CREMA-D "speaker/sentence" layout).
        video_name = video_name.replace('/', '_')
        pose = np.concatenate([ref_pose, pose], axis=0).transpose(1, 0)  # for compatibility
        eye_blink = np.concatenate([ref_eye_blink, eye_blink], axis=0).transpose(1, 0)

        if self.mode == 'test':
            return frame_tensor, hubert_tensor, pose, eye_blink, mouth_lmk_tensor, video_name, start
        return frame_tensor, hubert_tensor, pose, eye_blink, mouth_lmk_tensor, video_name, total_num_frames
if __name__ == "__main__":
    # Smoke test: read a few samples directly, then iterate once through a DataLoader.
    data_dir = "/train20/intern/permanent/hbcheng2/data/HDTF/images_25hz_128_chunk"
    pose_dir = "/train20/intern/permanent/hbcheng2/data/HDTF/pose_bar_chunk"
    eye_blink_dir = "/train20/intern/permanent/hbcheng2/data/HDTF/eye_blink_bbox_from_xpc_bar_2_chunk"
    # crema alternative:
    # data_dir='/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images'
    dataset = HDTF(
        data_dir=data_dir,
        pose_dir=pose_dir,
        eye_blink_dir=eye_blink_dir,
        image_size=128,
        max_num_frames=30,
        color_jitter=True,
    )
    for sample_idx in range(10):
        dataset[sample_idx]
    print('------')
    loader = data.DataLoader(dataset=dataset, batch_size=10, num_workers=8, shuffle=False)
    for batch_idx, _ in enumerate(loader):
        print(batch_idx)
================================================
FILE: DM_3/modules/local_attention.py
================================================
import sys
# sys.path.append('your/path/DAWN-pytorch')
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding
from einops_exts import rearrange_many
import math
import concurrent.futures
# from local_attn_cuda_pkg.test_cuda import attn_forward
# from local_attn_cuda_pkg.test_cuda import attn_forward, compute_res_forward
# def attn_forward(x, y, batch_size, hw, seq_len, k_size, head, d, device):
# attn = torch.zeros(batch_size, hw, seq_len, k_size, head, device=device)
# local_attn_res.attn_cuda(x, y, attn, batch_size, hw, seq_len, k_size, head, d)
# return attn
# def compute_res_forward(attn, z, batch_size, hw, seq_len, head, head_dim, k_size, device):
# res = torch.zeros(batch_size, hw, seq_len, head, head_dim, device=device)
# local_attn_res.compute_res_cuda(attn, z, res, batch_size, hw, seq_len, head, head_dim, k_size)
# return res
def exists(x):
    """Return ``True`` unless ``x`` is ``None`` (falsy values such as 0 still exist)."""
    return not (x is None)
def to_mask(x, mask, mode='mul'):
    """Apply ``mask`` to ``x``, broadcasting it over the trailing dimensions of ``x``.

    ``mask`` gets trailing singleton dimensions until it has as many dims as ``x``. It is
    then multiplied into ``x`` when ``mode == 'mul'`` and added to ``x`` for any other
    mode. A ``None`` mask returns ``x`` unchanged.
    """
    if mask is None:
        return x
    missing = x.dim() - mask.dim()
    if missing > 0:
        mask = mask[(Ellipsis,) + (None,) * missing]
    return x * mask if mode == 'mul' else x + mask
# def extract_seq_patches(x, kernel_size, rate):
# """x.shape = [batch_size, seq_len, seq_dim]"""
# seq_len = x.size(1)
# seq_dim = x.size(2)
# k_size = kernel_size + (rate - 1) * (kernel_size - 1)
# p_right = (k_size - 1) // 2
# p_left = k_size - 1 - p_right
# x = F.pad(x, (0, 0, p_left, p_right), mode='constant', value=0)
# xs = [x[:, i: i + seq_len] for i in range(0, k_size, rate)]
# x = torch.cat(xs, dim=2)
# return x.reshape(-1, seq_len, kernel_size, seq_dim)
def extract_seq_patches(x, kernel_size, rate):
    """Gather a dilated window of ``kernel_size`` neighbours around every sequence position.

    Args:
        x: tensor of shape ``(batch_size, hw, seq_len, seq_dim)``; windows slide along dim 2.
        kernel_size: number of elements taken per window.
        rate: dilation between consecutive window elements.

    Returns:
        Tensor of shape ``(batch_size, hw, seq_len, kernel_size, seq_dim)``.
        ``out[..., t, j, :]`` equals ``x[..., t - p_left + j * rate, :]``, or zero outside
        the sequence. Here ``p_left`` is the left padding computed below, which centres the
        window on ``t`` (left-biased when the dilated span is even).
    """
    # Size of the dilated window and the padding needed on each side.
    k_size = kernel_size + (rate - 1) * (kernel_size - 1)
    p_right = (k_size - 1) // 2
    p_left = k_size - 1 - p_right
    # Pad only the sequence axis (dim -2).
    x = F.pad(x, (0, 0, p_left, p_right), mode='constant', value=0)
    # One window per original position, so the windows advance by 1. The previous
    # ``step=rate`` produced only ceil(seq_len / rate) windows when rate > 1; the dilation
    # is applied by the ``::rate`` selection below instead.
    windows = x.unfold(dimension=2, size=k_size, step=1)  # (b, hw, seq_len, seq_dim, k_size)
    windows = windows.transpose(-1, -2)                    # (b, hw, seq_len, k_size, seq_dim)
    return windows[:, :, :, ::rate]
def window_attn(x, y, z, kernel_size, mask, rate):
    """Sliding-window attention along the sequence axis (reference loop implementation).

    Args:
        x, y, z: queries / keys / values, each ``[batch_size, hw, seq_len, heads, dim_head]``.
        kernel_size: number of taps; the window spans
            ``k_size = kernel_size + (rate - 1) * (kernel_size - 1)`` positions.
        mask: additive bias added to the scores ``[b, hw, seq_len, k_size, heads]``
            after ``unsqueeze(0)``, e.g. shape ``[1, seq_len, k_size, heads]``.
            ``None`` means no masking.
        rate: dilation; it only widens the window here. Every position inside
            the widened window is attended to.

    Returns:
        tensor ``[batch_size, hw, seq_len, heads * dim_head]``.

    Keys/values outside the sequence are zero-padded, so they score 0 unless
    ``mask`` suppresses them.
    """
    batch_size, hw, seq_len, head, head_dim = x.size()
    device = x.device
    k_size = kernel_size + (rate - 1) * (kernel_size - 1)
    p_right = (k_size - 1) // 2
    p_left = k_size - 1 - p_right
    # Zero-pad the sequence axis (dim 2) of keys and values.
    y = F.pad(y, (0, 0, 0, 0, p_left, p_right), mode='constant', value=0)
    z = F.pad(z, (0, 0, 0, 0, p_left, p_right), mode='constant', value=0)
    # Allocate directly on the target device rather than on the host plus a copy.
    attn = torch.zeros(batch_size, hw, seq_len, k_size, head, device=device)
    for i in range(seq_len):
        # Query i against keys i - p_left .. i + p_right.
        attn[:, :, i] = torch.einsum('b n h d, b n w h d -> b n w h', x[:, :, i], y[:, :, i:i + k_size])
    if mask is not None:
        attn = to_mask(attn, mask.unsqueeze(0), 'add')
    # Subtract the per-window max for numerical stability before the softmax over the window axis.
    attn = attn - attn.amax(dim=-2, keepdim=True).detach()
    attn = F.softmax(attn, dim=-2)
    res = torch.zeros(batch_size, hw, seq_len, head, head_dim, device=device)
    for i in range(seq_len):
        res[:, :, i] = torch.einsum('b n w h, b n w h d -> b n h d', attn[:, :, i], z[:, :, i:i + k_size])
    return res.view(batch_size, hw, seq_len, -1)
def window_attn_2(x, y, z, kernel_size, mask, rate): # bad optimization
    """Vectorised sliding-window attention with no Python loop over positions.

    It materialises every key/value window at once, so memory grows with
    ``seq_len * k_size``. The author marks it as a bad optimisation.

    param:
        x (Tensor): [batch_size, hw, seq_len, heads, dim_head]
        y (Tensor): [batch_size, hw, seq_len, heads, dim_head]
        z (Tensor): [batch_size, hw, seq_len, heads, dim_head]
        kernel_size (int): number of taps; window spans kernel_size + (rate-1)*(kernel_size-1)
        mask (Tensor or None): additive bias, broadcast after unsqueeze(0) against
            scores [b, hw, seq_len, k_size, heads]; None disables masking
        rate (int): dilation (only widens the window; all positions in it are attended)
    return:
        Tensor: [batch_size, hw, seq_len, heads * dim_head]
    """
    batch_size, hw, seq_len, head, head_dim = x.size()
    k_size = kernel_size + (rate - 1) * (kernel_size - 1)
    p_right = (k_size - 1) // 2
    p_left = k_size - 1 - p_right
    y_padded = F.pad(y, (0, 0, 0, 0, p_left, p_right), mode='constant', value=0)
    z_padded = F.pad(z, (0, 0, 0, 0, p_left, p_right), mode='constant', value=0)
    # unfold(2, k_size, 1) -> [b, hw, seq_len, heads, dim_head, k_size].
    # Then move the window axis in front of heads -> [b, hw, seq_len, k_size, heads, dim_head].
    # unfold respects the actual strides, so this is correct for any memory layout.
    y_windows = y_padded.unfold(2, k_size, 1).permute(0, 1, 2, 5, 3, 4)
    z_windows = z_padded.unfold(2, k_size, 1).permute(0, 1, 2, 5, 3, 4)
    attn = torch.einsum('b n l h d, b n l s h d -> b n l s h', x, y_windows)
    if mask is not None:
        attn = to_mask(attn, mask.unsqueeze(0), 'add')
    attn = attn - attn.amax(dim=-2, keepdim=True).detach()
    attn = F.softmax(attn, dim=-2)  # softmax over the window axis
    # Contract over the window axis without materialising attn * z_windows.
    res = torch.einsum('b n l s h, b n l s h d -> b n l h d', attn, z_windows)
    return res.reshape(batch_size, hw, seq_len, head * head_dim)
def window_attn_stream(x, y, z, kernel_size, mask, rate): # bad optimization
    """Variant of the loop-based window attention that issues each position's einsum on its own CUDA stream.

    Args:
        x, y, z: queries / keys / values ``[batch_size, hw, seq_len, heads, dim_head]``
            (CUDA tensors; streams are created with ``torch.cuda.Stream``).
        kernel_size, rate: window of ``kernel_size + (rate - 1) * (kernel_size - 1)`` positions.
        mask: additive bias broadcast (after ``unsqueeze(0)``) against the scores
            ``[b, hw, seq_len, k_size, heads]``; ``None`` disables masking.

    Returns:
        tensor ``[batch_size, hw, seq_len, heads * dim_head]``.
    """
    batch_size, hw, seq_len, head, head_dim = x.size()
    k_size = kernel_size + (rate - 1) * (kernel_size - 1)
    p_right = (k_size - 1) // 2
    p_left = k_size - 1 - p_right
    # Zero-pad the sequence axis (dim 2).
    y = F.pad(y, (0, 0, 0, 0, p_left, p_right), mode='constant', value=0)
    z = F.pad(z, (0, 0, 0, 0, p_left, p_right), mode='constant', value=0)
    attn = torch.zeros(batch_size, hw, seq_len, k_size, head, device=x.device)
    res = torch.zeros(batch_size, hw, seq_len, head, head_dim, device=x.device)
    streams = [torch.cuda.Stream() for _ in range(seq_len)]

    def run_on_streams(step):
        # Kernels queued on side streams are NOT ordered after work already queued
        # on the current stream: the pads/zeros above and the softmax below.
        # Make every side stream wait for that work before reading its outputs.
        current = torch.cuda.current_stream()
        for i, stream in enumerate(streams):
            stream.wait_stream(current)
            with torch.cuda.stream(stream):
                step(i)
        # Host-side wait so the following current-stream ops see finished results.
        for stream in streams:
            stream.synchronize()

    def compute_attn(i):
        attn[:, :, i] = torch.einsum('b n h d, b n w h d -> b n w h', x[:, :, i], y[:, :, i:i + k_size])

    def compute_res(i):
        # `attn` is looked up at call time, so this sees the softmaxed tensor below.
        res[:, :, i] = torch.einsum('b n w h, b n w h d -> b n h d', attn[:, :, i], z[:, :, i:i + k_size])

    run_on_streams(compute_attn)
    if mask is not None:
        attn = to_mask(attn, mask.unsqueeze(0), 'add')
    attn = attn - attn.amax(dim=-2, keepdim=True).detach()
    attn = F.softmax(attn, dim=-2)
    run_on_streams(compute_res)
    return res.view(batch_size, hw, seq_len, -1)
def create_sliding_window_mask(x, win_size, rate):
    """Cut, for every query row, the band of ``x`` that lies inside its sliding window.

    Args:
        x: bias/mask indexed as ``x[:, query, key]``. It is presumably
            ``(heads, len, len)``, as returned by ``RelativePositionBias.forward``;
            confirm for other callers.
        win_size: number of taps in the window.
        rate: dilation; the window spans ``win_size + (rate - 1) * (win_size - 1)`` keys.

    Returns:
        tensor ``(x.shape[0], x.shape[1], span)``. Row ``q`` holds keys
        ``q - left .. q + right`` of query ``q``; keys outside the range are -1e10.
    """
    span = win_size + (rate - 1) * (win_size - 1)
    right = (span - 1) // 2
    left = span - 1 - right
    # Pad the key axis (last dim) with a large negative value so padded keys are masked out.
    padded = F.pad(x, (left, right), mode='constant', value=-1e10)
    bands = [padded[:, q, q:q + span] for q in range(padded.shape[1])]
    return torch.stack(bands, dim=1)
class OurLayer(nn.Module):
    """``nn.Module`` base class that adds a ``reuse`` convenience helper."""

    def reuse(self, layer, *args, **kwargs):
        """Call ``layer`` with the given positional/keyword arguments and return its result."""
        return layer(*args, **kwargs)
def heavy_computation(x, y, attn, k_size, i):
    """Write, in place, the scores of query ``i`` against its ``k_size`` padded keys into ``attn[:, :, i]``.

    ``x`` is indexed as ``[b, n, pos, h, d]`` and ``y`` (padded keys) likewise.
    The result slot has layout ``[b, n, window, h]``. Returns ``None``.
    """
    query_i = x[:, :, i]
    keys_i = y[:, :, i:i + k_size]
    attn[:, :, i] = torch.einsum('bnhd,bnwhd->bnwh', query_i, keys_i)
def heavy_computation2(res, z, attn, k_size, i):
    """Write, in place, the attention-weighted sum of position ``i``'s value window into ``res[:, :, i]``.

    ``attn[:, :, i]`` has layout ``[b, n, window, h]`` and ``z`` (padded values)
    ``[b, n, pos, h, d]``. Returns ``None``.
    """
    weights_i = attn[:, :, i]
    values_i = z[:, :, i:i + k_size]
    res[:, :, i] = torch.einsum('bnwh,bnwhd->bnhd', weights_i, values_i)
from functools import partial
def window_attn_mp(x, y, z, kernel_size, mask, rate):
    """Window attention that fans the per-position einsums out to a process pool.

    Args:
        x, y, z: queries / keys / values ``[batch_size, hw, seq_len, heads, dim_head]``.
        kernel_size, rate: window of ``kernel_size + (rate - 1) * (kernel_size - 1)`` positions.
        mask: additive bias broadcast (after ``unsqueeze(0)``) against the scores
            ``[b, hw, seq_len, k_size, heads]``; ``None`` disables masking.

    Returns:
        tensor ``[batch_size, hw, seq_len, heads * dim_head]``.

    NOTE(review): the workers fill ``attn``/``res`` in place. Those writes reach
    this process only if the tensors' storage is shared with the worker processes
    (via torch's multiprocessing pickling); confirm on the target platform.
    """
    batch_size, hw, seq_len, head, head_dim = x.size()
    device = x.device
    k_size = kernel_size + (rate - 1) * (kernel_size - 1)
    p_right = (k_size - 1) // 2
    p_left = k_size - 1 - p_right
    # Zero-pad the sequence axis (dim 2).
    y = F.pad(y, (0, 0, 0, 0, p_left, p_right), mode='constant', value=0)
    z = F.pad(z, (0, 0, 0, 0, p_left, p_right), mode='constant', value=0)
    attn = torch.zeros(batch_size, hw, seq_len, k_size, head, device=device)
    unary = partial(heavy_computation, x, y, attn, k_size)
    with concurrent.futures.ProcessPoolExecutor() as executor:
        # Consume the results: Executor.map only re-raises worker (or pickling)
        # exceptions when its iterator is consumed. Without this, failures were
        # silently dropped and `attn` stayed all zeros.
        list(executor.map(unary, range(seq_len)))
    if mask is not None:
        attn = to_mask(attn, mask.unsqueeze(0), 'add')
    attn = attn - attn.amax(dim=-2, keepdim=True).detach()
    attn = F.softmax(attn, dim=-2)
    res = torch.zeros(batch_size, hw, seq_len, head, head_dim, device=device)
    unary2 = partial(heavy_computation2, res, z, attn, k_size)
    with concurrent.futures.ProcessPoolExecutor() as executor:
        list(executor.map(unary2, range(seq_len)))
    return res.view(batch_size, hw, seq_len, -1)
class LocalSelfAttention_opt(OurLayer):
    """Multi-head self-attention restricted to a sliding window along the sequence axis.

    Each query attends to ``2 * neighbors + 1`` positions centred on itself
    (computed by ``window_attn``). Inputs are 4-D ``(batch, hw, seq_len, d_model)``.

    Args:
        d_model: input/output feature size.
        heads: number of attention heads.
        size_per_head: per-head size, used when ``key_size`` is not given.
        neighbors: positions on each side of the query inside the window.
        rate: stored, but ``forward`` always calls ``window_attn`` with rate=1.
        rotary_emb: optional object with ``rotate_queries_or_keys`` applied to q and k.
        key_size: per-head q/k/v size; defaults to ``size_per_head``.
        mask_right: stored but not used by ``forward``.
    """

    def __init__(self, d_model, heads, size_per_head, neighbors=3, rate=1, rotary_emb=None,
                 key_size=None, mask_right=False):
        super(LocalSelfAttention_opt, self).__init__()
        self.heads = heads
        self.size_per_head = size_per_head
        self.out_dim = heads * size_per_head
        self.key_size = key_size if key_size else size_per_head
        self.neighbors = neighbors
        self.rate = rate
        self.mask_right = mask_right
        self.rotary_emb = rotary_emb
        self.to_qkv = nn.Linear(d_model, self.key_size * self.heads * 3, bias=False)
        self.to_out = nn.Linear(self.key_size * self.heads, d_model, bias=False)

    def forward(self, inputs, pos_bias=None, focus_present_mask=None):
        """Apply windowed attention to ``inputs`` of shape ``(batch, hw, seq_len, d_model)``.

        Args:
            inputs: input features.
            pos_bias: additive per-(query, window slot, head) bias/mask, expected as
                ``(heads, seq_len, 2 * neighbors + 1)``, e.g. from
                ``create_sliding_window_mask``. ``None`` disables it.
            focus_present_mask: accepted for interface compatibility; ignored.

        Returns:
            tensor ``(batch, hw, seq_len, d_model)``.
        """
        x = inputs
        kernel_size = 1 + 2 * self.neighbors
        batch_size, hw, seq_len, _ = x.size()
        if pos_bias is not None:
            # (heads, seq, win) -> (1, seq, win, heads). window_attn adds one more
            # leading axis so it broadcasts against its (b, hw, seq, win, heads) scores.
            window_mask = pos_bias.unsqueeze(0).permute(0, 2, 3, 1)
        else:
            window_mask = None
        qw, k, v = self.to_qkv(x).chunk(3, dim=-1)  # each (b, hw, seq_len, heads * key_size)
        qw = qw / (self.key_size ** 0.5)
        qw = qw.view(batch_size, hw, seq_len, self.heads, self.key_size)
        k = k.view(batch_size, hw, seq_len, self.heads, self.key_size)
        v = v.view(batch_size, hw, seq_len, self.heads, self.key_size)
        if exists(self.rotary_emb):
            qw = self.rotary_emb.rotate_queries_or_keys(qw)
            k = self.rotary_emb.rotate_queries_or_keys(k)
        res = window_attn(qw, k, v, kernel_size, window_mask, rate=1)
        return self.to_out(res)
class MultiHeadLocalAttention(nn.Module):
    """Full multi-head attention with positions farther than ``window_size`` apart masked out.

    The q/k/v projections are initialised deterministically (weights 1, biases 0),
    and no output projection is applied.
    """

    def __init__(self, d_model, num_heads, window_size):
        super(MultiHeadLocalAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.window_size = window_size
        assert d_model % num_heads == 0
        self.depth = d_model // num_heads
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        for proj in (self.query, self.key, self.value):
            proj.weight.data.fill_(1)
            proj.bias.data.fill_(0)

    def split_heads(self, x, batch_size):
        """Reshape ``(batch, seq, d_model)`` into ``(batch, num_heads, seq, depth)``."""
        return x.reshape(batch_size, -1, self.num_heads, self.depth).permute(0, 2, 1, 3)

    def forward(self, x):
        """Attend over ``x`` of shape ``(batch, seq_len, d_model)``; returns the same shape."""
        batch_size, seq_len, d_model = x.size()
        assert d_model == self.d_model
        q, k, v = (self.split_heads(proj(x), batch_size) for proj in (self.query, self.key, self.value))
        scores = q @ k.transpose(-2, -1) / (self.depth ** 0.5)  # (b, heads, seq, seq)
        pos = torch.arange(seq_len, device=x.device)
        too_far = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs() > self.window_size
        scores = scores.masked_fill(too_far[None, None], float('-inf'))
        weights = F.softmax(scores, dim=-1)
        merged = (weights @ v).permute(0, 2, 1, 3)  # (b, seq, heads, depth)
        return merged.reshape(batch_size, seq_len, d_model)
class Attention(nn.Module):
    """Full multi-head self-attention over the second-to-last axis (reference for the local variant).

    NOTE(review): ``to_qkv``/``to_out`` are initialised with all-ones weights, and
    the main path returns the concatenated heads WITHOUT applying ``to_out``.
    Only the all-focus-present shortcut applies ``to_out``. Both look like
    debugging choices; confirm before relying on this layer.
    """

    def __init__(
        self,
        dim,
        heads=4,
        dim_head=32,
        rotary_emb=None
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.rotary_emb = rotary_emb
        self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False)
        self.to_out = nn.Linear(hidden_dim, dim, bias=False)
        self.to_qkv.weight.data.fill_(1)
        self.to_out.weight.data.fill_(1)

    def forward(
        self,
        x,
        pos_bias=None,
        focus_present_mask=None
    ):
        """Attend over ``x`` (``(..., n, dim)``).

        Args:
            x: input features; attention runs over the second-to-last axis.
            pos_bias: optional additive bias broadcast against the ``(..., heads, n, n)`` scores.
            focus_present_mask: optional bool tensor of shape ``(b,)``. True samples
                attend only to themselves.

        Returns:
            ``(..., n, heads * dim_head)``, or ``to_out(values)`` when every sample
            focuses on the present.
        """
        n, device = x.shape[-2], x.device
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        if exists(focus_present_mask) and focus_present_mask.all():
            # If every sample attends only to itself, attention is the identity on
            # the values, so skip the score computation.
            values = qkv[-1]
            return self.to_out(values)
        # split out heads
        q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads)
        q = q * self.scale
        # rotate positions into queries and keys for time attention
        if exists(self.rotary_emb):
            q = self.rotary_emb.rotate_queries_or_keys(q)
            k = self.rotary_emb.rotate_queries_or_keys(k)
        sim = torch.einsum('... h i d, ... h j d -> ... h i j', q, k)
        # relative positional bias
        if exists(pos_bias):
            sim = sim + pos_bias
        if exists(focus_present_mask) and not (~focus_present_mask).all():
            attend_all_mask = torch.ones((n, n), device=device, dtype=torch.bool)
            attend_self_mask = torch.eye(n, device=device, dtype=torch.bool)
            mask = torch.where(
                rearrange(focus_present_mask, 'b -> b 1 1 1 1'),
                rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),
                rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),
            )
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
        # numerical stability
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)
        # aggregate values
        out = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v)
        out = rearrange(out, '... h n d -> ... n (h d)')
        return out
class RelativePositionBias(nn.Module):
    """Learned bucketed relative-position bias per head, plus a hard locality mask.

    Args:
        heads: number of attention heads (one bias value per bucket and head).
        num_buckets: number of relative-position buckets.
        max_distance: distance at which the logarithmic buckets saturate.
        mask_distance: query/key pairs with ``|j - i| > mask_distance`` receive an
            additional ``-1e10``. The default of 20 keeps the previous behaviour.
    """

    def __init__(
        self,
        heads=8,
        num_buckets=32,
        max_distance=128,
        mask_distance=20
    ):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.mask_distance = mask_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
        """Map relative positions (key - query) to bucket ids in ``[0, num_buckets)``.

        Half of the buckets hold positive offsets and half negative. Within each
        half, small offsets get exact buckets and larger ones share logarithmically
        sized buckets up to ``max_distance``.
        """
        ret = 0
        n = -relative_position
        num_buckets //= 2
        ret += (n < 0).long() * num_buckets
        n = torch.abs(n)
        max_exact = num_buckets // 2
        is_small = n < max_exact
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, n, device):
        """Return the ``(heads, n, n)`` bias for a length-``n`` sequence on ``device``."""
        pos = torch.arange(n, dtype=torch.long, device=device)
        rel_pos = pos[None, :] - pos[:, None]  # rel_pos[i, j] = j - i
        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets,
                                                   max_distance=self.max_distance)
        out_of_window = rel_pos.abs() > self.mask_distance
        mask = -(out_of_window * (1e10))
        values = self.relative_attention_bias(rp_bucket)  # (n, n, heads)
        return mask + values.permute(2, 0, 1)
if __name__ == "__main__":
    # Timing / sanity comparison between windowed attention and full attention.
    # NOTE(review): the two models have independent weights (and Attention skips
    # to_out), so the printed MSE is not an equivalence check.
    d_model = 256
    window_size = 20
    seq_len = 200
    batch_size = 1
    head = 4
    # RelativePositionBias masks |j - i| > 20, which matches window_size here.
    res_pos = RelativePositionBias(heads=head, max_distance=32)
    rope = RotaryEmbedding(min(64, d_model // head), seq_before_head_dim=True)
    rope2 = RotaryEmbedding(min(64, d_model // head))
    model = LocalSelfAttention_opt(d_model, head, d_model // head, window_size, rotary_emb=rope)
    model_2 = Attention(d_model, head, dim_head=d_model // head, rotary_emb=rope2)
    rp = res_pos(seq_len, 'cpu')
    xp_mask = create_sliding_window_mask(rp, 2 * window_size + 1, 1)
    for _ in range(5):
        x = torch.randn(batch_size, 9, seq_len, d_model)
        st = time.time()
        output = model(x, xp_mask)
        ed = time.time()
        print("optimized: ", ed - st)
        st = time.time()
        output_2 = model_2(x, pos_bias=rp)
        ed = time.time()
        print("origin: ", ed - st)
        print(((output - output_2) ** 2).mean())
================================================
FILE: DM_3/modules/text.py
================================================
# the code from https://github.com/lucidrains/video-diffusion-pytorch
import torch
from einops import rearrange
def exists(val):
    """Return True unless ``val`` is ``None``."""
    return not (val is None)
# singleton globals
# Lazily-populated caches: get_bert() fills MODEL and get_tokenizer() fills TOKENIZER
# on first use, and later calls reuse them.
MODEL = None
TOKENIZER = None
# NOTE(review): presumably the flattened size of the HuBERT audio feature (20 x 1024).
# It is not referenced elsewhere in this module; confirm against its importers.
HUBERT_MODEL_DIM = 20*1024
# BERT_MODEL_DIM = 768
def get_tokenizer():
    """Return the process-wide ``bert-base-cased`` tokenizer, loading it via torch.hub on first call."""
    global TOKENIZER
    if TOKENIZER is None:
        TOKENIZER = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased')
    return TOKENIZER
def get_bert():
    """Return the process-wide ``bert-base-cased`` model, loading it via torch.hub on first call.

    The model is moved to CUDA when a GPU is available.
    """
    global MODEL
    if MODEL is not None:
        return MODEL
    MODEL = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-cased')
    if torch.cuda.is_available():
        MODEL = MODEL.cuda()
    return MODEL
# tokenize
def tokenize(texts, add_special_tokens=True):
    """Tokenize a string or a list/tuple of strings into a padded id tensor.

    Args:
        texts: one string, or a list/tuple of strings.
        add_special_tokens: forwarded to the tokenizer (e.g. [CLS]/[SEP]).

    Returns:
        ``input_ids`` as a PyTorch tensor, padded to the longest sequence.
    """
    batch = texts if isinstance(texts, (list, tuple)) else [texts]
    encoding = get_tokenizer().batch_encode_plus(
        batch,
        add_special_tokens=add_special_tokens,
        padding=True,
        return_tensors='pt',
    )
    return encoding.input_ids
# embedding function
@torch.no_grad()
def bert_embed(
    token_ids,
    return_cls_repr=False,
    eps=1e-8,
    pad_id=0.
):
    """Embed token ids with the last hidden layer of the shared BERT model.

    Args:
        token_ids: 2-D id tensor ``(batch, seq)``, e.g. from ``tokenize``.
        return_cls_repr: if True, return the hidden state of the first token of each sequence.
        eps: added to the token count to avoid division by zero.
        pad_id: id treated as padding; excluded from attention and from the mean.

    Returns:
        ``(batch, hidden)``: the first-token state, or the mean of the last hidden
        states over the non-padding tokens after the first one.
    """
    model = get_bert()
    mask = token_ids != pad_id
    if torch.cuda.is_available():
        token_ids = token_ids.cuda()
        mask = mask.cuda()
    outputs = model(
        input_ids=token_ids,
        attention_mask=mask,
        output_hidden_states=True
    )
    hidden_state = outputs.hidden_states[-1]
    if return_cls_repr:
        return hidden_state[:, 0]  # return [cls] as representation
    # `mask` is always a tensor here, so there is no unmasked-mean fallback.
    # Mean over all tokens except [CLS], counting only non-padding positions.
    mask = rearrange(mask[:, 1:], 'b n -> b n 1')
    numer = (hidden_state[:, 1:] * mask).sum(dim=1)
    denom = mask.sum(dim=1)
    return numer / (denom + eps)
================================================
FILE: DM_3/modules/video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_mouth_mask_6D.py
================================================
'''
stage 1: frame 0 of each clip is used as the reference; clips are short and of fixed length.
Trained with a lip (mouth-region) loss and 6D pose, with conditioning injected via cross-attention.
'''
import os
import torch
import torch.nn as nn
import sys
sys.path.append('your/path')
from LFG.modules.generator import Generator
from LFG.modules.bg_motion_predictor import BGMotionPredictor
from LFG.modules.region_predictor import RegionPredictor
from DM_3.modules.video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi import DynamicNfUnet3D, DynamicNfGaussianDiffusion
import yaml
from sync_batchnorm import DataParallelWithCallback
from filter_fourier import *
from DM_v0_loss_exp.modules.util import AntiAliasInterpolation2d
from torchvision import models
from torchvision import transforms
import numpy as np
import time
from einops import rearrange
class Attention(nn.Module):
    """Additive attention over a 2-D context map, driven by a query vector.

    Scores are ``fc_attention(tanh(ctx_key + fc_query(ht_query)))``. They are
    exponentiated after subtracting the global max, gated by ``ctx_mask``, and
    normalised per sample over both spatial axes.
    """

    def __init__(self, params):
        super(Attention, self).__init__()
        dim_att = params['dim_attention']
        self.fc_query = nn.Linear(params['n'], dim_att, bias=False)
        self.fc_attention = nn.Linear(dim_att, 1)

    def forward(self, ctx_val, ctx_key, ctx_mask, ht_query):
        """Return ``(context, weights)``.

        Shapes, as implied by the broadcasting below: ``ht_query (b, n)``,
        ``ctx_key (b, H, W, dim_attention)``, ``ctx_mask (b, H, W)``,
        ``ctx_val (b, C, H, W)``. Returns context ``(b, C)`` and weights ``(b, H, W)``.
        """
        query = self.fc_query(ht_query).unsqueeze(1).unsqueeze(1)  # (b, 1, 1, dim_attention)
        energy = self.fc_attention(torch.tanh(ctx_key + query)).squeeze(3)
        # Shift by the (batch-global) max before exp for numerical stability.
        weights = torch.exp(energy - energy.max()) * ctx_mask
        normaliser = weights.sum(2).sum(1)[:, None, None] + 1e-10
        weights = weights / normaliser
        context = (ctx_val * weights[:, None, :, :]).sum(3).sum(2)
        return context, weights
class Face_loc_Encoder(nn.Module):
    """Two stride-2 3x3 conv + ReLU stages: ``(B, dim, H, W) -> (B, 16, ceil(H/4), ceil(W/4))``."""

    def __init__(self, dim = 1):
        super(Face_loc_Encoder, self).__init__()
        conv_kwargs = dict(kernel_size=3, stride=2, padding=1)
        self.conv1 = nn.Conv2d(dim, 8, **conv_kwargs)
        self.conv2 = nn.Conv2d(8, 16, **conv_kwargs)

    def forward(self, x):
        relu = nn.functional.relu
        return relu(self.conv2(relu(self.conv1(x))))
class Vgg19(torch.nn.Module):
    """
    Vgg19 network for perceptual loss.

    Splits torchvision's pretrained VGG19 ``features`` into five consecutive
    stages, covering layers [0, 2), [2, 7), [7, 12), [12, 21) and [21, 30), and
    returns the activation after each stage. Inputs are normalised with the
    mean/std stored as (non-trainable) parameters.
    """

    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        features = models.vgg19(pretrained=True).features
        bounds = ((0, 2), (2, 7), (7, 12), (12, 21), (21, 30))
        for stage, (start, stop) in enumerate(bounds, start=1):
            block = torch.nn.Sequential()
            for layer_idx in range(start, stop):
                block.add_module(str(layer_idx), features[layer_idx])
            setattr(self, 'slice%d' % stage, block)
        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
                                       requires_grad=False)
        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
                                      requires_grad=False)
        if not requires_grad:
            for p in self.parameters():
                p.requires_grad = False

    def forward(self, x):
        """Return the list of the five stage activations for the normalised input."""
        h = (x - self.mean) / self.std
        out = []
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = stage(h)
            out.append(h)
        return out
class ImagePyramide(torch.nn.Module):
    """
    Create an image pyramid for computing the pyramid perceptual loss.

    One ``AntiAliasInterpolation2d`` per scale. ``forward`` returns a dict keyed
    ``'prediction_<scale>'``.
    """

    def __init__(self, scales, num_channels):
        super(ImagePyramide, self).__init__()
        # ModuleDict keys cannot contain '.', so scales are stored with '-' instead.
        self.downs = nn.ModuleDict({
            str(scale).replace('.', '-'): AntiAliasInterpolation2d(num_channels, scale)
            for scale in scales
        })

    def forward(self, x):
        return {
            'prediction_' + str(key).replace('-', '.'): down(x)
            for key, down in self.downs.items()
        }
class FlowDiffusion(nn.Module):
def __init__(self, img_size=32, num_frames=40, sampling_timesteps=250,
             null_cond_prob=0.1,
             ddim_sampling_eta=1.,
             dim_mults=(1, 2, 4, 8),
             is_train=True,
             use_residual_flow=False,
             learn_null_cond=False,
             use_deconv=True,
             padding_mode="zeros",
             pretrained_pth="",
             config_pth=""):
    """Build the frozen LFG modules (restored from a checkpoint) and the trainable 3-D U-Net diffusion model.

    Args:
        img_size: forwarded to DynamicNfGaussianDiffusion as ``image_size``.
        num_frames: clip length for both the U-Net and the diffusion wrapper.
        sampling_timesteps: forwarded to the diffusion wrapper.
        null_cond_prob: forwarded to the diffusion wrapper (presumably the
            condition-drop probability for classifier-free guidance; confirm).
        ddim_sampling_eta: forwarded to the diffusion wrapper.
        dim_mults: U-Net channel multipliers.
        is_train: if True, the U-Net and diffusion wrapper are put in train mode.
        use_residual_flow: stored. When set, forward/sampling add an identity
            grid to the predicted flow.
        learn_null_cond, use_deconv, padding_mode: forwarded to DynamicNfUnet3D.
        pretrained_pth: checkpoint containing 'generator', 'region_predictor'
            and 'bg_predictor' state dicts.
        config_pth: YAML file providing 'model_params' and 'train_params'.
    """
    super(FlowDiffusion, self).__init__()
    self.use_residual_flow = use_residual_flow
    # NOTE(review): torch.load without map_location restores tensors to the devices they were saved from.
    checkpoint = torch.load(pretrained_pth)
    with open(config_pth) as f:
        config = yaml.safe_load(f)
    # Frozen LFG generator (eval mode, no grads).
    self.generator = Generator(num_regions=config['model_params']['num_regions'],
                               num_channels=config['model_params']['num_channels'],
                               revert_axis_swap=config['model_params']['revert_axis_swap'],
                               **config['model_params']['generator_params']).cuda()
    self.generator.load_state_dict(checkpoint['generator'])
    self.generator.eval()
    self.set_requires_grad(self.generator, False)
    # Frozen region predictor.
    self.region_predictor = RegionPredictor(num_regions=config['model_params']['num_regions'],
                                            num_channels=config['model_params']['num_channels'],
                                            estimate_affine=config['model_params']['estimate_affine'],
                                            **config['model_params']['region_predictor_params']).cuda()
    self.region_predictor.load_state_dict(checkpoint['region_predictor'])
    self.region_predictor.eval()
    self.set_requires_grad(self.region_predictor, False)
    # Frozen background-motion predictor.
    # NOTE(review): unlike the two modules above, this one is not moved with .cuda() here.
    self.bg_predictor = BGMotionPredictor(num_channels=config['model_params']['num_channels'],
                                          **config['model_params']['bg_predictor_params'])
    self.bg_predictor.load_state_dict(checkpoint['bg_predictor'])
    self.bg_predictor.eval()
    self.set_requires_grad(self.bg_predictor, False)
    self.scales = config['train_params']['scales']
    self.pyramid = ImagePyramide(self.scales, self.generator.num_channels)
    if torch.cuda.is_available():
        self.pyramid = self.pyramid.cuda()
    # self.vgg_loss_weights = config['train_params']['loss_weights']['rec_vgg']
    # if sum(self.vgg_loss_weights) != 0:
    #     self.vgg = Vgg19()
    #     if torch.cuda.is_available():
    #         self.vgg = self.vgg.cuda()
    # Trainable denoiser. cond_dim = cond_aud + cond_pose + cond_eye.
    # channels = 3 + 256 + 16: presumably the 3 flow/occlusion channels (out_grid_dim +
    # out_conf_dim), the 256-channel reference feature and the 16-channel
    # Face_loc_Encoder output; confirm against DynamicNfUnet3D.
    self.unet = DynamicNfUnet3D(dim=64,
                                cond_dim=1024 + 6 + 2,
                                cond_aud=1024,
                                cond_pose=6,
                                cond_eye=2,
                                num_frames=num_frames,
                                channels=3 + 256 + 16,
                                out_grid_dim=2,
                                out_conf_dim=1,
                                dim_mults=dim_mults,
                                use_hubert_audio_cond=True,
                                learn_null_cond=learn_null_cond,
                                use_final_activation=False,
                                use_deconv=use_deconv,
                                padding_mode=padding_mode)
    self.diffusion = DynamicNfGaussianDiffusion(
        denoise_fn = self.unet,
        num_frames=num_frames,
        image_size=img_size,
        sampling_timesteps=sampling_timesteps,
        timesteps=1000,  # number of steps
        loss_type='l2',  # L1 or L2
        use_dynamic_thres=True,
        null_cond_prob=null_cond_prob,
        ddim_sampling_eta=ddim_sampling_eta
    )
    # Conv encoder for the rasterised face bounding-box mask.
    self.face_loc_emb = Face_loc_Encoder()
    # training
    self.is_train = is_train
    if self.is_train:
        self.unet.train()
        self.diffusion.train()
def update_num_frames(self, new_num_frames):
    """Propagate a new clip length to the U-Net first, then to the diffusion wrapper."""
    for component in (self.unet, self.diffusion):
        component.update_num_frames(new_num_frames)
def generate_bbox_mask(self, bbox, size = 32):
    """Rasterise each sample's first-frame bounding box into a binary ``size x size`` mask.

    Args:
        bbox: tensor ``(b, c, fn)``; only frame index 0 is read and the tensor is
            NOT modified. Entries 0-1 are divided by entry 4 and entries 2-3 by
            entry 5, then scaled by ``size``. Entries 0/1 bound the columns and
            entries 2/3 the rows; 1 is added to the upper bounds.
            NOTE(review): presumably ``[x0, x1, y0, y1, img_w, img_h]``; confirm.
        size: side length of the output grid.

    Returns:
        float tensor ``(b, 1, size, size)`` on ``bbox``'s device: 1 inside the box, 0 elsewhere.
    """
    b, _, _ = bbox.size()
    # clone(): bbox[:, :, 0] is a view, so the in-place rescaling below used to
    # overwrite the caller's tensor, and a second call would rescale it again.
    box = bbox[:, :, 0].clone()
    box[:, :2] = (box[:, :2] / box[:, 4].unsqueeze(1)) * size
    box[:, 2:4] = (box[:, 2:4] / box[:, 5].unsqueeze(1)) * size
    bbox_left_top = box[:, :4:2].to(torch.int32)             # (col_min, row_min)
    bbox_right_bottom = (box[:, 1:4:2] + 1).to(torch.int32)  # (col_max + 1, row_max + 1)
    # Integer (not uint8) indices, so grids larger than 256 do not wrap around.
    # They are built on bbox's device, which the comparisons already required.
    grid = torch.arange(size, device=bbox.device)
    row_indices = grid.view(1, size, 1).expand(b, size, size)
    col_indices = grid.view(1, 1, size).expand(b, size, size)
    mask = ((row_indices >= bbox_left_top[:, 1].view(b, 1, 1))
            & (row_indices <= bbox_right_bottom[:, 1].view(b, 1, 1))
            & (col_indices >= bbox_left_top[:, 0].view(b, 1, 1))
            & (col_indices <= bbox_right_bottom[:, 0].view(b, 1, 1)))
    return mask.unsqueeze(1).float()  # b, 1, size, size
def generate_mouth_mask(self, mouth_lmk, origin_size, size = 32):
    """Rasterise the per-frame bounding box of the mouth landmarks into binary masks.

    Args:
        mouth_lmk: tensor ``(b, fn, points, 2)``. Coordinate 0 is compared with
            columns and coordinate 1 with rows.
        origin_size: divisor for the landmark coordinates. After ``unsqueeze(1)``
            it must broadcast against ``mouth_lmk``; ``forward`` passes
            ``bbox[:, None, -2:, 0]``, i.e. ``(b, 1, 2)``.
        size: side length of the output grid.

    Returns:
        float tensor ``(b, fn, size, size)`` on ``mouth_lmk``'s device: 1 inside the
        min/max landmark box (both bounds inclusive), 0 elsewhere.
    """
    b, fn, pn, c = mouth_lmk.size()
    origin_size = origin_size.unsqueeze(1)
    mouth_lmk = (mouth_lmk / origin_size) * size
    ld_coner = mouth_lmk.max(dim=-2)[0]  # (b, fn, 2) per-frame max over points
    ru_coner = mouth_lmk.min(dim=-2)[0]  # (b, fn, 2) per-frame min over points
    # Integer (not uint8) indices avoid wrap-around for size > 256. They are
    # created on the landmarks' device, which the comparisons already required.
    grid = torch.arange(size, device=mouth_lmk.device)
    row_indices = grid.view(1, size, 1).expand(b, fn, size, size)
    col_indices = grid.view(1, 1, size).expand(b, fn, size, size)
    mask = ((row_indices >= ru_coner[:, :, 1].view(b, fn, 1, 1))
            & (row_indices <= ld_coner[:, :, 1].view(b, fn, 1, 1))
            & (col_indices >= ru_coner[:, :, 0].view(b, fn, 1, 1))
            & (col_indices <= ld_coner[:, :, 0].view(b, fn, 1, 1)))
    return mask.float()  # b, fn, size, size
def forward(self, real_vid, ref_img, ref_text, ref_pose, ref_eye_blink, bbox, mouth_lmk, is_eval=False, ref_id = 0):
    """Compute the diffusion training losses for a batch of clips (and optionally decode predictions).

    Args:
        real_vid: video ``(b, c, f, h, w)``. It is divided by 255 before colour jitter.
        ref_img: ignored; it is replaced by frame ``ref_id`` of the jittered ``real_vid``.
        ref_text: per-frame audio conditioning. Pose and eye-blink offsets relative
            to frame ``ref_id`` are concatenated to it along the last dim.
        ref_pose: pose sequence. It is squeezed at dim 1, permuted to (b, f, ·),
            and its last channel is dropped.
        ref_eye_blink: eye-blink sequence, squeezed at dim 1 and permuted to (b, f, ·).
        bbox: face box tensor for ``generate_bbox_mask``. Its entries ``-2:`` at
            frame 0 are also the landmark divisor for ``generate_mouth_mask``.
        mouth_lmk: mouth landmarks ``(b, f, points, 2)``.
        is_eval: when training, also decode the predicted flow/occlusion with the generator.
        ref_id: index of the reference frame.

    Returns:
        dict with the frozen LFG outputs ("real_vid_grid", "real_vid_conf",
        "real_out_vid", "real_warped_vid"). When ``is_train`` it also holds
        "loss", "null_cond_mask", "floss" and "mouth_loss". With ``is_eval`` it
        additionally holds "fake_vid_grid", "fake_vid_conf", "fake_out_vid" and
        "fake_warped_vid".
    """
    # --- colour-jitter augmentation of the whole clip batch ---
    b, c, f, h, w = real_vid.size()
    real_vid = rearrange(real_vid, 'b c f h w -> (b f) c h w')
    bright = 64. / 255
    contrast = 0.25
    sat = 0.25
    hue = 0.04
    color_jitters = transforms.ColorJitter(hue=(-hue, hue),
                                           contrast=(max(0, 1 - contrast), 1 + contrast),
                                           saturation=(max(0, 1 - sat), 1 + sat),
                                           brightness=(max(0, 1 - bright), 1 + bright))
    # ColorJitter expects [..., 1 or 3, H, W]; float images must be scaled to [0, 1].
    real_vid = real_vid / 255.
    real_vid = color_jitters(real_vid)
    real_vid = rearrange(real_vid, '(b f) c h w -> b c f h w', b=b, f=f)
    # The reference image is always taken from the jittered clip itself.
    ref_img = real_vid[:, :, ref_id, :, :].clone().detach()
    b, _, nf, H, W = real_vid.size()

    # --- conditioning: audio + pose / eye-blink offsets relative to the reference frame ---
    ref_pose = ref_pose.squeeze(1).permute(0, 2, 1)[:, :, :-1]
    ref_eye_blink = ref_eye_blink.squeeze(1).permute(0, 2, 1)
    init_pose = ref_pose[:, ref_id].unsqueeze(1).repeat(1, nf, 1)
    init_eye = ref_eye_blink[:, ref_id].unsqueeze(1).repeat(1, nf, 1)  # b, fn, 2
    ref_text = torch.concat([ref_text, (ref_pose - init_pose), (ref_eye_blink - init_eye)], dim=-1)

    bbox_mask = self.generate_bbox_mask(bbox, size=real_vid.shape[-1])
    bbox_mask = self.face_loc_emb(bbox_mask)  # conv encoder for the face mask
    mouth_mask = self.generate_mouth_mask(mouth_lmk, bbox[:, None, -2:, 0], size=real_vid.shape[-1] // 4)

    output_dict = {}
    # --- ground-truth flow / occlusion from the frozen LFG modules ---
    with torch.no_grad():
        b, c, f, h, w = real_vid.size()
        real_vid_tmp = rearrange(real_vid, 'b c f h w -> (b f) c h w')
        ref_img_tmp = ref_img.unsqueeze(1).repeat(1, f, 1, 1, 1).reshape(-1, 3, h, w)
        source_region_params = self.region_predictor(ref_img_tmp)
        driving_region_params = self.region_predictor(real_vid_tmp)
        bg_params = self.bg_predictor(ref_img_tmp, real_vid_tmp)
        generated = self.generator(ref_img_tmp, source_region_params=source_region_params,
                                   driving_region_params=driving_region_params, bg_params=bg_params)
        output_dict["real_vid_grid"] = rearrange(generated["optical_flow"], '(b f) h w c -> b c f h w', b=b, f=f)
        output_dict["real_vid_conf"] = rearrange(generated["occlusion_map"], '(b f) c h w -> b c f h w', b=b, f=f)
        output_dict["real_out_vid"] = rearrange(generated["prediction"], '(b f) c h w -> b c f h w', b=b, f=f)
        output_dict["real_warped_vid"] = rearrange(generated["deformed"], '(b f) c h w -> b c f h w', b=b, f=f)
        # Every f-th row is the reference image of one sample.
        ref_img_fea = generated["bottle_neck_feat"][::f].clone().detach()
        del real_vid_tmp, ref_img_tmp
        del generated

    if self.is_train:
        if self.use_residual_flow:
            h, w, = H // 4, W // 4
            identity_grid = self.get_grid(b, nf, h, w, normalize=True).cuda()
            output_dict["loss"], output_dict["null_cond_mask"] = self.diffusion(
                torch.cat((output_dict["real_vid_grid"] - identity_grid,
                           output_dict["real_vid_conf"] * 2 - 1), dim=1),
                ref_img_fea,
                bbox_mask,
                ref_text)
        else:
            output_dict["loss"], output_dict["null_cond_mask"] = self.diffusion(
                torch.cat((output_dict["real_vid_grid"],
                           output_dict["real_vid_conf"] * 2 - 1), dim=1),
                ref_img_fea,
                bbox_mask,
                ref_text)
        pred = self.diffusion.pred_x0
        pred_flow = pred[:, :2, :, :, :]
        pred_conf = (pred[:, 2, :, :, :].unsqueeze(dim=1) + 1) * 0.5
        # reduction='none' replaces the deprecated reduce=False; the element-wise result is identical.
        loss_high_freq = (nn.MSELoss(reduction='none')(pred_flow, output_dict["real_vid_grid"])
                          + nn.MSELoss(reduction='none')(pred_conf, output_dict["real_vid_conf"]))
        # clamp(min=1): a non-empty 0/1 mask already sums to >= 1, so this only turns
        # the 0/0 = NaN case (no mouth pixel inside the grid) into a zero loss.
        output_dict['mouth_loss'] = ((output_dict["loss"] * mouth_mask.unsqueeze(1)).sum()) / mouth_mask.sum().clamp(min=1.0)
        output_dict["loss"] = output_dict["loss"].mean(1)
        output_dict["floss"] = loss_high_freq.mean(1)
        if is_eval:
            with torch.no_grad():
                fake_out_img_list = []
                fake_warped_img_list = []
                pred = self.diffusion.pred_x0  # bs, 3, nf, h/4, w/4
                if self.use_residual_flow:
                    output_dict["fake_vid_grid"] = pred[:, :2, :, :, :] + identity_grid
                else:
                    output_dict["fake_vid_grid"] = pred[:, :2, :, :, :]
                output_dict["fake_vid_conf"] = (pred[:, 2, :, :, :].unsqueeze(dim=1) + 1) * 0.5
                for idx in range(nf):
                    fake_grid = output_dict["fake_vid_grid"][:, :, idx, :, :].permute(0, 2, 3, 1)
                    fake_conf = output_dict["fake_vid_conf"][:, :, idx, :, :]
                    # decode the predicted flow / occlusion with the frozen generator
                    generated = self.generator.forward_with_flow(source_image=ref_img,
                                                                 optical_flow=fake_grid,
                                                                 occlusion_map=fake_conf)
                    fake_out_img_list.append(generated["prediction"])
                    fake_warped_img_list.append(generated["deformed"].detach())
                    del generated
                output_dict["fake_out_vid"] = torch.stack(fake_out_img_list, dim=2)
                output_dict["fake_warped_vid"] = torch.stack(fake_warped_img_list, dim=2).detach()
    return output_dict
def sample_one_video(self, real_vid, sample_img, sample_audio_hubert, sample_pose, sample_eye, sample_bbox, cond_scale):
    """Sample a flow/occlusion sequence with the diffusion model and render it.

    Args:
        real_vid: ground-truth clip, used only for the element-wise
            reconstruction losses; it must broadcast against the rendered
            clip of shape (bs, 3, nf, H, W).
        sample_img: reference image(s), (bs, 3, H, W).
        sample_audio_hubert: per-frame audio features; concatenated with the
            pose/eye deltas along the last dimension.
        sample_pose: pose sequence; permuted to (bs, frames, channels) and
            its last channel dropped.
        sample_eye: eye-blink sequence; permuted to (bs, frames, channels).
        sample_bbox: face boxes consumed by ``self.generate_bbox_mask``.
        cond_scale: guidance scale forwarded to ``self.diffusion.sample``; per
            the original note, 1.0 means the unconditional model is not used.

    Returns:
        dict with ``sample_vid_grid`` (bs, 2, nf, h, w),
        ``sample_vid_conf`` (bs, 1, nf, h, w), the rendered
        ``sample_out_vid`` / ``sample_warped_vid`` and the unreduced L1
        losses ``rec_loss`` / ``rec_warp_loss`` against ``real_vid``.
    """
    output_dict = {}
    # Reference features; the original note gives (bs, 256, 32, 32) for 128x128 inputs.
    sample_img_fea = self.generator.compute_fea(sample_img)
    bbox_mask = self.generate_bbox_mask(sample_bbox, size=sample_img.shape[-1])
    bbox_mask = self.face_loc_emb(bbox_mask)  # conv encoder for the face mask

    ref_pose = sample_pose.permute(0, 2, 1)[:, :, :-1]
    ref_eye_blink = sample_eye.permute(0, 2, 1)
    # Condition on pose / eye motion relative to the first frame.
    init_pose = ref_pose[:, 0].unsqueeze(1).repeat(1, ref_pose.shape[1], 1)
    init_eye = ref_eye_blink[:, 0].unsqueeze(1).repeat(1, ref_eye_blink.shape[1], 1)
    ref_text = torch.concat([sample_audio_hubert, ref_pose - init_pose, ref_eye_blink - init_eye], dim=-1)

    bs = sample_img_fea.size(0)
    # pred: (bs, 3, nf, h, w) = flow x/y + occlusion logit in [-1, 1]
    pred = self.diffusion.sample(sample_img_fea, bbox_mask, cond=ref_text,
                                 batch_size=bs, cond_scale=cond_scale)

    sample_grid = pred[:, :2, :, :, :]
    if self.use_residual_flow:
        b, _, nf, h, w = sample_grid.size()
        # Follow pred's device instead of hard-coding .cuda().
        identity_grid = self.get_grid(b, nf, h, w, normalize=True).to(pred.device)
        sample_grid = sample_grid + identity_grid
    output_dict["sample_vid_grid"] = sample_grid
    output_dict["sample_vid_conf"] = (pred[:, 2, :, :, :].unsqueeze(dim=1) + 1) * 0.5

    nf = output_dict["sample_vid_grid"].size(2)
    with torch.no_grad():
        sample_out_img_list = []
        sample_warped_img_list = []
        for idx in range(nf):
            frame_grid = output_dict["sample_vid_grid"][:, :, idx, :, :].permute(0, 2, 3, 1)
            frame_conf = output_dict["sample_vid_conf"][:, :, idx, :, :]
            generated = self.generator.forward_with_flow(source_image=sample_img,
                                                         optical_flow=frame_grid,
                                                         occlusion_map=frame_conf)
            sample_out_img_list.append(generated["prediction"])
            sample_warped_img_list.append(generated["deformed"])
    output_dict["sample_out_vid"] = torch.stack(sample_out_img_list, dim=2)
    output_dict["sample_warped_vid"] = torch.stack(sample_warped_img_list, dim=2)
    # reduction='none' is the supported spelling of the deprecated reduce=False.
    output_dict["rec_loss"] = nn.functional.l1_loss(real_vid, output_dict["sample_out_vid"], reduction='none')
    output_dict["rec_warp_loss"] = nn.functional.l1_loss(real_vid, output_dict["sample_warped_vid"], reduction='none')
    return output_dict
def get_grid(self, b, nf, H, W, normalize=True, device=None):
    """Return an identity sampling grid of shape (b, 2, nf, H, W).

    Channel 0 holds the x coordinate (varies along W) and channel 1 the y
    coordinate (varies along H). Values lie in [-1, 1] when ``normalize`` is
    True, otherwise they are integer pixel indices stored as float.
    ``device`` selects where the grid is created (CPU when None, as before).
    """
    if normalize:
        h_range = torch.linspace(-1, 1, H, device=device)
        w_range = torch.linspace(-1, 1, W, device=device)
    else:
        h_range = torch.arange(0, H, device=device)
        w_range = torch.arange(0, W, device=device)
    # indexing="ij" is the historical default; passing it explicitly avoids the
    # torch.meshgrid deprecation warning without changing the result.
    grid = torch.stack(torch.meshgrid(h_range, w_range, indexing="ij"), -1)
    grid = grid.repeat(b, 1, 1, 1).flip(3).float()  # flip (h, w) -> (x, y)
    return grid.permute(0, 3, 1, 2).unsqueeze(dim=2).repeat(1, 1, nf, 1, 1)
def set_requires_grad(self, nets, requires_grad=False):
    """Assign ``requires_grad`` to every parameter of one or more networks.

    Freezing networks (``requires_grad=False``) avoids computing gradients
    for modules that are only run for inference.

    Parameters:
        nets (network or list of networks) -- ``None`` entries are skipped
        requires_grad (bool) -- flag written to each parameter
    """
    candidates = nets if isinstance(nets, list) else [nets]
    for module in (net for net in candidates if net is not None):
        for weight in module.parameters():
            weight.requires_grad = requires_grad
if __name__ == "__main__":
    # Manual GPU smoke test for FlowDiffusion (all inputs are moved with .cuda()).
    # NOTE(review): this script appears stale. sample_one_video in this file
    # also requires real_vid, sample_pose, sample_eye and sample_bbox, none of
    # which are passed below, so that call raises TypeError as written. The
    # forward() call may be missing pose/eye/bbox inputs too — TODO confirm
    # against forward's signature before relying on this script.
    os.environ["CUDA_VISIBLE_DEVICES"] = "3"  # only effective if CUDA has not been initialised yet
    bs = 5
    img_size = 128
    num_frames = 10
    ref_text = ["play basketball"] * bs  # NOTE(review): unused; leftover from text conditioning
    ref_img = torch.rand((bs, 3, img_size, img_size), dtype=torch.float32).cuda()
    real_vid = torch.rand((bs, 3, num_frames, img_size, img_size), dtype=torch.float32).cuda()
    model = FlowDiffusion(num_frames=num_frames, use_residual_flow=False, sampling_timesteps=10, dim_mults=(1, 2, 4, 8, 16))
    model.cuda()
    # Former text-embedding path, kept for reference:
    # cond = bert_embed(tokenize(ref_text), return_cls_repr=model.diffusion.text_use_bert_cls).cuda()
    # Random stand-in for HuBERT audio features: (bs, frames, 1024).
    cond = torch.rand((bs,10,1024), dtype=torch.float32).cuda()
    model = DataParallelWithCallback(model)
    output_dict = model.forward(real_vid=real_vid, ref_img=ref_img, ref_text=cond)
    model.module.sample_one_video(sample_img=ref_img[0].unsqueeze(dim=0),
                                  sample_audio_hubert=cond[0].unsqueeze(dim=0),
                                  cond_scale=1.0)
================================================
FILE: DM_3/modules/video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_mouth_mask_rand_6D.py
================================================
'''
Stage 2: random reference frame with long, dynamic-length clips.
Adds a lip (mouth-region) loss, uses a 6-D pose, and conditions via cross-attention.
'''
import os
import torch
import torch.nn as nn
import sys
sys.path.append('your/path')
from LFG.modules.generator import Generator
from LFG.modules.bg_motion_predictor import BGMotionPredictor
from LFG.modules.region_predictor import RegionPredictor
from DM_3.modules.video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi import DynamicNfUnet3D, DynamicNfGaussianDiffusion
import yaml
from sync_batchnorm import DataParallelWithCallback
from filter_fourier import *
from torchvision import models
import numpy as np
import time
from einops import rearrange
class Attention(nn.Module):
    """Additive attention pooled over a 2-D context grid.

    ``params`` must provide ``'n'`` (query feature size) and
    ``'dim_attention'`` (hidden size of the scoring layer).
    """

    def __init__(self, params):
        super(Attention, self).__init__()
        query_dim = params['n']
        hidden_dim = params['dim_attention']
        self.fc_query = nn.Linear(query_dim, hidden_dim, bias=False)
        self.fc_attention = nn.Linear(hidden_dim, 1)

    def forward(self, ctx_val, ctx_key, ctx_mask, ht_query):
        """Score every context position against the query and pool the values.

        Args:
            ctx_val: values, (B, C, H, W).
            ctx_key: projected keys, (B, H, W, dim_attention).
            ctx_mask: multiplicative mask over positions, (B, H, W).
            ht_query: query vectors, (B, n).

        Returns:
            (context, weights): pooled values (B, C) and the normalised
            attention weights (B, H, W).
        """
        projected = self.fc_query(ht_query)[:, None, None, :]
        logits = self.fc_attention(torch.tanh(ctx_key + projected)).squeeze(3)
        # Shift by the max (over the whole tensor) before exponentiating for stability.
        weights = torch.exp(logits - logits.max()) * ctx_mask
        totals = weights.sum(2).sum(1)[:, None, None]
        weights = weights / (totals + 1e-10)
        context = (ctx_val * weights[:, None, :, :]).sum(3).sum(2)
        return context, weights
class Face_loc_Encoder(nn.Module):
    """Encode a face-location mask into a 16-channel map at 1/4 resolution.

    Two 3x3, stride-2, padding-1 convolutions (``dim`` -> 8 -> 16 channels),
    each followed by a ReLU.
    """

    def __init__(self, dim=1):
        super(Face_loc_Encoder, self).__init__()
        conv_kwargs = dict(kernel_size=3, stride=2, padding=1)
        self.conv1 = nn.Conv2d(dim, 8, **conv_kwargs)
        self.conv2 = nn.Conv2d(8, 16, **conv_kwargs)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = nn.functional.relu(conv(x))
        return x
class Vgg19(torch.nn.Module):
    """
    Vgg19 network for perceptual loss.

    The pretrained torchvision VGG-19 ``features`` layers 0-29 are split into
    five consecutive stages (``slice1`` .. ``slice5``); ``forward`` returns the
    activation after each stage. Inputs are normalised with the mean/std
    stored on the module before the first stage.
    """

    # (start, stop) layer indices of each stage inside vgg19().features
    _STAGE_BOUNDS = ((0, 2), (2, 7), (7, 12), (12, 21), (21, 30))

    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        features = models.vgg19(pretrained=True).features
        stages = [torch.nn.Sequential() for _ in self._STAGE_BOUNDS]
        for stage, (start, stop) in zip(stages, self._STAGE_BOUNDS):
            for layer_idx in range(start, stop):
                stage.add_module(str(layer_idx), features[layer_idx])
        self.slice1, self.slice2, self.slice3, self.slice4, self.slice5 = stages

        imagenet_mean = np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))
        imagenet_std = np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))
        self.mean = torch.nn.Parameter(data=torch.Tensor(imagenet_mean), requires_grad=False)
        self.std = torch.nn.Parameter(data=torch.Tensor(imagenet_std), requires_grad=False)
        if not requires_grad:
            self.requires_grad_(False)

    def forward(self, x):
        h = (x - self.mean) / self.std
        outputs = []
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = stage(h)
            outputs.append(h)
        return outputs
class FlowDiffusion(nn.Module):
    """Audio/pose-conditioned latent flow diffusion (stage 2).

    A frozen, pretrained LFG model (generator, region predictor, background
    motion predictor) converts real clips into optical-flow and occlusion
    targets at 1/4 of the image resolution. ``DynamicNfUnet3D`` wrapped in
    ``DynamicNfGaussianDiffusion`` learns to generate those targets from the
    reference-frame features, a face-location mask and per-frame audio
    (HuBERT), 6-D pose and eye-blink conditions.
    """

    def __init__(self, img_size=32, num_frames=40, sampling_timesteps=250,
                 null_cond_prob=0.1,
                 ddim_sampling_eta=1.,
                 dim_mults=(1, 2, 4, 8),
                 is_train=True,
                 use_residual_flow=False,
                 learn_null_cond=False,
                 use_deconv=True,
                 padding_mode="zeros",
                 pretrained_pth="your_path/data/log-hdtf/hdtf128_2023-11-17_20:13/snapshots/RegionMM.pth",
                 config_pth="your/path/DAWN-pytorch/config/hdtf128.yaml"):
        """Build the frozen LFG components and the trainable diffusion model.

        Args:
            img_size: ``image_size`` forwarded to the diffusion wrapper.
            num_frames: initial frame count of the UNet / diffusion (see
                ``update_num_frames``).
            sampling_timesteps, null_cond_prob, ddim_sampling_eta: forwarded
                to ``DynamicNfGaussianDiffusion``.
            dim_mults, learn_null_cond, use_deconv, padding_mode: forwarded to
                ``DynamicNfUnet3D``.
            is_train: put the UNet and diffusion into train mode.
            use_residual_flow: diffuse the flow relative to an identity grid.
            pretrained_pth: LFG checkpoint holding ``generator``,
                ``region_predictor`` and ``bg_predictor`` state dicts.
            config_pth: YAML config matching that checkpoint.
        """
        super(FlowDiffusion, self).__init__()
        self.use_residual_flow = use_residual_flow

        # NOTE: torch.load unpickles the file; only point it at trusted checkpoints.
        checkpoint = torch.load(pretrained_pth)
        with open(config_pth) as f:
            config = yaml.safe_load(f)
        model_params = config['model_params']

        # Frozen LFG components: restored from the checkpoint, eval mode, no grads.
        self.generator = Generator(num_regions=model_params['num_regions'],
                                   num_channels=model_params['num_channels'],
                                   revert_axis_swap=model_params['revert_axis_swap'],
                                   **model_params['generator_params']).cuda()
        self.generator.load_state_dict(checkpoint['generator'])
        self.generator.eval()
        self.set_requires_grad(self.generator, False)

        self.region_predictor = RegionPredictor(num_regions=model_params['num_regions'],
                                                num_channels=model_params['num_channels'],
                                                estimate_affine=model_params['estimate_affine'],
                                                **model_params['region_predictor_params']).cuda()
        self.region_predictor.load_state_dict(checkpoint['region_predictor'])
        self.region_predictor.eval()
        self.set_requires_grad(self.region_predictor, False)

        self.bg_predictor = BGMotionPredictor(num_channels=model_params['num_channels'],
                                              **model_params['bg_predictor_params'])
        self.bg_predictor.load_state_dict(checkpoint['bg_predictor'])
        self.bg_predictor.eval()
        self.set_requires_grad(self.bg_predictor, False)

        self.scales = config['train_params']['scales']

        # channels presumably = 3 (flow x/y + occlusion) + 256 (reference features)
        # + 16 (Face_loc_Encoder output) — TODO confirm against DynamicNfUnet3D.
        self.unet = DynamicNfUnet3D(dim=64,
                                    cond_dim=1024 + 6 + 2,  # audio + pose + eye blink
                                    cond_aud=1024,
                                    cond_pose=6,
                                    cond_eye=2,
                                    num_frames=num_frames,
                                    channels=3 + 256 + 16,
                                    out_grid_dim=2,
                                    out_conf_dim=1,
                                    dim_mults=dim_mults,
                                    use_hubert_audio_cond=True,
                                    learn_null_cond=learn_null_cond,
                                    use_final_activation=False,
                                    use_deconv=use_deconv,
                                    padding_mode=padding_mode)

        self.diffusion = DynamicNfGaussianDiffusion(
            denoise_fn=self.unet,
            num_frames=num_frames,
            image_size=img_size,
            sampling_timesteps=sampling_timesteps,
            timesteps=1000,  # number of diffusion steps
            loss_type='l2',  # L1 or L2
            use_dynamic_thres=True,
            null_cond_prob=null_cond_prob,
            ddim_sampling_eta=ddim_sampling_eta
        )
        self.face_loc_emb = Face_loc_Encoder()

        self.is_train = is_train
        if self.is_train:
            self.unet.train()
            self.diffusion.train()

    def update_num_frames(self, new_num_frames):
        """Change the frame count handled by the UNet and the diffusion wrapper."""
        self.unet.update_num_frames(new_num_frames)
        self.diffusion.update_num_frames(new_num_frames)

    def generate_bbox_mask(self, bbox, size = 32):
        """Rasterise each sample's frame-0 face box into a binary mask.

        Args:
            bbox: tensor (b, c, fn) with c >= 6; only frame 0 is used.
                Channels 0/1 are the x extent (min, max), 2/3 the y extent
                (min, max); channel 4 normalises x and channel 5 normalises y
                (presumably the source image width/height).
            size: side length of the output mask.

        Returns:
            float tensor (b, 1, size, size): 1 inside the box (inclusive, with
            one extra pixel on the right/bottom edge), 0 elsewhere. The input
            ``bbox`` is not modified.
        """
        b = bbox.size(0)
        first = bbox[:, :, 0]
        # Compute scaled copies; the previous version rescaled the caller's
        # tensor in place through this view.
        xs = first[:, 0:2] / first[:, 4:5] * size
        ys = first[:, 2:4] / first[:, 5:6] * size
        left = xs[:, 0].to(torch.int32).view(b, 1, 1)
        right = (xs[:, 1] + 1).to(torch.int32).view(b, 1, 1)
        top = ys[:, 0].to(torch.int32).view(b, 1, 1)
        bottom = (ys[:, 1] + 1).to(torch.int32).view(b, 1, 1)
        # int64 indices on bbox's device (uint8 indices wrapped for size > 256).
        coords = torch.arange(size, device=bbox.device)
        rows = coords.view(1, size, 1)
        cols = coords.view(1, 1, size)
        mask = (rows >= top) & (rows <= bottom) & (cols >= left) & (cols <= right)
        return mask.unsqueeze(1).float()

    def generate_mouth_mask(self, mouth_lmk, origin_size, size = 32):
        """Rasterise the per-frame bounding box of the mouth landmarks.

        Args:
            mouth_lmk: (b, fn, pn, 2) landmark (x, y) coordinates.
            origin_size: (b, 1, 2) normalisers for x and y; ``forward`` passes
                the last two bbox channels of frame 0.
            size: side length of the output mask.

        Returns:
            float tensor (b, fn, size, size): 1 inside each frame's landmark
            box (inclusive), 0 elsewhere.
        """
        b, fn, _, _ = mouth_lmk.size()
        scaled = mouth_lmk / origin_size.unsqueeze(1) * size
        corner_max = scaled.max(dim=-2)[0]  # (b, fn, 2): largest x / y
        corner_min = scaled.min(dim=-2)[0]  # (b, fn, 2): smallest x / y
        coords = torch.arange(size, device=mouth_lmk.device)
        rows = coords.view(1, 1, size, 1)
        cols = coords.view(1, 1, 1, size)
        mask = ((rows >= corner_min[:, :, 1].view(b, fn, 1, 1))
                & (rows <= corner_max[:, :, 1].view(b, fn, 1, 1))
                & (cols >= corner_min[:, :, 0].view(b, fn, 1, 1))
                & (cols <= corner_max[:, :, 0].view(b, fn, 1, 1)))
        return mask.float()

    def forward(self, real_vid, ref_img, ref_text, ref_pose, ref_eye_blink, bbox, mouth_lmk, is_eval=False):
        """Compute ground-truth flows and (when training) the diffusion losses.

        Args:
            real_vid: raw clip (b, 3, f, H, W) with values in [0, 255]. It is
                colour-jittered; frame 0 becomes the reference image and
                frames 1..f-1 are the prediction targets.
            ref_img: ignored; the reference is always the jittered frame 0.
            ref_text: per-frame audio features (b, nf, 1024); nf is presumably
                equal to f (one entry per frame, including the reference).
            ref_pose: (b, 1, 7, nf); the last pose channel is dropped.
            ref_eye_blink: (b, 1, 2, nf).
            bbox: (b, c, fn) face boxes, see ``generate_bbox_mask``.
            mouth_lmk: (b, fn', pn, 2) mouth landmarks in the coordinate space
                of bbox channels 4/5; fn' must broadcast against the diffusion
                loss frames — TODO confirm.
            is_eval: additionally render frames from the predicted flow.

        Returns:
            dict with the LFG ground truth (``real_vid_grid``,
            ``real_vid_conf``, ``real_out_vid``, ``real_warped_vid``), the
            training losses (``loss``, ``null_cond_mask``, ``mouth_loss``,
            ``floss``) when ``self.is_train``, and the rendered predictions
            (``fake_*``) when ``is_eval``.
        """
        # torchvision.transforms is not imported explicitly at module level in this file.
        from torchvision import transforms

        # --- colour augmentation; the jittered frame 0 becomes the reference ---
        b, _, f = real_vid.shape[:3]
        frames = rearrange(real_vid, 'b c f h w -> (b f) c h w') / 255.  # ColorJitter needs floats in [0, 1]
        bright, contrast, sat, hue = 64. / 255, 0.25, 0.25, 0.04
        color_jitters = transforms.ColorJitter(brightness=(max(0, 1 - bright), 1 + bright),
                                               contrast=(max(0, 1 - contrast), 1 + contrast),
                                               saturation=(max(0, 1 - sat), 1 + sat),
                                               hue=(-hue, hue))
        # NOTE(review): ColorJitter appears to draw one set of factors per call, so all
        # frames of all samples share the same jitter — confirm this is intended.
        frames = color_jitters(frames)
        real_vid = rearrange(frames, '(b f) c h w -> b c f h w', b=b, f=f)
        ref_img = real_vid[:, :, 0, :, :].clone().detach()
        real_vid = real_vid[:, :, 1:, :, :]
        nf_pred, H, W = real_vid.size(2), real_vid.size(3), real_vid.size(4)
        _, nf, _ = ref_text.size()

        # --- conditioning: audio + pose / eye motion relative to frame 0 ---
        ref_pose = ref_pose.squeeze(1).permute(0, 2, 1)[:, :, :-1]
        ref_eye_blink = ref_eye_blink.squeeze(1).permute(0, 2, 1)
        init_pose = ref_pose[:, 0].unsqueeze(1).repeat(1, nf, 1)
        init_eye = ref_eye_blink[:, 0].unsqueeze(1).repeat(1, nf, 1)
        ref_text = torch.concat([ref_text, ref_pose - init_pose, ref_eye_blink - init_eye], dim=-1)
        ref_text = ref_text[:, 1:]  # drop the reference frame's condition

        bbox_mask = self.generate_bbox_mask(bbox, size=W)   # (b, 1, W, W)
        bbox_mask = self.face_loc_emb(bbox_mask)            # conv encoder for the face mask
        mouth_mask = self.generate_mouth_mask(mouth_lmk, bbox[:, None, -2:, 0], size=W // 4)

        output_dict = {}
        # --- ground-truth flow / occlusion from the frozen LFG ---
        with torch.no_grad():
            real_frames = rearrange(real_vid, 'b c f h w -> (b f) c h w')
            ref_frames = ref_img.unsqueeze(1).repeat(1, nf_pred, 1, 1, 1).reshape(-1, 3, H, W)
            source_region_params = self.region_predictor(ref_frames)
            driving_region_params = self.region_predictor(real_frames)
            bg_params = self.bg_predictor(ref_frames, real_frames)
            generated = self.generator(ref_frames, source_region_params=source_region_params,
                                       driving_region_params=driving_region_params, bg_params=bg_params)
            output_dict["real_vid_grid"] = rearrange(generated["optical_flow"], '(b f) h w c -> b c f h w', b=b, f=nf_pred)
            output_dict["real_vid_conf"] = rearrange(generated["occlusion_map"], '(b f) c h w -> b c f h w', b=b, f=nf_pred)
            output_dict["real_out_vid"] = rearrange(generated["prediction"], '(b f) c h w -> b c f h w', b=b, f=nf_pred)
            output_dict["real_warped_vid"] = rearrange(generated["deformed"], '(b f) c h w -> b c f h w', b=b, f=nf_pred)
            # The reference image is repeated nf_pred times; keep one feature map per sample.
            ref_img_fea = generated["bottle_neck_feat"][::nf_pred].clone().detach()
            del real_frames, ref_frames, generated

        # Identity grid at flow resolution with one slice per *predicted* frame
        # (the previous version used the un-sliced condition length nf, which
        # does not match the nf_pred-frame flow target).
        identity_grid = None
        if self.use_residual_flow:
            identity_grid = self.get_grid(b, nf_pred, H // 4, W // 4, normalize=True, device=real_vid.device)

        if self.is_train:
            target_grid = output_dict["real_vid_grid"]
            if identity_grid is not None:
                target_grid = target_grid - identity_grid
            output_dict["loss"], output_dict["null_cond_mask"] = self.diffusion(
                torch.cat((target_grid, output_dict["real_vid_conf"] * 2 - 1), dim=1),
                ref_img_fea,
                bbox_mask,
                ref_text)
            pred = self.diffusion.pred_x0
            pred_flow = pred[:, :2, :, :, :]
            if identity_grid is not None:
                # pred_x0 estimates the residual target; compare in absolute grid space.
                pred_flow = pred_flow + identity_grid
            pred_conf = (pred[:, 2, :, :, :].unsqueeze(dim=1) + 1) * 0.5
            loss_high_freq = (nn.functional.mse_loss(pred_flow, output_dict["real_vid_grid"], reduction='none')
                              + nn.functional.mse_loss(pred_conf, output_dict["real_vid_conf"], reduction='none'))
            mouth_weight = mouth_mask.unsqueeze(1)
            # clamp avoids 0 / 0 = NaN when no grid cell falls inside the mouth boxes.
            output_dict['mouth_loss'] = (output_dict["loss"] * mouth_weight).sum() / mouth_weight.sum().clamp(min=1)
            output_dict["loss"] = output_dict["loss"].mean(1)
            output_dict["floss"] = loss_high_freq.mean(1)

        if is_eval:
            with torch.no_grad():
                pred = self.diffusion.pred_x0  # (b, 3, nf_pred, H/4, W/4)
                fake_grid = pred[:, :2, :, :, :]
                if identity_grid is not None:
                    fake_grid = fake_grid + identity_grid
                output_dict["fake_vid_grid"] = fake_grid
                output_dict["fake_vid_conf"] = (pred[:, 2, :, :, :].unsqueeze(dim=1) + 1) * 0.5
                fake_out_img_list = []
                fake_warped_img_list = []
                # Iterate over the frames actually predicted (equal to nf - 1 when
                # ref_text carries one entry per input frame).
                for idx in range(fake_grid.size(2)):
                    generated = self.generator.forward_with_flow(
                        source_image=ref_img,
                        optical_flow=output_dict["fake_vid_grid"][:, :, idx, :, :].permute(0, 2, 3, 1),
                        occlusion_map=output_dict["fake_vid_conf"][:, :, idx, :, :])
                    fake_out_img_list.append(generated["prediction"])
                    fake_warped_img_list.append(generated["deformed"].detach())
                    del generated
                output_dict["fake_out_vid"] = torch.stack(fake_out_img_list, dim=2)
                output_dict["fake_warped_vid"] = torch.stack(fake_warped_img_list, dim=2).detach()
        return output_dict

    def sample_one_video(self, real_vid, sample_img, sample_audio_hubert, sample_pose, sample_eye, sample_bbox, cond_scale):
        """Sample a flow/occlusion sequence and render it from ``sample_img``.

        The first frame's condition is dropped (as in ``forward``), so the
        sampled clip has one frame fewer than the pose/eye sequences;
        ``real_vid`` must broadcast against that rendered clip for the
        element-wise ``rec_loss`` / ``rec_warp_loss``.
        ``cond_scale`` is forwarded to ``self.diffusion.sample``; per the
        original note, 1.0 means the unconditional model is not used.
        """
        output_dict = {}
        sample_img_fea = self.generator.compute_fea(sample_img)  # (bs, 256, H/4, W/4) per the original note
        bbox_mask = self.face_loc_emb(self.generate_bbox_mask(sample_bbox, size=sample_img.shape[-1]))
        ref_pose = sample_pose.permute(0, 2, 1)[:, :, :-1]
        ref_eye_blink = sample_eye.permute(0, 2, 1)
        init_pose = ref_pose[:, 0].unsqueeze(1).repeat(1, ref_pose.shape[1], 1)
        init_eye = ref_eye_blink[:, 0].unsqueeze(1).repeat(1, ref_eye_blink.shape[1], 1)
        ref_text = torch.concat([sample_audio_hubert, ref_pose - init_pose, ref_eye_blink - init_eye], dim=-1)
        ref_text = ref_text[:, 1:]
        bs = sample_img_fea.size(0)
        pred = self.diffusion.sample(sample_img_fea, bbox_mask, cond=ref_text,
                                     batch_size=bs, cond_scale=cond_scale)
        sample_grid = pred[:, :2, :, :, :]
        if self.use_residual_flow:
            b, _, nf, h, w = sample_grid.size()
            sample_grid = sample_grid + self.get_grid(b, nf, h, w, normalize=True, device=pred.device)
        output_dict["sample_vid_grid"] = sample_grid
        output_dict["sample_vid_conf"] = (pred[:, 2, :, :, :].unsqueeze(dim=1) + 1) * 0.5
        nf = output_dict["sample_vid_grid"].size(2)
        with torch.no_grad():
            sample_out_img_list = []
            sample_warped_img_list = []
            for idx in range(nf):
                generated = self.generator.forward_with_flow(
                    source_image=sample_img,
                    optical_flow=output_dict["sample_vid_grid"][:, :, idx, :, :].permute(0, 2, 3, 1),
                    occlusion_map=output_dict["sample_vid_conf"][:, :, idx, :, :])
                sample_out_img_list.append(generated["prediction"])
                sample_warped_img_list.append(generated["deformed"])
        output_dict["sample_out_vid"] = torch.stack(sample_out_img_list, dim=2)
        output_dict["sample_warped_vid"] = torch.stack(sample_warped_img_list, dim=2)
        # reduction='none' is the supported spelling of the deprecated reduce=False.
        output_dict["rec_loss"] = nn.functional.l1_loss(real_vid, output_dict["sample_out_vid"], reduction='none')
        output_dict["rec_warp_loss"] = nn.functional.l1_loss(real_vid, output_dict["sample_warped_vid"], reduction='none')
        return output_dict

    def get_grid(self, b, nf, H, W, normalize=True, device=None):
        """Return an identity sampling grid of shape (b, 2, nf, H, W).

        Channel 0 holds x (varies along W), channel 1 holds y (varies along H).
        Values are in [-1, 1] when ``normalize`` is True, otherwise integer
        pixel indices stored as float. ``device`` defaults to CPU, as before.
        """
        if normalize:
            h_range = torch.linspace(-1, 1, H, device=device)
            w_range = torch.linspace(-1, 1, W, device=device)
        else:
            h_range = torch.arange(0, H, device=device)
            w_range = torch.arange(0, W, device=device)
        # Explicit indexing="ij" keeps the historical result without the deprecation warning.
        grid = torch.stack(torch.meshgrid(h_range, w_range, indexing="ij"), -1)
        grid = grid.repeat(b, 1, 1, 1).flip(3).float()  # flip (h, w) -> (x, y)
        return grid.permute(0, 3, 1, 2).unsqueeze(dim=2).repeat(1, 1, nf, 1, 1)

    def set_requires_grad(self, nets, requires_grad=False):
        """Set ``requires_grad`` on all parameters of the given network(s).

        Parameters:
            nets (network or list of networks) -- ``None`` entries are skipped
            requires_grad (bool) -- flag assigned to every parameter
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
if __name__ == "__main__":
    # Manual GPU smoke test for FlowDiffusion (all inputs are moved with .cuda()).
    # Requires the LFG checkpoint / config paths in FlowDiffusion's defaults to exist.
    # NOTE(review): this script is stale. FlowDiffusion.forward in this file
    # also requires ref_pose, ref_eye_blink, bbox and mouth_lmk, and
    # sample_one_video requires real_vid, sample_pose, sample_eye and
    # sample_bbox; none are passed below, so both calls raise TypeError.
    # Also note forward treats real_vid as [0, 255] pixels and uses frame 0 as
    # the reference, ignoring ref_img.
    os.environ["CUDA_VISIBLE_DEVICES"] = "3"  # only effective if CUDA has not been initialised yet
    bs = 5
    img_size = 128
    num_frames = 10
    ref_text = ["play basketball"] * bs  # NOTE(review): unused; leftover from text conditioning
    ref_img = torch.rand((bs, 3, img_size, img_size), dtype=torch.float32).cuda()
    real_vid = torch.rand((bs, 3, num_frames, img_size, img_size), dtype=torch.float32).cuda()
    model = FlowDiffusion(num_frames=num_frames, use_residual_flow=False, sampling_timesteps=10, dim_mults=(1, 2, 4, 8, 16))
    model.cuda()
    # Former text-embedding path, kept for reference:
    # cond = bert_embed(tokenize(ref_text), return_cls_repr=model.diffusion.text_use_bert_cls).cuda()
    # Random stand-in for HuBERT audio features: (bs, frames, 1024).
    cond = torch.rand((bs,10,1024), dtype=torch.float32).cuda()
    model = DataParallelWithCallback(model)
    output_dict = model.forward(real_vid=real_vid, ref_img=ref_img, ref_text=cond)
    model.module.sample_one_video(sample_img=ref_img[0].unsqueeze(dim=0),
                                  sample_audio_hubert=cond[0].unsqueeze(dim=0),
                                  cond_scale=1.0)
================================================
FILE: DM_3/modules/video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_test.py
================================================
import os
import torch
import torch.nn as nn
import sys
from LFG.modules.generator import Generator
from LFG.modules.bg_motion_predictor import BGMotionPredictor
from LFG.modules.region_predictor import RegionPredictor
from DM_3.modules.video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test import DynamicNfUnet3D, DynamicNfGaussianDiffusion
import yaml
from sync_batchnorm import DataParallelWithCallback
from filter_fourier import *
from torchvision import models
import numpy as np
import time
from einops import rearrange
class Attention(nn.Module):
    """Additive attention pooled over a 2-D context grid.

    ``params`` must provide ``'n'`` (query feature size) and
    ``'dim_attention'`` (hidden size of the scoring layer).
    """

    def __init__(self, params):
        super(Attention, self).__init__()
        query_dim = params['n']
        hidden_dim = params['dim_attention']
        self.fc_query = nn.Linear(query_dim, hidden_dim, bias=False)
        self.fc_attention = nn.Linear(hidden_dim, 1)

    def forward(self, ctx_val, ctx_key, ctx_mask, ht_query):
        """Score every context position against the query and pool the values.

        Args:
            ctx_val: values, (B, C, H, W).
            ctx_key: projected keys, (B, H, W, dim_attention).
            ctx_mask: multiplicative mask over positions, (B, H, W).
            ht_query: query vectors, (B, n).

        Returns:
            (context, weights): pooled values (B, C) and the normalised
            attention weights (B, H, W).
        """
        projected = self.fc_query(ht_query)[:, None, None, :]
        logits = self.fc_attention(torch.tanh(ctx_key + projected)).squeeze(3)
        # Shift by the max (over the whole tensor) before exponentiating for stability.
        weights = torch.exp(logits - logits.max()) * ctx_mask
        totals = weights.sum(2).sum(1)[:, None, None]
        weights = weights / (totals + 1e-10)
        context = (ctx_val * weights[:, None, :, :]).sum(3).sum(2)
        return context, weights
class Face_loc_Encoder(nn.Module):
    """Encode a face-location mask into a 16-channel map at 1/4 resolution.

    Two 3x3, stride-2, padding-1 convolutions (``dim`` -> 8 -> 16 channels),
    each followed by a ReLU.
    """

    def __init__(self, dim=1):
        super(Face_loc_Encoder, self).__init__()
        conv_kwargs = dict(kernel_size=3, stride=2, padding=1)
        self.conv1 = nn.Conv2d(dim, 8, **conv_kwargs)
        self.conv2 = nn.Conv2d(8, 16, **conv_kwargs)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = nn.functional.relu(conv(x))
        return x
class Vgg19(torch.nn.Module):
    """
    Vgg19 network for perceptual loss.

    The pretrained torchvision VGG-19 ``features`` layers 0-29 are split into
    five consecutive stages (``slice1`` .. ``slice5``); ``forward`` returns the
    activation after each stage. Inputs are normalised with the mean/std
    stored on the module before the first stage.
    """

    # (start, stop) layer indices of each stage inside vgg19().features
    _STAGE_BOUNDS = ((0, 2), (2, 7), (7, 12), (12, 21), (21, 30))

    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        features = models.vgg19(pretrained=True).features
        stages = [torch.nn.Sequential() for _ in self._STAGE_BOUNDS]
        for stage, (start, stop) in zip(stages, self._STAGE_BOUNDS):
            for layer_idx in range(start, stop):
                stage.add_module(str(layer_idx), features[layer_idx])
        self.slice1, self.slice2, self.slice3, self.slice4, self.slice5 = stages

        imagenet_mean = np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))
        imagenet_std = np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))
        self.mean = torch.nn.Parameter(data=torch.Tensor(imagenet_mean), requires_grad=False)
        self.std = torch.nn.Parameter(data=torch.Tensor(imagenet_std), requires_grad=False)
        if not requires_grad:
            self.requires_grad_(False)

    def forward(self, x):
        h = (x - self.mean) / self.std
        outputs = []
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = stage(h)
            outputs.append(h)
        return outputs
class FlowDiffusion(nn.Module):
def __init__(self, img_size=32, num_frames=40, sampling_timesteps=250, win_width = 40,
null_cond_prob=0.1,
ddim_sampling_eta=1.,
pose_dim = 7,
dim_mults=(1, 2, 4, 8),
is_train=True,
use_residual_flow=False,
learn_null_cond=False,
use_deconv=True,
padding_mode="zeros",
pretrained_pth="your_path/data/log-hdtf/hdtf128_2023-11-17_20:13/snapshots/RegionMM.pth",
config_pth="your/path/DAWN-pytorch/config/hdtf128.yaml"):
super(FlowDiffusion, self).__init__()
self.use_residual_flow = use_residual_flow
checkpoint = torch.load(pretrained_pth)
with open(config_pth) as f:
config = yaml.safe_load(f)
self.generator = Generator(num_regions=config['model_params']['num_regions'],
num_channels=config['model_params']['num_channels'],
revert_axis_swap=config['model_params']['revert_axis_swap'],
**config['model_params']['generator_params']).cuda()
self.generator.load_state_dict(checkpoint['generator'])
self.generator.eval()
self.set_requires_grad(self.generator, False)
self.region_predictor = RegionPredictor(num_regions=config['model_params']['num_regions'],
num_channels=config['model_params']['num_channels'],
estimate_affine=config['model_params']['estimate_affine'],
**config['model_params']['region_predictor_params']).cuda()
self.region_predictor.load_state_dict(checkpoint['region_predictor'])
self.region_predictor.eval()
self.set_requires_grad(self.region_predictor, False)
self.bg_predictor = BGMotionPredictor(num_channels=config['model_params']['num_channels'],
**config['model_params']['bg_predictor_params'])
self.bg_predictor.load_state_dict(checkpoint['bg_predictor'])
self.bg_predictor.eval()
self.set_requires_grad(self.bg_predictor, False)
self.scales = config['train_params']['scales']
self.pose_dim = pose_dim
self.unet = DynamicNfUnet3D(dim=64,
cond_dim=1024 + self.pose_dim + 2,
cond_aud=1024,
cond_pose=self.pose_dim,
cond_eye=2,
num_frames=num_frames,
channels=3 + 256 + 16,
out_grid_dim=2,
out_conf_dim=1,
dim_mults=dim_mults,
use_hubert_audio_cond=True,
learn_null_cond=learn_null_cond,
use_final_activation=False,
use_deconv=use_deconv,
padding_mode=padding_mode,
win_width = win_width)
self.diffusion = DynamicNfGaussianDiffusion(
denoise_fn = self.unet,
num_frames=num_frames,
image_size=img_size,
sampling_timesteps=sampling_timesteps,
timesteps=1000, # number of steps
loss_type='l2', # L1 or L2
use_dynamic_thres=True,
null_cond_prob=null_cond_prob,
ddim_sampling_eta=ddim_sampling_eta
)
self.face_loc_emb = Face_loc_Encoder()
# training
self.is_train = is_train
if self.is_train:
self.unet.train()
self.diffusion.train()
def update_num_frames(self, new_num_frames):
# to update num_frames of Unet3D and GaussianDiffusion
self.unet.update_num_frames(new_num_frames)
self.diffusion.update_num_frames(new_num_frames)
def generate_bbox_mask(self, bbox, size = 32):
# b = bbox.shape[0]
b, c, fn = bbox.size()
bbox = bbox[:,:,0] # b, c, fn
bbox[:, :2] = (bbox[:, :2]/bbox[:, 4].unsqueeze(1)) * size
bbox[:,2:4] = (bbox[:, 2:4]/bbox[:, 5].unsqueeze(1) )* size
bbox_left_top = bbox[:, :4:2].to(torch.int32)
bbox_right_bottom = (bbox[:, 1:4:2] +1).to(torch.int32)
row_indices = torch.arange(size).view(1, size, 1).expand(b, size, size).to(torch.uint8).cuda()
col_indices = torch.arange(size).view(1, 1, size).expand(b, size, size).to(torch.uint8).cuda()
mask = (row_indices >= bbox_left_top[:, 1].view(b, 1, 1)) & (row_indices <= bbox_right_bottom[:, 1].view(b, 1, 1)) & \
(col_indices >= bbox_left_top[:, 0].view(b, 1, 1)) & (col_indices <= bbox_right_bottom[:, 0].view(b, 1, 1))
bbox_mask = mask.unsqueeze(1).float() # b, 1, 32, 32
return bbox_mask
def forward(self, real_vid, ref_img, ref_text, ref_pose, ref_eye_blink, bbox, is_eval=False, ref_id = 0):
    """Run one training (and optionally evaluation) step.

    Args:
        real_vid: driving clip (b, c, nf, H, W) with pixel values in [0, 255]
            (it is divided by 255 before colour jitter).
        ref_img: ignored -- the reference frame is re-extracted from the
            jittered ``real_vid`` at ``ref_id``; kept for API compatibility.
        ref_text: per-frame audio features (b, nf, C_audio); pose and blink
            offsets are appended to it along the last axis.
        ref_pose / ref_eye_blink: sequences that become (b, nf, C) after
            ``squeeze(1).permute(0, 2, 1)``.
        bbox: face boxes consumed by ``generate_bbox_mask``.
        is_eval: in training mode, also decode the predicted flow into frames.
        ref_id: frame used as reference image and as initial pose/blink state.

    Returns:
        dict with the frozen autoencoder's real flow / occlusion / frames,
        plus ``loss``, ``floss``, ``null_cond_mask`` in training mode and
        ``fake_*`` entries when ``is_eval`` is also set.
    """
    b, c, nf, H, W = real_vid.size()

    # Photometric augmentation; torchvision draws one set of jitter factors
    # per call, so every frame of the batch gets the same jitter.
    bright = 64. / 255
    contrast = 0.25
    sat = 0.25
    hue = 0.04
    color_jitters = transforms.ColorJitter(hue=(-hue, hue),
                                           contrast=(max(0, 1 - contrast), 1 + contrast),
                                           saturation=(max(0, 1 - sat), 1 + sat),
                                           brightness=(max(0, 1 - bright), 1 + bright))
    frames = rearrange(real_vid, 'b c f h w -> (b f) c h w')
    frames = color_jitters(frames / 255.)  # ColorJitter expects float images in [0, 1]
    real_vid = rearrange(frames, '(b f) c h w -> b c f h w', b=b, f=nf)
    ref_img = real_vid[:, :, ref_id, :, :].clone().detach()

    # Condition = audio features + pose / blink relative to the reference frame.
    ref_pose = ref_pose.squeeze(1).permute(0, 2, 1)
    ref_eye_blink = ref_eye_blink.squeeze(1).permute(0, 2, 1)
    init_pose = ref_pose[:, ref_id].unsqueeze(1).repeat(1, nf, 1)
    init_eye = ref_eye_blink[:, ref_id].unsqueeze(1).repeat(1, nf, 1)
    ref_text = torch.concat([ref_text, ref_pose - init_pose, ref_eye_blink - init_eye], dim=-1)

    bbox_mask = self.generate_bbox_mask(bbox, size=H)
    bbox_mask = self.face_loc_emb(bbox_mask)  # conv encoder for the face-location mask

    output_dict = {}
    with torch.no_grad():
        # Frozen autoencoder: flow / occlusion of every driving frame w.r.t. the reference.
        real_vid_tmp = rearrange(real_vid, 'b c f h w -> (b f) c h w')
        # One copy of the reference frame per driving frame at the clip's own
        # resolution (this used to be hard-coded to 3 x 128 x 128).
        ref_img_tmp = ref_img.unsqueeze(1).repeat(1, nf, 1, 1, 1).flatten(0, 1)
        source_region_params = self.region_predictor(ref_img_tmp)
        driving_region_params = self.region_predictor(real_vid_tmp)
        bg_params = self.bg_predictor(ref_img_tmp, real_vid_tmp)
        generated = self.generator(ref_img_tmp, source_region_params=source_region_params,
                                   driving_region_params=driving_region_params, bg_params=bg_params)
        output_dict["real_vid_grid"] = rearrange(generated["optical_flow"], '(b f) h w c -> b c f h w', b=b, f=nf)
        output_dict["real_vid_conf"] = rearrange(generated["occlusion_map"], '(b f) c h w -> b c f h w', b=b, f=nf)
        output_dict["real_out_vid"] = rearrange(generated["prediction"], '(b f) c h w -> b c f h w', b=b, f=nf)
        output_dict["real_warped_vid"] = rearrange(generated["deformed"], '(b f) c h w -> b c f h w', b=b, f=nf)
        # All frames of a video share the same source, so keep one bottleneck per video.
        ref_img_fea = generated["bottle_neck_feat"][::nf].clone().detach()
        del real_vid_tmp, ref_img_tmp, generated

    if self.is_train:
        if self.use_residual_flow:
            # The flow lives at 1/4 of the image resolution; learn it as a residual from identity.
            identity_grid = self.get_grid(b, nf, H // 4, W // 4, normalize=True).to(real_vid.device)
            target = torch.cat((output_dict["real_vid_grid"] - identity_grid,
                                output_dict["real_vid_conf"] * 2 - 1), dim=1)
        else:
            target = torch.cat((output_dict["real_vid_grid"],
                                output_dict["real_vid_conf"] * 2 - 1), dim=1)
        output_dict["loss"], output_dict["null_cond_mask"] = self.diffusion(
            target, ref_img_fea, bbox_mask, ref_text)

        pred = self.diffusion.pred_x0
        pred_flow = pred[:, :2, :, :, :]
        # NOTE(review): with use_residual_flow, pred_flow is a residual while
        # real_vid_grid is absolute -- confirm hf_loss_2 is insensitive to that.
        loss_high_freq = hf_loss_2(pred_flow, output_dict["real_vid_grid"], dim=2)
        output_dict["loss"] = output_dict["loss"].mean(1)
        output_dict["floss"] = loss_high_freq.mean(1)

        if is_eval:
            with torch.no_grad():
                if self.use_residual_flow:
                    output_dict["fake_vid_grid"] = pred[:, :2, :, :, :] + identity_grid
                else:
                    output_dict["fake_vid_grid"] = pred[:, :2, :, :, :]
                # Map the predicted occlusion channel from [-1, 1] back to [0, 1].
                output_dict["fake_vid_conf"] = (pred[:, 2, :, :, :].unsqueeze(dim=1) + 1) * 0.5
                fake_out_img_list = []
                fake_warped_img_list = []
                for idx in range(nf):
                    generated = self.generator.forward_with_flow(
                        source_image=ref_img,
                        optical_flow=output_dict["fake_vid_grid"][:, :, idx, :, :].permute(0, 2, 3, 1),
                        occlusion_map=output_dict["fake_vid_conf"][:, :, idx, :, :])
                    fake_out_img_list.append(generated["prediction"])
                    fake_warped_img_list.append(generated["deformed"].detach())
                    del generated
                output_dict["fake_out_vid"] = torch.stack(fake_out_img_list, dim=2)
                output_dict["fake_warped_vid"] = torch.stack(fake_warped_img_list, dim=2).detach()
    return output_dict
def sample_one_video(self, sample_img, sample_audio_hubert, sample_pose, sample_eye, sample_bbox, cond_scale, init_pose = None, init_eye = None, real_vid = None):
    """Sample a flow/occlusion sequence with the diffusion model and decode it into frames.

    Args:
        sample_img: reference image batch (bs, 3, H, W).
        sample_audio_hubert: per-frame audio features (bs, nf, C_audio).
        sample_pose: pose sequence (bs, C_pose, nf); only the first
            ``self.pose_dim`` channels are used.
        sample_eye: blink sequence (bs, C_eye, nf).
        sample_bbox: face box passed to ``generate_bbox_mask``.
        cond_scale: classifier-free guidance scale for ``diffusion.sample``.
        init_pose / init_eye: optional (bs, C) initial states; default to frame 0.
        real_vid: unused, kept for API compatibility.

    Returns:
        dict with ``sample_vid_grid``, ``sample_vid_conf``,
        ``sample_out_vid`` and ``sample_warped_vid``.
    """
    output_dict = {}
    sample_img_fea = self.generator.compute_fea(sample_img)
    bbox_mask = self.generate_bbox_mask(sample_bbox, size=sample_img.shape[-1])
    bbox_mask = self.face_loc_emb(bbox_mask)  # conv encoder for the face-location mask

    ref_pose = sample_pose[:, :self.pose_dim].permute(0, 2, 1)  # (bs, nf, C_pose)
    ref_eye_blink = sample_eye.permute(0, 2, 1)                  # (bs, nf, C_eye)
    if init_pose is None:
        init_pose = ref_pose[:, 0]
    init_pose = init_pose.unsqueeze(1).repeat(1, ref_pose.shape[1], 1)[:, :, :self.pose_dim]
    if init_eye is None:
        init_eye = ref_eye_blink[:, 0]
    init_eye = init_eye.unsqueeze(1).repeat(1, ref_eye_blink.shape[1], 1)
    # If the driving pose has fewer channels than the initial state, append
    # the initial state's last channel so the two can be subtracted.
    if ref_pose.shape[-1] != init_pose.shape[-1]:
        ref_pose = torch.concat([ref_pose, init_pose[:, :, -1].unsqueeze(-1)], dim=-1)
    ref_text = torch.concat([sample_audio_hubert, ref_pose - init_pose, ref_eye_blink - init_eye], dim=-1)

    bs = sample_img_fea.size(0)
    start_time = time.time()
    # cond_scale == 1.0 skips the unconditional branch; pred is (bs, 3, nf, h, w).
    pred = self.diffusion.sample(sample_img_fea, bbox_mask, cond=ref_text,
                                 batch_size=bs, cond_scale=cond_scale)
    if self.use_residual_flow:
        b, _, nf, h, w = pred[:, :2].size()
        identity_grid = self.get_grid(b, nf, h, w, normalize=True).to(pred.device)
        output_dict["sample_vid_grid"] = pred[:, :2] + identity_grid
    else:
        output_dict["sample_vid_grid"] = pred[:, :2]
    # Map the predicted occlusion channel from [-1, 1] back to [0, 1].
    output_dict["sample_vid_conf"] = (pred[:, 2].unsqueeze(dim=1) + 1) * 0.5
    nf = output_dict["sample_vid_grid"].size(2)
    end_time = time.time()
    print(f'DDIM time {end_time- start_time}')

    with torch.no_grad():
        sample_out_img_list = []
        sample_warped_img_list = []
        for idx in range(nf):
            generated = self.generator.forward_with_flow(
                source_image=sample_img,
                optical_flow=output_dict["sample_vid_grid"][:, :, idx].permute(0, 2, 3, 1),
                occlusion_map=output_dict["sample_vid_conf"][:, :, idx])
            sample_out_img_list.append(generated["prediction"])
            sample_warped_img_list.append(generated["deformed"])
        output_dict["sample_out_vid"] = torch.stack(sample_out_img_list, dim=2)
        output_dict["sample_warped_vid"] = torch.stack(sample_warped_img_list, dim=2)
    return output_dict
def get_grid(self, b, nf, H, W, normalize=True):
    """Build an identity sampling grid of shape (b, 2, nf, H, W).

    Channel 0 holds the x (column) coordinate and channel 1 the y (row)
    coordinate, either in [-1, 1] (``normalize=True``) or as integer pixel
    positions; the result is float32 either way.
    """
    if normalize:
        make_axis = lambda n: torch.linspace(-1, 1, n)
    else:
        make_axis = lambda n: torch.arange(0, n)
    ys, xs = torch.meshgrid([make_axis(H), make_axis(W)])
    xy = torch.stack((xs, ys), dim=-1).float()  # (H, W, 2), x first
    xy = xy.permute(2, 0, 1)                     # (2, H, W)
    return xy[None, :, None].repeat(b, 1, nf, 1, 1)
def set_requires_grad(self, nets, requires_grad=False):
    """Assign ``requires_grad`` on every parameter of one or several networks.

    Parameters:
        nets (network or list of networks) -- ``None`` entries are skipped
        requires_grad (bool)               -- value given to each parameter
    """
    net_list = nets if isinstance(nets, list) else [nets]
    for net in filter(lambda candidate: candidate is not None, net_list):
        for param in net.parameters():
            param.requires_grad = requires_grad
if __name__ == "__main__":
    # Smoke test: one training forward pass and one sampling call on random
    # inputs. Needs a GPU; the FlowDiffusion constructor also loads pretrained
    # autoencoder weights (presumably from its default paths -- check them).
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")
    bs = 5
    img_size = 128
    num_frames = 10
    model = FlowDiffusion(num_frames=num_frames, use_residual_flow=False, sampling_timesteps=10, dim_mults=(1, 2, 4, 8, 16))
    model.cuda()
    pose_dim = model.pose_dim
    ref_img = torch.rand((bs, 3, img_size, img_size), dtype=torch.float32).cuda()
    # forward() divides the driving clip by 255, so feed it in [0, 255].
    real_vid = (torch.rand((bs, 3, num_frames, img_size, img_size), dtype=torch.float32) * 255).cuda()
    # Stand-in for per-frame HuBERT features: (bs, num_frames, 1024).
    cond = torch.rand((bs, num_frames, 1024), dtype=torch.float32).cuda()
    # Pose / blink in the (bs, 1, channels, num_frames) layout that forward() squeezes and permutes.
    ref_pose = torch.rand((bs, 1, pose_dim, num_frames), dtype=torch.float32).cuda()
    ref_eye_blink = torch.rand((bs, 1, 2, num_frames), dtype=torch.float32).cuda()
    # Boxes for generate_bbox_mask: channels 0-1 are divided by channel 4 and
    # channels 2-3 by channel 5; here the centre half of the frame.
    box = torch.tensor([0.25, 0.75, 0.25, 0.75, 1.0, 1.0]) * img_size
    bbox = box.view(1, 6, 1).repeat(bs, 1, num_frames).cuda()
    # Copy taken before forward(), which may rescale its bbox input in place.
    sample_bbox = bbox[:1].clone()
    model = DataParallelWithCallback(model)
    output_dict = model(real_vid=real_vid, ref_img=ref_img, ref_text=cond,
                        ref_pose=ref_pose, ref_eye_blink=ref_eye_blink, bbox=bbox)
    with torch.no_grad():
        model.module.sample_one_video(sample_img=ref_img[:1],
                                      sample_audio_hubert=cond[:1],
                                      sample_pose=ref_pose[0],
                                      sample_eye=ref_eye_blink[0],
                                      sample_bbox=sample_bbox,
                                      cond_scale=1.0)
================================================
FILE: DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi.py
================================================
'''
Adds pose conditioning on top of the baseline model,
using cross-attention to inject each condition (audio, pose, eye blink) separately.
Training version.
'''
import math
import torch
from torch import nn, einsum
import torch.nn.functional as F
from functools import partial
from torchvision import transforms as T
from PIL import Image
from tqdm import tqdm
from einops import rearrange, repeat, reduce, pack, unpack
from einops_exts import rearrange_many
from rotary_embedding_torch import RotaryEmbedding
# from DM.modules.text import tokenize, bert_embed, HUBERT_MODEL_DIM
# helper functions
def exists(x):
    """Return True when ``x`` is anything other than None."""
    return False if x is None else True
def noop(*unused_args, **unused_kwargs):
    """Accept any arguments and do nothing."""
    return None
def is_odd(n):
    """Return True if ``n`` leaves remainder 1 when divided by 2."""
    remainder = n % 2
    return remainder == 1
def default(val, d):
    """Return ``val`` unless it is None; otherwise ``d``, calling it first if it is callable."""
    if val is not None:
        return val
    if callable(d):
        return d()
    return d
def cycle(dl):
    """Yield the items of ``dl`` forever, re-iterating it each time it is exhausted."""
    while True:
        yield from dl
def num_to_groups(num, divisor):
    """Split ``num`` into full chunks of ``divisor`` plus a trailing remainder chunk if non-zero."""
    full_chunks, leftover = divmod(num, divisor)
    sizes = [divisor] * full_chunks
    if leftover > 0:
        sizes.append(leftover)
    return sizes
def prob_mask_like(shape, prob, device):
    """Return a bool tensor of ``shape`` whose entries are True with probability ``prob``.

    ``prob`` of exactly 1 or 0 short-circuits to all-True / all-False without
    consuming random numbers.
    """
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    if prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    draws = torch.zeros(shape, device=device).float().uniform_(0, 1)
    return draws < prob
def is_list_str(x):
    """Return True if ``x`` is a list or tuple whose elements are all strings.

    Uses ``isinstance`` so ``str`` subclasses count as strings. An empty
    list/tuple returns True.
    """
    if not isinstance(x, (list, tuple)):
        return False
    return all(isinstance(el, str) for el in x)
# relative positional bias
class RelativePositionBias(nn.Module):
    """Learned bucketed relative-position bias (the T5 scheme) for temporal attention.

    Offsets are split by direction into two halves of ``num_buckets``. Small
    distances get one bucket each; larger ones share log-spaced buckets up to
    ``max_distance``. Each bucket holds one learned scalar per head.
    """

    def __init__(
        self,
        heads=8,
        num_buckets=32,
        max_distance=128
    ):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
        half = num_buckets // 2
        n = -relative_position
        # Offsets with n < 0 (key after query) use the upper half of the buckets.
        direction_offset = (n < 0).long() * half
        n = torch.abs(n)
        max_exact = half // 2
        is_small = n < max_exact
        log_ratio = torch.log(n.float() / max_exact) / math.log(max_distance / max_exact)
        large_bucket = max_exact + (log_ratio * (half - max_exact)).long()
        large_bucket = large_bucket.clamp(max=half - 1)
        return direction_offset + torch.where(is_small, n, large_bucket)

    def forward(self, n, device):
        """Return the bias tensor of shape (heads, n, n) for a length-``n`` sequence."""
        positions = torch.arange(n, dtype=torch.long, device=device)
        rel_pos = positions[None, :] - positions[:, None]  # key index minus query index
        buckets = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets,
                                                 max_distance=self.max_distance)
        return self.relative_attention_bias(buckets).permute(2, 0, 1)
# small helper modules
class EMA:
    """Exponential moving average of model parameters with decay ``beta``."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        """Blend each parameter of ``current_model`` into the matching parameter of ``ma_model``."""
        for cur, avg in zip(current_model.parameters(), ma_model.parameters()):
            avg.data = self.update_average(avg.data, cur.data)

    def update_average(self, old, new):
        """Return ``beta * old + (1 - beta) * new``, or ``new`` if there is no previous value."""
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new
class Residual(nn.Module):
    """Wrap ``fn`` with an identity skip connection: returns ``fn(x, ...) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        branch = self.fn(x, *args, **kwargs)
        return branch + x
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D batch of positions: sin features followed by cos features."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(-step * torch.arange(half_dim, device=x.device))
        angles = x[:, None] * freqs.unsqueeze(0)
        return torch.cat([angles.sin(), angles.cos()], dim=-1)
def Upsample(dim, use_deconv=True, padding_mode="reflect"):
    """Double the spatial size of (b, c, f, h, w) features; the time axis is untouched.

    Uses a strided transposed conv when ``use_deconv``; otherwise nearest-neighbour
    upsampling followed by a 3x3 spatial conv with ``padding_mode`` padding.
    """
    if not use_deconv:
        return nn.Sequential(
            nn.Upsample(scale_factor=(1, 2, 2), mode='nearest'),
            nn.Conv3d(dim, dim, (1, 3, 3), (1, 1, 1), (0, 1, 1), padding_mode=padding_mode),
        )
    return nn.ConvTranspose3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1))
def Downsample(dim):
    """Halve the spatial size of (b, c, f, h, w) features with a strided 4x4 conv; time axis unchanged."""
    kernel, stride, padding = (1, 4, 4), (1, 2, 2), (0, 1, 1)
    return nn.Conv3d(dim, dim, kernel, stride, padding)
class LayerNorm(nn.Module):
    """Normalise (b, c, f, h, w) tensors over the channel axis with a learned per-channel gain (no bias)."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1))

    def forward(self, x):
        mean = x.mean(dim=1, keepdim=True)
        var = x.var(dim=1, unbiased=False, keepdim=True)
        return (x - mean) / torch.sqrt(var + self.eps) * self.gamma
class LayerNorm_img(nn.Module):
    """Layer norm over the last dimension with a learned gain ``g`` (no bias).

    With ``stable=True`` the input is first divided by its detached maximum
    over the last dimension. eps is 1e-5 for float32 inputs and 1e-3 otherwise.
    """

    def __init__(self, dim, stable = False):
        super().__init__()
        self.stable = stable
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        if self.stable:
            x = x / x.amax(dim=-1, keepdim=True).detach()
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        centered = x - x.mean(dim=-1, keepdim=True)
        inv_std = torch.rsqrt(x.var(dim=-1, unbiased=False, keepdim=True) + eps)
        return centered * inv_std * self.g
class PreNorm(nn.Module):
    """Apply the channel-wise ``LayerNorm`` to the input, then ``fn`` (keyword args go to ``fn``)."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
class Identity(nn.Module):
    """Return the first argument unchanged; any extra constructor/call arguments are ignored."""

    def __init__(self, *unused_args, **unused_kwargs):
        super().__init__()

    def forward(self, x, *unused_args, **unused_kwargs):
        return x
def l2norm(t):
    """Scale ``t`` to unit L2 norm along its last dimension."""
    return F.normalize(t, p=2, dim=-1)
# building block modules
class Block(nn.Module):
    """Conv(1x3x3) -> GroupNorm -> optional time scale/shift modulation -> SiLU.

    ``audio_scale_shift`` is accepted for interface compatibility but not used.
    """

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding=(0, 1, 1))
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, time_scale_shift=None, audio_scale_shift=None):
        h = self.norm(self.proj(x))
        if time_scale_shift is not None:
            scale, shift = time_scale_shift
            h = h * (scale + 1) + shift
        return self.act(h)
class ResnetBlock(nn.Module):
    """Two ``Block``s with a residual connection, modulated by a time embedding.

    The time embedding (b, time_emb_dim) is projected to a per-channel
    scale/shift for the first block. An optional per-frame audio embedding
    (b, n, audio_emb_dim) is projected too, but ``Block`` currently ignores it.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=None, groups=8):
        super().__init__()
        self.time_mlp = (nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))
                         if exists(time_emb_dim) else None)
        self.audio_mlp = (nn.Sequential(nn.SiLU(), nn.Linear(audio_emb_dim, dim_out * 2))
                          if exists(audio_emb_dim) else None)
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None, audio_emb=None):
        time_scale_shift = None
        audio_scale_shift = None
        if exists(self.time_mlp):
            assert exists(time_emb), 'time emb must be passed in'
            projected_t = rearrange(self.time_mlp(time_emb), 'b c -> b c 1 1 1')
            time_scale_shift = projected_t.chunk(2, dim=1)
        if exists(self.audio_mlp):
            assert exists(audio_emb), 'audio emb must be passed in'
            projected_a = rearrange(self.audio_mlp(audio_emb), 'b n c -> b c n 1 1')
            audio_scale_shift = projected_a.chunk(2, dim=1)
        h = self.block1(x, time_scale_shift=time_scale_shift, audio_scale_shift=audio_scale_shift)
        return self.block2(h) + self.res_conv(x)
class ResnetBlock_ca(nn.Module):
    """Residual block conditioned on time (scale/shift) and one per-frame embedding via cross-attention.

    Layout mirrors ``ResnetBlock_ca_mul`` with a single condition. ``block1``
    is modulated by the time embedding. When ``audio_emb_dim`` is given, the
    per-frame embedding (b, f, audio_emb_dim) is projected to ``dim_out * 2``
    features and every spatial position of frame ``f`` attends to it. The
    attention output (``dim_out`` channels) is added to ``block1``'s output
    before ``block2``.

    The attention module is now built with the ``out_dim`` that
    ``CrossAttention`` requires; construction used to raise TypeError. Its
    result is also used now, instead of being computed and discarded.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=None, groups=8):
        super().__init__()
        self.time_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out * 2)
        ) if exists(time_emb_dim) else None
        self.audio_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(audio_emb_dim, dim_out * 2)
        ) if exists(audio_emb_dim) else None
        # Queries: dim-channel input features; context: projected condition.
        self.cross_attn = CrossAttention(
            dim = dim,
            context_dim = dim_out * 2,
            out_dim = dim_out
        )
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None, audio_emb=None):
        b, _, f = x.shape[:3]
        time_scale_shift = None
        if exists(self.time_mlp):
            assert exists(time_emb), 'time emb must be passed in'
            time_emb = rearrange(self.time_mlp(time_emb), 'b c -> b c 1 1 1')
            time_scale_shift = time_emb.chunk(2, dim=1)
        h = self.block1(x, time_scale_shift=time_scale_shift)
        if exists(self.audio_mlp):
            assert exists(audio_emb), 'audio emb must be passed in'
            context = self.audio_mlp(audio_emb)
            # ((b f), H*W, c): every spatial position of a frame attends to that frame's condition.
            h_cond = rearrange(x, 'b c f ... -> (b f) ... c')
            h_cond, ps = pack([h_cond], 'b * c')
            h_cond = self.cross_attn(h_cond, context = context)
            h_cond, = unpack(h_cond, ps, 'b * c')
            h_cond = rearrange(h_cond, '(b f) ... c -> b c f ...', b = b, f = f)
            h = h + h_cond
        h = self.block2(h)
        return h + self.res_conv(x)
class ResnetBlock_ca_mul(nn.Module):
# Residual block: block1 (modulated by a time-embedding scale/shift), plus the
# sum of three cross-attention branches (audio, pose, eye-blink conditions),
# then block2, plus a skip connection (1x1x1 conv when dim != dim_out).
def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=None, pose_emb_dim=None, eye_emb_dim=None, groups=8):
super().__init__()
# Each condition gets its own SiLU + Linear projection to dim_out * 2 features
# (the context width of the cross-attention layers below). An MLP is only
# built when its input width is given.
self.time_mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_emb_dim, dim_out * 2)
) if exists(time_emb_dim) else None
self.audio_mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(audio_emb_dim, dim_out * 2)
) if exists(audio_emb_dim) else None
self.pose_mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(pose_emb_dim, dim_out * 2)
) if exists(pose_emb_dim) else None
self.eye_mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(eye_emb_dim, dim_out * 2)
) if exists(eye_emb_dim) else None
# Widths used in forward() to split the concatenated [audio | pose | eye] condition.
self.audio_emb_dim = audio_emb_dim
self.pose_emb_dim = pose_emb_dim
self.eye_emb_dim = eye_emb_dim
# self.audio_mlp_2 = nn.Sequential(
# nn.SiLU(),
# nn.Linear(dim_out, dim_out * 2)
# ) if exists(audio_emb_dim) else None
# One cross-attention per condition. Queries are the dim-channel input
# features; the output has dim_out channels so it can be added to block1's output.
attn_klass = CrossAttention
self.cross_attn_aud = attn_klass(
dim = dim,
context_dim = dim_out * 2,
out_dim = dim_out
)
self.cross_attn_pose = attn_klass(
dim = dim,
context_dim = dim_out * 2,
out_dim = dim_out
)
self.cross_attn_eye = attn_klass(
dim = dim,
context_dim = dim_out * 2,
out_dim = dim_out
)
self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=groups)
self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, time_emb=None, audio_emb=None):
# x: (b, c, f, H, W), unpacked below. Despite its name, audio_emb carries the
# concatenated per-frame condition [audio | pose | eye] along its last axis.
time_scale_shift = None
# NOTE(review): audio_scale_shift is never used below.
audio_scale_shift = None
'''
need seperate 3 diffiserent condition
'''
if exists(audio_emb):
# Split the last axis: [0, A) audio, [A, A+P) pose, [A+P, end) eye blink.
pose_emb = audio_emb[:,:,self.audio_emb_dim:self.audio_emb_dim + self.pose_emb_dim]
eye_emb = audio_emb[:,:,self.audio_emb_dim + self.pose_emb_dim: ]
audio_emb = audio_emb[:,:,:self.audio_emb_dim]
b, c, f, H, W = x.size()
if exists(self.time_mlp):
assert exists(time_emb), 'time emb must be passed in'
time_emb = self.time_mlp(time_emb)
time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') # bs, 128, 1, 1
time_scale_shift = time_emb.chunk(2, dim=1) # bs, 64, 1, 1
# added by lml to get audio embedding
if exists(self.audio_mlp): # mouth lmk + audio emb
assert exists(audio_emb), 'audio emb must be passed in'
audio_emb = self.audio_mlp(audio_emb)
# NOTE(review): pose_mlp / eye_mlp are assumed to exist whenever audio_mlp
# does (Unet3D's block_klass_cond passes all three widths).
pose_emb = self.pose_mlp(pose_emb) # TODO: embedding
eye_emb = self.eye_mlp(eye_emb)
# NOTE(review): indentation is not recoverable from this view. The branch below
# must be nested under `if exists(self.audio_mlp)`; otherwise blocks built
# without condition widths (Unet3D's final_conv / occlusion_map heads, which
# nn.Sequential calls with x only) would fail `assert exists(audio_emb)`.
if exists(self.cross_attn_aud):
# h = rearrange(x, 'b c f ... -> (b f) ... c')
# h, ps = pack([h], 'b * c')
# audio_emb = rearrange(audio_emb, 'b f ... -> (b f) ...')
# audio_emb = self.cross_attn(h, context = audio_emb)
# # h, = unpack(h, ps, 'b * c')
# # h = rearrange(h, '(b f) ... c -> b c f ...', b = b, f = f, c = c)
# # audio_emb = self.audio_mlp_2(audio_emb)
# audio_emb = rearrange(audio_emb, '(b f) ... -> b f ...', b = b, f = f)
assert exists(audio_emb)
# Flatten to ((b f), H*W, c): every spatial position of a frame attends to
# that frame's condition (CrossAttention folds (b, f, c) context to (b f, c)).
h_cond = rearrange(x, 'b c f ... -> (b f) ... c')
# h = rearrange(x, 'b c ... -> b ... c')
h_cond, ps = pack([h_cond], 'b * c')
h_pose = self.cross_attn_pose(h_cond, context = pose_emb)
h_aud = self.cross_attn_aud(h_cond, context = audio_emb)
h_eye = self.cross_attn_eye(h_cond, context = eye_emb)
# The three condition branches are simply summed.
h_cond = h_pose + h_aud + h_eye
h_cond, = unpack(h_cond, ps, 'b * c')
# h = rearrange(h, 'b ... c -> b c ...')
h_cond = rearrange(h_cond, '(b f) ... c -> b c f ...', b = b, f = f)
# audio_emb = rearrange(audio_emb, 'b f (h w) c -> b c f h w', w = W, h = H) # bs, 128, nf, 1, 1
# audio_scale_shift = audio_emb.chunk(2, dim=1) # bs, 64, nf, 1, 1
h = self.block1(x, time_scale_shift=time_scale_shift)
# The conditioning features are added after block1 (only when condition MLPs exist).
if exists(self.audio_mlp):
h = h_cond + h
h = self.block2(h)
return h + self.res_conv(x)
class CrossAttention(nn.Module):
    """Cosine-similarity cross-attention from per-position features to one context token per row.

    ``x`` is (B, n, dim). ``context`` is (b, f, context_dim) and is folded to
    (b*f, context_dim), so B must equal b*f: each row of ``x`` attends to its
    own context token plus a learned null key/value pair. Keys and queries
    are L2-normalised, rescaled by learned per-dimension gains, then
    multiplied by ``scale``.
    """

    def __init__(
        self,
        dim,
        out_dim,
        *,
        context_dim = None,
        dim_head = 8,
        heads = 8,
        norm_context = False,
        scale = 8
    ):
        super().__init__()
        self.scale = scale
        self.heads = heads
        inner_dim = dim_head * heads
        context_dim = default(context_dim, dim)
        # NOTE: construction order is kept as-is (random init / state_dict order).
        self.norm = LayerNorm_img(dim)
        self.norm_context = LayerNorm_img(context_dim) if norm_context else Identity()
        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, out_dim, bias = False),
            LayerNorm_img(out_dim)
        )

    def forward(self, x, context, mask = None):
        batch = x.shape[0]
        x = self.norm(x)
        context = self.norm_context(rearrange(context, 'b f c -> (b f) c'))
        q = self.to_q(x)
        k, v = self.to_kv(context[:, None, :]).chunk(2, dim = -1)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in (q, k, v))

        # Prepend the learned null key / value (classifier-free guidance).
        null_k, null_v = (repeat(t, 'd -> b h 1 d', h = self.heads, b = batch)
                          for t in self.null_kv.unbind(dim = -2))
        k = torch.cat((null_k, k), dim = -2)
        v = torch.cat((null_v, v), dim = -2)

        q = l2norm(q) * self.q_scale
        k = l2norm(k) * self.k_scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        if exists(mask):
            # The null token is always attendable.
            mask = rearrange(F.pad(mask, (1, 0), value = True), 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim = -1, dtype = torch.float32).to(sim.dtype)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        return self.to_out(rearrange(out, 'b h n d -> b n (h d)'))
class LinearCrossAttention(CrossAttention):
    """Linear-complexity variant of ``CrossAttention``.

    Reuses the parent's projections, norms and learned null key/value.
    Queries are softmaxed over features and keys over context positions; no
    (n x n) similarity matrix is formed. Heads are folded into the batch axis
    as ``(b h)``. The previous version reshaped to 4-D ``b h n d`` but then
    used 3-D einsum patterns and a ``(b h) n d`` output pattern, so it could
    not run.
    """

    def forward(self, x, context, mask = None):
        b = x.shape[0]
        x = self.norm(x)
        context = rearrange(context, 'b f c -> (b f) c')
        context = self.norm_context(context)
        q, k, v = (self.to_q(x), *self.to_kv(context[:, None, :]).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads), (q, k, v))

        # Prepend the learned null key / value (classifier-free guidance).
        nk, nv = map(lambda t: repeat(t, 'd -> (b h) 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))
        k = torch.cat((nk, k), dim = -2)  # ((b h), 1 + n_ctx, d)
        v = torch.cat((nv, v), dim = -2)

        if exists(mask):
            max_neg_value = -torch.finfo(k.dtype).max
            mask = F.pad(mask, (1, 0), value = True)  # the null token stays visible
            mask = repeat(mask, 'b j -> (b h) j 1', h = self.heads)
            k = k.masked_fill(~mask, max_neg_value)
            v = v.masked_fill(~mask, 0.)

        # Linear attention: softmax over features for q, over positions for k.
        q = q.softmax(dim = -1) * self.scale
        k = k.softmax(dim = -2)
        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)
        return self.to_out(out)
class SpatialLinearAttention(nn.Module):
    """Per-frame linear attention over the h*w spatial positions of (b, c, f, h, w) features."""

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        batch, _, _, height, width = x.shape
        frames = rearrange(x, 'b c f h w -> (b f) c h w')
        q, k, v = rearrange_many(self.to_qkv(frames).chunk(3, dim=1),
                                 'b (h c) x y -> b h c (x y)', h=self.heads)
        # Softmax over feature channels for queries and over positions for keys.
        q = q.softmax(dim=-2) * self.scale
        k = k.softmax(dim=-1)
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h=self.heads, x=height, y=width)
        return rearrange(self.to_out(out), '(b f) c h w -> b c f h w', b=batch)
# attention along space and time
class EinopsToAndFrom(nn.Module):
    """Rearrange the input from ``from_einops`` to ``to_einops``, apply ``fn``, then rearrange back.

    Axis sizes needed to undo the rearrangement are read from the input's
    shape, matched positionally to the space-separated names in ``from_einops``.
    """

    def __init__(self, from_einops, to_einops, fn):
        super().__init__()
        self.from_einops = from_einops
        self.to_einops = to_einops
        self.fn = fn

    def forward(self, x, **kwargs):
        axis_sizes = {name: size for name, size in zip(self.from_einops.split(' '), x.shape)}
        to_pattern = f'{self.from_einops} -> {self.to_einops}'
        back_pattern = f'{self.to_einops} -> {self.from_einops}'
        y = self.fn(rearrange(x, to_pattern), **kwargs)
        return rearrange(y, back_pattern, **axis_sizes)
class Attention(nn.Module):
    """Multi-head softmax self-attention over the second-to-last axis.

    Used for temporal attention (input 'b (h w) f c') and spatial attention
    (input 'b f (h w) c'). Supports an optional rotary embedding, an additive
    positional bias, and ``focus_present_mask``: samples flagged True attend
    only to themselves.
    """

    def __init__(
        self,
        dim,
        heads=4,
        dim_head=32,
        rotary_emb=None
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.rotary_emb = rotary_emb
        self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False)
        self.to_out = nn.Linear(hidden_dim, dim, bias=False)

    def forward(
        self,
        x,
        pos_bias=None,
        focus_present_mask=None
    ):
        n, device = x.shape[-2], x.device
        qkv = self.to_qkv(x).chunk(3, dim=-1)

        if exists(focus_present_mask) and focus_present_mask.all():
            # Attending only to oneself is equivalent to projecting the values directly.
            return self.to_out(qkv[-1])

        q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads)
        q = q * self.scale
        if exists(self.rotary_emb):
            q = self.rotary_emb.rotate_queries_or_keys(q)
            k = self.rotary_emb.rotate_queries_or_keys(k)

        sim = einsum('... h i d, ... h j d -> ... h i j', q, k)
        if exists(pos_bias):
            sim = sim + pos_bias

        if exists(focus_present_mask) and focus_present_mask.any():
            attend_all_mask = torch.ones((n, n), device=device, dtype=torch.bool)
            attend_self_mask = torch.eye(n, device=device, dtype=torch.bool)
            mask = torch.where(
                rearrange(focus_present_mask, 'b -> b 1 1 1 1'),
                rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),
                rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),
            )
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # Subtract the (detached) row max for numerical stability before softmax.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)
        out = einsum('... h i j, ... h j d -> ... h i d', attn, v)
        return self.to_out(rearrange(out, '... h n d -> ... n (h d)'))
# model
class Unet3D(nn.Module):
def __init__(
self,
dim,
cond_aud=1024,
cond_pose=7,
cond_eye=2,
cond_dim=None,
out_grid_dim=2,
out_conf_dim=1,
num_frames=40,
dim_mults=(1, 2, 4, 8),
channels=3,
attn_heads=8,
attn_dim_head=32,
use_hubert_audio_cond=False,
init_dim=None,
init_kernel_size=7,
use_sparse_linear_attn=True,
resnet_groups=8,
use_final_activation=False,
learn_null_cond=False,
use_deconv=True,
padding_mode="zeros",
):
# Build the 3-D U-Net denoiser. Every down / mid / up ResnetBlock_ca_mul is
# conditioned on the time embedding and on the concatenated per-frame
# [audio (cond_aud) | pose (cond_pose) | eye (cond_eye)] condition. Two heads
# output a flow grid (out_grid_dim channels) and an occlusion map (out_conf_dim).
# NOTE(review): module construction order determines random init and
# state_dict layout -- keep it unchanged for checkpoint compatibility.
super().__init__()
self.null_cond_mask = None
self.channels = channels
self.num_frames = num_frames
# NOTE(review): HUBERT_MODEL_DIM is not used in this constructor.
self.HUBERT_MODEL_DIM = 1024
# temporal attention and its relative positional encoding
rotary_emb = RotaryEmbedding(min(32, attn_dim_head))
# Factory: attention across frames at each spatial position ('b (h w) f c').
temporal_attn = lambda dim: EinopsToAndFrom('b c f h w', 'b (h w) f c',
Attention(dim, heads=attn_heads, dim_head=attn_dim_head,
rotary_emb=rotary_emb))
self.time_rel_pos_bias = RelativePositionBias(heads=attn_heads,
max_distance=32) # realistically will not be able to generate that many frames of video... yet
# initial conv
init_dim = default(init_dim, dim)
assert is_odd(init_kernel_size)
init_padding = init_kernel_size // 2
self.init_conv = nn.Conv3d(channels, init_dim, (1, init_kernel_size, init_kernel_size),
padding=(0, init_padding, init_padding))
self.init_temporal_attn = Residual(PreNorm(init_dim, temporal_attn(init_dim)))
# dimensions
# Channel widths per resolution level: (init_dim, dim*m1, dim*m2, ...).
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
# time conditioning
time_dim = dim * 4
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(dim),
nn.Linear(dim, time_dim),
nn.GELU(),
nn.Linear(time_dim, time_dim)
)
# audio conditioning
self.has_cond = exists(cond_dim) or use_hubert_audio_cond
self.cond_dim = cond_dim
self.cond_aud_dim = cond_aud
self.cond_pose_dim = cond_pose
self.cond_eye_dim = cond_eye
# modified by lml
self.learn_null_cond = learn_null_cond
# cat(t,cond) is not suitable
# cond_dim = time_dim + int(cond_dim or 0)
# layers
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
# block type
# block_klass: unconditioned block (used by the output heads);
# block_klass_cond: block with time + audio/pose/eye conditioning.
block_klass = partial(ResnetBlock_ca_mul, groups=resnet_groups)
block_klass_cond = partial(block_klass, time_emb_dim=time_dim, audio_emb_dim=self.cond_aud_dim, pose_emb_dim=self.cond_pose_dim, eye_emb_dim=self.cond_eye_dim)
# block_klass_cond = partial(block_klass, time_emb_dim=cond_dim) # cat embedding
# modules for all layers
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
# Per level: 2 conditioned blocks, spatial attention, temporal attention, 2x downsample (not on the last level).
self.downs.append(nn.ModuleList([
block_klass_cond(dim_in, dim_out),
block_klass_cond(dim_out, dim_out),
Residual(PreNorm(dim_out, SpatialLinearAttention(dim_out,
heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(),
Residual(PreNorm(dim_out, temporal_attn(dim_out))),
Downsample(dim_out) if not is_last else nn.Identity()
]))
mid_dim = dims[-1]
self.mid_block1 = block_klass_cond(mid_dim, mid_dim)
spatial_attn = EinopsToAndFrom('b c f h w', 'b f (h w) c', Attention(mid_dim, heads=attn_heads))
self.mid_spatial_attn = Residual(PreNorm(mid_dim, spatial_attn))
self.mid_temporal_attn = Residual(PreNorm(mid_dim, temporal_attn(mid_dim)))
self.mid_block2 = block_klass_cond(mid_dim, mid_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
is_last = ind >= (num_resolutions - 1)
# First up block takes dim_out * 2 channels -- presumably the upsampled
# features concatenated with the matching skip from self.downs (see forward).
self.ups.append(nn.ModuleList([
block_klass_cond(dim_out * 2, dim_in),
block_klass_cond(dim_in, dim_in),
Residual(PreNorm(dim_in, SpatialLinearAttention(dim_in,
heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(),
Residual(PreNorm(dim_in, temporal_attn(dim_in))),
Upsample(dim_in, use_deconv, padding_mode) if not is_last else nn.Identity()
]))
# out_dim = default(out_grid_dim, channels)
# Output heads take dim * 2 channels -- presumably the last decoder features
# concatenated with the init_conv residual `r` saved in forward; confirm there.
self.final_conv = nn.Sequential(
block_klass(dim * 2, dim),
nn.Conv3d(dim, out_grid_dim, 1)
)
# added by nhm
self.use_final_activation = use_final_activation
if self.use_final_activation:
self.final_activation = nn.Tanh()
else:
self.final_activation = nn.Identity()
# added by nhm for predicting occlusion mask
self.occlusion_map = nn.Sequential(
block_klass(dim * 2, dim),
nn.Conv3d(dim, out_conf_dim, 1)
)
def forward_with_cond_scale(
        self,
        *args,
        cond_scale=2.,
        **kwargs
):
    """Classifier-free-guidance wrapper around ``self.forward``.

    The network is always run once with the condition kept (``null_cond_prob=0.``).
    Unless guidance is a no-op (``cond_scale == 1``) or the model has no condition
    (``not self.has_cond``), it is run a second time with the condition fully nulled
    (``null_cond_prob=1.``). The two outputs are then combined as
    ``uncond + (cond - uncond) * cond_scale``.
    """
    cond_out = self.forward(*args, null_cond_prob=0., **kwargs)
    guidance_is_noop = cond_scale == 1 or not self.has_cond
    if guidance_is_noop:
        return cond_out
    uncond_out = self.forward(*args, null_cond_prob=1., **kwargs)
    return uncond_out + (cond_out - uncond_out) * cond_scale
def forward(
        self,
        x,
        time,
        cond=None,
        null_cond_prob=0.,
        focus_present_mask=None,
        prob_focus_present=0.
        # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time)
):
    """Run the conditional 3D U-Net on ``x`` at diffusion step(s) ``time``.

    Args:
        x: 5-D input; dim 2 is used as the frame count for the temporal
            relative-position bias.
        time: diffusion timesteps, embedded by ``self.time_mlp`` when present.
        cond: per-frame conditioning, required when ``self.has_cond``. It is mixed
            with a ``(1, self.num_frames, self.cond_dim)`` null embedding, so it is
            presumably ``(batch, self.num_frames, self.cond_dim)`` -- TODO confirm.
        null_cond_prob: probability, per (sample, frame), of replacing ``cond`` with
            the null embedding (classifier-free guidance dropout).
        focus_present_mask: optional per-sample mask passed to every temporal
            attention layer. When omitted, it is drawn with probability
            ``prob_focus_present``.

    Returns:
        ``self.final_conv(h)`` and ``self.occlusion_map(h)`` concatenated on dim 1.

    Side effects:
        Sets ``self.null_cond_emb``. When ``self.has_cond``, also sets
        ``self.null_cond_mask`` to the per-frame mask actually applied.
    """
    assert not (self.has_cond and not exists(cond)), 'cond must be passed in if cond_dim specified'
    batch, device = x.shape[0], x.device
    focus_present_mask = default(focus_present_mask,
                                 lambda: prob_mask_like((batch,), prob_focus_present, device=device))
    time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=device)

    x = self.init_conv(x)
    r = x.clone()  # long skip, concatenated back in before the output heads
    x = self.init_temporal_attn(x, pos_bias=time_rel_pos_bias)
    t = self.time_mlp(time) if exists(self.time_mlp) else None

    null_emb_shape = (1, self.num_frames, self.cond_dim)
    if not self.has_cond:
        self.null_cond_emb = None
    elif self.learn_null_cond:
        # Reuse the learnable null embedding instead of drawing a brand-new random
        # nn.Parameter on every call. A fresh one per call can never be optimised and
        # makes the "null" condition differ between calls. Recreate it only when it
        # is missing or its frame count no longer matches self.num_frames.
        current = getattr(self, 'null_cond_emb', None)
        if not isinstance(current, nn.Parameter) or tuple(current.shape) != null_emb_shape:
            self.null_cond_emb = nn.Parameter(torch.randn(*null_emb_shape, device=device))
    else:
        self.null_cond_emb = torch.zeros(*null_emb_shape, device=device)

    # classifier free guidance: per (sample, frame), swap cond for the null embedding
    if self.has_cond:
        self.null_cond_mask = prob_mask_like((batch, self.num_frames,), null_cond_prob, device=device)
        cond = torch.where(rearrange(self.null_cond_mask, 'b n -> b n 1'), self.null_cond_emb.to(cond.device), cond)

    h = []
    for block1, block2, spatial_attn, temporal_attn, downsample in self.downs:
        x = block1(x, t, cond)
        x = block2(x, t, cond)
        x = spatial_attn(x)
        x = temporal_attn(x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask)
        h.append(x)
        x = downsample(x)

    x = self.mid_block1(x, t, cond)
    x = self.mid_spatial_attn(x)
    x = self.mid_temporal_attn(x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask)
    x = self.mid_block2(x, t, cond)

    for block1, block2, spatial_attn, temporal_attn, upsample in self.ups:
        x = torch.cat((x, h.pop()), dim=1)
        x = block1(x, t, cond)
        x = block2(x, t, cond)
        x = spatial_attn(x)
        x = temporal_attn(x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask)
        x = upsample(x)

    x = torch.cat((x, r), dim=1)
    return torch.cat((self.final_conv(x), self.occlusion_map(x)), dim=1)
# to dynamically change num_frames of Unet3D
class DynamicNfUnet3D(Unet3D):
    """Unet3D whose ``num_frames`` can be changed after construction.

    ``default_num_frames`` is taken as the first positional argument. All other
    arguments are forwarded unchanged to ``Unet3D.__init__``, and any ``num_frames``
    the parent set is then overwritten with ``default_num_frames``.
    """

    def __init__(self, default_num_frames=20, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.default_num_frames = default_num_frames
        self.num_frames = self.default_num_frames

    def update_num_frames(self, new_num_frames):
        """Set the frame count read by subsequent forward passes."""
        self.num_frames = new_num_frames
# gaussian diffusion trainer class
def extract(a, t, x_shape):
    """Index the 1-D schedule ``a`` at timesteps ``t`` and reshape for broadcasting.

    Returns a tensor of shape ``(len(t), 1, ..., 1)`` with ``len(x_shape)`` dims in
    total, so it broadcasts against a batch of shape ``x_shape``.
    """
    b, *_ = t.shape
    trailing_ones = (1,) * (len(x_shape) - 1)
    return a.gather(-1, t).reshape(b, *trailing_ones)
def cosine_beta_schedule(timesteps, s=0.008):
    """
    Cosine noise schedule, as proposed in https://openreview.net/forum?id=-NEXDKk8gZ

    Returns a float64 tensor of ``timesteps`` betas clipped to ``[0, 0.9999]``.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64)
    f_t = torch.cos(((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_bar = f_t / f_t[0]
    ratios = alphas_bar[1:] / alphas_bar[:-1]
    # the final step drives alpha_bar to ~0, so clip to keep betas < 1
    return torch.clip(1 - ratios, 0, 0.9999)
class GaussianDiffusion(nn.Module):
    """Gaussian diffusion (cosine schedule) around a conditional video denoiser.

    ``denoise_fn`` is called on the noisy sample concatenated (dim 1) with the
    frame-repeated conditioning features and predicts the added noise; ``p_losses``
    regresses it against the true noise. ``sample`` runs ancestral DDPM sampling,
    or DDIM when ``sampling_timesteps < timesteps``.

    ``denoise_fn`` must provide ``forward`` (accepting ``null_cond_prob``),
    ``forward_with_cond_scale`` and a ``null_cond_mask`` attribute set by ``forward``.
    """

    def __init__(
            self,
            denoise_fn,
            *,
            image_size,
            num_frames,
            text_use_bert_cls=False,
            channels=3,
            timesteps=1000,
            sampling_timesteps=250,
            ddim_sampling_eta=1.,
            loss_type='l1',
            use_dynamic_thres=False,  # from the Imagen paper
            dynamic_thres_percentile=0.9,
            null_cond_prob=0.1
    ):
        super().__init__()
        self.null_cond_prob = null_cond_prob
        self.channels = channels
        self.image_size = image_size
        self.num_frames = num_frames
        self.denoise_fn = denoise_fn

        betas = cosine_beta_schedule(timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        # fewer sampling steps than training steps switches sample() to DDIM
        self.sampling_timesteps = default(sampling_timesteps, timesteps)
        self.is_ddim_sampling = self.sampling_timesteps < timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

        # register buffer helper function that casts float64 to float32
        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        register_buffer('posterior_variance', posterior_variance)
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        # text conditioning parameters
        self.text_use_bert_cls = text_use_bert_cls

        # dynamic thresholding when sampling
        self.use_dynamic_thres = use_dynamic_thres
        self.dynamic_thres_percentile = dynamic_thres_percentile

    def _threshold(self, x0):
        """Clamp a predicted x_0 to ``[-s, s]`` and divide by ``s``.

        ``s`` is 1 (static clipping). With ``use_dynamic_thres`` it is instead the
        per-sample ``dynamic_thres_percentile`` quantile of ``|x0|``, floored at 1.
        """
        s = 1.
        if self.use_dynamic_thres:
            s = torch.quantile(
                rearrange(x0, 'b ... -> b (...)').abs(),
                self.dynamic_thres_percentile,
                dim=-1
            )
            s.clamp_(min=1.)
            s = s.view(-1, *((1,) * (x0.ndim - 1)))
        # clip by threshold, depending on whether static or dynamic
        return x0.clamp(-s, s) / s

    def q_mean_variance(self, x_start, t):
        """Mean, variance and log-variance of q(x_t | x_0)."""
        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def predict_start_from_noise(self, x_t, t, noise):
        """Invert q(x_t | x_0) to recover x_0 from x_t and predicted noise."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        """Mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, fea, clip_denoised: bool, cond=None, cond_scale=1.):
        """Model posterior p(x_{t-1} | x_t): predict noise, recover x_0, form q_posterior.

        ``fea`` is repeated along a new dim 2 to match ``x.size(2)`` before being
        concatenated with ``x`` on dim 1.
        """
        fea = fea.unsqueeze(dim=2).repeat(1, 1, x.size(2), 1, 1)
        pred_noise = self.denoise_fn.forward_with_cond_scale(torch.cat([x, fea], dim=1),
                                                             t,
                                                             cond=cond,
                                                             cond_scale=cond_scale)
        x_recon = self.predict_start_from_noise(x, t=t, noise=pred_noise)
        if clip_denoised:
            x_recon = self._threshold(x_recon)
        return self.q_posterior(x_start=x_recon, x_t=x, t=t)

    @torch.inference_mode()
    def p_sample(self, x, t, fea, cond=None, cond_scale=1., clip_denoised=True):
        """One ancestral sampling step x_t -> x_{t-1}; no noise is added where t == 0."""
        b = x.shape[0]
        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, fea=fea,
                                                                 clip_denoised=clip_denoised, cond=cond,
                                                                 cond_scale=cond_scale)
        noise = torch.randn_like(x)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.inference_mode()
    def p_sample_loop(self, fea, shape, cond=None, cond_scale=1.):
        """Full ancestral DDPM sampling from pure noise of ``shape``."""
        device = self.betas.device
        b = shape[0]
        img = torch.randn(shape, device=device)
        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), fea, cond=cond,
                                cond_scale=cond_scale)
        return img

    @torch.inference_mode()
    def sample(self, fea, bbox_mask, cond=None, cond_scale=1., batch_size=16):
        """Generate samples of shape (B, channels, num_frames, image_size, image_size).

        ``B`` is ``cond.shape[0]`` when ``cond`` is given, otherwise ``batch_size``.
        ``fea`` and ``bbox_mask`` are concatenated on dim 1 and used as the
        conditioning features at every step.
        """
        batch_size = cond.shape[0] if exists(cond) else batch_size
        sample_fn = self.ddim_sample if self.is_ddim_sampling else self.p_sample_loop
        fea = torch.cat([fea, bbox_mask], dim=1)
        shape = (batch_size, self.channels, self.num_frames, self.image_size, self.image_size)
        return sample_fn(fea, shape, cond=cond, cond_scale=cond_scale)

    # add by nhm
    @torch.no_grad()
    def ddim_sample(self, fea, shape, cond=None, cond_scale=1., clip_denoised=True):
        """DDIM sampling over ``self.sampling_timesteps`` steps with ``eta = ddim_sampling_eta``."""
        batch, device, total_timesteps, sampling_timesteps, eta = \
            shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta

        times = torch.linspace(0., total_timesteps, steps=sampling_timesteps + 2)[:-1]
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))

        img = torch.randn(shape, device=device)
        fea = fea.unsqueeze(dim=2).repeat(1, 1, img.size(2), 1, 1)

        for time, time_next in tqdm(time_pairs, desc='sampling loop time step'):
            # NOTE(review): the update below reads alphas_cumprod_prev[time], while
            # predict_start_from_noise uses the alphas_cumprod-based buffers at the
            # same index -- confirm this one-step offset is intended.
            alpha = self.alphas_cumprod_prev[time]
            alpha_next = self.alphas_cumprod_prev[time_next]
            time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
            pred_noise = self.denoise_fn.forward_with_cond_scale(
                torch.cat([img, fea], dim=1),
                time_cond,
                cond=cond,
                cond_scale=cond_scale)
            x_start = self.predict_start_from_noise(img, t=time_cond, noise=pred_noise)
            if clip_denoised:
                x_start = self._threshold(x_start)

            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = ((1 - alpha_next) - sigma ** 2).sqrt()
            noise = torch.randn_like(img) if time_next > 0 else 0.
            img = x_start * alpha_next.sqrt() + \
                  c * pred_noise + \
                  sigma * noise
        return img

    @torch.inference_mode()
    def interpolate(self, x1, x2, t=None, lam=0.5, fea=None, cond=None, cond_scale=1.):
        """Noise ``x1`` and ``x2`` to step ``t``, mix them with weight ``lam``, then denoise.

        ``fea`` is the conditioning passed to every ``p_sample`` step. As in
        ``sample``, it should already include the concatenated bbox mask. It is
        required because the denoiser is always conditioned on it.

        Raises:
            TypeError: if ``fea`` is not given.
        """
        if fea is None:
            raise TypeError('interpolate() requires `fea`: every denoising step is conditioned on it')
        b, device = x1.shape[0], x1.device
        t = default(t, self.num_timesteps - 1)
        assert x1.shape == x2.shape

        t_batched = torch.full((b,), int(t), device=device, dtype=torch.long)
        xt1, xt2 = (self.q_sample(x, t=t_batched) for x in (x1, x2))
        img = (1 - lam) * xt1 + lam * xt2
        for i in tqdm(reversed(range(0, t)), desc='interpolation sample time step', total=t):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), fea,
                                cond=cond, cond_scale=cond_scale)
        return img

    def q_sample(self, x_start, t, noise=None):
        """Draw x_t ~ q(x_t | x_0) (with fresh Gaussian noise unless ``noise`` is given)."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, t, fea, bbox_mask, cond=None, noise=None, clip_denoised=True, **kwargs):
        """Per-element noise-prediction loss (``reduction='none'``) and the denoiser's null-cond mask.

        Also stores the x_0 prediction in ``self.pred_x0``: thresholded when
        ``clip_denoised``, raw otherwise.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        pred_noise = self.denoise_fn.forward(torch.cat([x_noisy, fea, bbox_mask], dim=1), t, cond=cond,
                                             null_cond_prob=self.null_cond_prob,
                                             **kwargs)

        if self.loss_type == 'l1':
            loss = F.l1_loss(noise, pred_noise, reduction='none')
        elif self.loss_type == 'l2':
            loss = F.mse_loss(noise, pred_noise, reduction='none')
        else:
            raise NotImplementedError()

        pred_x0 = self.predict_start_from_noise(x_noisy, t, pred_noise)
        # always refresh pred_x0 so it never holds a stale value from an earlier step
        self.pred_x0 = self._threshold(pred_x0) if clip_denoised else pred_x0

        return loss, self.denoise_fn.null_cond_mask

    def forward(self, x, fea, bbox_mask, cond, *args, **kwargs):
        """Training entry point: draw one random timestep per sample and return ``p_losses``.

        ``fea`` and ``bbox_mask`` get a new dim 2 repeated ``x.size(2)`` times, so
        they are presumably (b, C, H, W) -- TODO confirm.
        """
        b, device = x.shape[0], x.device
        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
        num_frames = x.size(2)
        fea = fea.unsqueeze(dim=2).repeat(1, 1, num_frames, 1, 1)
        bbox_mask = bbox_mask.unsqueeze(dim=2).repeat(1, 1, num_frames, 1, 1)
        return self.p_losses(x, t, fea, bbox_mask, cond, *args, **kwargs)
# trainer class
# PIL image mode used for each supported channel count
CHANNELS_TO_MODE = {1: 'L', 3: 'RGB', 4: 'RGBA'}


def seek_all_images(img, channels=3):
    """Yield every frame of a multi-frame image, converted to the mode for ``channels``.

    Frames are visited via ``img.seek(0)``, ``img.seek(1)``, ... until ``EOFError``.
    ``channels`` must be a key of ``CHANNELS_TO_MODE``.
    """
    assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid'
    target_mode = CHANNELS_TO_MODE[channels]
    frame_idx = 0
    while True:
        try:
            img.seek(frame_idx)
            yield img.convert(target_mode)
        except EOFError:
            return
        frame_idx += 1
# to dynamically change num_frames of GaussianDiffusion
class DynamicNfGaussianDiffusion(GaussianDiffusion):
    """GaussianDiffusion whose ``num_frames`` (read by ``sample``) can change after construction.

    ``default_num_frames`` is the first positional argument. All other arguments
    go to ``GaussianDiffusion.__init__``, and the ``num_frames`` it stores is then
    replaced by ``default_num_frames``.
    """

    def __init__(self, default_num_frames=20, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.default_num_frames = default_num_frames
        self.num_frames = self.default_num_frames

    def update_num_frames(self, new_num_frames):
        """Set the frame count used for subsequent sampling."""
        self.num_frames = new_num_frames
# tensor of shape (channels, frames, height, width) -> gif
def video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True):
    """Save a (channels, frames, height, width) tensor as an animated GIF at ``path``.

    Each ``tensor[:, i]`` slice is converted with ``T.ToPILImage()``. ``duration``,
    ``loop`` and ``optimize`` are passed through to PIL's ``Image.save``.

    Returns:
        The list of PIL frames that were written.

    Raises:
        ValueError: if ``tensor`` has no frames along dim 1.
    """
    to_pil = T.ToPILImage()
    # Materialise the frames: a bare ``map`` would be consumed by the unpacking
    # below, leaving only an exhausted iterator to return.
    images = [to_pil(frame) for frame in tensor.unbind(dim=1)]
    if not images:
        raise ValueError('video_tensor_to_gif: tensor has no frames along dim 1')
    first_img, *rest_imgs = images
    first_img.save(path, save_all=True, append_images=rest_imgs, duration=duration, loop=loop, optimize=optimize)
    return images
# gif -> (channels, frame, height, width) tensor
def gif_to_tensor(path, channels=3, transform=T.ToTensor()):
    """Load every frame of the GIF at ``path`` and stack them along dim 1.

    Each frame is converted to the PIL mode for ``channels`` and passed through
    ``transform``. With the default ``T.ToTensor()`` the result is
    (channels, frames, height, width).
    """
    # Close the underlying file deterministically. All frames are transformed
    # (materialised in the tuple) before the image is closed.
    with Image.open(path) as img:
        frames = tuple(map(transform, seek_all_images(img, channels=channels)))
    return torch.stack(frames, dim=1)
def identity(t, *args, **kwargs):
    """Return ``t`` unchanged; any extra positional or keyword arguments are ignored."""
    result = t
    return result
def normalize_img(t):
    """Map values from the [0, 1] range to [-1, 1] (computes 2 * t - 1)."""
    return 2 * t - 1
# def unnormalize_img(t):
# return (t + 1) * 0.5
def cast_num_frames(t, *, frames):
    """Crop or zero-pad dim 1 of ``t`` so it has exactly ``frames`` entries.

    Returns ``t`` itself when the length already matches. Padding is appended via
    ``F.pad`` on dim -3, which coincides with dim 1 for a 4-D tensor such as
    (channels, frames, height, width).
    """
    current = t.shape[1]
    if current > frames:
        return t[:, :frames]
    if current < frames:
        return F.pad(t, (0, 0, 0, 0, 0, frames - current))
    return t
================================================
FILE: DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test.py
================================================
'''
Adds pose conditioning on top of the baseline model.
Uses cross-attention to inject the different conditions.
Uses local attention for inference: faster, but uses more RAM.
'''
import math
import torch
from torch import nn, einsum
import torch.nn.functional as F
from functools import partial
from torchvision import transforms as T
from PIL import Image
from tqdm import tqdm
from einops import rearrange, repeat, reduce, pack, unpack
from einops_exts import rearrange_many
from rotary_embedding_torch import RotaryEmbedding
# from DM.modules.text import tokenize, bert_embed, HUBERT_MODEL_DIM
# helper functions
def exists(x):
    """Return True for any value other than None (falsy values such as 0 or '' count)."""
    return not (x is None)
def noop(*args, **kwargs):
    """Accept any arguments, do nothing, and return None."""
def is_odd(n):
    """Return whether ``n % 2 == 1`` (Python's modulo makes this True for negative odd ints too)."""
    remainder = n % 2
    return remainder == 1
def default(val, d):
    """Return ``val`` unless it is None; otherwise ``d``, calling it first if it is callable."""
    if val is None:
        return d() if callable(d) else d
    return val
def cycle(dl):
    """Yield the items of ``dl`` forever, re-iterating it from the start after each pass."""
    while True:
        yield from dl
def num_to_groups(num, divisor):
    """Split ``num`` into chunks of size ``divisor``, plus a final smaller chunk for any remainder."""
    full_chunks, leftover = divmod(num, divisor)
    sizes = [divisor] * full_chunks
    if leftover > 0:
        sizes.append(leftover)
    return sizes
def prob_mask_like(shape, prob, device):
    """Boolean tensor of ``shape`` whose entries are True independently with probability ``prob``.

    ``prob`` of exactly 1 or 0 short-circuits to an all-True / all-False mask
    without drawing random numbers.
    """
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    if prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    draws = torch.empty(shape, device=device, dtype=torch.float32).uniform_(0, 1)
    return draws < prob
def is_list_str(x):
    """Return True if ``x`` is a list or tuple whose elements are all exactly of type ``str``."""
    return isinstance(x, (list, tuple)) and all(type(el) == str for el in x)
# relative positional bias
class RelativePositionBias(nn.Module):
    """Learned, bucketed relative-position bias with a hard attention window.

    ``forward(n, device)`` returns a ``(heads, n, n)`` bias. An extra ``-1e8`` is
    added wherever the key-minus-query offset exceeds ``window_width`` in magnitude,
    which effectively removes those pairs after a softmax.
    """

    def __init__(
        self,
        heads=8,
        num_buckets=32,
        max_distance=128,
        window_width = 20
    ):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)
        self.window_width = window_width

    @staticmethod
    def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
        """Map (key - query) offsets to bucket ids in ``[0, num_buckets)``.

        Half of the buckets are used for positive offsets and half for the rest.
        Within each half, distances below ``num_buckets // 4`` get their own bucket
        and larger ones are log-spaced up to ``max_distance``, saturating at the last
        bucket of the half.
        """
        half = num_buckets // 2
        n = -relative_position
        bucket = (n < 0).long() * half
        n = n.abs()

        max_exact = half // 2
        is_small = n < max_exact
        log_ratio = torch.log(n.float() / max_exact) / math.log(max_distance / max_exact)
        val_if_large = max_exact + (log_ratio * (half - max_exact)).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, half - 1))
        return bucket + torch.where(is_small, n, val_if_large)

    def forward(self, n, device):
        positions = torch.arange(n, dtype=torch.long, device=device)
        # rel_pos[i, j] = key position j - query position i
        rel_pos = positions.unsqueeze(0) - positions.unsqueeze(1)
        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets,
                                                   max_distance=self.max_distance)
        outside_window = (rel_pos > self.window_width) + (rel_pos < -self.window_width)
        mask = -(outside_window * (1e8))
        values = self.relative_attention_bias(rp_bucket)  # (n, n, heads)
        return values.permute(2, 0, 1) + mask
# small helper modules
class EMA:
    """Exponential moving average of weights: ``avg <- beta * avg + (1 - beta) * new``."""

    def __init__(self, beta):
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        """Blend ``current_model``'s parameters into ``ma_model``'s in place, pairing them in ``parameters()`` order."""
        pairs = zip(current_model.parameters(), ma_model.parameters())
        for live_param, avg_param in pairs:
            avg_param.data = self.update_average(avg_param.data, live_param.data)

    def update_average(self, old, new):
        """Return the EMA step from ``old`` towards ``new``, or ``new`` when ``old`` is None."""
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new
class Residual(nn.Module):
    """Identity skip connection around ``fn``: returns ``fn(x, ...) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        branch = self.fn(x, *args, **kwargs)
        return branch + x
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D tensor of positions/timesteps.

    For 1-D ``x`` the output is ``(len(x), 2 * (dim // 2))``: sines of
    geometrically spaced frequencies in the first half, cosines in the second.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        log_step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -log_step)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
def Upsample(dim, use_deconv=True, padding_mode="reflect"):
    """2x upsampling of the last two axes of a 5-D (N, C, D, H, W) input; D is unchanged.

    ``use_deconv=True`` returns a strided ConvTranspose3d. Otherwise it returns
    nearest-neighbour upsampling followed by a (1, 3, 3) Conv3d using
    ``padding_mode``. Channel count stays ``dim``.
    """
    if not use_deconv:
        return nn.Sequential(
            nn.Upsample(scale_factor=(1, 2, 2), mode='nearest'),
            nn.Conv3d(dim, dim, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1),
                      padding_mode=padding_mode),
        )
    return nn.ConvTranspose3d(dim, dim, kernel_size=(1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1))
def Downsample(dim):
    """Halve the last two axes of a 5-D (N, C, D, H, W) input with a strided (1, 4, 4) Conv3d; D is kept."""
    return nn.Conv3d(dim, dim, kernel_size=(1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1))
class LayerNorm(nn.Module):
    """Normalise over dim 1 (channels) of a 5-D tensor with a learned per-channel gain (no bias)."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1))

    def forward(self, x):
        mu = x.mean(dim=1, keepdim=True)
        sigma2 = x.var(dim=1, unbiased=False, keepdim=True)
        return (x - mu) / (sigma2 + self.eps).sqrt() * self.gamma
class LayerNorm_img(nn.Module):
    """LayerNorm over the last axis with a learned gain ``g`` (no bias).

    With ``stable=True`` the input is first divided by its (detached) max along the
    last axis. ``eps`` is 1e-5 for float32 inputs and 1e-3 for any other dtype.
    """

    def __init__(self, dim, stable = False):
        super().__init__()
        self.stable = stable
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        if self.stable:
            x = x / x.amax(dim=-1, keepdim=True).detach()
        eps = 1e-3 if x.dtype != torch.float32 else 1e-5
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, unbiased=False, keepdim=True)
        return (x - mean) * (var + eps).rsqrt() * self.g
class PreNorm(nn.Module):
    """Apply the channel-wise ``LayerNorm`` and then ``fn``, forwarding keyword arguments to ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
class Identity(nn.Module):
    """No-op module: constructor and forward accept and ignore extra arguments; forward returns ``x``."""

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        output = x
        return output
def l2norm(t):
    """Scale ``t`` to unit L2 norm along its last axis (norms are floored at 1e-12)."""
    norms = t.norm(2, -1, keepdim=True).clamp_min(1e-12).expand_as(t)
    return t / norms
# building block modules
class Block(nn.Module):
    """(1, 3, 3) Conv3d -> GroupNorm -> optional ``x * (scale + 1) + shift`` -> SiLU.

    ``audio_scale_shift`` is accepted for call compatibility but is not applied.
    """

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding=(0, 1, 1))
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, time_scale_shift=None, audio_scale_shift=None):
        h = self.norm(self.proj(x))
        if time_scale_shift is not None:
            scale, shift = time_scale_shift
            h = h * (scale + 1) + shift
        return self.act(h)
class ResnetBlock(nn.Module):
    """Two ``Block``s plus a residual (1x1 Conv3d when the channel count changes).

    ``time_emb`` (b, c) is projected to a per-channel scale/shift for ``block1``.
    ``audio_emb`` (b, n, c) is projected to a per-frame scale/shift and passed to
    ``block1`` as ``audio_scale_shift``.
    NOTE(review): ``Block.forward`` as written in this file does not apply
    ``audio_scale_shift``; confirm whether that is intended.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=None, groups=8):
        super().__init__()
        self.time_mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))
            if exists(time_emb_dim) else None
        )
        self.audio_mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(audio_emb_dim, dim_out * 2))
            if exists(audio_emb_dim) else None
        )
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None, audio_emb=None):
        time_scale_shift = None
        if exists(self.time_mlp):
            assert exists(time_emb), 'time emb must be passed in'
            projected_time = rearrange(self.time_mlp(time_emb), 'b c -> b c 1 1 1')
            time_scale_shift = projected_time.chunk(2, dim=1)

        audio_scale_shift = None
        if exists(self.audio_mlp):
            assert exists(audio_emb), 'audio emb must be passed in'
            projected_audio = rearrange(self.audio_mlp(audio_emb), 'b n c -> b c n 1 1')
            audio_scale_shift = projected_audio.chunk(2, dim=1)

        h = self.block1(x, time_scale_shift=time_scale_shift, audio_scale_shift=audio_scale_shift)
        return self.block2(h) + self.res_conv(x)
class ResnetBlock_ca(nn.Module):
    """Residual block with time scale/shift conditioning and an audio cross-attention branch.

    NOTE(review): the cross-attended features computed in ``forward`` are never fed
    into ``block1`` (which is applied to ``x``), so they do not affect the output.
    The branch is kept unchanged here; confirm the intended wiring before relying on it.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=None, groups=8):
        super().__init__()
        self.time_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out * 2)
        ) if exists(time_emb_dim) else None

        self.audio_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(audio_emb_dim, dim_out * 2)
        ) if exists(audio_emb_dim) else None

        # ``out_dim`` is a required argument of CrossAttention. Its output is added back
        # onto the ``dim``-channel tokens in forward, so it must project to ``dim``.
        self.cross_attn = CrossAttention(
            dim=dim,
            out_dim=dim,
            context_dim=dim_out * 2
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None, audio_emb=None):
        b, _, f, _, _ = x.size()

        time_scale_shift = None
        if exists(self.time_mlp):
            assert exists(time_emb), 'time emb must be passed in'
            time_emb = self.time_mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1 1')
            time_scale_shift = time_emb.chunk(2, dim=1)

        if exists(self.audio_mlp):
            assert exists(audio_emb), 'audio emb must be passed in'
            audio_emb = self.audio_mlp(audio_emb)

        if exists(self.cross_attn):
            assert exists(audio_emb)
            # spatial tokens of every (batch, frame) attend to that frame's audio embedding
            h = rearrange(x, 'b c f ... -> (b f) ... c')
            h, ps = pack([h], 'b * c')
            h = self.cross_attn(h, context=audio_emb) + h
            h, = unpack(h, ps, 'b * c')
            h = rearrange(h, '(b f) ... c -> b f c ...', b=b, f=f)

        h = self.block1(x, time_scale_shift=time_scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)
class ResnetBlock_ca_mul(nn.Module):
    """Residual block with time scale/shift plus audio/pose/eye cross-attention conditioning.

    ``audio_emb`` passed to ``forward`` is the last-axis concatenation of three
    embeddings. The first ``audio_emb_dim`` features are audio, the next
    ``pose_emb_dim`` are pose, and the remainder is eye. Each part goes through its
    own MLP and serves as the context of a separate cross-attention over the spatial
    tokens of each frame. The three results are summed and added to ``block1``'s output.

    Blocks built without ``audio_emb_dim`` have no conditioning MLP. For them the
    cross-attention branch is skipped, so they work as a plain residual block, for
    example when called with only ``x``.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=None, pose_emb_dim=None, eye_emb_dim=None, groups=8):
        super().__init__()
        self.time_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out * 2)
        ) if exists(time_emb_dim) else None

        self.audio_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(audio_emb_dim, dim_out * 2)
        ) if exists(audio_emb_dim) else None

        self.pose_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(pose_emb_dim, dim_out * 2)
        ) if exists(pose_emb_dim) else None

        self.eye_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(eye_emb_dim, dim_out * 2)
        ) if exists(eye_emb_dim) else None

        self.audio_emb_dim = audio_emb_dim
        self.pose_emb_dim = pose_emb_dim
        self.eye_emb_dim = eye_emb_dim

        attn_klass = CrossAttention
        self.cross_attn_aud = attn_klass(
            dim = dim,
            context_dim = dim_out * 2,
            out_dim = dim_out
        )
        self.cross_attn_pose = attn_klass(
            dim = dim,
            context_dim = dim_out * 2,
            out_dim = dim_out
        )
        self.cross_attn_eye = attn_klass(
            dim = dim,
            context_dim = dim_out * 2,
            out_dim = dim_out
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None, audio_emb=None):
        b, _, f, _, _ = x.size()

        time_scale_shift = None
        if exists(self.time_mlp):
            assert exists(time_emb), 'time emb must be passed in'
            time_emb = self.time_mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1 1')
            time_scale_shift = time_emb.chunk(2, dim=1)

        h_cond = None
        # The cross-attention result is only consumed when the block has a
        # conditioning MLP. Running the branch without one used to trip the
        # assert / None-slice even though its output would have been discarded.
        if exists(self.audio_mlp):  # mouth lmk + audio emb
            assert exists(audio_emb), 'audio emb must be passed in'
            # split the concatenated condition into its three parts
            pose_emb = audio_emb[:, :, self.audio_emb_dim:self.audio_emb_dim + self.pose_emb_dim]
            eye_emb = audio_emb[:, :, self.audio_emb_dim + self.pose_emb_dim:]
            audio_emb = audio_emb[:, :, :self.audio_emb_dim]

            audio_emb = self.audio_mlp(audio_emb)
            pose_emb = self.pose_mlp(pose_emb)  # TODO: embedding
            eye_emb = self.eye_mlp(eye_emb)

            # one token sequence (the spatial positions) per (batch, frame)
            h_cond = rearrange(x, 'b c f ... -> (b f) ... c')
            h_cond, ps = pack([h_cond], 'b * c')
            h_pose = self.cross_attn_pose(h_cond, context = pose_emb)
            h_aud = self.cross_attn_aud(h_cond, context = audio_emb)
            h_eye = self.cross_attn_eye(h_cond, context = eye_emb)
            h_cond = h_pose + h_aud + h_eye
            h_cond, = unpack(h_cond, ps, 'b * c')
            h_cond = rearrange(h_cond, '(b f) ... c -> b c f ...', b = b, f = f)

        h = self.block1(x, time_scale_shift=time_scale_shift)
        if h_cond is not None:
            h = h_cond + h
        h = self.block2(h)
        return h + self.res_conv(x)
class CrossAttention(nn.Module):
    """Cosine-similarity cross-attention from a token sequence to one context vector per row.

    ``x`` is (B, N, dim). ``context`` is (b, f, context_dim) and is flattened to
    ``(b * f)`` rows, which must equal B. A learned null key/value pair is prepended
    to the keys/values (classifier-free guidance trick from the prior network), so
    each query attends over two slots. ``mask``, if given, is (B, j) and gets a
    leading True for the null slot.
    """

    def __init__(
        self,
        dim,
        out_dim,
        *,
        context_dim = None,
        dim_head = 8,
        heads = 8,
        norm_context = False,
        scale = 8
    ):
        super().__init__()
        self.scale = scale
        self.heads = heads
        inner_dim = dim_head * heads

        context_dim = default(context_dim, dim)

        self.norm = LayerNorm_img(dim)
        self.norm_context = LayerNorm_img(context_dim) if norm_context else Identity()

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        # learned per-dimension temperatures for the cosine-sim logits
        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, out_dim, bias = False),
            LayerNorm_img(out_dim)
        )

    def forward(self, x, context, mask = None):
        batch = x.shape[0]
        x = self.norm(x)
        context = self.norm_context(rearrange(context, 'b f c -> (b f) c'))

        queries = self.to_q(x)
        keys, values = self.to_kv(context[:, None, :]).chunk(2, dim = -1)
        queries, keys, values = (
            rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in (queries, keys, values)
        )

        null_k, null_v = (
            repeat(t, 'd -> b h 1 d', h = self.heads, b = batch) for t in self.null_kv.unbind(dim = -2)
        )
        keys = torch.cat((null_k, keys), dim = -2)
        values = torch.cat((null_v, values), dim = -2)

        queries = l2norm(queries) * self.q_scale
        keys = l2norm(keys) * self.k_scale
        sim = einsum('b h i d, b h j d -> b h i j', queries, keys) * self.scale

        if exists(mask):
            keep = rearrange(F.pad(mask, (1, 0), value = True), 'b j -> b 1 1 j')
            sim = sim.masked_fill(~keep, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim = -1, dtype = torch.float32).to(sim.dtype)
        out = einsum('b h i j, b h j d -> b h i d', attn, values)
        return self.to_out(rearrange(out, 'b h n d -> b n (h d)'))
class LinearCrossAttention(CrossAttention):
    """Linear-attention variant of ``CrossAttention``.

    Queries are softmaxed over features and keys over the key axis, and there is no
    cosine similarity. Unlike the parent, ``context`` is used as given: it is indexed
    as ``context[:, None, :]``, so it should already have one row per row of ``x``.
    """

    def forward(self, x, context, mask = None):
        batch = x.shape[0]
        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context[:, None, :]).chunk(2, dim = -1))
        # The per-head width is inferred from the projections (dim_head). Previously it
        # was pinned to x's channel count // heads, which broke whenever
        # dim != dim_head * heads.
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads), (q, k, v))

        # add null key / value for classifier free guidance in prior net
        nk, nv = map(lambda t: repeat(t, 'd -> (b h) 1 d', h = self.heads, b = batch), self.null_kv.unbind(dim = -2))
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        if exists(mask):
            max_neg_value = -torch.finfo(x.dtype).max
            mask = F.pad(mask, (1, 0), value = True)
            # k / v are laid out as (b h); broadcast the per-batch mask across heads
            mask = repeat(mask, 'b n -> (b h) n 1', h = self.heads)
            k = k.masked_fill(~mask, max_neg_value)
            v = v.masked_fill(~mask, 0.)

        # linear attention
        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)
        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)
        return self.to_out(out)
class SpatialLinearAttention(nn.Module):
    """Linear attention over the H*W positions of each frame of a (b, c, f, h, w) tensor."""

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, f, h, w = x.shape
        frames = rearrange(x, 'b c f h w -> (b f) c h w')
        q, k, v = rearrange_many(self.to_qkv(frames).chunk(3, dim=1), 'b (h c) x y -> b h c (x y)', h=self.heads)

        # softmax q over features and k over positions, then aggregate via a d x e context
        q = q.softmax(dim=-2) * self.scale
        k = k.softmax(dim=-1)
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        attended = torch.einsum('b h d e, b h d n -> b h e n', context, q)

        attended = rearrange(attended, 'b h c (x y) -> b (h c) x y', h=self.heads, x=h, y=w)
        return rearrange(self.to_out(attended), '(b f) c h w -> b c f h w', b=b)
# attention along space and time
class EinopsToAndFrom(nn.Module):
    """Run ``fn`` on a rearranged view of the input, then undo the rearrange.

    The input is rearranged from ``from_einops`` to ``to_einops``, passed to
    ``fn`` with any keyword arguments, and rearranged back. The axis sizes
    used for the inverse are taken from the input's own shape, so grouped
    axes such as ``(h w)`` can be split again.
    """

    def __init__(self, from_einops, to_einops, fn):
        super().__init__()
        self.from_einops = from_einops
        self.to_einops = to_einops
        self.fn = fn

    def forward(self, x, **kwargs):
        # Map each space-separated axis name of the source pattern to its size.
        axis_sizes = {name: size for name, size in zip(self.from_einops.split(' '), x.shape)}
        to_pattern = f'{self.from_einops} -> {self.to_einops}'
        back_pattern = f'{self.to_einops} -> {self.from_einops}'
        transformed = self.fn(rearrange(x, to_pattern), **kwargs)
        return rearrange(transformed, back_pattern, **axis_sizes)
class Attention(nn.Module):
    """Multi-head softmax self-attention over the second-to-last axis of ``x``.

    Input is ``(..., n, dim)``; leading axes are treated as batch. Optional
    features:

    - a rotary embedding applied to q and k;
    - an additive positional bias on the attention logits;
    - a per-sample ``focus_present_mask`` that restricts a sample to attend
      only to itself.
    """

    def __init__(
        self,
        dim,
        heads=4,
        dim_head=32,
        rotary_emb=None
    ):
        super().__init__()
        inner_dim = heads * dim_head
        self.scale = dim_head ** -0.5
        self.heads = heads
        # Registration order (rotary_emb, to_qkv, to_out) matches the original.
        self.rotary_emb = rotary_emb
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(
        self,
        x,
        pos_bias=None,
        focus_present_mask=None
    ):
        """Attend over axis -2 of ``x``.

        Per the original note, temporal use passes ``'b (h w) f c'`` and
        spatial use passes ``'b f (h w) c'``.

        Args:
            x: tokens ``(..., n, dim)``.
            pos_bias: optional tensor added to the ``(..., heads, n, n)`` logits.
            focus_present_mask: optional per-sample boolean tensor of shape
                ``(b,)``. ``True`` samples attend only to their own position.

        Returns:
            Tensor ``(..., n, dim)``.
        """
        seq_len, device = x.shape[-2], x.device
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        has_focus_mask = exists(focus_present_mask)

        # Every sample attends only to itself: the result is just the values.
        if has_focus_mask and focus_present_mask.all():
            return self.to_out(qkv[-1])

        q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads)
        q = q * self.scale

        # Rotary positions for the time axis.
        if exists(self.rotary_emb):
            q = self.rotary_emb.rotate_queries_or_keys(q)
            k = self.rotary_emb.rotate_queries_or_keys(k)

        logits = einsum('... h i d, ... h j d -> ... h i j', q, k)
        if exists(pos_bias):
            logits = logits + pos_bias

        # Mixed batch: only some samples are restricted to self-attention.
        if has_focus_mask and not (~focus_present_mask).all():
            attend_all = torch.ones((seq_len, seq_len), device=device, dtype=torch.bool)
            attend_self = torch.eye(seq_len, device=device, dtype=torch.bool)
            allowed = torch.where(
                rearrange(focus_present_mask, 'b -> b 1 1 1 1'),
                rearrange(attend_self, 'i j -> 1 1 1 i j'),
                rearrange(attend_all, 'i j -> 1 1 1 i j'),
            )
            logits = logits.masked_fill(~allowed, -torch.finfo(logits.dtype).max)

        # Subtract the detached row max for numerical stability before softmax.
        logits = logits - logits.amax(dim=-1, keepdim=True).detach()
        weights = logits.softmax(dim=-1)

        out = einsum('... h i j, ... h j d -> ... h i d', weights, v)
        return self.to_out(rearrange(out, '... h n d -> ... n (h d)'))
# model
class Unet3D(nn.Module):
def __init__(
    self,
    dim,
    cond_aud=1024,
    cond_pose=7,
    cond_eye=2,
    cond_dim=None,
    out_grid_dim=2,
    out_conf_dim=1,
    num_frames=40,
    dim_mults=(1, 2, 4, 8),
    channels=3,
    attn_heads=8,
    attn_dim_head=32,
    use_hubert_audio_cond=False,
    init_dim=None,
    init_kernel_size=7,
    use_sparse_linear_attn=True,
    resnet_groups=8,
    use_final_activation=False,
    learn_null_cond=False,
    use_deconv=True,
    padding_mode="zeros",
    win_width = 20
):
    """Build the 3-D U-Net layers.

    Args:
        dim: base channel width. Level widths are ``dim * m`` for ``m`` in
            ``dim_mults``; the time embedding width is ``dim * 4``.
        cond_aud: passed as ``audio_emb_dim`` to every conditioned
            ``ResnetBlock_ca_mul``.
        cond_pose: passed as ``pose_emb_dim`` to every conditioned block.
        cond_eye: passed as ``eye_emb_dim`` to every conditioned block.
        cond_dim: stored. Here it only affects ``self.has_cond`` (non-None
            makes it True).
        out_grid_dim: output channels of ``final_conv``.
        out_conf_dim: output channels of ``occlusion_map``.
        num_frames: stored as ``self.num_frames``.
        dim_mults: per-level channel multipliers; one down/up level each.
        channels: input channels of ``init_conv``.
        attn_heads: heads for all attention layers and the relative position bias.
        attn_dim_head: per-head width for temporal attention. Also bounds the
            rotary embedding size as ``min(32, attn_dim_head)``.
        use_hubert_audio_cond: if True, ``self.has_cond`` is True.
        init_dim: output channels of ``init_conv``; defaults to ``dim``.
        init_kernel_size: spatial kernel of ``init_conv``; must be odd.
        use_sparse_linear_attn: when False, the per-level
            ``SpatialLinearAttention`` layers are replaced by ``nn.Identity``.
        resnet_groups: ``groups`` argument for every ``ResnetBlock_ca_mul``.
        use_final_activation: when True, ``self.final_activation`` is Tanh,
            otherwise Identity.
        learn_null_cond: stored only; not used in this method.
        use_deconv: forwarded to ``Upsample``.
        padding_mode: forwarded to ``Upsample``.
        win_width: passed as ``window_width`` to ``RelativePositionBias``.

    NOTE(review): submodules are constructed in a fixed order. Reordering
    changes the initial weights (RNG consumption) and ``parameters()`` order.
    """
    super().__init__()
    # NOTE(review): purpose of null_cond_mask is not visible in this method.
    self.null_cond_mask = None
    self.channels = channels
    self.num_frames = num_frames
    # Constant; not referenced elsewhere in this method.
    self.HUBERT_MODEL_DIM = 1024
    # temporal attention and its relative positional encoding
    # A single rotary embedding instance is shared by every temporal attention
    # layer built via the lambda below.
    rotary_emb = RotaryEmbedding(min(32, attn_dim_head))
    # Temporal attention: fold space into the batch and attend over frames.
    temporal_attn = lambda dim: EinopsToAndFrom('b c f h w', 'b (h w) f c',
                                                Attention(dim, heads=attn_heads, dim_head=attn_dim_head,
                                                          rotary_emb=rotary_emb))
    self.time_rel_pos_bias = RelativePositionBias(heads=attn_heads,
                                                  max_distance=32, window_width = win_width) # realistically will not be able to generate that many frames of video... yet
    # initial conv
    init_dim = default(init_dim, dim)
    # An odd kernel with padding k // 2 keeps the spatial size unchanged.
    assert is_odd(init_kernel_size)
    init_padding = init_kernel_size // 2
    # Kernel depth 1: this conv mixes spatially only, frame by frame.
    self.init_conv = nn.Conv3d(channels, init_dim, (1, init_kernel_size, init_kernel_size),
                               padding=(0, init_padding, init_padding))
    self.init_temporal_attn = Residual(PreNorm(init_dim, temporal_attn(init_dim)))
    # dimensions
    # e.g. dim=64, dim_mults=(1, 2, 4, 8) -> dims=[init_dim, 64, 128, 256, 512];
    # in_out pairs consecutive widths, one pair per resolution level.
    dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
    in_out = list(zip(dims[:-1], dims[1:]))
    # time conditioning
    time_dim = dim * 4
    self.time_mlp = nn.Sequential(
        SinusoidalPosEmb(dim),
        nn.Linear(dim, time_dim),
        nn.GELU(),
        nn.Linear(time_dim, time_dim)
    )
    # audio conditioning
    # With the defaults (cond_dim=None, use_hubert_audio_cond=False) has_cond is
    # False, and forward_with_cond_scale then skips classifier-free guidance.
    self.has_cond = exists(cond_dim) or use_hubert_audio_cond
    self.cond_dim = cond_dim
    self.cond_aud_dim = cond_aud
    self.cond_pose_dim = cond_pose
    self.cond_eye_dim = cond_eye
    # modified by lml
    self.learn_null_cond = learn_null_cond
    # Concatenating the time and condition embeddings was found unsuitable;
    # the conditions are instead fed to the blocks as separate embeddings.
    # cond_dim = time_dim + int(cond_dim or 0)
    # layers
    self.downs = nn.ModuleList([])
    self.ups = nn.ModuleList([])
    num_resolutions = len(in_out)
    # block type
    # block_klass: unconditioned residual block. block_klass_cond: the same
    # block with time, audio, pose and eye embedding inputs.
    block_klass = partial(ResnetBlock_ca_mul, groups=resnet_groups)
    block_klass_cond = partial(block_klass, time_emb_dim=time_dim, audio_emb_dim=self.cond_aud_dim, pose_emb_dim=self.cond_pose_dim, eye_emb_dim=self.cond_eye_dim)
    # block_klass_cond = partial(block_klass, time_emb_dim=cond_dim) # cat embedding
    # modules for all layers
    # Each down level: two conditioned blocks, spatial attention, temporal
    # attention, then Downsample (omitted on the last level).
    for ind, (dim_in, dim_out) in enumerate(in_out):
        is_last = ind >= (num_resolutions - 1)
        self.downs.append(nn.ModuleList([
            block_klass_cond(dim_in, dim_out),
            block_klass_cond(dim_out, dim_out),
            Residual(PreNorm(dim_out, SpatialLinearAttention(dim_out,
                                                             heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(),
            Residual(PreNorm(dim_out, temporal_attn(dim_out))),
            Downsample(dim_out) if not is_last else nn.Identity()
        ]))
    mid_dim = dims[-1]
    self.mid_block1 = block_klass_cond(mid_dim, mid_dim)
    # Full (quadratic) attention over all h*w positions within each frame.
    # NOTE(review): no dim_head is passed, so this uses Attention's default
    # (32) rather than attn_dim_head.
    spatial_attn = EinopsToAndFrom('b c f h w', 'b f (h w) c', Attention(mid_dim, heads=attn_heads))
    self.mid_spatial_attn = Residual(PreNorm(mid_dim, spatial_attn))
    self.mid_temporal_attn = Residual(PreNorm(mid_dim, temporal_attn(mid_dim)))
    self.mid_block2 = block_klass_cond(mid_dim, mid_dim)
    # Each up level mirrors a down level. The first block takes dim_out * 2
    # channels, presumably the up-path features concatenated with a skip
    # connection (forward is not visible here; TODO confirm).
    for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
        is_last = ind >= (num_resolutions - 1)
        self.ups.append(nn.ModuleList([
            block_klass_cond(dim_out * 2, dim_in),
            block_klass_cond(dim_in, dim_in),
            Residual(PreNorm(dim_in, SpatialLinearAttention(dim_in,
                                                            heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(),
            Residual(PreNorm(dim_in, temporal_attn(dim_in))),
            Upsample(dim_in, use_deconv, padding_mode) if not is_last else nn.Identity()
        ]))
    # out_dim = default(out_grid_dim, channels)
    # Output head for the flow/grid prediction (out_grid_dim channels). The
    # dim * 2 input presumably concatenates two dim-wide feature maps; TODO
    # confirm against forward.
    self.final_conv = nn.Sequential(
        block_klass(dim * 2, dim),
        nn.Conv3d(dim, out_grid_dim, 1)
    )
    # added by nhm
    self.use_final_activation = use_final_activation
    if self.use_final_activation:
        self.final_activation = nn.Tanh()
    else:
        self.final_activation = nn.Identity()
    # added by nhm for predicting occlusion mask
    # Separate head with the same structure, producing out_conf_dim channels.
    self.occlusion_map = nn.Sequential(
        block_klass(dim * 2, dim),
        nn.Conv3d(dim, out_conf_dim, 1)
    )
def forward_with_cond_scale(
    self,
    *args,
    cond_scale=2.,
    **kwargs
):
    """Classifier-free-guidance wrapper around ``forward``.

    ``forward`` is first run with conditioning kept (``null_cond_prob=0.``).
    That output is returned directly when ``cond_scale`` is 1 or the model
    has no conditioning (``self.has_cond`` is falsy).

    Otherwise a second pass is made with conditioning dropped
    (``null_cond_prob=1.``), and the two outputs are combined as
    ``uncond + (cond - uncond) * cond_scale``.
    """
    cond_out = self.forward(*args, null_cond_prob=0., **kwargs)
    skip_guidance = cond_scale == 1 or not self.has_cond
    if skip_guidance:
        return cond_out
    uncond_out = self.forward(*args, null_cond_prob=1., **kwargs)
    return uncond_out + (cond_out - uncond_out) * cond_scale
def forward(
self,
x,
time,
cond=None,
null_cond_prob=0.,
focus_present_mask=None,
gitextract_jmo9ls54/ ├── .gitignore ├── DAWN_256.yaml ├── DM_3/ │ ├── datasets_hdtf_wpose_lmk_block_lmk.py │ ├── datasets_hdtf_wpose_lmk_block_lmk_rand.py │ ├── modules/ │ │ ├── local_attention.py │ │ ├── text.py │ │ ├── video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_mouth_mask_6D.py │ │ ├── video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_mouth_mask_rand_6D.py │ │ ├── video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_test.py │ │ ├── video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi.py │ │ ├── video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test.py │ │ └── video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test_local_opt.py │ ├── test_lr.py │ ├── train_vdm_hdtf_wpose_plus_faceemb_init_cond_liploss_6D.py │ ├── train_vdm_hdtf_wpose_plus_faceemb_init_cond_liploss_6D_s2.py │ └── utils.py ├── LFG/ │ ├── __init__.py │ ├── augmentation.py │ ├── frames_dataset.py │ ├── hdtf_dataset.py │ ├── modules/ │ │ ├── avd_network.py │ │ ├── bg_motion_predictor.py │ │ ├── flow_autoenc.py │ │ ├── generator.py │ │ ├── model.py │ │ ├── pixelwise_flow_predictor.py │ │ ├── region_predictor.py │ │ └── util.py │ ├── run_hdtf.py │ ├── run_hdtf_crema.py │ ├── sync_batchnorm/ │ │ ├── __init__.py │ │ ├── batchnorm.py │ │ ├── comm.py │ │ ├── replicate.py │ │ └── unittest.py │ ├── test_flowautoenc_crema_video.py │ ├── test_flowautoenc_hdtf_video.py │ ├── test_flowautoenc_hdtf_video_256.py │ ├── train.py │ └── vis_flow.py ├── PBnet/ │ ├── run_cvae_h_ann_reemb_rope_eye_3.sh │ └── src/ │ ├── __init__.py │ ├── config.py │ ├── datasets/ │ │ ├── __init__.py │ │ ├── datasets_hdtf_pos_chunk_norm_2_fast.py │ │ ├── datasets_hdtf_pos_chunk_norm_eye_fast.py │ │ ├── datasets_hdtf_pos_df.py │ │ ├── datasets_hdtf_pos_dict_norm_2.py │ │ ├── datasets_hdtf_wpose_lmk_block.py │ │ ├── get_dataset.py │ │ └── tools.py │ ├── evaluate/ │ │ ├── __init__.py │ │ ├── 
action2motion/ │ │ │ ├── accuracy.py │ │ │ ├── diversity.py │ │ │ ├── evaluate.py │ │ │ ├── fid.py │ │ │ └── models.py │ │ ├── evaluate_cvae.py │ │ ├── evaluate_cvae_debug.py │ │ ├── evaluate_cvae_f3.py │ │ ├── evaluate_cvae_f3_debug.py │ │ ├── evaluate_cvae_f3_mel.py │ │ ├── evaluate_cvae_norm.py │ │ ├── evaluate_cvae_norm_all.py │ │ ├── evaluate_cvae_norm_all_seg.py │ │ ├── evaluate_cvae_norm_all_seg_weye.py │ │ ├── evaluate_cvae_norm_all_seg_weye2.py │ │ ├── evaluate_cvae_norm_eye_pose.py │ │ ├── evaluate_cvae_norm_eye_pose_test.py │ │ ├── evaluate_cvae_onlyeye_all_seg.py │ │ ├── othermetrics/ │ │ │ ├── acceleration.py │ │ │ └── evaluation.py │ │ ├── stgcn/ │ │ │ ├── accuracy.py │ │ │ ├── diversity.py │ │ │ ├── evaluate.py │ │ │ └── fid.py │ │ ├── tables/ │ │ │ ├── archtable.py │ │ │ ├── bstable.py │ │ │ ├── easy_table.py │ │ │ ├── easy_table_A2M.py │ │ │ ├── kltable.py │ │ │ ├── latexmodela2m.py │ │ │ ├── latexmodelsa2m.py │ │ │ ├── latexmodelsstgcn.py │ │ │ ├── losstable.py │ │ │ ├── maketable.py │ │ │ ├── numlayertable.py │ │ │ └── posereptable.py │ │ ├── tools.py │ │ ├── tvae_eval.py │ │ ├── tvae_eval_norm.py │ │ ├── tvae_eval_norm_all.py │ │ ├── tvae_eval_norm_eye_pose.py │ │ ├── tvae_eval_norm_eye_pose_seg.py │ │ ├── tvae_eval_norm_seg.py │ │ ├── tvae_eval_onlyeye_all_seg.py │ │ ├── tvae_eval_single.py │ │ ├── tvae_eval_single_both_eye_pose.py │ │ ├── tvae_eval_std.py │ │ ├── tvae_eval_train.py │ │ ├── tvae_eval_train_norm.py │ │ └── tvae_eval_train_std.py │ ├── generate/ │ │ └── generate_sequences.py │ ├── models/ │ │ ├── __init__.py │ │ ├── architectures/ │ │ │ ├── __init__.py │ │ │ ├── autotrans.py │ │ │ ├── fc.py │ │ │ ├── gru.py │ │ │ ├── grutrans.py │ │ │ ├── mlp.py │ │ │ ├── resnet34.py │ │ │ ├── tools/ │ │ │ │ ├── embeddings.py │ │ │ │ ├── resnet.py │ │ │ │ ├── transformer_layers.py │ │ │ │ └── util.py │ │ │ ├── transformer.py │ │ │ ├── transformerdecoder.py │ │ │ ├── transformerdecoder4.py │ │ │ ├── transformerdecoder5.py │ │ │ ├── 
transformerreemb.py │ │ │ ├── transformerreemb5.py │ │ │ ├── transformerreemb6.py │ │ │ └── transgru.py │ │ ├── get_model.py │ │ ├── modeltype/ │ │ │ ├── __init__.py │ │ │ ├── cae.py │ │ │ ├── cae_0.py │ │ │ ├── cvae.py │ │ │ └── lstm.py │ │ ├── rotation2xyz.py │ │ ├── smpl.py │ │ └── tools/ │ │ ├── __init__.py │ │ ├── graphconv.py │ │ ├── hessian_penalty.py │ │ ├── losses.py │ │ ├── mmd.py │ │ ├── msssim_loss.py │ │ ├── normalize_data.py │ │ ├── ssim_loss.py │ │ └── tools.py │ ├── parser/ │ │ ├── base.py │ │ ├── checkpoint.py │ │ ├── dataset.py │ │ ├── evaluation.py │ │ ├── finetunning.py │ │ ├── generate.py │ │ ├── model.py │ │ ├── recognition.py │ │ ├── tools.py │ │ ├── training.py │ │ └── visualize.py │ ├── preprocess/ │ │ ├── humanact12_process.py │ │ ├── phspdtools.py │ │ └── uestc_vibe_postprocessing.py │ ├── recognition/ │ │ ├── compute_accuracy.py │ │ ├── get_model.py │ │ └── models/ │ │ ├── stgcn.py │ │ └── stgcnutils/ │ │ ├── graph.py │ │ └── tgcn.py │ ├── render/ │ │ ├── renderer.py │ │ └── rendermotion.py │ ├── train/ │ │ ├── __init__.py │ │ ├── train_cvae_ganloss_ann_eye.py │ │ ├── train_cvae_ganloss_ann_fast.py │ │ ├── trainer.py │ │ ├── trainer_gan.py │ │ └── trainer_gan_ann.py │ ├── utils/ │ │ ├── PYTORCH3D_LICENSE │ │ ├── __init__.py │ │ ├── fixseed.py │ │ ├── get_model_and_data.py │ │ ├── misc.py │ │ ├── rotation_conversions.py │ │ ├── tensors.py │ │ ├── tensors_eye.py │ │ ├── tensors_eye_eval.py │ │ ├── tensors_hdtf.py │ │ ├── tensors_onlyeye.py │ │ ├── utils.py │ │ └── video.py │ └── visualize/ │ ├── __init__.py │ ├── anim.py │ ├── visualize.py │ ├── visualize_checkpoint.py │ ├── visualize_dataset.py │ ├── visualize_latent_space.py │ ├── visualize_nturefined.py │ └── visualize_sequence.py ├── README.md ├── README_CN.md ├── config/ │ ├── DAWN_128.yaml │ ├── DAWN_256.yaml │ ├── hdtf128.yaml │ ├── hdtf128_1000ep.yaml │ ├── hdtf128_1000ep_crema.yaml │ ├── hdtf256.yaml │ └── hdtf256_400ep.yaml ├── extract_init_states/ │ ├── FaceBoxes/ │ │ ├── 
FaceBoxes.py │ │ ├── FaceBoxes_ONNX.py │ │ ├── __init__.py │ │ ├── build_cpu_nms.sh │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ └── faceboxes.py │ │ ├── onnx.py │ │ ├── readme.md │ │ ├── utils/ │ │ │ ├── .gitignore │ │ │ ├── __init__.py │ │ │ ├── box_utils.py │ │ │ ├── build.py │ │ │ ├── config.py │ │ │ ├── functions.py │ │ │ ├── nms/ │ │ │ │ ├── .gitignore │ │ │ │ ├── __init__.py │ │ │ │ ├── cpu_nms.cp38-win_amd64.pyd │ │ │ │ ├── cpu_nms.pyx │ │ │ │ └── py_cpu_nms.py │ │ │ ├── nms_wrapper.py │ │ │ ├── prior_box.py │ │ │ └── timer.py │ │ └── weights/ │ │ ├── .gitignore │ │ ├── FaceBoxesProd.pth │ │ └── readme.md │ ├── TDDFA_ONNX.py │ ├── bfm/ │ │ ├── .gitignore │ │ ├── __init__.py │ │ ├── bfm.py │ │ ├── bfm_onnx.py │ │ └── readme.md │ ├── build.sh │ ├── configs/ │ │ ├── .gitignore │ │ ├── BFM_UV.mat │ │ ├── bfm_noneck_v3.onnx │ │ ├── bfm_noneck_v3.pkl │ │ ├── indices.npy │ │ ├── mb05_120x120.yml │ │ ├── mb1_120x120.yml │ │ ├── ncc_code.npy │ │ ├── param_mean_std_62d_120x120.pkl │ │ ├── readme.md │ │ ├── resnet_120x120.yml │ │ └── tri.pkl │ ├── demo_pose_extract_2d_lmk_img.py │ ├── functions.py │ ├── models/ │ │ ├── __init__.py │ │ ├── mobilenet_v1.py │ │ ├── mobilenet_v3.py │ │ └── resnet.py │ ├── pose.py │ ├── readme.md │ ├── utils/ │ │ ├── __init__.py │ │ ├── asset/ │ │ │ ├── .gitignore │ │ │ ├── build_render_ctypes.sh │ │ │ └── render.c │ │ ├── depth.py │ │ ├── functions.py │ │ ├── io.py │ │ ├── onnx.py │ │ ├── pncc.py │ │ ├── pose.py │ │ ├── render.py │ │ ├── render_ctypes.py │ │ ├── serialization.py │ │ ├── tddfa_util.py │ │ └── uv.py │ └── weights/ │ ├── .gitignore │ ├── mb05_120x120.pth │ ├── mb1_120x120.onnx │ ├── mb1_120x120.pth │ └── readme.md ├── filter_fourier.py ├── hubert_extract/ │ └── data_gen/ │ └── process_lrs3/ │ ├── binarizer.py │ ├── process_audio_hubert.py │ ├── process_audio_hubert_interpolate.py │ ├── process_audio_hubert_interpolate_batch.py │ ├── process_audio_hubert_interpolate_demo.py │ ├── process_audio_hubert_interpolate_single.py │ 
└── process_audio_mel_f0.py ├── misc.py ├── requirements.txt ├── run_ood_test/ │ ├── run_DM_v0_df_test_128_both_pose_blink.sh │ ├── run_DM_v0_df_test_128_separate_pose_blink.sh │ ├── run_DM_v0_df_test_256.sh │ ├── run_DM_v0_df_test_256_1.sh │ └── run_DM_v0_df_test_256_1_separate_pose_blink.sh ├── sync_batchnorm/ │ ├── __init__.py │ ├── batchnorm.py │ ├── comm.py │ ├── replicate.py │ ├── replicate_ddp.py │ └── unittest.py └── unified_video_generator.py
SYMBOL INDEX (1652 symbols across 212 files)
FILE: DM_3/datasets_hdtf_wpose_lmk_block_lmk.py
function resize (line 29) | def resize(im, desired_size, interpolation):
class HDTF (line 45) | class HDTF(data.Dataset):
method __init__ (line 46) | def __init__(self, data_dir, pose_dir, eye_blink_dir, max_num_frames=8...
method check_head (line 105) | def check_head(self, frame_list, video_name, start, end):
method get_block_data_for_two (line 118) | def get_block_data_for_two(self, path, start, end):
method get_block_data (line 149) | def get_block_data(self, path, start, end):
method check_len (line 184) | def check_len(self, name):
method __len__ (line 189) | def __len__(self):
method __getitem__ (line 192) | def __getitem__(self, idx):
FILE: DM_3/datasets_hdtf_wpose_lmk_block_lmk_rand.py
function resize (line 29) | def resize(im, desired_size, interpolation):
class HDTF (line 45) | class HDTF(data.Dataset):
method __init__ (line 46) | def __init__(self, data_dir, pose_dir, eye_blink_dir, max_num_frames=8...
method check_head (line 110) | def check_head(self, frame_list, video_name, start, end):
method get_block_data_for_two (line 123) | def get_block_data_for_two(self, path, start, end):
method get_block_data (line 154) | def get_block_data(self, path, start, end):
method check_len (line 189) | def check_len(self, name):
method __len__ (line 194) | def __len__(self):
method __getitem__ (line 197) | def __getitem__(self, idx):
FILE: DM_3/modules/local_attention.py
function exists (line 24) | def exists(x):
function to_mask (line 27) | def to_mask(x, mask, mode='mul'):
function extract_seq_patches (line 50) | def extract_seq_patches(x, kernel_size, rate):
function window_attn (line 71) | def window_attn(x, y, z, kernel_size, mask, rate):
function window_attn_2 (line 102) | def window_attn_2(x, y, z, kernel_size, mask, rate): # bad optimization
function window_attn_stream (line 167) | def window_attn_stream(x, y, z, kernel_size, mask, rate): # bad optimiz...
function create_sliding_window_mask (line 212) | def create_sliding_window_mask(x, win_size, rate):
class OurLayer (line 228) | class OurLayer(nn.Module):
method reuse (line 230) | def reuse(self, layer, *args, **kwargs):
function heavy_computation (line 235) | def heavy_computation(x, y, attn, k_size, i):
function heavy_computation2 (line 238) | def heavy_computation2(res, z, attn, k_size, i):
function window_attn_mp (line 242) | def window_attn_mp(x, y, z, kernel_size, mask, rate):
class LocalSelfAttention_opt (line 275) | class LocalSelfAttention_opt(OurLayer):
method __init__ (line 277) | def __init__(self, d_model, heads, size_per_head, neighbors=3, rate=1,...
method forward (line 300) | def forward(self, inputs, pos_bias, focus_present_mask=None,):
class MultiHeadLocalAttention (line 345) | class MultiHeadLocalAttention(nn.Module):
method __init__ (line 346) | def __init__(self, d_model, num_heads, window_size):
method split_heads (line 369) | def split_heads(self, x, batch_size):
method forward (line 374) | def forward(self, x):
class Attention (line 404) | class Attention(nn.Module):
method __init__ (line 405) | def __init__(
method forward (line 424) | def forward(
class RelativePositionBias (line 490) | class RelativePositionBias(nn.Module):
method __init__ (line 491) | def __init__(
method _relative_position_bucket (line 503) | def _relative_position_bucket(relative_position, num_buckets=32, max_d...
method forward (line 522) | def forward(self, n, device):
FILE: DM_3/modules/text.py
function exists (line 6) | def exists(val):
function get_tokenizer (line 18) | def get_tokenizer():
function get_bert (line 25) | def get_bert():
function tokenize (line 37) | def tokenize(texts, add_special_tokens=True):
function bert_embed (line 57) | def bert_embed(
FILE: DM_3/modules/video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_mouth_mask_6D.py
class Attention (line 26) | class Attention(nn.Module):
method __init__ (line 27) | def __init__(self, params):
method forward (line 32) | def forward(self, ctx_val, ctx_key, ctx_mask, ht_query):
class Face_loc_Encoder (line 47) | class Face_loc_Encoder(nn.Module):
method __init__ (line 48) | def __init__(self, dim = 1):
method forward (line 53) | def forward(self, x):
class Vgg19 (line 60) | class Vgg19(torch.nn.Module):
method __init__ (line 65) | def __init__(self, requires_grad=False):
method forward (line 93) | def forward(self, x):
class ImagePyramide (line 103) | class ImagePyramide(torch.nn.Module):
method __init__ (line 108) | def __init__(self, scales, num_channels):
method forward (line 115) | def forward(self, x):
class FlowDiffusion (line 121) | class FlowDiffusion(nn.Module):
method __init__ (line 122) | def __init__(self, img_size=32, num_frames=40, sampling_timesteps=250,
method update_num_frames (line 209) | def update_num_frames(self, new_num_frames):
method generate_bbox_mask (line 214) | def generate_bbox_mask(self, bbox, size = 32):
method generate_mouth_mask (line 238) | def generate_mouth_mask(self, mouth_lmk, origin_size, size = 32):
method forward (line 257) | def forward(self, real_vid, ref_img, ref_text, ref_pose, ref_eye_blink...
method sample_one_video (line 399) | def sample_one_video(self, real_vid, sample_img, sample_audio_hubert, ...
method get_grid (line 466) | def get_grid(self, b, nf, H, W, normalize=True):
method set_requires_grad (line 476) | def set_requires_grad(self, nets, requires_grad=False):
FILE: DM_3/modules/video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_mouth_mask_rand_6D.py
class Attention (line 25) | class Attention(nn.Module):
method __init__ (line 26) | def __init__(self, params):
method forward (line 31) | def forward(self, ctx_val, ctx_key, ctx_mask, ht_query):
class Face_loc_Encoder (line 46) | class Face_loc_Encoder(nn.Module):
method __init__ (line 47) | def __init__(self, dim = 1):
method forward (line 52) | def forward(self, x):
class Vgg19 (line 59) | class Vgg19(torch.nn.Module):
method __init__ (line 64) | def __init__(self, requires_grad=False):
method forward (line 92) | def forward(self, x):
class FlowDiffusion (line 103) | class FlowDiffusion(nn.Module):
method __init__ (line 104) | def __init__(self, img_size=32, num_frames=40, sampling_timesteps=250,
method update_num_frames (line 182) | def update_num_frames(self, new_num_frames):
method generate_bbox_mask (line 187) | def generate_bbox_mask(self, bbox, size = 32):
method generate_mouth_mask (line 208) | def generate_mouth_mask(self, mouth_lmk, origin_size, size = 32):
method forward (line 227) | def forward(self, real_vid, ref_img, ref_text, ref_pose, ref_eye_blink...
method sample_one_video (line 404) | def sample_one_video(self, real_vid, sample_img, sample_audio_hubert, ...
method get_grid (line 454) | def get_grid(self, b, nf, H, W, normalize=True):
method set_requires_grad (line 464) | def set_requires_grad(self, nets, requires_grad=False):
FILE: DM_3/modules/video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_test.py
class Attention (line 18) | class Attention(nn.Module):
method __init__ (line 19) | def __init__(self, params):
method forward (line 24) | def forward(self, ctx_val, ctx_key, ctx_mask, ht_query):
class Face_loc_Encoder (line 39) | class Face_loc_Encoder(nn.Module):
method __init__ (line 40) | def __init__(self, dim = 1):
method forward (line 45) | def forward(self, x):
class Vgg19 (line 52) | class Vgg19(torch.nn.Module):
method __init__ (line 57) | def __init__(self, requires_grad=False):
method forward (line 85) | def forward(self, x):
class FlowDiffusion (line 96) | class FlowDiffusion(nn.Module):
method __init__ (line 97) | def __init__(self, img_size=32, num_frames=40, sampling_timesteps=250,...
method update_num_frames (line 177) | def update_num_frames(self, new_num_frames):
method generate_bbox_mask (line 182) | def generate_bbox_mask(self, bbox, size = 32):
method forward (line 203) | def forward(self, real_vid, ref_img, ref_text, ref_pose, ref_eye_blink...
method sample_one_video (line 325) | def sample_one_video(self, sample_img, sample_audio_hubert, sample_pos...
method get_grid (line 408) | def get_grid(self, b, nf, H, W, normalize=True):
method set_requires_grad (line 418) | def set_requires_grad(self, nets, requires_grad=False):
FILE: DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi.py
function exists (line 27) | def exists(x):
function noop (line 31) | def noop(*args, **kwargs):
function is_odd (line 35) | def is_odd(n):
function default (line 39) | def default(val, d):
function cycle (line 45) | def cycle(dl):
function num_to_groups (line 51) | def num_to_groups(num, divisor):
function prob_mask_like (line 60) | def prob_mask_like(shape, prob, device):
function is_list_str (line 69) | def is_list_str(x):
class RelativePositionBias (line 77) | class RelativePositionBias(nn.Module):
method __init__ (line 78) | def __init__(
method _relative_position_bucket (line 90) | def _relative_position_bucket(relative_position, num_buckets=32, max_d...
method forward (line 109) | def forward(self, n, device):
class EMA (line 123) | class EMA():
method __init__ (line 124) | def __init__(self, beta):
method update_model_average (line 128) | def update_model_average(self, ma_model, current_model):
method update_average (line 133) | def update_average(self, old, new):
class Residual (line 139) | class Residual(nn.Module):
method __init__ (line 140) | def __init__(self, fn):
method forward (line 144) | def forward(self, x, *args, **kwargs):
class SinusoidalPosEmb (line 148) | class SinusoidalPosEmb(nn.Module):
method __init__ (line 149) | def __init__(self, dim):
method forward (line 153) | def forward(self, x):
function Upsample (line 163) | def Upsample(dim, use_deconv=True, padding_mode="reflect"):
function Downsample (line 173) | def Downsample(dim):
class LayerNorm (line 177) | class LayerNorm(nn.Module):
method __init__ (line 178) | def __init__(self, dim, eps=1e-5):
method forward (line 183) | def forward(self, x):
class LayerNorm_img (line 188) | class LayerNorm_img(nn.Module):
method __init__ (line 189) | def __init__(self, dim, stable = False):
method forward (line 194) | def forward(self, x):
class PreNorm (line 203) | class PreNorm(nn.Module):
method __init__ (line 204) | def __init__(self, dim, fn):
method forward (line 209) | def forward(self, x, **kwargs):
class Identity (line 213) | class Identity(nn.Module):
method __init__ (line 214) | def __init__(self, *args, **kwargs):
method forward (line 217) | def forward(self, x, *args, **kwargs):
function l2norm (line 220) | def l2norm(t):
class Block (line 224) | class Block(nn.Module):
method __init__ (line 225) | def __init__(self, dim, dim_out, groups=8):
method forward (line 231) | def forward(self, x, time_scale_shift=None, audio_scale_shift=None):
class ResnetBlock (line 249) | class ResnetBlock(nn.Module):
method __init__ (line 250) | def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=N...
method forward (line 266) | def forward(self, x, time_emb=None, audio_emb=None):
class ResnetBlock_ca (line 287) | class ResnetBlock_ca(nn.Module):
method __init__ (line 288) | def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=N...
method forward (line 317) | def forward(self, x, time_emb=None, audio_emb=None):
class ResnetBlock_ca_mul (line 361) | class ResnetBlock_ca_mul(nn.Module):
method __init__ (line 362) | def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=N...
method forward (line 417) | def forward(self, x, time_emb=None, audio_emb=None):
class CrossAttention (line 479) | class CrossAttention(nn.Module):
method __init__ (line 480) | def __init__(
method forward (line 514) | def forward(self, x, context, mask = None):
class LinearCrossAttention (line 559) | class LinearCrossAttention(CrossAttention):
method forward (line 560) | def forward(self, x, context, mask = None):
class SpatialLinearAttention (line 599) | class SpatialLinearAttention(nn.Module):
method __init__ (line 600) | def __init__(self, dim, heads=4, dim_head=32):
method forward (line 608) | def forward(self, x):
class EinopsToAndFrom (line 629) | class EinopsToAndFrom(nn.Module):
method __init__ (line 630) | def __init__(self, from_einops, to_einops, fn):
method forward (line 636) | def forward(self, x, **kwargs):
class Attention (line 645) | class Attention(nn.Module):
method __init__ (line 646) | def __init__(
method forward (line 662) | def forward(
class Unet3D (line 725) | class Unet3D(nn.Module):
method __init__ (line 726) | def __init__(
method forward_with_cond_scale (line 875) | def forward_with_cond_scale(
method forward (line 888) | def forward(
class DynamicNfUnet3D (line 955) | class DynamicNfUnet3D(Unet3D):
method __init__ (line 956) | def __init__(self, default_num_frames=20, *args, **kwargs):
method update_num_frames (line 960) | def update_num_frames(self, new_num_frames):
function extract (line 965) | def extract(a, t, x_shape):
function cosine_beta_schedule (line 971) | def cosine_beta_schedule(timesteps, s=0.008):
class GaussianDiffusion (line 984) | class GaussianDiffusion(nn.Module):
method __init__ (line 985) | def __init__(
method q_mean_variance (line 1062) | def q_mean_variance(self, x_start, t):
method predict_start_from_noise (line 1068) | def predict_start_from_noise(self, x_t, t, noise):
method q_posterior (line 1074) | def q_posterior(self, x_start, x_t, t):
method p_mean_variance (line 1083) | def p_mean_variance(self, x, t, fea, clip_denoised: bool, cond=None, c...
method p_sample (line 1109) | def p_sample(self, x, t, fea, cond=None, cond_scale=1., clip_denoised=...
method p_sample_loop (line 1120) | def p_sample_loop(self, fea, shape, cond=None, cond_scale=1.):
method sample (line 1134) | def sample(self, fea, bbox_mask, cond=None, cond_scale=1., batch_size=...
method ddim_sample (line 1153) | def ddim_sample(self, fea, shape, cond=None, cond_scale=1., clip_denoi...
method interpolate (line 1207) | def interpolate(self, x1, x2, t=None, lam=0.5):
method q_sample (line 1222) | def q_sample(self, x_start, t, noise=None):
method p_losses (line 1230) | def p_losses(self, x_start, t, fea, bbox_mask, cond=None, noise=None, ...
method forward (line 1270) | def forward(self, x, fea, bbox_mask, cond, *args, **kwargs):
function seek_all_images (line 1289) | def seek_all_images(img, channels=3):
class DynamicNfGaussianDiffusion (line 1303) | class DynamicNfGaussianDiffusion(GaussianDiffusion):
method __init__ (line 1304) | def __init__(self, default_num_frames=20, *args, **kwargs):
method update_num_frames (line 1308) | def update_num_frames(self, new_num_frames):
function video_tensor_to_gif (line 1313) | def video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True):
function gif_to_tensor (line 1322) | def gif_to_tensor(path, channels=3, transform=T.ToTensor()):
function identity (line 1328) | def identity(t, *args, **kwargs):
function normalize_img (line 1332) | def normalize_img(t):
function cast_num_frames (line 1340) | def cast_num_frames(t, *, frames):
FILE: DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test.py
function exists (line 27) | def exists(x):
function noop (line 31) | def noop(*args, **kwargs):
function is_odd (line 35) | def is_odd(n):
function default (line 39) | def default(val, d):
function cycle (line 45) | def cycle(dl):
function num_to_groups (line 51) | def num_to_groups(num, divisor):
function prob_mask_like (line 60) | def prob_mask_like(shape, prob, device):
function is_list_str (line 69) | def is_list_str(x):
class RelativePositionBias (line 77) | class RelativePositionBias(nn.Module):
method __init__ (line 78) | def __init__(
method _relative_position_bucket (line 92) | def _relative_position_bucket(relative_position, num_buckets=32, max_d...
method forward (line 111) | def forward(self, n, device):
class EMA (line 125) | class EMA():
method __init__ (line 126) | def __init__(self, beta):
method update_model_average (line 130) | def update_model_average(self, ma_model, current_model):
method update_average (line 135) | def update_average(self, old, new):
class Residual (line 141) | class Residual(nn.Module):
method __init__ (line 142) | def __init__(self, fn):
method forward (line 146) | def forward(self, x, *args, **kwargs):
class SinusoidalPosEmb (line 150) | class SinusoidalPosEmb(nn.Module):
method __init__ (line 151) | def __init__(self, dim):
method forward (line 155) | def forward(self, x):
function Upsample (line 165) | def Upsample(dim, use_deconv=True, padding_mode="reflect"):
function Downsample (line 175) | def Downsample(dim):
class LayerNorm (line 179) | class LayerNorm(nn.Module):
method __init__ (line 180) | def __init__(self, dim, eps=1e-5):
method forward (line 185) | def forward(self, x):
class LayerNorm_img (line 190) | class LayerNorm_img(nn.Module):
method __init__ (line 191) | def __init__(self, dim, stable = False):
method forward (line 196) | def forward(self, x):
class PreNorm (line 205) | class PreNorm(nn.Module):
method __init__ (line 206) | def __init__(self, dim, fn):
method forward (line 211) | def forward(self, x, **kwargs):
class Identity (line 215) | class Identity(nn.Module):
method __init__ (line 216) | def __init__(self, *args, **kwargs):
method forward (line 219) | def forward(self, x, *args, **kwargs):
function l2norm (line 222) | def l2norm(t):
class Block (line 226) | class Block(nn.Module):
method __init__ (line 227) | def __init__(self, dim, dim_out, groups=8):
method forward (line 233) | def forward(self, x, time_scale_shift=None, audio_scale_shift=None):
class ResnetBlock (line 251) | class ResnetBlock(nn.Module):
method __init__ (line 252) | def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=N...
method forward (line 268) | def forward(self, x, time_emb=None, audio_emb=None):
class ResnetBlock_ca (line 289) | class ResnetBlock_ca(nn.Module):
method __init__ (line 290) | def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=N...
method forward (line 319) | def forward(self, x, time_emb=None, audio_emb=None):
class ResnetBlock_ca_mul (line 363) | class ResnetBlock_ca_mul(nn.Module):
method __init__ (line 364) | def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=N...
method forward (line 419) | def forward(self, x, time_emb=None, audio_emb=None):
class CrossAttention (line 481) | class CrossAttention(nn.Module):
method __init__ (line 482) | def __init__(
method forward (line 516) | def forward(self, x, context, mask = None):
class LinearCrossAttention (line 561) | class LinearCrossAttention(CrossAttention):
method forward (line 562) | def forward(self, x, context, mask = None):
class SpatialLinearAttention (line 602) | class SpatialLinearAttention(nn.Module):
method __init__ (line 603) | def __init__(self, dim, heads=4, dim_head=32):
method forward (line 611) | def forward(self, x):
class EinopsToAndFrom (line 632) | class EinopsToAndFrom(nn.Module):
method __init__ (line 633) | def __init__(self, from_einops, to_einops, fn):
method forward (line 639) | def forward(self, x, **kwargs):
class Attention (line 648) | class Attention(nn.Module):
method __init__ (line 649) | def __init__(
method forward (line 665) | def forward(
class Unet3D (line 728) | class Unet3D(nn.Module):
method __init__ (line 729) | def __init__(
method forward_with_cond_scale (line 879) | def forward_with_cond_scale(
method forward (line 892) | def forward(
class DynamicNfUnet3D (line 959) | class DynamicNfUnet3D(Unet3D):
method __init__ (line 960) | def __init__(self, default_num_frames=20, *args, **kwargs):
method update_num_frames (line 964) | def update_num_frames(self, new_num_frames):
function extract (line 969) | def extract(a, t, x_shape):
function cosine_beta_schedule (line 975) | def cosine_beta_schedule(timesteps, s=0.008):
class GaussianDiffusion (line 988) | class GaussianDiffusion(nn.Module):
method __init__ (line 989) | def __init__(
method q_mean_variance (line 1066) | def q_mean_variance(self, x_start, t):
method predict_start_from_noise (line 1072) | def predict_start_from_noise(self, x_t, t, noise):
method q_posterior (line 1078) | def q_posterior(self, x_start, x_t, t):
method p_mean_variance (line 1087) | def p_mean_variance(self, x, t, fea, clip_denoised: bool, cond=None, c...
method p_sample (line 1113) | def p_sample(self, x, t, fea, cond=None, cond_scale=1., clip_denoised=...
method p_sample_loop (line 1124) | def p_sample_loop(self, fea, shape, cond=None, cond_scale=1.):
method sample (line 1138) | def sample(self, fea, bbox_mask, cond=None, cond_scale=1., batch_size=...
method ddim_sample (line 1157) | def ddim_sample(self, fea, shape, cond=None, cond_scale=1., clip_denoi...
method interpolate (line 1211) | def interpolate(self, x1, x2, t=None, lam=0.5):
method q_sample (line 1226) | def q_sample(self, x_start, t, noise=None):
method p_losses (line 1234) | def p_losses(self, x_start, t, fea, bbox_mask, cond=None, noise=None, ...
method forward (line 1274) | def forward(self, x, fea, bbox_mask, cond, *args, **kwargs):
function seek_all_images (line 1293) | def seek_all_images(img, channels=3):
class DynamicNfGaussianDiffusion (line 1307) | class DynamicNfGaussianDiffusion(GaussianDiffusion):
method __init__ (line 1308) | def __init__(self, default_num_frames=20, *args, **kwargs):
method update_num_frames (line 1312) | def update_num_frames(self, new_num_frames):
function video_tensor_to_gif (line 1317) | def video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True):
function gif_to_tensor (line 1326) | def gif_to_tensor(path, channels=3, transform=T.ToTensor()):
function identity (line 1332) | def identity(t, *args, **kwargs):
function normalize_img (line 1336) | def normalize_img(t):
function cast_num_frames (line 1344) | def cast_num_frames(t, *, frames):
FILE: DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test_local_opt.py
function exists (line 28) | def exists(x):
function noop (line 32) | def noop(*args, **kwargs):
function is_odd (line 36) | def is_odd(n):
function default (line 40) | def default(val, d):
function cycle (line 46) | def cycle(dl):
function num_to_groups (line 52) | def num_to_groups(num, divisor):
function prob_mask_like (line 61) | def prob_mask_like(shape, prob, device):
function is_list_str (line 70) | def is_list_str(x):
class RelativePositionBias (line 78) | class RelativePositionBias(nn.Module):
method __init__ (line 79) | def __init__(
method _relative_position_bucket (line 93) | def _relative_position_bucket(relative_position, num_buckets=32, max_d...
method forward (line 112) | def forward(self, n, device):
class EMA (line 126) | class EMA():
method __init__ (line 127) | def __init__(self, beta):
method update_model_average (line 131) | def update_model_average(self, ma_model, current_model):
method update_average (line 136) | def update_average(self, old, new):
class Residual (line 142) | class Residual(nn.Module):
method __init__ (line 143) | def __init__(self, fn):
method forward (line 147) | def forward(self, x, *args, **kwargs):
class SinusoidalPosEmb (line 151) | class SinusoidalPosEmb(nn.Module):
method __init__ (line 152) | def __init__(self, dim):
method forward (line 156) | def forward(self, x):
function Upsample (line 166) | def Upsample(dim, use_deconv=True, padding_mode="reflect"):
function Downsample (line 176) | def Downsample(dim):
class LayerNorm (line 180) | class LayerNorm(nn.Module):
method __init__ (line 181) | def __init__(self, dim, eps=1e-5):
method forward (line 186) | def forward(self, x):
class LayerNorm_img (line 191) | class LayerNorm_img(nn.Module):
method __init__ (line 192) | def __init__(self, dim, stable = False):
method forward (line 197) | def forward(self, x):
class PreNorm (line 206) | class PreNorm(nn.Module):
method __init__ (line 207) | def __init__(self, dim, fn):
method forward (line 212) | def forward(self, x, **kwargs):
class Identity (line 216) | class Identity(nn.Module):
method __init__ (line 217) | def __init__(self, *args, **kwargs):
method forward (line 220) | def forward(self, x, *args, **kwargs):
function l2norm (line 223) | def l2norm(t):
class Block (line 227) | class Block(nn.Module):
method __init__ (line 228) | def __init__(self, dim, dim_out, groups=8):
method forward (line 234) | def forward(self, x, time_scale_shift=None, audio_scale_shift=None):
class ResnetBlock (line 252) | class ResnetBlock(nn.Module):
method __init__ (line 253) | def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=N...
method forward (line 269) | def forward(self, x, time_emb=None, audio_emb=None):
class ResnetBlock_ca (line 290) | class ResnetBlock_ca(nn.Module):
method __init__ (line 291) | def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=N...
method forward (line 320) | def forward(self, x, time_emb=None, audio_emb=None):
class ResnetBlock_ca_mul (line 364) | class ResnetBlock_ca_mul(nn.Module):
method __init__ (line 365) | def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=N...
method forward (line 420) | def forward(self, x, time_emb=None, audio_emb=None):
class CrossAttention (line 482) | class CrossAttention(nn.Module):
method __init__ (line 483) | def __init__(
method forward (line 517) | def forward(self, x, context, mask = None):
class LinearCrossAttention (line 562) | class LinearCrossAttention(CrossAttention):
method forward (line 563) | def forward(self, x, context, mask = None):
class SpatialLinearAttention (line 603) | class SpatialLinearAttention(nn.Module):
method __init__ (line 604) | def __init__(self, dim, heads=4, dim_head=32):
method forward (line 612) | def forward(self, x):
class EinopsToAndFrom (line 633) | class EinopsToAndFrom(nn.Module):
method __init__ (line 634) | def __init__(self, from_einops, to_einops, fn):
method forward (line 640) | def forward(self, x, **kwargs):
class Attention (line 649) | class Attention(nn.Module):
method __init__ (line 650) | def __init__(
method forward (line 666) | def forward(
class Unet3D (line 729) | class Unet3D(nn.Module):
method __init__ (line 730) | def __init__(
method forward_with_cond_scale (line 881) | def forward_with_cond_scale(
method forward (line 894) | def forward(
class DynamicNfUnet3D (line 961) | class DynamicNfUnet3D(Unet3D):
method __init__ (line 962) | def __init__(self, default_num_frames=20, *args, **kwargs):
method update_num_frames (line 966) | def update_num_frames(self, new_num_frames):
function extract (line 971) | def extract(a, t, x_shape):
function cosine_beta_schedule (line 977) | def cosine_beta_schedule(timesteps, s=0.008):
class GaussianDiffusion (line 990) | class GaussianDiffusion(nn.Module):
method __init__ (line 991) | def __init__(
method q_mean_variance (line 1068) | def q_mean_variance(self, x_start, t):
method predict_start_from_noise (line 1074) | def predict_start_from_noise(self, x_t, t, noise):
method q_posterior (line 1080) | def q_posterior(self, x_start, x_t, t):
method p_mean_variance (line 1089) | def p_mean_variance(self, x, t, fea, clip_denoised: bool, cond=None, c...
method p_sample (line 1115) | def p_sample(self, x, t, fea, cond=None, cond_scale=1., clip_denoised=...
method p_sample_loop (line 1126) | def p_sample_loop(self, fea, shape, cond=None, cond_scale=1.):
method sample (line 1140) | def sample(self, fea, bbox_mask, cond=None, cond_scale=1., batch_size=...
method ddim_sample (line 1159) | def ddim_sample(self, fea, shape, cond=None, cond_scale=1., clip_denoi...
method interpolate (line 1213) | def interpolate(self, x1, x2, t=None, lam=0.5):
method q_sample (line 1228) | def q_sample(self, x_start, t, noise=None):
method p_losses (line 1236) | def p_losses(self, x_start, t, fea, bbox_mask, cond=None, noise=None, ...
method forward (line 1276) | def forward(self, x, fea, bbox_mask, cond, *args, **kwargs):
function seek_all_images (line 1295) | def seek_all_images(img, channels=3):
class DynamicNfGaussianDiffusion (line 1309) | class DynamicNfGaussianDiffusion(GaussianDiffusion):
method __init__ (line 1310) | def __init__(self, default_num_frames=20, *args, **kwargs):
method update_num_frames (line 1314) | def update_num_frames(self, new_num_frames):
function video_tensor_to_gif (line 1319) | def video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True):
function gif_to_tensor (line 1328) | def gif_to_tensor(path, channels=3, transform=T.ToTensor()):
function identity (line 1334) | def identity(t, *args, **kwargs):
function normalize_img (line 1338) | def normalize_img(t):
function cast_num_frames (line 1346) | def cast_num_frames(t, *, frames):
FILE: DM_3/train_vdm_hdtf_wpose_plus_faceemb_init_cond_liploss_6D.py
function get_arguments (line 112) | def get_arguments():
function sample_img (line 157) | def sample_img(rec_img_batch, idx=0):
function main (line 166) | def main():
class AverageMeter (line 504) | class AverageMeter(object):
method __init__ (line 507) | def __init__(self):
method reset (line 510) | def reset(self):
method update (line 516) | def update(self, val, n=1):
function setup_seed (line 523) | def setup_seed(seed):
FILE: DM_3/train_vdm_hdtf_wpose_plus_faceemb_init_cond_liploss_6D_s2.py
function get_arguments (line 118) | def get_arguments():
function sample_img (line 163) | def sample_img(rec_img_batch, idx=0):
function main (line 172) | def main():
class AverageMeter (line 531) | class AverageMeter(object):
method __init__ (line 534) | def __init__(self):
method reset (line 537) | def reset(self):
method update (line 543) | def update(self, val, n=1):
function setup_seed (line 550) | def setup_seed(seed):
FILE: DM_3/utils.py
class MultiEpochsDataLoader (line 5) | class MultiEpochsDataLoader(torch.utils.data.DataLoader):
method __init__ (line 7) | def __init__(self, *args, **kwargs):
method __len__ (line 14) | def __len__(self):
method __iter__ (line 17) | def __iter__(self):
class _RepeatSampler (line 22) | class _RepeatSampler(object):
method __init__ (line 28) | def __init__(self, sampler):
method __iter__ (line 31) | def __iter__(self):
FILE: LFG/augmentation.py
function crop_clip (line 20) | def crop_clip(clip, min_h, min_w, h, w):
function pad_clip (line 34) | def pad_clip(clip, h, w):
function resize_clip (line 42) | def resize_clip(clip, size, interpolation='bilinear'):
function get_resize_sizes (line 81) | def get_resize_sizes(im_h, im_w, size):
class RandomFlip (line 91) | class RandomFlip(object):
method __init__ (line 92) | def __init__(self, time_flip=False, horizontal_flip=False):
method __call__ (line 96) | def __call__(self, clip):
class RandomResize (line 105) | class RandomResize(object):
method __init__ (line 115) | def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
method __call__ (line 119) | def __call__(self, clip):
class RandomCrop (line 136) | class RandomCrop(object):
method __init__ (line 143) | def __init__(self, size):
method __call__ (line 149) | def __call__(self, clip):
class RandomRotation (line 175) | class RandomRotation(object):
method __init__ (line 184) | def __init__(self, degrees):
method __call__ (line 197) | def __call__(self, clip):
class ColorJitter (line 217) | class ColorJitter(object):
method __init__ (line 230) | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
method get_params (line 236) | def get_params(self, brightness, contrast, saturation, hue):
method __call__ (line 261) | def __call__(self, clip):
class AllAugmentationTransform (line 323) | class AllAugmentationTransform:
method __init__ (line 324) | def __init__(self, resize_param=None, rotation_param=None, flip_param=...
method __call__ (line 342) | def __call__(self, clip):
FILE: LFG/frames_dataset.py
function read_video (line 26) | def read_video(name, frame_shape):
class FramesDataset (line 76) | class FramesDataset(Dataset):
method __init__ (line 84) | def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=Fa...
method __len__ (line 118) | def __len__(self):
method __getitem__ (line 121) | def __getitem__(self, idx):
class DatasetRepeater (line 178) | class DatasetRepeater(Dataset):
method __init__ (line 183) | def __init__(self, dataset, num_repeats=100):
method __len__ (line 187) | def __len__(self):
method __getitem__ (line 190) | def __getitem__(self, idx):
class PairedDataset (line 194) | class PairedDataset(Dataset):
method __init__ (line 199) | def __init__(self, initial_dataset, number_of_pairs, seed=0):
method __len__ (line 224) | def __len__(self):
method __getitem__ (line 227) | def __getitem__(self, idx):
FILE: LFG/hdtf_dataset.py
function resize (line 18) | def resize(im, desired_size, interpolation):
class FramesDataset (line 36) | class FramesDataset(Dataset):
method __init__ (line 44) | def __init__(self, root_dir, frame_shape=256, id_sampling=False,
method __len__ (line 67) | def __len__(self):
method __getitem__ (line 70) | def __getitem__(self, idx):
class DatasetRepeater (line 109) | class DatasetRepeater(Dataset):
method __init__ (line 114) | def __init__(self, dataset, num_repeats=100):
method __len__ (line 118) | def __len__(self):
method __getitem__ (line 121) | def __getitem__(self, idx):
FILE: LFG/modules/avd_network.py
class AVDNetwork (line 13) | class AVDNetwork(nn.Module):
method __init__ (line 18) | def __init__(self, num_regions, id_bottle_size=64, pose_bottle_size=64...
method region_params_to_emb (line 64) | def region_params_to_emb(x):
method emb_to_region_params (line 71) | def emb_to_region_params(self, emb):
method forward (line 77) | def forward(self, x_id, x_pose, alpha=0.2):
FILE: LFG/modules/bg_motion_predictor.py
class BGMotionPredictor (line 15) | class BGMotionPredictor(nn.Module):
method __init__ (line 20) | def __init__(self, block_expansion, num_channels, max_features, num_bl...
method forward (line 42) | def forward(self, source_image, driving_image):
FILE: LFG/modules/flow_autoenc.py
class FlowAE (line 14) | class FlowAE(nn.Module):
method __init__ (line 15) | def __init__(self, is_train=False,
method forward (line 39) | def forward(self):
method set_train_input (line 49) | def set_train_input(self, ref_img, dri_img):
FILE: LFG/modules/generator.py
class Generator (line 19) | class Generator(nn.Module):
method __init__ (line 25) | def __init__(self, num_channels, num_regions, block_expansion, max_fea...
method deform_input (line 62) | def deform_input(inp, optical_flow):
method apply_optical (line 71) | def apply_optical(self, input_previous=None, input_skip=None, motion_p...
method forward (line 92) | def forward(self, source_image, driving_region_params, source_region_p...
method compute_fea (line 132) | def compute_fea(self, source_image):
method forward_with_flow (line 138) | def forward_with_flow(self, source_image, optical_flow, occlusion_map):
FILE: LFG/modules/model.py
class Vgg19 (line 19) | class Vgg19(torch.nn.Module):
method __init__ (line 24) | def __init__(self, requires_grad=False):
method forward (line 52) | def forward(self, x):
class ImagePyramide (line 63) | class ImagePyramide(torch.nn.Module):
method __init__ (line 68) | def __init__(self, scales, num_channels):
method forward (line 75) | def forward(self, x):
class Transform (line 82) | class Transform:
method __init__ (line 87) | def __init__(self, bs, **kwargs):
method transform_frame (line 102) | def transform_frame(self, frame):
method warp_coordinates (line 108) | def warp_coordinates(self, coordinates):
method jacobian (line 129) | def jacobian(self, coordinates):
function detach_kp (line 137) | def detach_kp(kp):
class ReconstructionModel (line 141) | class ReconstructionModel(torch.nn.Module):
method __init__ (line 146) | def __init__(self, region_predictor, bg_predictor, generator, train_pa...
method forward (line 164) | def forward(self, x):
FILE: LFG/modules/pixelwise_flow_predictor.py
class PixelwiseFlowPredictor (line 17) | class PixelwiseFlowPredictor(nn.Module):
method __init__ (line 23) | def __init__(self, block_expansion, num_blocks, max_features, num_regi...
method create_heatmap_representations (line 48) | def create_heatmap_representations(self, source_image, driving_region_...
method create_sparse_motions (line 66) | def create_sparse_motions(self, source_image, driving_region_params, s...
method create_deformed_source_image (line 95) | def create_deformed_source_image(self, source_image, sparse_motions):
method forward (line 104) | def forward(self, source_image, driving_region_params, source_region_p...
FILE: LFG/modules/region_predictor.py
function svd (line 16) | def svd(covar, fast=False):
class RegionPredictor (line 28) | class RegionPredictor(nn.Module):
method __init__ (line 33) | def __init__(self, block_expansion, num_regions, num_channels, max_fea...
method region2affine (line 60) | def region2affine(self, region):
method forward (line 77) | def forward(self, x):
FILE: LFG/modules/util.py
function region2gaussian (line 22) | def region2gaussian(center, covar, spatial_size):
function make_coordinate_grid (line 51) | def make_coordinate_grid(spatial_size, type):
class ResBlock2d (line 70) | class ResBlock2d(nn.Module):
method __init__ (line 75) | def __init__(self, in_features, kernel_size, padding):
method forward (line 84) | def forward(self, x):
class UpBlock2d (line 95) | class UpBlock2d(nn.Module):
method __init__ (line 100) | def __init__(self, in_features, out_features, kernel_size=3, padding=1...
method forward (line 107) | def forward(self, x):
class DownBlock2d (line 115) | class DownBlock2d(nn.Module):
method __init__ (line 120) | def __init__(self, in_features, out_features, kernel_size=3, padding=1...
method forward (line 127) | def forward(self, x):
class SameBlock2d (line 135) | class SameBlock2d(nn.Module):
method __init__ (line 140) | def __init__(self, in_features, out_features, groups=1, kernel_size=3,...
method forward (line 146) | def forward(self, x):
class Encoder (line 153) | class Encoder(nn.Module):
method __init__ (line 158) | def __init__(self, block_expansion, in_features, num_blocks=3, max_fea...
method forward (line 168) | def forward(self, x):
class Decoder (line 175) | class Decoder(nn.Module):
method __init__ (line 180) | def __init__(self, block_expansion, in_features, num_blocks=3, max_fea...
method forward (line 193) | def forward(self, x):
class Hourglass (line 202) | class Hourglass(nn.Module):
method __init__ (line 207) | def __init__(self, block_expansion, in_features, num_blocks=3, max_fea...
method forward (line 213) | def forward(self, x):
class AntiAliasInterpolation2d (line 217) | class AntiAliasInterpolation2d(nn.Module):
method __init__ (line 222) | def __init__(self, channels, scale):
method forward (line 256) | def forward(self, input):
function to_homogeneous (line 267) | def to_homogeneous(coordinates):
function from_homogeneous (line 275) | def from_homogeneous(coordinates):
function draw_colored_heatmap (line 279) | def draw_colored_heatmap(heatmap, colormap, bg_color):
class Visualizer (line 301) | class Visualizer:
method __init__ (line 302) | def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbo...
method draw_image_with_kp (line 308) | def draw_image_with_kp(self, image, kp_array):
method create_image_column_with_kp (line 318) | def create_image_column_with_kp(self, images, kp):
method create_image_column (line 322) | def create_image_column(self, images):
method create_image_grid (line 329) | def create_image_grid(self, *args):
method sample (line 339) | def sample(x, index):
method visualize (line 342) | def visualize(self, driving, source, out, index=0):
FILE: LFG/run_hdtf.py
class Logger (line 29) | class Logger(object):
method __init__ (line 30) | def __init__(self, filename='default.log', stream=sys.stdout):
method write (line 34) | def write(self, message):
method flush (line 38) | def flush(self):
function setup_seed (line 42) | def setup_seed(seed):
FILE: LFG/run_hdtf_crema.py
class Logger (line 29) | class Logger(object):
method __init__ (line 30) | def __init__(self, filename='default.log', stream=sys.stdout):
method write (line 34) | def write(self, message):
method flush (line 38) | def flush(self):
function setup_seed (line 42) | def setup_seed(seed):
FILE: LFG/sync_batchnorm/batchnorm.py
function _sum_ft (line 24) | def _sum_ft(tensor):
function _unsqueeze_ft (line 29) | def _unsqueeze_ft(tensor):
class _SynchronizedBatchNorm (line 38) | class _SynchronizedBatchNorm(_BatchNorm):
method __init__ (line 39) | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
method forward (line 48) | def forward(self, input):
method __data_parallel_replicate__ (line 80) | def __data_parallel_replicate__(self, ctx, copy_id):
method _data_parallel_master (line 90) | def _data_parallel_master(self, intermediates):
method _compute_mean_std (line 113) | def _compute_mean_std(self, sum_, ssum, size):
class SynchronizedBatchNorm1d (line 128) | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
method _check_input_dim (line 184) | def _check_input_dim(self, input):
class SynchronizedBatchNorm2d (line 191) | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
method _check_input_dim (line 247) | def _check_input_dim(self, input):
class SynchronizedBatchNorm3d (line 254) | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
method _check_input_dim (line 311) | def _check_input_dim(self, input):
FILE: LFG/sync_batchnorm/comm.py
class FutureResult (line 18) | class FutureResult(object):
method __init__ (line 21) | def __init__(self):
method put (line 26) | def put(self, result):
method get (line 32) | def get(self):
class SlavePipe (line 46) | class SlavePipe(_SlavePipeBase):
method run_slave (line 49) | def run_slave(self, msg):
class SyncMaster (line 56) | class SyncMaster(object):
method __init__ (line 67) | def __init__(self, master_callback):
method __getstate__ (line 78) | def __getstate__(self):
method __setstate__ (line 81) | def __setstate__(self, state):
method register_slave (line 84) | def register_slave(self, identifier):
method run_master (line 102) | def run_master(self, master_msg):
method nr_slaves (line 136) | def nr_slaves(self):
FILE: LFG/sync_batchnorm/replicate.py
class CallbackContext (line 23) | class CallbackContext(object):
function execute_replication_callbacks (line 27) | def execute_replication_callbacks(modules):
class DataParallelWithCallback (line 50) | class DataParallelWithCallback(DataParallel):
method replicate (line 64) | def replicate(self, module, device_ids):
function patch_replication_callback (line 70) | def patch_replication_callback(data_parallel):
FILE: LFG/sync_batchnorm/unittest.py
function as_numpy (line 17) | def as_numpy(v):
class TorchTestCase (line 23) | class TorchTestCase(unittest.TestCase):
method assertTensorClose (line 24) | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
FILE: LFG/test_flowautoenc_crema_video.py
function get_arguments (line 67) | def get_arguments():
function extract_audio_by_frames (line 92) | def extract_audio_by_frames(input_wav_path, start_frame_index, num_frame...
function sample_img (line 111) | def sample_img(rec_img_batch):
function main (line 120) | def main():
class AverageMeter (line 312) | class AverageMeter(object):
method __init__ (line 315) | def __init__(self):
method reset (line 318) | def reset(self):
method update (line 324) | def update(self, val, n=1):
function setup_seed (line 331) | def setup_seed(seed):
FILE: LFG/test_flowautoenc_hdtf_video.py
function get_arguments (line 67) | def get_arguments():
function extract_audio_by_frames (line 92) | def extract_audio_by_frames(input_wav_path, start_frame_index, num_frame...
function sample_img (line 111) | def sample_img(rec_img_batch):
function main (line 120) | def main():
class AverageMeter (line 312) | class AverageMeter(object):
method __init__ (line 315) | def __init__(self):
method reset (line 318) | def reset(self):
method update (line 324) | def update(self, val, n=1):
function setup_seed (line 331) | def setup_seed(seed):
FILE: LFG/test_flowautoenc_hdtf_video_256.py
function get_arguments (line 67) | def get_arguments():
function extract_audio_by_frames (line 92) | def extract_audio_by_frames(input_wav_path, start_frame_index, num_frame...
function sample_img (line 111) | def sample_img(rec_img_batch):
function main (line 120) | def main():
class AverageMeter (line 292) | class AverageMeter(object):
method __init__ (line 295) | def __init__(self):
method reset (line 298) | def reset(self):
method update (line 304) | def update(self, val, n=1):
function setup_seed (line 311) | def setup_seed(seed):
FILE: LFG/train.py
class AverageMeter (line 16) | class AverageMeter(object):
method __init__ (line 19) | def __init__(self):
method reset (line 22) | def reset(self):
method update (line 28) | def update(self, val, n=1):
function train (line 35) | def train(config, generator, region_predictor, bg_predictor, checkpoint,...
FILE: LFG/vis_flow.py
function visualize_dense_optical_flow (line 5) | def visualize_dense_optical_flow(flow_tensor, save_path):
function grid2flow (line 27) | def grid2flow(warped_grid, grid_size=64, img_size=256):
FILE: PBnet/src/datasets/datasets_hdtf_pos_chunk_norm_2_fast.py
function resize (line 28) | def resize(im, desired_size, interpolation):
class HDTF (line 44) | class HDTF(data.Dataset):
method __init__ (line 45) | def __init__(self, data_dir, max_num_frames=80, mode='train'):
method check_head (line 96) | def check_head(self, frame_list, video_name, start, end):
method get_block_data_for_two (line 107) | def get_block_data_for_two(self, path, start, end):
method get_block_data (line 138) | def get_block_data(self, path, start, end):
method check_len (line 173) | def check_len(self, name):
method __len__ (line 177) | def __len__(self):
method __getitem__ (line 180) | def __getitem__(self, idx):
method update_parameters (line 211) | def update_parameters(self, parameters):
FILE: PBnet/src/datasets/datasets_hdtf_pos_chunk_norm_eye_fast.py
function resize (line 28) | def resize(im, desired_size, interpolation):
class HDTF (line 44) | class HDTF(data.Dataset):
method __init__ (line 45) | def __init__(self, data_dir, max_num_frames=80, mode='train'):
method check_head (line 117) | def check_head(self, frame_list, video_name, start, end):
method get_block_data_for_two (line 127) | def get_block_data_for_two(self, path, start, end):
method get_block_data (line 158) | def get_block_data(self, path, start, end):
method check_len (line 193) | def check_len(self, name):
method __len__ (line 197) | def __len__(self):
method __getitem__ (line 200) | def __getitem__(self, idx):
method update_parameters (line 265) | def update_parameters(self, parameters):
FILE: PBnet/src/datasets/datasets_hdtf_pos_df.py
function resize (line 22) | def resize(im, desired_size, interpolation):
class HDTF (line 38) | class HDTF(data.Dataset):
method __init__ (line 39) | def __init__(self, data_dir, max_num_frames=80, min_num_frames=40, mod...
method __len__ (line 92) | def __len__(self):
method __getitem__ (line 100) | def __getitem__(self, idx):
method update_parameters (line 159) | def update_parameters(self, parameters):
FILE: PBnet/src/datasets/datasets_hdtf_pos_dict_norm_2.py
function resize (line 28) | def resize(im, desired_size, interpolation):
class HDTF (line 44) | class HDTF(data.Dataset):
method __init__ (line 45) | def __init__(self, data_dir, max_num_frames=80, mode='train'):
method check_head (line 110) | def check_head(self, frame_list, video_name, start, end):
method get_block_data_for_two (line 120) | def get_block_data_for_two(self, path, start, end):
method get_block_data (line 151) | def get_block_data(self, path, start, end):
method check_len (line 186) | def check_len(self, name):
method __len__ (line 190) | def __len__(self):
method __getitem__ (line 193) | def __getitem__(self, idx):
method update_parameters (line 240) | def update_parameters(self, parameters):
FILE: PBnet/src/datasets/datasets_hdtf_wpose_lmk_block.py
function resize (line 28) | def resize(im, desired_size, interpolation):
class HDTF (line 44) | class HDTF(data.Dataset):
method __init__ (line 45) | def __init__(self, data_dir, pose_dir, eye_blink_dir, max_num_frames=8...
method check_head (line 102) | def check_head(self, frame_list, video_name, start, end):
method get_block_data_for_two (line 112) | def get_block_data_for_two(self, path, start, end):
method get_block_data (line 143) | def get_block_data(self, path, start, end):
method check_len (line 178) | def check_len(self, name):
method __len__ (line 183) | def __len__(self):
method __getitem__ (line 186) | def __getitem__(self, idx):
FILE: PBnet/src/datasets/get_dataset.py
function get_dataset (line 1) | def get_dataset(name="ntu13"):
function get_datasets (line 13) | def get_datasets(parameters):
FILE: PBnet/src/datasets/tools.py
function parse_info_name (line 5) | def parse_info_name(path):
FILE: PBnet/src/evaluate/action2motion/accuracy.py
function calculate_accuracy (line 4) | def calculate_accuracy(model, motion_loader, num_labels, classifier, dev...
FILE: PBnet/src/evaluate/action2motion/diversity.py
function calculate_diversity_multimodality (line 6) | def calculate_diversity_multimodality(activations, labels, num_labels):
FILE: PBnet/src/evaluate/action2motion/evaluate.py
class A2MEvaluation (line 9) | class A2MEvaluation:
method __init__ (line 10) | def __init__(self, dataname, device):
method compute_features (line 31) | def compute_features(self, model, motionloader):
method calculate_activation_statistics (line 44) | def calculate_activation_statistics(activations):
method evaluate (line 50) | def evaluate(self, model, loaders):
FILE: PBnet/src/evaluate/action2motion/fid.py
function calculate_fid (line 6) | def calculate_fid(statistics_1, statistics_2):
function calculate_frechet_distance (line 11) | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
FILE: PBnet/src/evaluate/action2motion/models.py
class MotionDiscriminator (line 6) | class MotionDiscriminator(nn.Module):
method __init__ (line 7) | def __init__(self, input_size, hidden_size, hidden_layer, device, outp...
method forward (line 20) | def forward(self, motion_sequence, lengths=None, hidden_unit=None):
method initHidden (line 40) | def initHidden(self, num_samples, layer):
class MotionDiscriminatorForFID (line 44) | class MotionDiscriminatorForFID(MotionDiscriminator):
method forward (line 45) | def forward(self, motion_sequence, lengths=None, hidden_unit=None):
function load_classifier (line 70) | def load_classifier(dataset_type, input_size_raw, num_classes, device):
function load_classifier_for_fid (line 78) | def load_classifier_for_fid(dataset_type, input_size_raw, num_classes, d...
function test (line 86) | def test():
FILE: PBnet/src/evaluate/evaluate_cvae.py
function main (line 11) | def main():
FILE: PBnet/src/evaluate/evaluate_cvae_debug.py
function main (line 10) | def main():
FILE: PBnet/src/evaluate/evaluate_cvae_f3.py
function main (line 10) | def main():
FILE: PBnet/src/evaluate/evaluate_cvae_f3_debug.py
function main (line 10) | def main():
FILE: PBnet/src/evaluate/evaluate_cvae_f3_mel.py
function main (line 10) | def main():
FILE: PBnet/src/evaluate/evaluate_cvae_norm.py
function main (line 10) | def main():
FILE: PBnet/src/evaluate/evaluate_cvae_norm_all.py
function main (line 10) | def main():
FILE: PBnet/src/evaluate/evaluate_cvae_norm_all_seg.py
function main (line 10) | def main():
FILE: PBnet/src/evaluate/evaluate_cvae_norm_all_seg_weye.py
function main (line 11) | def main():
FILE: PBnet/src/evaluate/evaluate_cvae_norm_all_seg_weye2.py
function main (line 10) | def main():
FILE: PBnet/src/evaluate/evaluate_cvae_norm_eye_pose.py
function main (line 10) | def main():
FILE: PBnet/src/evaluate/evaluate_cvae_norm_eye_pose_test.py
function main (line 10) | def main():
FILE: PBnet/src/evaluate/evaluate_cvae_onlyeye_all_seg.py
function main (line 10) | def main():
FILE: PBnet/src/evaluate/othermetrics/acceleration.py
function calculate_acceletation (line 7) | def calculate_acceletation(motionloader, device, xyz):
FILE: PBnet/src/evaluate/othermetrics/evaluation.py
class OtherMetricsEvaluation (line 8) | class OtherMetricsEvaluation:
method __init__ (line 16) | def __init__(self, device):
method compute_features (line 19) | def compute_features(self, model, motionloader, xyz=True):
method reconstructionloss (line 33) | def reconstructionloss(self, motionloader, xyz=True):
method evaluate (line 52) | def evaluate(self, model, num_classes, loaders, xyz=True):
FILE: PBnet/src/evaluate/stgcn/accuracy.py
function calculate_accuracy (line 4) | def calculate_accuracy(model, motion_loader, num_labels, classifier, dev...
FILE: PBnet/src/evaluate/stgcn/diversity.py
function calculate_diversity_multimodality (line 6) | def calculate_diversity_multimodality(activations, labels, num_labels, s...
FILE: PBnet/src/evaluate/stgcn/evaluate.py
class Evaluation (line 10) | class Evaluation:
method __init__ (line 11) | def __init__(self, dataname, parameters, device, seed=None):
method compute_features (line 35) | def compute_features(self, model, motionloader):
method calculate_activation_statistics (line 48) | def calculate_activation_statistics(activations):
method evaluate (line 54) | def evaluate(self, model, loaders):
FILE: PBnet/src/evaluate/stgcn/fid.py
function calculate_fid (line 6) | def calculate_fid(statistics_1, statistics_2):
function calculate_frechet_distance (line 11) | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
FILE: PBnet/src/evaluate/tables/archtable.py
function valformat (line 10) | def valformat(val, power=3):
function format_values (line 16) | def format_values(values, key):
function construct_table (line 33) | def construct_table(folder):
function parse_opts (line 176) | def parse_opts():
FILE: PBnet/src/evaluate/tables/bstable.py
function valformat (line 10) | def valformat(val, power=3):
function format_values (line 16) | def format_values(values, key):
function construct_table (line 33) | def construct_table(folder):
function parse_opts (line 150) | def parse_opts():
FILE: PBnet/src/evaluate/tables/easy_table.py
function get_gtname (line 9) | def get_gtname(mname):
function get_genname (line 13) | def get_genname(mname):
function get_reconsname (line 17) | def get_reconsname(mname):
function valformat (line 21) | def valformat(val, power=3):
function format_values (line 27) | def format_values(values, key, latex=True):
function print_results (line 46) | def print_results(folder, evaluation):
function parse_opts (line 94) | def parse_opts():
FILE: PBnet/src/evaluate/tables/easy_table_A2M.py
function valformat (line 9) | def valformat(val, power=3):
function construct_table (line 15) | def construct_table(folder, evaluation):
function parse_opts (line 77) | def parse_opts():
FILE: PBnet/src/evaluate/tables/kltable.py
function valformat (line 10) | def valformat(val, power=3):
function format_values (line 16) | def format_values(values, key):
function construct_table (line 33) | def construct_table(folder):
function parse_opts (line 149) | def parse_opts():
FILE: PBnet/src/evaluate/tables/latexmodela2m.py
function get_gtname (line 9) | def get_gtname(mname):
function get_genname (line 13) | def get_genname(mname):
function get_reconsname (line 17) | def get_reconsname(mname):
function construct_table (line 21) | def construct_table(folder, evaluation):
function parse_opts (line 78) | def parse_opts():
FILE: PBnet/src/evaluate/tables/latexmodelsa2m.py
function valformat (line 10) | def valformat(val, power=3):
function construct_table (line 16) | def construct_table(folder):
function parse_opts (line 121) | def parse_opts():
FILE: PBnet/src/evaluate/tables/latexmodelsstgcn.py
function get_gtname (line 10) | def get_gtname(mname):
function get_genname (line 14) | def get_genname(mname):
function get_reconsname (line 18) | def get_reconsname(mname):
function valformat (line 22) | def valformat(val, power=3):
function construct_table (line 28) | def construct_table(folder):
function parse_opts (line 137) | def parse_opts():
FILE: PBnet/src/evaluate/tables/losstable.py
function valformat (line 10) | def valformat(val, power=3):
function format_values (line 16) | def format_values(values, key):
function construct_table (line 33) | def construct_table(folder):
function parse_opts (line 171) | def parse_opts():
FILE: PBnet/src/evaluate/tables/maketable.py
function bold (line 31) | def bold(string):
function colorize_template (line 35) | def colorize_template(string, color):
function colorize_bold_template (line 39) | def colorize_bold_template(string, color):
function format_table (line 43) | def format_table(val, gtval, mname):
function get_gtname (line 94) | def get_gtname(mname):
function get_genname (line 98) | def get_genname(mname):
function get_reconsname (line 102) | def get_reconsname(mname):
function collect_tables (line 106) | def collect_tables(folder, expname, lastepoch=False, norecons=False):
function parse_opts (line 256) | def parse_opts():
FILE: PBnet/src/evaluate/tables/numlayertable.py
function valformat (line 10) | def valformat(val, power=3):
function format_values (line 16) | def format_values(values, key):
function construct_table (line 33) | def construct_table(folder):
function parse_opts (line 150) | def parse_opts():
FILE: PBnet/src/evaluate/tables/posereptable.py
function valformat (line 10) | def valformat(val, power=3):
function format_values (line 16) | def format_values(values, key):
function construct_table (line 33) | def construct_table(folder):
function parse_opts (line 157) | def parse_opts():
FILE: PBnet/src/evaluate/tools.py
function format_metrics (line 4) | def format_metrics(metrics, formatter="{:.6}"):
function save_metrics (line 11) | def save_metrics(path, metrics):
function load_metrics (line 16) | def load_metrics(path):
FILE: PBnet/src/evaluate/tvae_eval.py
function evaluate (line 17) | def evaluate(parameters, dataset, folder, checkpointname, epoch, niter):
FILE: PBnet/src/evaluate/tvae_eval_norm.py
function transform (line 16) | def transform(x, min_val, max_val):
function evaluate (line 20) | def evaluate(parameters, dataset, folder, checkpointname, epoch, niter):
FILE: PBnet/src/evaluate/tvae_eval_norm_all.py
function save_images_as_npy (line 16) | def save_images_as_npy(input_data, output_file):
function save_as_chunk (line 24) | def save_as_chunk(dir, data):
function transform (line 36) | def transform(x, min_val, max_val):
function evaluate (line 40) | def evaluate(parameters, dataset, folder, checkpointname, epoch, niter):
FILE: PBnet/src/evaluate/tvae_eval_norm_eye_pose.py
function transform (line 17) | def transform(x, min_val, max_val):
function evaluate (line 21) | def evaluate(parameters, dataset, folder, checkpointname, epoch, niter):
FILE: PBnet/src/evaluate/tvae_eval_norm_eye_pose_seg.py
function transform (line 18) | def transform(x, min_val, max_val):
function save_images_as_npy (line 22) | def save_images_as_npy(input_data, output_file):
function save_as_chunk (line 29) | def save_as_chunk(dir, data):
function evaluate (line 39) | def evaluate(parameters, dataset, folder, checkpointname, epoch, niter):
FILE: PBnet/src/evaluate/tvae_eval_norm_seg.py
function save_images_as_npy (line 17) | def save_images_as_npy(input_data, output_file):
function save_as_chunk (line 25) | def save_as_chunk(dir, data):
function transform (line 37) | def transform(x, min_val, max_val):
function evaluate (line 41) | def evaluate(parameters, dataset, folder, checkpointname, epoch, niter):
FILE: PBnet/src/evaluate/tvae_eval_onlyeye_all_seg.py
function save_images_as_npy (line 16) | def save_images_as_npy(input_data, output_file):
function save_as_chunk (line 24) | def save_as_chunk(dir, data):
function evaluate (line 40) | def evaluate(parameters, dataset, folder, checkpointname, epoch, niter):
FILE: PBnet/src/evaluate/tvae_eval_single.py
function inv_transform (line 25) | def inv_transform(x, min_val, max_val):
function save_images_as_npy (line 29) | def save_images_as_npy(input_data, output_file):
function evaluate (line 43) | def evaluate(parameters_pose, parameters_blink, audio_path, init_pose_pa...
function get_arguments (line 114) | def get_arguments():
FILE: PBnet/src/evaluate/tvae_eval_single_both_eye_pose.py
function inv_transform (line 25) | def inv_transform(x, min_val, max_val):
function save_images_as_npy (line 29) | def save_images_as_npy(input_data, output_file):
function evaluate (line 43) | def evaluate(parameters, audio_path, init_pose_path, init_blink_path, ch...
function get_arguments (line 112) | def get_arguments():
FILE: PBnet/src/evaluate/tvae_eval_std.py
function evaluate (line 17) | def evaluate(parameters, dataset, folder, checkpointname, epoch, niter):
FILE: PBnet/src/evaluate/tvae_eval_train.py
function evaluate (line 17) | def evaluate(parameters, dataset, folder, checkpointname, epoch, niter):
FILE: PBnet/src/evaluate/tvae_eval_train_norm.py
function transform (line 16) | def transform(x, min_val, max_val):
function evaluate (line 20) | def evaluate(parameters, dataset, folder, checkpointname, epoch, niter):
FILE: PBnet/src/evaluate/tvae_eval_train_std.py
function evaluate (line 17) | def evaluate(parameters, dataset, folder, checkpointname, epoch, niter):
FILE: PBnet/src/generate/generate_sequences.py
function generate_actions (line 16) | def generate_actions(beta, model, dataset, epoch, params, folder, num_fr...
function main (line 119) | def main():
FILE: PBnet/src/models/architectures/autotrans.py
function subsequent_mask (line 13) | def subsequent_mask(size: int):
function augment_x (line 25) | def augment_x(x, y, mask, lengths, num_classes, concatenate_time):
function augment_z (line 44) | def augment_z(z, y, mask, lengths, num_classes, concatenate_time):
class Decoder_AUTOTRANS (line 60) | class Decoder_AUTOTRANS(nn.Module):
method __init__ (line 61) | def __init__(self, modeltype, njoints, nfeats, num_frames, num_classes...
method forward (line 112) | def forward(self, batch):
FILE: PBnet/src/models/architectures/fc.py
class Encoder_FC (line 6) | class Encoder_FC(nn.Module):
method __init__ (line 7) | def __init__(self, modeltype, njoints, nfeats, num_frames, num_classes...
method forward (line 37) | def forward(self, batch):
class Decoder_FC (line 57) | class Decoder_FC(nn.Module):
method __init__ (line 58) | def __init__(self, modeltype, njoints, nfeats, num_frames, num_classes...
method forward (line 84) | def forward(self, batch):
FILE: PBnet/src/models/architectures/gru.py
function augment_x (line 6) | def augment_x(x, y, mask, lengths, num_classes, concatenate_time):
function augment_z (line 25) | def augment_z(z, y, mask, lengths, num_classes, concatenate_time):
class Encoder_GRU (line 41) | class Encoder_GRU(nn.Module):
method __init__ (line 42) | def __init__(self, modeltype, njoints, nfeats, num_frames,
method forward (line 76) | def forward(self, batch):
class Decoder_GRU (line 95) | class Decoder_GRU(nn.Module):
method __init__ (line 96) | def __init__(self, modeltype, njoints, nfeats, num_frames,
method forward (line 127) | def forward(self, batch):
FILE: PBnet/src/models/architectures/mlp.py
class Upsample (line 6) | class Upsample(nn.Module):
method __init__ (line 7) | def __init__(self, input_dim, output_dim, kernel, stride):
method forward (line 14) | def forward(self, x):
class ResidualConv (line 17) | class ResidualConv(nn.Module):
method __init__ (line 18) | def __init__(self, input_dim, output_dim, stride, padding):
method forward (line 36) | def forward(self, x):
class PositionalEncoding (line 40) | class PositionalEncoding(nn.Module):
method __init__ (line 41) | def __init__(self, d_model, dropout=0.1, max_len=5000):
method forward (line 54) | def forward(self, x):
class RelativePositionBias (line 60) | class RelativePositionBias(nn.Module):
method __init__ (line 61) | def __init__(
method _relative_position_bucket (line 73) | def _relative_position_bucket(relative_position, num_buckets=32, max_d...
method forward (line 92) | def forward(self, n, device):
class TimeEncoding (line 102) | class TimeEncoding(nn.Module):
method __init__ (line 103) | def __init__(self, d_model, dropout=0.1, max_len=5000):
method forward (line 107) | def forward(self, x, mask, lengths):
class ResUnet (line 115) | class ResUnet(nn.Module):
method __init__ (line 116) | def __init__(self, channel=1, filters=[32, 64, 128, 256]):
method forward (line 148) | def forward(self, x):
class Encoder_MLP (line 176) | class Encoder_MLP(nn.Module):
method __init__ (line 177) | def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=7, p...
method forward (line 203) | def forward(self, batch):
class Decoder_MLP (line 232) | class Decoder_MLP(nn.Module):
method __init__ (line 233) | def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=7, p...
method forward (line 259) | def forward(self, batch):
FILE: PBnet/src/models/architectures/resnet34.py
function conv3x3 (line 6) | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
function conv1x1 (line 12) | def conv1x1(in_planes, out_planes, stride=1):
class BasicBlock (line 16) | class BasicBlock(nn.Module):
method __init__ (line 19) | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
method forward (line 37) | def forward(self, x):
class Bottleneck (line 56) | class Bottleneck(nn.Module):
method __init__ (line 59) | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
method forward (line 76) | def forward(self, x):
class ResNet (line 98) | class ResNet(nn.Module):
method __init__ (line 100) | def __init__(self, block, layers, num_classes=1000, zero_init_residual...
method _make_layer (line 151) | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
method forward (line 175) | def forward(self, x):
function _resnet (line 192) | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
function resnet34 (line 196) | def resnet34(pretrained=False, progress=True, **kwargs):
class MyResNet34 (line 208) | class MyResNet34(nn.Module):
method __init__ (line 209) | def __init__(self,embedding_dim,input_channel = 3):
method forward (line 212) | def forward(self, x):
FILE: PBnet/src/models/architectures/tools/embeddings.py
function get_activation (line 9) | def get_activation(activation_type):
class MaskedNorm (line 38) | class MaskedNorm(nn.Module):
method __init__ (line 44) | def __init__(self, norm_type, num_groups, num_features):
method forward (line 58) | def forward(self, x: Tensor, mask: Tensor):
class Embeddings (line 77) | class Embeddings(nn.Module):
method __init__ (line 84) | def __init__(
method forward (line 134) | def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
method __repr__ (line 156) | def __repr__(self):
class SpatialEmbeddings (line 164) | class SpatialEmbeddings(nn.Module):
method __init__ (line 172) | def __init__(
method forward (line 219) | def forward(self, x: Tensor, mask: Tensor) -> Tensor:
method __repr__ (line 239) | def __repr__(self):
FILE: PBnet/src/models/architectures/tools/resnet.py
function conv3x3 (line 4) | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
function conv1x1 (line 10) | def conv1x1(in_planes, out_planes, stride=1):
class BasicBlock (line 14) | class BasicBlock(nn.Module):
method __init__ (line 17) | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
method forward (line 35) | def forward(self, x):
class Bottleneck (line 54) | class Bottleneck(nn.Module):
method __init__ (line 57) | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
method forward (line 74) | def forward(self, x):
class ResNet (line 96) | class ResNet(nn.Module):
method __init__ (line 98) | def __init__(self, block, layers, num_classes=1000, zero_init_residual...
method _make_layer (line 149) | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
method forward (line 173) | def forward(self, x):
function _resnet (line 190) | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
function resnet34 (line 194) | def resnet34(pretrained=False, progress=True, **kwargs):
FILE: PBnet/src/models/architectures/tools/transformer_layers.py
class MultiHeadedAttention (line 11) | class MultiHeadedAttention(nn.Module):
method __init__ (line 19) | def __init__(self, num_heads: int, size: int, dropout: float = 0.1):
method forward (line 42) | def forward(self, k: Tensor, v: Tensor, q: Tensor, mask: Tensor = None):
class PositionwiseFeedForward (line 97) | class PositionwiseFeedForward(nn.Module):
method __init__ (line 103) | def __init__(self, input_size, ff_size, dropout=0.1):
method forward (line 120) | def forward(self, x):
class PositionalEncoding (line 126) | class PositionalEncoding(nn.Module):
method __init__ (line 136) | def __init__(self,
method forward (line 159) | def forward(self, emb):
class TransformerEncoderLayer (line 169) | class TransformerEncoderLayer(nn.Module):
method __init__ (line 175) | def __init__(self,
method forward (line 198) | def forward(self, x: Tensor, mask: Tensor) -> Tensor:
class TransformerDecoderLayer (line 216) | class TransformerDecoderLayer(nn.Module):
method __init__ (line 223) | def __init__(self,
method forward (line 255) | def forward(self,
FILE: PBnet/src/models/architectures/tools/util.py
class MyResNet34 (line 10) | class MyResNet34(nn.Module):
method __init__ (line 11) | def __init__(self,embedding_dim,input_channel = 3):
method forward (line 14) | def forward(self, x):
FILE: PBnet/src/models/architectures/transformer.py
class PositionalEncoding (line 7) | class PositionalEncoding(nn.Module):
method __init__ (line 8) | def __init__(self, d_model, dropout=0.1, max_len=5000):
method forward (line 21) | def forward(self, x):
class RelativePositionBias (line 27) | class RelativePositionBias(nn.Module):
method __init__ (line 28) | def __init__(
method _relative_position_bucket (line 40) | def _relative_position_bucket(relative_position, num_buckets=32, max_d...
method forward (line 59) | def forward(self, n, device):
class TimeEncoding (line 69) | class TimeEncoding(nn.Module):
method __init__ (line 70) | def __init__(self, d_model, dropout=0.1, max_len=5000):
method forward (line 74) | def forward(self, x, mask, lengths):
class Encoder_TRANSFORMER (line 83) | class Encoder_TRANSFORMER(nn.Module):
method __init__ (line 84) | def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=7, p...
method forward (line 134) | def forward(self, batch):
class Decoder_TRANSFORMER (line 170) | class Decoder_TRANSFORMER(nn.Module):
method __init__ (line 171) | def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=7, p...
method forward (line 226) | def forward(self, batch):
FILE: PBnet/src/models/architectures/transformerdecoder.py
class MultiheadAttention (line 17) | class MultiheadAttention(nn.Module):
method __init__ (line 18) | def __init__(self, embed_size, heads, dropout = None, batch_first = No...
method sinusoidal_position_embedding (line 35) | def sinusoidal_position_embedding(self, batch_size, nums_head, max_len...
method RoPE (line 50) | def RoPE(self, q, k):
method forward (line 73) | def forward(self, q, k, v, attn_mask = None, key_padding_mask=None, ne...
class Transformer (line 108) | class Transformer(Module):
method __init__ (line 143) | def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_lay...
method forward (line 177) | def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor]...
method generate_square_subsequent_mask (line 244) | def generate_square_subsequent_mask(sz: int) -> Tensor:
method _reset_parameters (line 250) | def _reset_parameters(self):
class TransformerEncoder (line 258) | class TransformerEncoder(Module):
method __init__ (line 274) | def __init__(self, encoder_layer, num_layers, norm=None):
method forward (line 280) | def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_...
class TransformerDecoder (line 302) | class TransformerDecoder(Module):
method __init__ (line 319) | def __init__(self, decoder_layer, num_layers, norm=None):
method forward (line 325) | def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tens...
class TransformerEncoderLayer (line 354) | class TransformerEncoderLayer(Module):
method __init__ (line 387) | def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 20...
method __setstate__ (line 412) | def __setstate__(self, state):
method forward (line 417) | def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_...
method _sa_block (line 442) | def _sa_block(self, x: Tensor,
method _ff_block (line 451) | def _ff_block(self, x: Tensor) -> Tensor:
class TransformerDecoderLayer (line 456) | class TransformerDecoderLayer(Module):
method __init__ (line 492) | def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 20...
method __setstate__ (line 521) | def __setstate__(self, state):
method forward (line 526) | def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tens...
method _sa_block (line 556) | def _sa_block(self, x: Tensor,
method _mha_block (line 565) | def _mha_block(self, x: Tensor, mem: Tensor,
method _ff_block (line 574) | def _ff_block(self, x: Tensor) -> Tensor:
function _get_clones (line 579) | def _get_clones(module, N):
function _get_activation_fn (line 583) | def _get_activation_fn(activation):
FILE: PBnet/src/models/architectures/transformerdecoder4.py
function exists (line 21) | def exists(x):
class Attention (line 24) | class Attention(nn.Module):
method __init__ (line 25) | def __init__(
method forward (line 41) | def forward(
class Attention_2 (line 102) | class Attention_2(nn.Module):
method __init__ (line 103) | def __init__(
method forward (line 121) | def forward(
class PositionwiseFeedforwardLayer (line 170) | class PositionwiseFeedforwardLayer(nn.Module):
method __init__ (line 171) | def __init__(self, d_model, d_ff, dropout):
method forward (line 179) | def forward(self, x):
class DecoderLayer (line 186) | class DecoderLayer(nn.Module):
method __init__ (line 187) | def __init__(self, d_model, num_heads, d_ff, dropout, rotary_emb):
method forward (line 201) | def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
class TransformerDecoder (line 208) | class TransformerDecoder(nn.Module):
method __init__ (line 209) | def __init__(self, num_layers, d_model, num_heads, dim_feedforward, dr...
method forward (line 215) | def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
FILE: PBnet/src/models/architectures/transformerdecoder5.py
function exists (line 20) | def exists(x):
class Attention (line 23) | class Attention(nn.Module):
method __init__ (line 24) | def __init__(
method forward (line 40) | def forward(
class Attention_2 (line 101) | class Attention_2(nn.Module):
method __init__ (line 102) | def __init__(
method forward (line 120) | def forward(
class PositionwiseFeedforwardLayer (line 169) | class PositionwiseFeedforwardLayer(nn.Module):
method __init__ (line 170) | def __init__(self, d_model, d_ff, dropout):
method forward (line 178) | def forward(self, x):
class DecoderLayer (line 185) | class DecoderLayer(nn.Module):
method __init__ (line 186) | def __init__(self, d_model, num_heads, d_ff, dropout, rotary_emb):
method forward (line 202) | def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
class TransformerDecoder (line 209) | class TransformerDecoder(nn.Module):
method __init__ (line 210) | def __init__(self, num_layers, d_model, num_heads, dim_feedforward, dr...
method forward (line 216) | def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
FILE: PBnet/src/models/architectures/transformerreemb.py
function exists (line 12) | def exists(x):
class LayerNorm (line 15) | class LayerNorm(nn.Module):
method __init__ (line 16) | def __init__(self, dim, eps=1e-5):
method forward (line 21) | def forward(self, x):
class PreNorm (line 26) | class PreNorm(nn.Module):
method __init__ (line 27) | def __init__(self, dim, fn):
method forward (line 32) | def forward(self, x, **kwargs):
class Residual (line 36) | class Residual(nn.Module):
method __init__ (line 37) | def __init__(self, fn):
method forward (line 41) | def forward(self, x, *args, **kwargs):
class EinopsToAndFrom (line 44) | class EinopsToAndFrom(nn.Module):
method __init__ (line 45) | def __init__(self, from_einops, to_einops, fn):
method forward (line 51) | def forward(self, x, **kwargs):
class Attention (line 59) | class Attention(nn.Module):
method __init__ (line 60) | def __init__(
method forward (line 76) | def forward(
class PositionalEncoding (line 138) | class PositionalEncoding(nn.Module):
method __init__ (line 139) | def __init__(self, d_model, dropout=0.1, max_len=20000):
method forward (line 152) | def forward(self, x):
class RelativePositionBias (line 158) | class RelativePositionBias(nn.Module):
method __init__ (line 159) | def __init__(
method _relative_position_bucket (line 171) | def _relative_position_bucket(relative_position, num_buckets=32, max_d...
method forward (line 190) | def forward(self, n, device, eval = False):
class TimeEncoding (line 205) | class TimeEncoding(nn.Module):
method __init__ (line 206) | def __init__(self, d_model, dropout=0.1, max_len=5000):
method forward (line 210) | def forward(self, x, mask, lengths):
class Encoder_TRANSFORMERREEMB (line 219) | class Encoder_TRANSFORMERREEMB(nn.Module):
method __init__ (line 220) | def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=7, p...
method forward (line 270) | def forward(self, batch):
class Decoder_TRANSFORMERREEMB (line 306) | class Decoder_TRANSFORMERREEMB(nn.Module):
method __init__ (line 307) | def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=7, p...
method forward (line 375) | def forward(self, batch):
FILE: PBnet/src/models/architectures/transformerreemb5.py
function exists (line 13) | def exists(x):
class LayerNorm (line 16) | class LayerNorm(nn.Module):
method __init__ (line 17) | def __init__(self, dim, eps=1e-5):
method forward (line 22) | def forward(self, x):
class PreNorm (line 27) | class PreNorm(nn.Module):
method __init__ (line 28) | def __init__(self, dim, fn):
method forward (line 33) | def forward(self, x, **kwargs):
class Residual (line 37) | class Residual(nn.Module):
method __init__ (line 38) | def __init__(self, fn):
method forward (line 42) | def forward(self, x, *args, **kwargs):
class EinopsToAndFrom (line 45) | class EinopsToAndFrom(nn.Module):
method __init__ (line 46) | def __init__(self, from_einops, to_einops, fn):
method forward (line 52) | def forward(self, x, **kwargs):
class PositionalEncoding (line 61) | class PositionalEncoding(nn.Module):
method __init__ (line 62) | def __init__(self, d_model, dropout=0.1, max_len=20000):
method forward (line 75) | def forward(self, x):
class RelativePositionBias (line 81) | class RelativePositionBias(nn.Module):
method __init__ (line 82) | def __init__(
method _relative_position_bucket (line 94) | def _relative_position_bucket(relative_position, num_buckets=32, max_d...
method forward (line 113) | def forward(self, n, device, eval = False):
class TimeEncoding (line 132) | class TimeEncoding(nn.Module):
method __init__ (line 133) | def __init__(self, d_model, dropout=0.1, max_len=5000):
method forward (line 137) | def forward(self, x, mask, lengths):
class Encoder_TRANSFORMERREEMB5 (line 146) | class Encoder_TRANSFORMERREEMB5(nn.Module):
method __init__ (line 147) | def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=6, e...
method forward (line 198) | def forward(self, batch):
class Decoder_TRANSFORMERREEMB5 (line 234) | class Decoder_TRANSFORMERREEMB5(nn.Module):
method __init__ (line 235) | def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=6, e...
method forward (line 311) | def forward(self, batch):
FILE: PBnet/src/models/architectures/transformerreemb6.py
function exists (line 13) | def exists(x):
class LayerNorm (line 16) | class LayerNorm(nn.Module):
method __init__ (line 17) | def __init__(self, dim, eps=1e-5):
method forward (line 22) | def forward(self, x):
class PreNorm (line 27) | class PreNorm(nn.Module):
method __init__ (line 28) | def __init__(self, dim, fn):
method forward (line 33) | def forward(self, x, **kwargs):
class Residual (line 37) | class Residual(nn.Module):
method __init__ (line 38) | def __init__(self, fn):
method forward (line 42) | def forward(self, x, *args, **kwargs):
class EinopsToAndFrom (line 45) | class EinopsToAndFrom(nn.Module):
method __init__ (line 46) | def __init__(self, from_einops, to_einops, fn):
method forward (line 52) | def forward(self, x, **kwargs):
class PositionalEncoding (line 61) | class PositionalEncoding(nn.Module):
method __init__ (line 62) | def __init__(self, d_model, dropout=0.1, max_len=20000):
method forward (line 75) | def forward(self, x):
class RelativePositionBias (line 81) | class RelativePositionBias(nn.Module):
method __init__ (line 82) | def __init__(
method _relative_position_bucket (line 94) | def _relative_position_bucket(relative_position, num_buckets=32, max_d...
method forward (line 113) | def forward(self, n, device, eval = False):
class TimeEncoding (line 132) | class TimeEncoding(nn.Module):
method __init__ (line 133) | def __init__(self, d_model, dropout=0.1, max_len=5000):
method forward (line 137) | def forward(self, x, mask, lengths):
class Encoder_TRANSFORMERREEMB6 (line 146) | class Encoder_TRANSFORMERREEMB6(nn.Module):
method __init__ (line 147) | def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=6, e...
method forward (line 198) | def forward(self, batch):
class Decoder_TRANSFORMERREEMB6 (line 234) | class Decoder_TRANSFORMERREEMB6(nn.Module):
method __init__ (line 235) | def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=6, e...
method forward (line 310) | def forward(self, batch):
FILE: PBnet/src/models/get_model.py
function get_model (line 19) | def get_model(parameters):
FILE: PBnet/src/models/modeltype/cae.py
class CAE (line 10) | class CAE(nn.Module):
method __init__ (line 11) | def __init__(self, encoder, decoder, device, lambdas, latent_dim, **kw...
method forward (line 46) | def forward(self, batch):
method compute_loss (line 66) | def compute_loss(self, batch, epoch = 0):
method lengths_to_mask (line 88) | def lengths_to_mask(lengths):
method generate_one (line 96) | def generate_one(self, cls, duration, fact=1, xyz=False):
method generate (line 112) | def generate(self, pose, audio, durations,
method return_latent (line 174) | def return_latent(self, batch, seed=None):
FILE: PBnet/src/models/modeltype/cae_0.py
class CAE (line 8) | class CAE(nn.Module):
method __init__ (line 9) | def __init__(self, encoder, decoder, device, lambdas, latent_dim, **kw...
method forward (line 43) | def forward(self, batch):
method compute_loss (line 61) | def compute_loss(self, batch):
method lengths_to_mask (line 73) | def lengths_to_mask(lengths):
method generate_one (line 81) | def generate_one(self, cls, duration, fact=1, xyz=False):
method generate (line 97) | def generate(self, pose, audio, durations,
method return_latent (line 159) | def return_latent(self, batch, seed=None):
FILE: PBnet/src/models/modeltype/cvae.py
class CVAE (line 5) | class CVAE(CAE):
method reparameterize (line 6) | def reparameterize(self, batch, seed=None):
method forward (line 20) | def forward(self, batch):
method return_latent (line 40) | def return_latent(self, batch, seed=None):
FILE: PBnet/src/models/modeltype/lstm.py
class MyResNet34 (line 11) | class MyResNet34(nn.Module):
method __init__ (line 12) | def __init__(self,embedding_dim,input_channel = 3):
method forward (line 15) | def forward(self, x):
class LSTM (line 19) | class LSTM(nn.Module):
method __init__ (line 20) | def __init__(self, encoder, decoder, device, lambdas, latent_dim, **kw...
method compute_loss (line 33) | def compute_loss(self, batch):
method forward (line 49) | def forward(self,batch):
method lengths_to_mask (line 72) | def lengths_to_mask(lengths):
method generate (line 80) | def generate(self, pose, audio, durations,
FILE: PBnet/src/models/rotation2xyz.py
class Rotation2xyz (line 8) | class Rotation2xyz:
method __init__ (line 9) | def __init__(self, device):
method __call__ (line 13) | def __call__(self, x, mask, pose_rep, translation, glob,
FILE: PBnet/src/models/smpl.py
class SMPL (line 60) | class SMPL(_SMPLLayer):
method __init__ (line 63) | def __init__(self, model_path=SMPL_MODEL_PATH, **kwargs):
method forward (line 82) | def forward(self, *args, **kwargs):
FILE: PBnet/src/models/tools/graphconv.py
class GraphConvolution (line 9) | class GraphConvolution(Module):
method __init__ (line 14) | def __init__(self, in_features, out_features, bias=True):
method reset_parameters (line 25) | def reset_parameters(self):
method forward (line 31) | def forward(self, input, adj):
method __repr__ (line 39) | def __repr__(self):
FILE: PBnet/src/models/tools/hessian_penalty.py
function hessian_penalty (line 29) | def hessian_penalty(G, batch, k=2, epsilon=0.1, reduction=torch.max, ret...
function rademacher (line 67) | def rademacher(shape, device='cpu'):
function multi_layer_second_directional_derivative (line 75) | def multi_layer_second_directional_derivative(G, batch, dz, G_z, epsilon...
function stack_var_and_reduce (line 91) | def stack_var_and_reduce(list_of_activations, reduction=torch.max):
function multi_stack_var_and_reduce (line 99) | def multi_stack_var_and_reduce(sdds, reduction=torch.max, return_separat...
function listify (line 108) | def listify(x):
function _test_hessian_penalty (line 116) | def _test_hessian_penalty():
FILE: PBnet/src/models/tools/losses.py
function compute_rc_loss (line 9) | def compute_rc_loss(model, batch):
function compute_reg_loss (line 23) | def compute_reg_loss(model, batch):
function compute_rc_weight_loss (line 37) | def compute_rc_weight_loss(model, batch):
function compute_hp_loss (line 62) | def compute_hp_loss(model, batch):
function compute_kl_loss (line 67) | def compute_kl_loss(model, batch):
function compute_ssim_loss (line 73) | def compute_ssim_loss(model, batch):
function ssimnorm_loss (line 86) | def ssimnorm_loss(x, output, mask, bs):
function ssimnorm_self_loss (line 100) | def ssimnorm_self_loss(x, output, mask, bs):
function ssim255_loss (line 112) | def ssim255_loss(x, output, mask, bs):
function comput_var_loss (line 126) | def comput_var_loss(model, batch):
function compute_mmd_loss (line 147) | def compute_mmd_loss(model, batch):
function get_loss_function (line 164) | def get_loss_function(ltype):
function get_loss_names (line 168) | def get_loss_names():
FILE: PBnet/src/models/tools/mmd.py
function compute_kernel (line 5) | def compute_kernel(x, y):
function compute_mmd (line 17) | def compute_mmd(x, y):
FILE: PBnet/src/models/tools/msssim_loss.py
function gaussian (line 7) | def gaussian(window_size, sigma):
function create_window (line 12) | def create_window(window_size, channel=1):
function ssim (line 19) | def ssim(img1, img2, window_size=11, window=None, size_average=True, ful...
function msssim (line 73) | def msssim(img1, img2, window_size=11, size_average=True, val_range=None...
class SSIM (line 111) | class SSIM(torch.nn.Module):
method __init__ (line 112) | def __init__(self, window_size=11, size_average=True, val_range=None):
method forward (line 122) | def forward(self, img1, img2):
class MSSSIM (line 134) | class MSSSIM(torch.nn.Module):
method __init__ (line 135) | def __init__(self, window_size=11, size_average=True, channel=3):
method forward (line 141) | def forward(self, img1, img2):
FILE: PBnet/src/models/tools/normalize_data.py
function normalize_data (line 3) | def normalize_data(data, min_vals, max_vals):
FILE: PBnet/src/models/tools/ssim_loss.py
function gaussian (line 7) | def gaussian(window_size, sigma):
function create_window (line 11) | def create_window(window_size, channel):
function _ssim (line 17) | def _ssim(img1, img2, window, window_size, channel, val_range = 1, size_...
class SSIM (line 39) | class SSIM(torch.nn.Module):
method __init__ (line 40) | def __init__(self, window_size = 11, size_average = True):
method forward (line 47) | def forward(self, img1, img2):
function ssim (line 65) | def ssim(img1, img2, window_size = 11, val_range=1, size_average = True):
function read_pose_from_txt (line 75) | def read_pose_from_txt(file_path):
FILE: PBnet/src/models/tools/tools.py
class AutoParams (line 5) | class AutoParams(nn.Module):
method __init__ (line 6) | def __init__(self, **kargs):
function freeze_params (line 28) | def freeze_params(module: nn.Module) -> None:
FILE: PBnet/src/parser/base.py
function add_misc_options (line 4) | def add_misc_options(parser):
function add_cuda_options (line 10) | def add_cuda_options(parser):
function adding_cuda (line 19) | def adding_cuda(parameters):
FILE: PBnet/src/parser/checkpoint.py
function parser (line 6) | def parser():
function construct_checkpointname (line 23) | def construct_checkpointname(parameters, folder):
FILE: PBnet/src/parser/dataset.py
function add_dataset_options (line 4) | def add_dataset_options(parser):
FILE: PBnet/src/parser/evaluation.py
function parser (line 10) | def parser():
FILE: PBnet/src/parser/finetunning.py
function parser (line 5) | def parser():
FILE: PBnet/src/parser/generate.py
function add_generation_options (line 8) | def add_generation_options(parser):
function parser (line 25) | def parser():
FILE: PBnet/src/parser/model.py
function add_model_options (line 4) | def add_model_options(parser):
function parse_modelname (line 30) | def parse_modelname(modelname):
FILE: PBnet/src/parser/recognition.py
function training_parser (line 10) | def training_parser():
FILE: PBnet/src/parser/tools.py
function save_args (line 5) | def save_args(opt, folder):
function load_args (line 14) | def load_args(filename):
FILE: PBnet/src/parser/training.py
function add_training_options (line 10) | def add_training_options(parser):
function parser (line 19) | def parser():
FILE: PBnet/src/parser/visualize.py
function construct_figname (line 9) | def construct_figname(parameters):
function add_visualize_options (line 14) | def add_visualize_options(parser):
function parser (line 56) | def parser(checkpoint=True):
FILE: PBnet/src/preprocess/humanact12_process.py
function splitname (line 7) | def splitname(name):
function create_phpsd_name (line 17) | def create_phpsd_name(name):
function get_frames (line 23) | def get_frames(name):
function get_action (line 28) | def get_action(name, coarse=True):
function process_datata (line 55) | def process_datata(savepath, posesfolder="data/PHPSDposes", datapath="da...
FILE: PBnet/src/preprocess/phspdtools.py
class Transform (line 7) | class Transform:
method __init__ (line 8) | def __init__(self, R=np.eye(3, dtype='float'), t=np.zeros(3, 'float'),...
method __mul__ (line 13) | def __mul__(self, other):
method inv (line 22) | def inv(self):
method transform (line 28) | def transform(self, xyz):
method getmat4 (line 36) | def getmat4(self):
function quat2R (line 44) | def quat2R(quat):
function convert_param2tranform (line 83) | def convert_param2tranform(param, scale=1):
class CameraParams (line 90) | class CameraParams:
method __init__ (line 91) | def __init__(self, cam_folder="data/phspdCameras"):
method get_intrinsic (line 117) | def get_intrinsic(self, cam_name, subject_no):
method get_extrinsic (line 133) | def get_extrinsic(self, cams_name, subject_no):
method get_gender (line 155) | def get_gender(self, subject_no):
FILE: PBnet/src/preprocess/uestc_vibe_postprocessing.py
function get_kinect_motion (line 14) | def get_kinect_motion(tar, videos, index):
function motionto2d (line 25) | def motionto2d(motion, W=960, H=540):
function motionto2dvibe (line 36) | def motionto2dvibe(motion, cam):
function get_kcenter (line 41) | def get_kcenter(tar, videos, index):
function get_concat_goodtracks (line 49) | def get_concat_goodtracks(allvibe, tar, videos, index):
function interpolate_track (line 101) | def interpolate_track(gvibe):
FILE: PBnet/src/recognition/compute_accuracy.py
function compute_accuracy (line 16) | def compute_accuracy(model, datasets, parameters):
function main (line 43) | def main():
FILE: PBnet/src/recognition/get_model.py
function get_model (line 4) | def get_model(parameters):
FILE: PBnet/src/recognition/models/stgcn.py
class STGCN (line 11) | class STGCN(nn.Module):
method __init__ (line 29) | def __init__(self, in_channels, num_class, graph_args,
method forward (line 75) | def forward(self, batch):
method compute_accuracy (line 114) | def compute_accuracy(self, batch):
method compute_loss (line 123) | def compute_loss(self, batch):
class st_gcn (line 134) | class st_gcn(nn.Module):
method __init__ (line 155) | def __init__(self,
method forward (line 203) | def forward(self, x, A):
FILE: PBnet/src/recognition/models/stgcnutils/graph.py
class Graph (line 7) | class Graph:
method __init__ (line 26) | def __init__(self,
method __str__ (line 42) | def __str__(self):
method get_edge (line 45) | def get_edge(self, layout):
method get_adjacency (line 99) | def get_adjacency(self, strategy):
function get_hop_distance (line 144) | def get_hop_distance(num_node, edge, max_hop=1):
function normalize_digraph (line 159) | def normalize_digraph(A):
function normalize_undigraph (line 170) | def normalize_undigraph(A):
FILE: PBnet/src/recognition/models/stgcnutils/tgcn.py
class ConvTemporalGraphical (line 7) | class ConvTemporalGraphical(nn.Module):
method __init__ (line 34) | def __init__(self,
method forward (line 55) | def forward(self, x, A):
FILE: PBnet/src/render/renderer.py
function get_smpl_faces (line 19) | def get_smpl_faces():
class WeakPerspectiveCamera (line 23) | class WeakPerspectiveCamera(pyrender.Camera):
method __init__ (line 24) | def __init__(self,
method get_projection_matrix (line 38) | def get_projection_matrix(self, width=None, height=None):
class Renderer (line 48) | class Renderer:
method __init__ (line 49) | def __init__(self, background=None, resolution=(224, 224), bg_color=[0...
method render (line 104) | def render(self, img, verts, cam, angle=None, axis=None, mesh_filename...
function get_renderer (line 154) | def get_renderer(width, height):
FILE: PBnet/src/render/rendermotion.py
function get_rotation (line 9) | def get_rotation(theta=np.pi/3):
function render_video (line 18) | def render_video(meshes, key, action, renderer, savepath, background, ca...
function main (line 43) | def main():
FILE: PBnet/src/train/train_cvae_ganloss_ann_eye.py
class ConvNormRelu (line 33) | class ConvNormRelu(nn.Module):
method __init__ (line 34) | def __init__(self, in_channels, out_channels, kernel_size, stride, pad...
method forward (line 46) | def forward(self, x):
class D_patchgan (line 51) | class D_patchgan(nn.Module):
method __init__ (line 52) | def __init__(self, n_downsampling=2, pos_dim=6, eye_dim=0, norm='batch'):
method forward (line 71) | def forward(self, x):
method calculate_GAN_loss (line 77) | def calculate_GAN_loss(self, batch):
function get_model (line 91) | def get_model(parameters):
function do_epochs (line 108) | def do_epochs(model, model_d, dataset, parameters, optimizer_g, optimize...
FILE: PBnet/src/train/train_cvae_ganloss_ann_fast.py
class ConvNormRelu (line 30) | class ConvNormRelu(nn.Module):
method __init__ (line 31) | def __init__(self, in_channels, out_channels, kernel_size, stride, pad...
method forward (line 43) | def forward(self, x):
class D_patchgan (line 48) | class D_patchgan(nn.Module):
method __init__ (line 49) | def __init__(self, n_downsampling=2, norm='batch'):
method forward (line 66) | def forward(self, x):
method calculate_GAN_loss (line 72) | def calculate_GAN_loss(self, batch):
function get_model (line 86) | def get_model(parameters):
function do_epochs (line 103) | def do_epochs(model, model_d, dataset, parameters, optimizer_g, optimize...
FILE: PBnet/src/train/trainer.py
function train_or_test (line 5) | def train_or_test(model, optimizer, iterator, device, mode="train"):
function train (line 45) | def train(model, optimizer, iterator, device):
function test (line 49) | def test(model, optimizer, iterator, device):
FILE: PBnet/src/train/trainer_gan.py
function train_or_test (line 6) | def train_or_test(model, model_d, optimizer_g, optimizer_d, iterator, de...
function train (line 73) | def train(model, model_d, optimizer_g, optimizer_d, iterator, device):
function test (line 77) | def test(model, model_d, optimizer_g, optimizer_d, iterator, device):
FILE: PBnet/src/train/trainer_gan_ann.py
function train_or_test (line 6) | def train_or_test(model, model_d, optimizer_g, optimizer_d, iterator, de...
function train (line 79) | def train(model, model_d, optimizer_g, optimizer_d, iterator, device, ep...
function test (line 83) | def test(model, model_d, optimizer_g, optimizer_d, iterator, device):
FILE: PBnet/src/utils/fixseed.py
function fixseed (line 6) | def fixseed(seed):
FILE: PBnet/src/utils/get_model_and_data.py
function get_model_and_data (line 6) | def get_model_and_data(parameters):
FILE: PBnet/src/utils/misc.py
function to_numpy (line 4) | def to_numpy(tensor):
function to_torch (line 13) | def to_torch(ndarray):
function cleanexit (line 22) | def cleanexit():
FILE: PBnet/src/utils/rotation_conversions.py
function quaternion_to_matrix (line 37) | def quaternion_to_matrix(quaternions):
function _copysign (line 68) | def _copysign(a, b):
function _sqrt_positive_part (line 86) | def _sqrt_positive_part(x):
function matrix_to_quaternion (line 97) | def matrix_to_quaternion(matrix):
function _axis_angle_rotation (line 122) | def _axis_angle_rotation(axis: str, angle):
function euler_angles_to_matrix (line 150) | def euler_angles_to_matrix(euler_angles, convention: str):
function _angle_from_tan (line 175) | def _angle_from_tan(
function _index_from_letter (line 208) | def _index_from_letter(letter: str):
function matrix_to_euler_angles (line 217) | def matrix_to_euler_angles(matrix, convention: str):
function random_quaternions (line 259) | def random_quaternions(
function random_rotations (line 283) | def random_rotations(
function random_rotation (line 306) | def random_rotation(
function standardize_quaternion (line 325) | def standardize_quaternion(quaternions):
function quaternion_raw_multiply (line 340) | def quaternion_raw_multiply(a, b):
function quaternion_multiply (line 361) | def quaternion_multiply(a, b):
function quaternion_invert (line 378) | def quaternion_invert(quaternion):
function quaternion_apply (line 394) | def quaternion_apply(quaternion, point):
function axis_angle_to_matrix (line 417) | def axis_angle_to_matrix(axis_angle):
function matrix_to_axis_angle (line 433) | def matrix_to_axis_angle(matrix):
function axis_angle_to_quaternion (line 449) | def axis_angle_to_quaternion(axis_angle):
function quaternion_to_axis_angle (line 481) | def quaternion_to_axis_angle(quaternions):
function rotation_6d_to_matrix (line 512) | def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
function matrix_to_rotation_6d (line 536) | def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
FILE: PBnet/src/utils/tensors.py
function lengths_to_mask (line 4) | def lengths_to_mask(lengths):
function collate_tensors (line 10) | def collate_tensors(batch):
function collate (line 23) | def collate(batch):
FILE: PBnet/src/utils/tensors_eye.py
function lengths_to_mask (line 4) | def lengths_to_mask(lengths):
function collate_tensors (line 10) | def collate_tensors(batch):
function collate (line 23) | def collate(batch):
FILE: PBnet/src/utils/tensors_eye_eval.py
function lengths_to_mask (line 4) | def lengths_to_mask(lengths):
function collate_tensors (line 10) | def collate_tensors(batch):
function collate (line 23) | def collate(batch):
FILE: PBnet/src/utils/tensors_hdtf.py
function lengths_to_mask (line 4) | def lengths_to_mask(lengths):
function collate_tensors (line 10) | def collate_tensors(batch):
function collate_old (line 23) | def collate_old(batch):
function collate (line 43) | def collate(batch):
FILE: PBnet/src/utils/tensors_onlyeye.py
function lengths_to_mask (line 4) | def lengths_to_mask(lengths):
function collate_tensors (line 10) | def collate_tensors(batch):
function collate (line 23) | def collate(batch):
function collate_eval (line 46) | def collate_eval(batch):
FILE: PBnet/src/utils/utils.py
class _RepeatSampler (line 7) | class _RepeatSampler(object):
method __init__ (line 9) | def __init__(self, sampler):
method __iter__ (line 12) | def __iter__(self):
class MultiEpochsDataLoader (line 16) | class MultiEpochsDataLoader(torch.utils.data.DataLoader):
method __init__ (line 20) | def __init__(self, *args, **kwargs):
method __len__ (line 25) | def __len__(self):
method __iter__ (line 28) | def __iter__(self):
class CudaDataLoader (line 32) | class CudaDataLoader:
method __init__ (line 34) | def __init__(self, loader, device, queue_size=2):
method load_loop (line 47) | def load_loop(self):
method load_instance (line 53) | def load_instance(self, sample):
method __iter__ (line 64) | def __iter__(self):
method __next__ (line 68) | def __next__(self):
method next (line 83) | def next(self):
method __len__ (line 86) | def __len__(self):
method sampler (line 90) | def sampler(self):
method dataset (line 94) | def dataset(self):
FILE: PBnet/src/utils/video.py
function load_video (line 5) | def load_video(filename):
class SaveVideo (line 12) | class SaveVideo:
method __init__ (line 13) | def __init__(self, outname, fps):
method __enter__ (line 17) | def __enter__(self):
method __exit__ (line 23) | def __exit__(self, exc_type, exc_value, exc_traceback):
method __iadd__ (line 26) | def __iadd__(self, data):
FILE: PBnet/src/visualize/anim.py
function add_shadow (line 30) | def add_shadow(img, shadow=15):
function load_anim (line 38) | def load_anim(path, timesize=None):
function plot_3d_motion (line 52) | def plot_3d_motion(motion, length, save_path, params, title="", interval...
function plot_3d_motion_dico (line 134) | def plot_3d_motion_dico(x):
FILE: PBnet/src/visualize/visualize.py
function stack_images (line 11) | def stack_images(real, real_gens, gen):
function generate_by_video (line 25) | def generate_by_video(visualization, reconstructions, generation,
function viz_epoch (line 110) | def viz_epoch(model, dataset, epoch, params, folder, writer=None):
function viz_dataset (line 257) | def viz_dataset(dataset, params, folder):
function generate_by_video_sequences (line 319) | def generate_by_video_sequences(visualization, label_to_action_name, par...
function stack_images_sequence (line 362) | def stack_images_sequence(visu):
FILE: PBnet/src/visualize/visualize_checkpoint.py
function main (line 14) | def main():
FILE: PBnet/src/visualize/visualize_nturefined.py
function viz_ntu13 (line 11) | def viz_ntu13(dataset, device):
FILE: extract_init_states/FaceBoxes/FaceBoxes.py
function viz_bbox (line 33) | def viz_bbox(img, dets, wfp='out.jpg'):
class FaceBoxes (line 48) | class FaceBoxes:
method __init__ (line 49) | def __init__(self, timer_flag=False):
method __call__ (line 61) | def __call__(self, img_):
function main (line 144) | def main():
FILE: extract_init_states/FaceBoxes/FaceBoxes_ONNX.py
function viz_bbox (line 33) | def viz_bbox(img, dets, wfp='out.jpg'):
class FaceBoxes_ONNX (line 48) | class FaceBoxes_ONNX(object):
method __init__ (line 49) | def __init__(self, timer_flag=False):
method __call__ (line 56) | def __call__(self, img_):
function main (line 147) | def main():
FILE: extract_init_states/FaceBoxes/models/faceboxes.py
class BasicConv2d (line 8) | class BasicConv2d(nn.Module):
method __init__ (line 10) | def __init__(self, in_channels, out_channels, **kwargs):
method forward (line 15) | def forward(self, x):
class Inception (line 21) | class Inception(nn.Module):
method __init__ (line 22) | def __init__(self):
method forward (line 32) | def forward(self, x):
class CRelu (line 49) | class CRelu(nn.Module):
method __init__ (line 51) | def __init__(self, in_channels, out_channels, **kwargs):
method forward (line 56) | def forward(self, x):
class FaceBoxesNet (line 64) | class FaceBoxesNet(nn.Module):
method __init__ (line 66) | def __init__(self, phase, size, num_classes):
method multibox (line 102) | def multibox(self, num_classes):
method forward (line 113) | def forward(self, x):
FILE: extract_init_states/FaceBoxes/onnx.py
function convert_to_onnx (line 11) | def convert_to_onnx(onnx_path):
FILE: extract_init_states/FaceBoxes/utils/box_utils.py
function point_form (line 7) | def point_form(boxes):
function center_size (line 19) | def center_size(boxes):
function intersect (line 31) | def intersect(box_a, box_b):
function jaccard (line 52) | def jaccard(box_a, box_b):
function matrix_iou (line 73) | def matrix_iou(a, b):
function matrix_iof (line 86) | def matrix_iof(a, b):
function match (line 98) | def match(threshold, truths, priors, variances, labels, loc_t, conf_t, i...
function encode (line 152) | def encode(matched, priors, variances):
function decode (line 177) | def decode(loc, priors, variances):
function log_sum_exp (line 198) | def log_sum_exp(x):
function nms (line 212) | def nms(boxes, scores, overlap=0.5, top_k=200):
FILE: extract_init_states/FaceBoxes/utils/build.py
function find_in_path (line 18) | def find_in_path(name, path):
class custom_build_ext (line 36) | class custom_build_ext(build_ext):
method build_extensions (line 37) | def build_extensions(self):
FILE: extract_init_states/FaceBoxes/utils/functions.py
function check_keys (line 7) | def check_keys(model, pretrained_state_dict):
function remove_prefix (line 20) | def remove_prefix(state_dict, prefix):
function load_model (line 27) | def load_model(model, pretrained_path, load_to_cpu):
FILE: extract_init_states/FaceBoxes/utils/nms/py_cpu_nms.py
function py_cpu_nms (line 10) | def py_cpu_nms(dets, thresh):
FILE: extract_init_states/FaceBoxes/utils/nms_wrapper.py
function nms (line 13) | def nms(dets, thresh):
FILE: extract_init_states/FaceBoxes/utils/prior_box.py
class PriorBox (line 10) | class PriorBox(object):
method __init__ (line 11) | def __init__(self, image_size=None):
method forward (line 20) | def forward(self):
FILE: extract_init_states/FaceBoxes/utils/timer.py
class Timer (line 13) | class Timer(object):
method __init__ (line 16) | def __init__(self):
method tic (line 23) | def tic(self):
method toc (line 28) | def toc(self, average=True):
method clear (line 38) | def clear(self):
FILE: extract_init_states/TDDFA_ONNX.py
class TDDFA_ONNX (line 29) | class TDDFA_ONNX(object):
method __init__ (line 32) | def __init__(self, **kvs):
method __call__ (line 74) | def __call__(self, img_ori, objs, **kvs):
method recon_vers (line 105) | def recon_vers(self, param_lst, roi_box_lst, **kvs):
FILE: extract_init_states/bfm/bfm.py
function _to_ctype (line 16) | def _to_ctype(arr):
class BFMModel (line 22) | class BFMModel(object):
method __init__ (line 23) | def __init__(self, bfm_fp, shape_dim=40, exp_dim=10):
FILE: extract_init_states/bfm/bfm_onnx.py
function _to_ctype (line 19) | def _to_ctype(arr):
function _load_tri (line 25) | def _load_tri(bfm_fp):
class BFMModel_ONNX (line 35) | class BFMModel_ONNX(nn.Module):
method __init__ (line 38) | def __init__(self, bfm_fp, shape_dim=40, exp_dim=10):
method forward (line 63) | def forward(self, *inps):
function convert_bfm_to_onnx (line 73) | def convert_bfm_to_onnx(bfm_onnx_fp, shape_dim=40, exp_dim=10):
FILE: extract_init_states/demo_pose_extract_2d_lmk_img.py
function main (line 30) | def main(args,img, save_path, pose_path):
FILE: extract_init_states/functions.py
function get_suffix (line 15) | def get_suffix(filename):
function crop_img (line 23) | def crop_img(img, roi_box):
function calc_hypotenuse (line 56) | def calc_hypotenuse(pts):
function parse_roi_box_from_landmark (line 65) | def parse_roi_box_from_landmark(pts):
function parse_roi_box_from_bbox (line 85) | def parse_roi_box_from_bbox(bbox):
function plot_image (line 101) | def plot_image(img):
function draw_landmarks (line 112) | def draw_landmarks(img, pts, style='fancy', wfp=None, show_flag=False, *...
function cv_draw_landmark (line 159) | def cv_draw_landmark(img_ori, pts, box=None, color=GREEN, size=1):
function calculate_bbox (line 183) | def calculate_bbox(img, lmk):
function calculate_eye (line 204) | def calculate_eye(lmk):
FILE: extract_init_states/models/mobilenet_v1.py
class DepthWiseBlock (line 22) | class DepthWiseBlock(nn.Module):
method __init__ (line 23) | def __init__(self, inplanes, planes, stride=1, prelu=False):
method forward (line 36) | def forward(self, x):
class MobileNet (line 48) | class MobileNet(nn.Module):
method __init__ (line 49) | def __init__(self, widen_factor=1.0, num_classes=1000, prelu=False, in...
method forward (line 96) | def forward(self, x):
function mobilenet (line 122) | def mobilenet(**kwargs):
function mobilenet_2 (line 141) | def mobilenet_2(num_classes=62, input_channel=3):
function mobilenet_1 (line 146) | def mobilenet_1(num_classes=62, input_channel=3):
function mobilenet_075 (line 151) | def mobilenet_075(num_classes=62, input_channel=3):
function mobilenet_05 (line 156) | def mobilenet_05(num_classes=62, input_channel=3):
function mobilenet_025 (line 161) | def mobilenet_025(num_classes=62, input_channel=3):
FILE: extract_init_states/models/mobilenet_v3.py
function conv_bn (line 10) | def conv_bn(inp, oup, stride, conv_layer=nn.Conv2d, norm_layer=nn.BatchN...
function conv_1x1_bn (line 18) | def conv_1x1_bn(inp, oup, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2...
class Hswish (line 26) | class Hswish(nn.Module):
method __init__ (line 27) | def __init__(self, inplace=True):
method forward (line 31) | def forward(self, x):
class Hsigmoid (line 35) | class Hsigmoid(nn.Module):
method __init__ (line 36) | def __init__(self, inplace=True):
method forward (line 40) | def forward(self, x):
class SEModule (line 44) | class SEModule(nn.Module):
method __init__ (line 45) | def __init__(self, channel, reduction=4):
method forward (line 56) | def forward(self, x):
class Identity (line 63) | class Identity(nn.Module):
method __init__ (line 64) | def __init__(self, channel):
method forward (line 67) | def forward(self, x):
function make_divisible (line 71) | def make_divisible(x, divisible_by=8):
class MobileBottleneck (line 76) | class MobileBottleneck(nn.Module):
method __init__ (line 77) | def __init__(self, inp, oup, kernel, stride, exp, se=False, nl='RE'):
method forward (line 112) | def forward(self, x):
class MobileNetV3 (line 119) | class MobileNetV3(nn.Module):
method __init__ (line 120) | def __init__(self, widen_factor=1.0, num_classes=141, num_landmarks=13...
method forward (line 208) | def forward(self, x):
method _initialize_weights (line 221) | def _initialize_weights(self):
function mobilenet_v3 (line 237) | def mobilenet_v3(**kwargs):
FILE: extract_init_states/models/resnet.py
function conv3x3 (line 9) | def conv3x3(in_planes, out_planes, stride=1):
class BasicBlock (line 15) | class BasicBlock(nn.Module):
method __init__ (line 18) | def __init__(self, inplanes, planes, stride=1, downsample=None):
method forward (line 28) | def forward(self, x):
class ResNet (line 47) | class ResNet(nn.Module):
method __init__ (line 50) | def __init__(self, block, layers, num_classes=62, num_landmarks=136, i...
method _make_layer (line 86) | def _make_layer(self, block, planes, blocks, stride=1):
method forward (line 103) | def forward(self, x):
function resnet22 (line 134) | def resnet22(**kwargs):
function main (line 145) | def main():
FILE: extract_init_states/pose.py
function P2sRt (line 18) | def P2sRt(P):
function matrix2angle (line 39) | def matrix2angle(R):
function angle2matrix (line 65) | def angle2matrix(theta):
function angle2matrix_3ddfa (line 112) | def angle2matrix_3ddfa(angles):
function calc_pose (line 140) | def calc_pose(param):
function build_camera_box (line 150) | def build_camera_box(rear_size=90):
function plot_pose_box (line 171) | def plot_pose_box(img, P, ver, color=(40, 255, 0), line_width=2):
function viz_pose (line 201) | def viz_pose(img, param_lst, ver_lst, show_flag=False, wfp=None):
function pose_6 (line 217) | def pose_6(param):
function smooth_pose (line 231) | def smooth_pose(img, param_lst, ver_lst, pose_new, show_flag=False, wfp=...
function get_pose (line 263) | def get_pose(img, param_lst, ver_lst, show_flag=False, wfp=None, wnp = N...
FILE: extract_init_states/utils/asset/render.c
type Tuple3D (line 9) | struct Tuple3D
function _render (line 16) | void _render(const int *triangles,
FILE: extract_init_states/utils/depth.py
function depth (line 17) | def depth(img, ver_lst, tri, show_flag=False, wfp=None, with_bg_flag=True):
FILE: extract_init_states/utils/functions.py
function get_suffix (line 15) | def get_suffix(filename):
function crop_img (line 23) | def crop_img(img, roi_box):
function calc_hypotenuse (line 56) | def calc_hypotenuse(pts):
function parse_roi_box_from_landmark (line 65) | def parse_roi_box_from_landmark(pts):
function parse_roi_box_from_bbox (line 85) | def parse_roi_box_from_bbox(bbox):
function plot_image (line 101) | def plot_image(img):
function draw_landmarks (line 112) | def draw_landmarks(img, pts, style='fancy', wfp=None, show_flag=False, *...
function cv_draw_landmark (line 159) | def cv_draw_landmark(img_ori, pts, box=None, color=GREEN, size=1):
function calculate_bbox (line 183) | def calculate_bbox(img, lmk):
function calculate_eye (line 204) | def calculate_eye(lmk):
FILE: extract_init_states/utils/io.py
function mkdir (line 11) | def mkdir(d):
function _get_suffix (line 15) | def _get_suffix(filename):
function _load (line 23) | def _load(fp):
function _dump (line 31) | def _dump(wfp, obj):
function _load_tensor (line 41) | def _load_tensor(fp, mode='cpu'):
function _tensor_to_cuda (line 48) | def _tensor_to_cuda(x):
function _load_gpu (line 55) | def _load_gpu(fp):
FILE: extract_init_states/utils/onnx.py
function convert_to_onnx (line 14) | def convert_to_onnx(**kvs):
FILE: extract_init_states/utils/pncc.py
function calc_ncc_code (line 21) | def calc_ncc_code():
function pncc (line 34) | def pncc(img, ver_lst, tri, show_flag=False, wfp=None, with_bg_flag=True):
function main (line 57) | def main():
FILE: extract_init_states/utils/pose.py
function P2sRt (line 18) | def P2sRt(P):
function matrix2angle (line 39) | def matrix2angle(R):
function angle2matrix (line 65) | def angle2matrix(theta):
function angle2matrix_3ddfa (line 112) | def angle2matrix_3ddfa(angles):
function calc_pose (line 140) | def calc_pose(param):
function build_camera_box (line 150) | def build_camera_box(rear_size=90):
function plot_pose_box (line 171) | def plot_pose_box(img, P, ver, color=(40, 255, 0), line_width=2):
function viz_pose (line 201) | def viz_pose(img, param_lst, ver_lst, show_flag=False, wfp=None):
function pose_6 (line 217) | def pose_6(param):
function smooth_pose (line 231) | def smooth_pose(img, param_lst, ver_lst, pose_new, show_flag=False, wfp=...
function get_pose (line 263) | def get_pose(img, param_lst, ver_lst, show_flag=False, wfp=None, wnp = N...
FILE: extract_init_states/utils/render.py
function render (line 30) | def render(img, ver_lst, tri, alpha=0.6, show_flag=False, wfp=None, with...
FILE: extract_init_states/utils/render_ctypes.py
class TrianglesMeshRender (line 27) | class TrianglesMeshRender(object):
method __init__ (line 28) | def __init__(
method __call__ (line 50) | def __call__(self, vertices, triangles, bg):
function render (line 67) | def render(img, ver_lst, tri, alpha=0.6, show_flag=False, wfp=None, with...
FILE: extract_init_states/utils/serialization.py
function ser_to_ply_single (line 22) | def ser_to_ply_single(ver_lst, tri, height, wfp, reverse=True):
function ser_to_ply_multiple (line 50) | def ser_to_ply_multiple(ver_lst, tri, height, wfp, reverse=True):
function get_colors (line 84) | def get_colors(img, ver):
function ser_to_obj_single (line 94) | def ser_to_obj_single(img, ver_lst, tri, height, wfp):
function ser_to_obj_multiple (line 117) | def ser_to_obj_multiple(img, ver_lst, tri, height, wfp):
FILE: extract_init_states/utils/tddfa_util.py
function _to_ctype (line 14) | def _to_ctype(arr):
function str2bool (line 20) | def str2bool(v):
function load_model (line 29) | def load_model(model, checkpoint_fp):
class ToTensorGjz (line 44) | class ToTensorGjz(object):
method __call__ (line 45) | def __call__(self, pic):
method __repr__ (line 50) | def __repr__(self):
class NormalizeGjz (line 54) | class NormalizeGjz(object):
method __init__ (line 55) | def __init__(self, mean, std):
method __call__ (line 59) | def __call__(self, tensor):
function similar_transform (line 64) | def similar_transform(pts3d, roi_box, size):
function _parse_param (line 80) | def _parse_param(param):
FILE: extract_init_states/utils/uv.py
function load_uv_coords (line 22) | def load_uv_coords(fp):
function process_uv (line 28) | def process_uv(uv_coords, uv_h=256, uv_w=256):
function get_colors (line 41) | def get_colors(img, ver):
function bilinear_interpolate (line 52) | def bilinear_interpolate(img, x, y):
function uv_tex (line 79) | def uv_tex(img, ver_lst, tri, uv_h=256, uv_w=256, uv_c=3, show_flag=Fals...
FILE: filter_fourier.py
function gaussian_pdf (line 11) | def gaussian_pdf(x, mean, std):
function gaussian_density (line 15) | def gaussian_density(length = 20, amplitude = 2, mean = 19, sigma = 3):
function fourier_filter (line 21) | def fourier_filter(fea):
function fourier_filter_1D (line 50) | def fourier_filter_1D(fea, dim):
function hf_loss (line 69) | def hf_loss(fea, mask, dim):
function hf_loss_2 (line 77) | def hf_loss_2(fea_x, fea_y, dim):
class KalmanFilter1D (line 90) | class KalmanFilter1D:
method __init__ (line 91) | def __init__(self, A, H, Q, R, x_init, P_init):
method update (line 99) | def update(self, z):
function kalman_filter (line 111) | def kalman_filter(observations, dim):
function naive_filter (line 123) | def naive_filter(fea):
FILE: hubert_extract/data_gen/process_lrs3/binarizer.py
function load_video_npy (line 13) | def load_video_npy(fn):
function cal_lm3d_in_video_dict (line 23) | def cal_lm3d_in_video_dict(video_dict, face3d_helper):
function load_audio_npy (line 30) | def load_audio_npy(fn):
FILE: hubert_extract/data_gen/process_lrs3/process_audio_hubert.py
function get_hubert_from_16k_wav (line 14) | def get_hubert_from_16k_wav(wav_16k_name):
function get_hubert_from_16k_speech (line 20) | def get_hubert_from_16k_speech(speech, device="cuda:1"):
FILE: hubert_extract/data_gen/process_lrs3/process_audio_hubert_interpolate.py
function get_hubert_from_16k_wav (line 18) | def get_hubert_from_16k_wav(wav_16k_name):
function get_hubert_from_16k_speech (line 24) | def get_hubert_from_16k_speech(speech, device="cuda:1"):
FILE: hubert_extract/data_gen/process_lrs3/process_audio_hubert_interpolate_batch.py
function get_hubert_from_16k_wav (line 18) | def get_hubert_from_16k_wav(wav_16k_name):
function get_hubert_from_16k_speech (line 24) | def get_hubert_from_16k_speech(speech, device="cuda:3"):
FILE: hubert_extract/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py
function get_hubert_from_16k_wav (line 28) | def get_hubert_from_16k_wav(wav_16k_name):
function get_hubert_from_16k_speech (line 34) | def get_hubert_from_16k_speech(speech, device="cuda:0"):
function get_arguments (line 97) | def get_arguments():
function convert_wav_to_16k (line 112) | def convert_wav_to_16k(input_file, output_file):
function delete_file (line 121) | def delete_file(file_path):
FILE: hubert_extract/data_gen/process_lrs3/process_audio_hubert_interpolate_single.py
function get_hubert_from_16k_wav (line 18) | def get_hubert_from_16k_wav(wav_16k_name):
function get_hubert_from_16k_speech (line 24) | def get_hubert_from_16k_speech(speech, device="cuda:1"):
FILE: hubert_extract/data_gen/process_lrs3/process_audio_mel_f0.py
function librosa_pad_lr (line 12) | def librosa_pad_lr(x, fsize, fshift, pad_sides=1):
function extract_mel_from_fname (line 23) | def extract_mel_from_fname(wav_path,
function extract_f0_from_wav_and_mel (line 58) | def extract_f0_from_wav_and_mel(wav, mel,
function extract_mel_f0_from_fname (line 77) | def extract_mel_f0_from_fname(fname, out_name=None):
FILE: misc.py
function fig2data (line 16) | def fig2data(fig):
function plot_grid (line 35) | def plot_grid(x, y, ax=None, **kwargs):
function grid2fig (line 44) | def grid2fig(warped_grid, grid_size=32, img_size=256):
function flow2fig (line 68) | def flow2fig(warped_grid, id_grid, grid_size=32, img_size=128):
function conf2fig (line 79) | def conf2fig(conf, img_size=128):
class Logger (line 86) | class Logger(object):
method __init__ (line 87) | def __init__(self, filename='default.log', stream=sys.stdout):
method write (line 91) | def write(self, message):
method flush (line 95) | def flush(self):
function resize (line 99) | def resize(im, desired_size, interpolation):
function resample (line 116) | def resample(image, flow):
function get_grid (line 140) | def get_grid(batchsize, size, minval=-1.0, maxval=1.0):
function get_checkpoint (line 179) | def get_checkpoint(checkpoint_path, url=''):
function download_file_from_google_drive (line 204) | def download_file_from_google_drive(file_id, destination):
function get_confirm_token (line 224) | def get_confirm_token(response):
function save_response_content (line 239) | def save_response_content(response, destination):
function get_rank (line 256) | def get_rank():
function is_master (line 265) | def is_master():
FILE: sync_batchnorm/batchnorm.py
function _sum_ft (line 24) | def _sum_ft(tensor):
function _unsqueeze_ft (line 29) | def _unsqueeze_ft(tensor):
class _SynchronizedBatchNorm (line 38) | class _SynchronizedBatchNorm(_BatchNorm):
method __init__ (line 39) | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
method forward (line 48) | def forward(self, input):
method __data_parallel_replicate__ (line 80) | def __data_parallel_replicate__(self, ctx, copy_id):
method _data_parallel_master (line 90) | def _data_parallel_master(self, intermediates):
method _compute_mean_std (line 113) | def _compute_mean_std(self, sum_, ssum, size):
class SynchronizedBatchNorm1d (line 128) | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
method _check_input_dim (line 184) | def _check_input_dim(self, input):
class SynchronizedBatchNorm2d (line 191) | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
method _check_input_dim (line 247) | def _check_input_dim(self, input):
class SynchronizedBatchNorm3d (line 254) | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
method _check_input_dim (line 311) | def _check_input_dim(self, input):
FILE: sync_batchnorm/comm.py
class FutureResult (line 18) | class FutureResult(object):
method __init__ (line 21) | def __init__(self):
method put (line 26) | def put(self, result):
method get (line 32) | def get(self):
class SlavePipe (line 46) | class SlavePipe(_SlavePipeBase):
method run_slave (line 49) | def run_slave(self, msg):
class SyncMaster (line 56) | class SyncMaster(object):
method __init__ (line 67) | def __init__(self, master_callback):
method __getstate__ (line 78) | def __getstate__(self):
method __setstate__ (line 81) | def __setstate__(self, state):
method register_slave (line 84) | def register_slave(self, identifier):
method run_master (line 102) | def run_master(self, master_msg):
method nr_slaves (line 136) | def nr_slaves(self):
FILE: sync_batchnorm/replicate.py
class CallbackContext (line 23) | class CallbackContext(object):
function execute_replication_callbacks (line 27) | def execute_replication_callbacks(modules):
class DataParallelWithCallback (line 50) | class DataParallelWithCallback(DataParallel):
method replicate (line 64) | def replicate(self, module, device_ids):
method update_num_frames (line 69) | def update_num_frames(self, new_num_frames):
function patch_replication_callback (line 75) | def patch_replication_callback(data_parallel):
FILE: sync_batchnorm/replicate_ddp.py
class CallbackContext (line 24) | class CallbackContext(object):
function execute_replication_callbacks (line 28) | def execute_replication_callbacks(modules):
class DataParallelWithCallback_ddp (line 51) | class DataParallelWithCallback_ddp(DistributedDataParallel):
method replicate (line 65) | def replicate(self, module, device_ids):
method update_num_frames (line 70) | def update_num_frames(self, new_num_frames):
function patch_replication_callback_ddp (line 76) | def patch_replication_callback_ddp(data_parallel):
FILE: sync_batchnorm/unittest.py
function as_numpy (line 17) | def as_numpy(v):
class TorchTestCase (line 23) | class TorchTestCase(unittest.TestCase):
method assertTensorClose (line 24) | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
FILE: unified_video_generator.py
function inv_transform (line 31) | def inv_transform(x, min_vals, max_vals):
function load_args (line 34) | def load_args(filename):
class VideoGenerator (line 39) | class VideoGenerator:
method __init__ (line 40) | def __init__(self, args):
method extract_pose (line 131) | def extract_pose(self):
method process_audio (line 202) | def process_audio(self):
method generate_pose_blink (line 252) | def generate_pose_blink(self):
method generate_final_video (line 304) | def generate_final_video(self):
method run (line 402) | def run(self):
method _convert_wav_to_16k (line 417) | def _convert_wav_to_16k(input_file, output_file):
method _get_hubert_from_16k_speech (line 434) | def _get_hubert_from_16k_speech(self, speech, device="cuda:0"):
method _init_video_model (line 504) | def _init_video_model(self, model_config):
method _process_output_frame (line 533) | def _process_output_frame(self, frame_batch, mean=(0.0, 0.0, 0.0), ind...
method _extract_audio_segment (line 550) | def _extract_audio_segment(self, input_wav, start_frame, num_frames, f...
method _combine_video_audio (line 567) | def _combine_video_audio(self, audio_path, video_path, output_path):
function parse_args (line 588) | def parse_args():
function main (line 597) | def main():
Condensed preview — 291 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (1,454K chars).
[
{
"path": ".gitignore",
"chars": 91,
"preview": ".idea/\n__pycache__\n**/__pycache__\ncache\nsubmit*\npretrain_models/*\n*.mp3\n*.mp4\ntmp*\noutput/*"
},
{
"path": "DAWN_256.yaml",
"chars": 348,
"preview": "input_size: 256\nmax_n_frames: 200\nrandom_seed: 1234\nmean: [0.0, 0.0, 0.0]\nwin_width: 40\nsampling_step: 20\nddim_sampling_"
},
{
"path": "DM_3/datasets_hdtf_wpose_lmk_block_lmk.py",
"chars": 10870,
"preview": "# dataset for HDTF, stage 1\nfrom os import name\nimport sys\nsys.path.append('your_path')\n\nimport os\nimport random\nimport "
},
{
"path": "DM_3/datasets_hdtf_wpose_lmk_block_lmk_rand.py",
"chars": 13284,
"preview": "# dataset for HDTF, stage 2\nfrom os import name\nimport sys\nsys.path.append('your_path')\n\nimport os\nimport random\nimport "
},
{
"path": "DM_3/modules/local_attention.py",
"chars": 21025,
"preview": "import sys\n# sys.path.append('your/path/DAWN-pytorch')\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as "
},
{
"path": "DM_3/modules/text.py",
"chars": 1979,
"preview": "# the code from https://github.com/lucidrains/video-diffusion-pytorch\nimport torch\nfrom einops import rearrange\n\n\ndef ex"
},
{
"path": "DM_3/modules/video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_mouth_mask_6D.py",
"chars": 24602,
"preview": "'''\nstage 1: using 0th as the reference, short and fixed clip\n\nwith lip loss, 6D pose, conditioned by cross attention\n\n'"
},
{
"path": "DM_3/modules/video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_mouth_mask_rand_6D.py",
"chars": 24856,
"preview": "'''\nstage 2: using random reference, long and dynamic clip\n\nwith lip loss, 6D pose, conditioned by cross attention\n\n'''\n"
},
{
"path": "DM_3/modules/video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_test.py",
"chars": 21880,
"preview": "import os\nimport torch\nimport torch.nn as nn\nimport sys\nfrom LFG.modules.generator import Generator\nfrom LFG.modules.bg_"
},
{
"path": "DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi.py",
"chars": 48654,
"preview": "'''\nadding pose condtioning on baseline\nusing cross attention to add different condition\n\nfor training\n'''\nimport math\ni"
},
{
"path": "DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test.py",
"chars": 48938,
"preview": "'''\nadding pose condtioning on baseline\nusing cross attention to add different condition\n\nusing local attention, for inf"
},
{
"path": "DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test_local_opt.py",
"chars": 49271,
"preview": "'''\nadding pose condtioning on baseline\nusing cross attention to add different condition\n\nusing ram optimized local atte"
},
{
"path": "DM_3/test_lr.py",
"chars": 334,
"preview": "import torch\nimport torch.optim as optim\nfrom torch.optim.lr_scheduler import CosineAnnealingLR\n\nmodel = torch.nn.Linear"
},
{
"path": "DM_3/train_vdm_hdtf_wpose_plus_faceemb_init_cond_liploss_6D.py",
"chars": 25434,
"preview": "import sys\nsys.path.append('your/path/')\n\nimport argparse\nfrom datetime import datetime, time\n\nimport imageio\nimport tor"
},
{
"path": "DM_3/train_vdm_hdtf_wpose_plus_faceemb_init_cond_liploss_6D_s2.py",
"chars": 26411,
"preview": "import sys\nsys.path.append('your/path/')\n\nimport argparse\nfrom datetime import datetime, time\n\nimport imageio\nimport tor"
},
{
"path": "DM_3/utils.py",
"chars": 842,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass MultiEpochsDataLoader(torch.utils.data.DataLoa"
},
{
"path": "LFG/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "LFG/augmentation.py",
"chars": 12532,
"preview": "\"\"\"\nCode from https://github.com/hassony2/torch_videovision\n\"\"\"\n\nimport numbers\n\nimport random\nimport numpy as np\nimport"
},
{
"path": "LFG/frames_dataset.py",
"chars": 8875,
"preview": "\"\"\"\nCopyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.\nNo licens"
},
{
"path": "LFG/hdtf_dataset.py",
"chars": 4129,
"preview": "# build MUG dataset for RegionMM\n\nimport os\nimport imageio\n\nimport numpy as np\nfrom torch.utils.data import Dataset\nimpo"
},
{
"path": "LFG/modules/avd_network.py",
"chars": 3458,
"preview": "\"\"\"\nCopyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.\nNo licens"
},
{
"path": "LFG/modules/bg_motion_predictor.py",
"chars": 2931,
"preview": "\"\"\"\nCopyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.\nNo licens"
},
{
"path": "LFG/modules/flow_autoenc.py",
"chars": 2887,
"preview": "# utilize RegionMM to design a flow auto-encoder\n\nimport torch\nimport torch.nn as nn\nimport sys\nsys.path.append('/train2"
},
{
"path": "LFG/modules/generator.py",
"chars": 8293,
"preview": "\"\"\"\nCopyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.\nNo licens"
},
{
"path": "LFG/modules/model.py",
"chars": 9781,
"preview": "\"\"\"\nCopyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.\nNo licens"
},
{
"path": "LFG/modules/pixelwise_flow_predictor.py",
"chars": 7029,
"preview": "\"\"\"\nCopyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.\nNo licens"
},
{
"path": "LFG/modules/region_predictor.py",
"chars": 4860,
"preview": "\"\"\"\nCopyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.\nNo licens"
},
{
"path": "LFG/modules/util.py",
"chars": 15455,
"preview": "\"\"\"\nCopyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.\nNo licens"
},
{
"path": "LFG/run_hdtf.py",
"chars": 6127,
"preview": "# Estimate flow and occlusion mask via RegionMM (or called MRAA) for MUG dataset\n# this code is based on RegionMM from S"
},
{
"path": "LFG/run_hdtf_crema.py",
"chars": 6129,
"preview": "# Estimate flow and occlusion mask via RegionMM (or called MRAA) for MUG dataset\n# this code is based on RegionMM from S"
},
{
"path": "LFG/sync_batchnorm/__init__.py",
"chars": 449,
"preview": "# -*- coding: utf-8 -*-\n# File : __init__.py\n# Author : Jiayuan Mao\n# Email : maojiayuan@gmail.com\n# Date : 27/01/2"
},
{
"path": "LFG/sync_batchnorm/batchnorm.py",
"chars": 12973,
"preview": "# -*- coding: utf-8 -*-\n# File : batchnorm.py\n# Author : Jiayuan Mao\n# Email : maojiayuan@gmail.com\n# Date : 27/01/"
},
{
"path": "LFG/sync_batchnorm/comm.py",
"chars": 4449,
"preview": "# -*- coding: utf-8 -*-\n# File : comm.py\n# Author : Jiayuan Mao\n# Email : maojiayuan@gmail.com\n# Date : 27/01/2018\n"
},
{
"path": "LFG/sync_batchnorm/replicate.py",
"chars": 3226,
"preview": "# -*- coding: utf-8 -*-\n# File : replicate.py\n# Author : Jiayuan Mao\n# Email : maojiayuan@gmail.com\n# Date : 27/01/"
},
{
"path": "LFG/sync_batchnorm/unittest.py",
"chars": 835,
"preview": "# -*- coding: utf-8 -*-\n# File : unittest.py\n# Author : Jiayuan Mao\n# Email : maojiayuan@gmail.com\n# Date : 27/01/2"
},
{
"path": "LFG/test_flowautoenc_crema_video.py",
"chars": 13961,
"preview": "# use LFG to reconstruct testing videos and measure the loss in video domain\n# using RegionMM\n\nimport argparse\nimport im"
},
{
"path": "LFG/test_flowautoenc_hdtf_video.py",
"chars": 14016,
"preview": "# use LFG to reconstruct testing videos and measure the loss in video domain\n# using RegionMM\n\nimport argparse\nimport im"
},
{
"path": "LFG/test_flowautoenc_hdtf_video_256.py",
"chars": 11887,
"preview": "# use LFG to reconstruct testing videos and measure the loss in video domain\n# using RegionMM\n\nimport argparse\nimport im"
},
{
"path": "LFG/train.py",
"chars": 7499,
"preview": "# train a LFAE\n# this code is based on RegionMM (MRAA): https://github.com/snap-research/articulated-animation\nimport os"
},
{
"path": "LFG/vis_flow.py",
"chars": 1054,
"preview": "import torch\nimport numpy as np\nimport matplotlib.pyplot as plt\n\ndef visualize_dense_optical_flow(flow_tensor, save_path"
},
{
"path": "PBnet/run_cvae_h_ann_reemb_rope_eye_3.sh",
"chars": 1499,
"preview": "source /home4/intern/lmlin2/.bashrc\nconda activate actor\n# crema rc delta pose\nexport CUDA_VISIBLE_DEVICES=\"0\"\n# python "
},
{
"path": "PBnet/src/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "PBnet/src/config.py",
"chars": 266,
"preview": "import os\n\nSMPL_DATA_PATH = \"models/smpl/\"\n\nSMPL_KINTREE_PATH = os.path.join(SMPL_DATA_PATH, \"kintree_table.pkl\")\nSMPL_M"
},
{
"path": "PBnet/src/datasets/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "PBnet/src/datasets/datasets_hdtf_pos_chunk_norm_2_fast.py",
"chars": 8729,
"preview": "from os import name\nimport sys\nsys.path.append('your_path')\n\nimport os\nimport random\nimport torch\n\nimport numpy as np\nim"
},
{
"path": "PBnet/src/datasets/datasets_hdtf_pos_chunk_norm_eye_fast.py",
"chars": 11062,
"preview": "from os import name\nimport sys\n# sys.path.append('your_path')\n\nimport os\nimport random\nimport torch\n\nimport numpy as np\n"
},
{
"path": "PBnet/src/datasets/datasets_hdtf_pos_df.py",
"chars": 8059,
"preview": "from os import name\nfrom src.datasets.datasets_hdtf_pos import HDTF\nimport sys\nsys.path.append('your_path')\n\nimport os\ni"
},
{
"path": "PBnet/src/datasets/datasets_hdtf_pos_dict_norm_2.py",
"chars": 9906,
"preview": "from os import name\nimport sys\nsys.path.append('your_path')\n\nimport os\nimport random\nimport torch\n\nimport numpy as np\nim"
},
{
"path": "PBnet/src/datasets/datasets_hdtf_wpose_lmk_block.py",
"chars": 10335,
"preview": "from os import name\nimport sys\nsys.path.append('your_path')\n\nimport os\nimport random\nimport torch\n\nimport numpy as np\nim"
},
{
"path": "PBnet/src/datasets/get_dataset.py",
"chars": 792,
"preview": "def get_dataset(name=\"ntu13\"):\n if name == \"ntu13\":\n from .ntu13 import NTU13\n return NTU13\n elif na"
},
{
"path": "PBnet/src/datasets/tools.py",
"chars": 429,
"preview": "import os\nimport string\n\n\ndef parse_info_name(path):\n name = os.path.splitext(os.path.split(path)[-1])[0]\n info = "
},
{
"path": "PBnet/src/evaluate/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "PBnet/src/evaluate/action2motion/accuracy.py",
"chars": 565,
"preview": "import torch\n\n\ndef calculate_accuracy(model, motion_loader, num_labels, classifier, device):\n confusion = torch.zeros"
},
{
"path": "PBnet/src/evaluate/action2motion/diversity.py",
"chars": 1552,
"preview": "import torch\nimport numpy as np\n\n\n# from action2motion\ndef calculate_diversity_multimodality(activations, labels, num_la"
},
{
"path": "PBnet/src/evaluate/action2motion/evaluate.py",
"chars": 3751,
"preview": "import torch\nimport numpy as np\nfrom .models import load_classifier, load_classifier_for_fid\nfrom .accuracy import calcu"
},
{
"path": "PBnet/src/evaluate/action2motion/fid.py",
"chars": 2350,
"preview": "import numpy as np\nfrom scipy import linalg\n\n\n# from action2motion\ndef calculate_fid(statistics_1, statistics_2):\n re"
},
{
"path": "PBnet/src/evaluate/action2motion/models.py",
"chars": 5139,
"preview": "import torch\nimport torch.nn as nn\n\n\n# adapted from action2motion to take inputs of different lengths\nclass MotionDiscri"
},
{
"path": "PBnet/src/evaluate/evaluate_cvae.py",
"chars": 1332,
"preview": "import sys\nsys.path.append('/train20/intern/permanent/lmlin2/ReferenceCode/ACTOR-master')\n\nfrom src.parser.evaluation im"
},
{
"path": "PBnet/src/evaluate/evaluate_cvae_debug.py",
"chars": 1244,
"preview": "import sys\nsys.path.append('your_path/PBnet')\n\nfrom src.parser.evaluation import parser\nfrom src.datasets.datasets_crema"
},
{
"path": "PBnet/src/evaluate/evaluate_cvae_f3.py",
"chars": 1232,
"preview": "import sys\nsys.path.append('your_path/PBnet')\n\nfrom src.parser.evaluation import parser\nfrom src.datasets.datasets_crema"
},
{
"path": "PBnet/src/evaluate/evaluate_cvae_f3_debug.py",
"chars": 1238,
"preview": "import sys\nsys.path.append('your_path/PBnet')\n\nfrom src.parser.evaluation import parser\nfrom src.datasets.datasets_crema"
},
{
"path": "PBnet/src/evaluate/evaluate_cvae_f3_mel.py",
"chars": 1237,
"preview": "import sys\nsys.path.append('your_path/PBnet')\n\nfrom src.parser.evaluation import parser\nfrom src.datasets.datasets_crema"
},
{
"path": "PBnet/src/evaluate/evaluate_cvae_norm.py",
"chars": 1268,
"preview": "import sys\nsys.path.append('your_path/PBnet')\n\nfrom src.parser.evaluation import parser\nfrom src.datasets.datasets_crema"
},
{
"path": "PBnet/src/evaluate/evaluate_cvae_norm_all.py",
"chars": 1300,
"preview": "import sys\nsys.path.append('your_path/PBnet')\n\nfrom src.parser.evaluation import parser\nfrom src.datasets.datasets_crema"
},
{
"path": "PBnet/src/evaluate/evaluate_cvae_norm_all_seg.py",
"chars": 1300,
"preview": "import sys\nsys.path.append('your_path/PBnet')\n\nfrom src.parser.evaluation import parser\nfrom src.datasets.datasets_crema"
},
{
"path": "PBnet/src/evaluate/evaluate_cvae_norm_all_seg_weye.py",
"chars": 1314,
"preview": "import sys\nsys.path.append('your_path/PBnet')\n\nfrom src.parser.evaluation import parser\nfrom src.datasets.datasets_crema"
},
{
"path": "PBnet/src/evaluate/evaluate_cvae_norm_all_seg_weye2.py",
"chars": 1364,
"preview": "import sys\nsys.path.append('your_path/PBnet')\n\nfrom src.parser.evaluation import parser\nfrom src.datasets.datasets_crema"
},
{
"path": "PBnet/src/evaluate/evaluate_cvae_norm_eye_pose.py",
"chars": 1263,
"preview": "import sys\nsys.path.append('your_path/PBnet')\n\nfrom src.parser.evaluation import parser\nfrom src.datasets.datasets_crema"
},
{
"path": "PBnet/src/evaluate/evaluate_cvae_norm_eye_pose_test.py",
"chars": 1254,
"preview": "import sys\nsys.path.append('your_path/PBnet')\n\nfrom src.parser.evaluation import parser\nfrom src.datasets.datasets_crema"
},
{
"path": "PBnet/src/evaluate/evaluate_cvae_onlyeye_all_seg.py",
"chars": 1345,
"preview": "import sys\nsys.path.append('/train20/intern/permanent/lmlin2/ReferenceCode/ACTOR-master')\n\nfrom src.parser.evaluation im"
},
{
"path": "PBnet/src/evaluate/othermetrics/acceleration.py",
"chars": 934,
"preview": "import torch\nimport numpy as np\n\nfrom src.utils.tensors import lengths_to_mask\n\n\ndef calculate_acceletation(motionloader"
},
{
"path": "PBnet/src/evaluate/othermetrics/evaluation.py",
"chars": 2871,
"preview": "import torch\nimport numpy as np\n\nfrom ..action2motion.diversity import calculate_diversity_multimodality\nfrom .accelerat"
},
{
"path": "PBnet/src/evaluate/stgcn/accuracy.py",
"chars": 533,
"preview": "import torch\n\n\ndef calculate_accuracy(model, motion_loader, num_labels, classifier, device):\n confusion = torch.zeros"
},
{
"path": "PBnet/src/evaluate/stgcn/diversity.py",
"chars": 1618,
"preview": "import torch\nimport numpy as np\n\n\n# from action2motion\ndef calculate_diversity_multimodality(activations, labels, num_la"
},
{
"path": "PBnet/src/evaluate/stgcn/evaluate.py",
"chars": 3891,
"preview": "import torch\nimport numpy as np\nfrom .accuracy import calculate_accuracy\nfrom .fid import calculate_fid\nfrom .diversity "
},
{
"path": "PBnet/src/evaluate/stgcn/fid.py",
"chars": 2350,
"preview": "import numpy as np\nfrom scipy import linalg\n\n\n# from action2motion\ndef calculate_fid(statistics_1, statistics_2):\n re"
},
{
"path": "PBnet/src/evaluate/tables/archtable.py",
"chars": 6278,
"preview": "import os\nimport glob\nimport math\nimport re\nimport numpy as np\n\nfrom .tools import load_metrics\n\n\ndef valformat(val, pow"
},
{
"path": "PBnet/src/evaluate/tables/bstable.py",
"chars": 5353,
"preview": "import os\nimport glob\nimport math\nimport re\nimport numpy as np\n\nfrom .tools import load_metrics\n\n\ndef valformat(val, pow"
},
{
"path": "PBnet/src/evaluate/tables/easy_table.py",
"chars": 2616,
"preview": "import os\nimport glob\nimport math\nimport numpy as np\n\nfrom ..tools import load_metrics\n\n\ndef get_gtname(mname):\n retu"
},
{
"path": "PBnet/src/evaluate/tables/easy_table_A2M.py",
"chars": 2781,
"preview": "import os\nimport glob\nimport math\nimport numpy as np\n\nfrom ..tools import load_metrics\n\n\ndef valformat(val, power=3):\n "
},
{
"path": "PBnet/src/evaluate/tables/kltable.py",
"chars": 5336,
"preview": "import os\nimport glob\nimport math\nimport re\nimport numpy as np\n\nfrom .tools import load_metrics\n\n\ndef valformat(val, pow"
},
{
"path": "PBnet/src/evaluate/tables/latexmodela2m.py",
"chars": 2779,
"preview": "import os\nimport glob\nimport math\nimport numpy as np\n\nfrom .tools import load_metrics\n\n\ndef get_gtname(mname):\n retur"
},
{
"path": "PBnet/src/evaluate/tables/latexmodelsa2m.py",
"chars": 4902,
"preview": "import os\nimport glob\nimport math\nimport re\nimport numpy as np\n\nfrom .tools import load_metrics\n\n\ndef valformat(val, pow"
},
{
"path": "PBnet/src/evaluate/tables/latexmodelsstgcn.py",
"chars": 5180,
"preview": "import os\nimport glob\nimport math\nimport re\nimport numpy as np\n\nfrom .tools import load_metrics\n\n\ndef get_gtname(mname):"
},
{
"path": "PBnet/src/evaluate/tables/losstable.py",
"chars": 6079,
"preview": "import os\nimport glob\nimport math\nimport re\nimport numpy as np\n\nfrom .tools import load_metrics\n\n\ndef valformat(val, pow"
},
{
"path": "PBnet/src/evaluate/tables/maketable.py",
"chars": 9222,
"preview": "import os\nimport glob\nimport math\n\nfrom .tools import load_metrics\n\nMETRICS = {\"joints\": [\"acceleration\", \"rc\", \"diversi"
},
{
"path": "PBnet/src/evaluate/tables/numlayertable.py",
"chars": 5373,
"preview": "import os\nimport glob\nimport math\nimport re\nimport numpy as np\n\nfrom .tools import load_metrics\n\n\ndef valformat(val, pow"
},
{
"path": "PBnet/src/evaluate/tables/posereptable.py",
"chars": 5627,
"preview": "import os\nimport glob\nimport math\nimport re\nimport numpy as np\n\nfrom .tools import load_metrics\n\n\ndef valformat(val, pow"
},
{
"path": "PBnet/src/evaluate/tools.py",
"chars": 449,
"preview": "import yaml\n\n\ndef format_metrics(metrics, formatter=\"{:.6}\"):\n newmetrics = {}\n for key, val in metrics.items():\n "
},
{
"path": "PBnet/src/evaluate/tvae_eval.py",
"chars": 3894,
"preview": "import torch\nfrom tqdm import tqdm\n\nfrom src.utils.fixseed import fixseed\n\nfrom torch.utils.data import DataLoader\nfrom "
},
{
"path": "PBnet/src/evaluate/tvae_eval_norm.py",
"chars": 3972,
"preview": "import torch\nfrom tqdm import tqdm\n\nfrom src.utils.fixseed import fixseed\n\nfrom src.utils.utils import MultiEpochsDataLo"
},
{
"path": "PBnet/src/evaluate/tvae_eval_norm_all.py",
"chars": 5069,
"preview": "import torch\nfrom tqdm import tqdm\n\nfrom src.utils.fixseed import fixseed\n\nfrom src.utils.utils import MultiEpochsDataLo"
},
{
"path": "PBnet/src/evaluate/tvae_eval_norm_eye_pose.py",
"chars": 4288,
"preview": "import torch\nfrom tqdm import tqdm\n\nfrom src.utils.fixseed import fixseed\n\nfrom src.utils.utils import MultiEpochsDataLo"
},
{
"path": "PBnet/src/evaluate/tvae_eval_norm_eye_pose_seg.py",
"chars": 6762,
"preview": "import torch\nfrom tqdm import tqdm\n\nfrom src.utils.fixseed import fixseed\n\nfrom src.utils.utils import MultiEpochsDataLo"
},
{
"path": "PBnet/src/evaluate/tvae_eval_norm_seg.py",
"chars": 6209,
"preview": "import torch\nfrom tqdm import tqdm\n\nfrom src.utils.fixseed import fixseed\n\nfrom src.utils.utils import MultiEpochsDataLo"
},
{
"path": "PBnet/src/evaluate/tvae_eval_onlyeye_all_seg.py",
"chars": 6212,
"preview": "import torch\nfrom tqdm import tqdm\n\nfrom src.utils.fixseed import fixseed\n\nfrom src.utils.utils import MultiEpochsDataLo"
},
{
"path": "PBnet/src/evaluate/tvae_eval_single.py",
"chars": 6409,
"preview": "import torch\nfrom tqdm import tqdm\nimport sys\nimport os\ncurrent_dir = os.path.dirname(os.path.abspath(__file__))\nparent_"
},
{
"path": "PBnet/src/evaluate/tvae_eval_single_both_eye_pose.py",
"chars": 5510,
"preview": "import torch\nfrom tqdm import tqdm\nimport os\nimport sys\n# adding path of PBnet\ncurrent_dir = os.path.dirname(os.path.abs"
},
{
"path": "PBnet/src/evaluate/tvae_eval_std.py",
"chars": 3888,
"preview": "import torch\nfrom tqdm import tqdm\n\nfrom src.utils.fixseed import fixseed\n\nfrom src.utils.utils import MultiEpochsDataLo"
},
{
"path": "PBnet/src/evaluate/tvae_eval_train.py",
"chars": 3785,
"preview": "import torch\nfrom tqdm import tqdm\n\nfrom src.utils.fixseed import fixseed\n\nfrom src.utils.utils import MultiEpochsDataLo"
},
{
"path": "PBnet/src/evaluate/tvae_eval_train_norm.py",
"chars": 4486,
"preview": "import torch\nfrom tqdm import tqdm\n\nfrom src.utils.fixseed import fixseed\n\nfrom src.utils.utils import MultiEpochsDataLo"
},
{
"path": "PBnet/src/evaluate/tvae_eval_train_std.py",
"chars": 3819,
"preview": "import torch\nfrom tqdm import tqdm\n\nfrom src.utils.fixseed import fixseed\n\nfrom src.utils.utils import MultiEpochsDataLo"
},
{
"path": "PBnet/src/generate/generate_sequences.py",
"chars": 6036,
"preview": "import os\n\nimport matplotlib.pyplot as plt\nimport torch\nimport numpy as np\n\nfrom src.utils.get_model_and_data import get"
},
{
"path": "PBnet/src/models/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "PBnet/src/models/architectures/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "PBnet/src/models/architectures/autotrans.py",
"chars": 7434,
"preview": "from .transformer import Encoder_TRANSFORMER as Encoder_AUTOTRANS # noqa\n\nimport torch\nimport torch.nn as nn\nimport tor"
},
{
"path": "PBnet/src/models/architectures/fc.py",
"chars": 3587,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass Encoder_FC(nn.Module):\n def __init__(self,"
},
{
"path": "PBnet/src/models/architectures/gru.py",
"chars": 5058,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef augment_x(x, y, mask, lengths, num_classes, con"
},
{
"path": "PBnet/src/models/architectures/grutrans.py",
"chars": 130,
"preview": "from .gru import Encoder_GRU as Encoder_GRUTRANS # noqa\nfrom .transformer import Decoder_TRANSFORMER as Decoder_GRUTRAN"
},
{
"path": "PBnet/src/models/architectures/mlp.py",
"chars": 11257,
"preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Upsample(nn.Module):\n de"
},
{
"path": "PBnet/src/models/architectures/resnet34.py",
"chars": 8199,
"preview": "import torch\nimport torch.nn as nn\nfrom sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d\nfrom sync_batchnorm"
},
{
"path": "PBnet/src/models/architectures/tools/embeddings.py",
"chars": 7192,
"preview": "# This file is taken from signjoey repository\nimport math\nimport torch\n\nfrom torch import nn, Tensor\nfrom ....tools.tool"
},
{
"path": "PBnet/src/models/architectures/tools/resnet.py",
"chars": 7767,
"preview": "import torch\nimport torch.nn as nn\n\ndef conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n \"\"\"3x3 convo"
},
{
"path": "PBnet/src/models/architectures/tools/transformer_layers.py",
"chars": 9739,
"preview": "# -*- coding: utf-8 -*-\nimport math\nimport torch\nimport torch.nn as nn\nfrom torch import Tensor\n\n# Took from https://git"
},
{
"path": "PBnet/src/models/architectures/tools/util.py",
"chars": 556,
"preview": "from torch import nn\n\nimport torch.nn.functional as F\nimport torch\n\nfrom sync_batchnorm import SynchronizedBatchNorm2d a"
},
{
"path": "PBnet/src/models/architectures/transformer.py",
"chars": 12427,
"preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass PositionalEncoding(nn.Modu"
},
{
"path": "PBnet/src/models/architectures/transformerdecoder.py",
"chars": 27726,
"preview": "import copy\nfrom typing import Optional, Any, Union, Callable\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn.funct"
},
{
"path": "PBnet/src/models/architectures/transformerdecoder4.py",
"chars": 7203,
"preview": "import copy\nfrom typing import Optional, Any, Union, Callable\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn.funct"
},
{
"path": "PBnet/src/models/architectures/transformerdecoder5.py",
"chars": 7242,
"preview": "import copy\nfrom typing import Optional, Any, Union, Callable\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn.funct"
},
{
"path": "PBnet/src/models/architectures/transformerreemb.py",
"chars": 17801,
"preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange, repe"
},
{
"path": "PBnet/src/models/architectures/transformerreemb5.py",
"chars": 16440,
"preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange, repe"
},
{
"path": "PBnet/src/models/architectures/transformerreemb6.py",
"chars": 16382,
"preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange, repe"
},
{
"path": "PBnet/src/models/architectures/transgru.py",
"chars": 132,
"preview": "from .transformer import Encoder_TRANSFORMER as Encoder_TRANSGRU # noqa\nfrom .gru import Decoder_GRU as Decoder_TRANSGR"
},
{
"path": "PBnet/src/models/get_model.py",
"chars": 1355,
"preview": "import importlib\n\nimport sys\nimport os\ncurrent_dir = os.path.dirname(os.path.abspath(__file__))\nparent_dir = os.path.dir"
},
{
"path": "PBnet/src/models/modeltype/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "PBnet/src/models/modeltype/cae.py",
"chars": 6743,
"preview": "import torch\nimport torch.nn as nn\n\nfrom ..tools.losses import get_loss_function\nimport torch.nn.functional as F\n# from "
},
{
"path": "PBnet/src/models/modeltype/cae_0.py",
"chars": 6284,
"preview": "import torch\nimport torch.nn as nn\n\nfrom ..tools.losses import get_loss_function\n# from ..rotation2xyz import Rotation2x"
},
{
"path": "PBnet/src/models/modeltype/cvae.py",
"chars": 1347,
"preview": "import torch\nfrom .cae import CAE\n\n\nclass CVAE(CAE):\n def reparameterize(self, batch, seed=None):\n mu, logvar "
},
{
"path": "PBnet/src/models/modeltype/lstm.py",
"chars": 3609,
"preview": "import torch\nimport torch.nn as nn\n\nfrom ..tools.losses import get_loss_function\nimport torch.nn.functional as F\n# from "
},
{
"path": "PBnet/src/models/rotation2xyz.py",
"chars": 3178,
"preview": "import torch\nimport src.utils.rotation_conversions as geometry\n\nfrom .smpl import SMPL, JOINTSTYPE_ROOT\nfrom .get_model "
},
{
"path": "PBnet/src/models/smpl.py",
"chars": 3657,
"preview": "import numpy as np\nimport torch\n\nimport contextlib\n\nfrom smplx import SMPLLayer as _SMPLLayer\nfrom smplx.lbs import vert"
},
{
"path": "PBnet/src/models/tools/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "PBnet/src/models/tools/graphconv.py",
"chars": 1297,
"preview": "import math\n\nimport torch\n\nfrom torch.nn.parameter import Parameter\nfrom torch.nn.modules.module import Module\n\n\nclass G"
},
{
"path": "PBnet/src/models/tools/hessian_penalty.py",
"chars": 6690,
"preview": "\"\"\"\n## Adapted to work with our \"batches\"\nOfficial PyTorch implementation of the Hessian Penalty regularization term fro"
},
{
"path": "PBnet/src/models/tools/losses.py",
"chars": 5548,
"preview": "import torch\nfrom einops import rearrange\nimport torch.nn.functional as F\nfrom .hessian_penalty import hessian_penalty\nf"
},
{
"path": "PBnet/src/models/tools/mmd.py",
"chars": 712,
"preview": "import torch\n\n\n# from https://github.com/napsternxg/pytorch-practice/blob/master/Pytorch%20-%20MMD%20VAE.ipynb\ndef compu"
},
{
"path": "PBnet/src/models/tools/msssim_loss.py",
"chars": 4958,
"preview": "import torch\nimport torch.nn.functional as F\nfrom math import exp\nimport numpy as np\n\n\ndef gaussian(window_size, sigma):"
},
{
"path": "PBnet/src/models/tools/normalize_data.py",
"chars": 934,
"preview": "import torch\n\ndef normalize_data(data, min_vals, max_vals):\n min_vals = min_vals.unsqueeze(0).unsqueeze(0) \n max_"
},
{
"path": "PBnet/src/models/tools/ssim_loss.py",
"chars": 3836,
"preview": "import torch\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\nimport numpy as np\nfrom math import exp"
},
{
"path": "PBnet/src/models/tools/tools.py",
"chars": 1093,
"preview": "import torch.nn as nn\nfrom torch.nn.modules.module import ModuleAttributeError\n\n\nclass AutoParams(nn.Module):\n def __"
},
{
"path": "PBnet/src/parser/base.py",
"chars": 1042,
"preview": "from argparse import ArgumentParser # noqa\n\n\ndef add_misc_options(parser):\n group = parser.add_argument_group('Misce"
},
{
"path": "PBnet/src/parser/checkpoint.py",
"chars": 2980,
"preview": "import os\nfrom .base import ArgumentParser, adding_cuda\nfrom .tools import load_args\n\n\ndef parser():\n parser = Argume"
},
{
"path": "PBnet/src/parser/dataset.py",
"chars": 1925,
"preview": "from src.datasets.dataset import POSE_REPS\n\n\ndef add_dataset_options(parser):\n group = parser.add_argument_group('Dat"
},
{
"path": "PBnet/src/parser/evaluation.py",
"chars": 1469,
"preview": "import argparse\nimport os\nimport sys\nsys.path.append('/train20/intern/permanent/lmlin2/ReferenceCode/ACTOR-master')\n\nfro"
},
{
"path": "PBnet/src/parser/finetunning.py",
"chars": 1264,
"preview": "import os\nfrom .base import argparse, adding_cuda, load_args\n\n \ndef parser():\n parser = argparse.ArgumentParser()\n"
},
{
"path": "PBnet/src/parser/generate.py",
"chars": 1799,
"preview": "import os\n\nfrom src.models.get_model import JOINTSTYPES\nfrom .base import ArgumentParser, add_cuda_options, adding_cuda\n"
},
{
"path": "PBnet/src/parser/model.py",
"chars": 2702,
"preview": "from src.models.get_model import LOSSES, MODELTYPES, ARCHINAMES\n\n\ndef add_model_options(parser):\n group = parser.add_"
},
{
"path": "PBnet/src/parser/recognition.py",
"chars": 1101,
"preview": "import os\n\nfrom .base import argparse, add_misc_options, add_cuda_options, adding_cuda\nfrom .tools import save_args\nfrom"
},
{
"path": "PBnet/src/parser/tools.py",
"chars": 375,
"preview": "import os\nimport yaml\n\n\ndef save_args(opt, folder):\n os.makedirs(folder, exist_ok=True)\n \n # Save as yaml\n o"
},
{
"path": "PBnet/src/parser/training.py",
"chars": 1912,
"preview": "import os\n\nfrom .base import add_misc_options, add_cuda_options, adding_cuda, ArgumentParser\nfrom .tools import save_arg"
},
{
"path": "PBnet/src/parser/visualize.py",
"chars": 3648,
"preview": "import os\n\nfrom src.models.get_model import JOINTSTYPES\nfrom .base import ArgumentParser, add_cuda_options, adding_cuda\n"
},
{
"path": "PBnet/src/preprocess/humanact12_process.py",
"chars": 4549,
"preview": "import os\nimport numpy as np\nimport pickle as pkl\nfrom phspdtools import CameraParams\n\n\ndef splitname(name):\n subject"
},
{
"path": "PBnet/src/preprocess/phspdtools.py",
"chars": 5645,
"preview": "# taken and adapted from https://github.com/JimmyZou/PolarHumanPoseShape/\nimport pickle\nimport numpy as np\nimport os\n\n\nc"
},
{
"path": "PBnet/src/preprocess/uestc_vibe_postprocessing.py",
"chars": 7452,
"preview": "import numpy as np\nimport pickle as pkl\nimport tarfile\nimport os\nimport scipy.io as sio\nfrom tqdm import tqdm\nimport src"
},
{
"path": "PBnet/src/recognition/compute_accuracy.py",
"chars": 2098,
"preview": "import os\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom tqdm import tqdm\n\nfrom src.utils.get_model_and_data"
},
{
"path": "PBnet/src/recognition/get_model.py",
"chars": 475,
"preview": "from .models.stgcn import STGCN\n\n\ndef get_model(parameters):\n layout = \"smpl\" if parameters[\"glob\"] else \"smpl_noglob"
},
{
"path": "PBnet/src/recognition/models/stgcn.py",
"chars": 7967,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .stgcnutils.tgcn import ConvTemporalGraphical\nf"
},
{
"path": "PBnet/src/recognition/models/stgcnutils/graph.py",
"chars": 7500,
"preview": "import numpy as np\nimport pickle as pkl\n\nfrom src.config import SMPL_KINTREE_PATH\n\n\nclass Graph:\n \"\"\" The Graph to mo"
},
{
"path": "PBnet/src/recognition/models/stgcnutils/tgcn.py",
"chars": 2398,
"preview": "# The based unit of graph convolutional networks.\n\nimport torch\nimport torch.nn as nn\n\n\nclass ConvTemporalGraphical(nn.M"
},
{
"path": "PBnet/src/render/renderer.py",
"chars": 4768,
"preview": "\"\"\"\nThis script is borrowed from https://github.com/mkocabas/VIBE\n Adhere to their licence to use this script\n It has be"
},
{
"path": "PBnet/src/render/rendermotion.py",
"chars": 3273,
"preview": "import numpy as np\nimport imageio\nimport os\nimport argparse\nfrom tqdm import tqdm\nfrom .renderer import get_renderer\n\n\nd"
},
{
"path": "PBnet/src/train/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "PBnet/src/train/train_cvae_ganloss_ann_eye.py",
"chars": 8242,
"preview": "import sys\nsys.path.append('your_path/PBnet')\n\nimport os\nimport torch\nfrom torch.utils.tensorboard import SummaryWriter\n"
},
{
"path": "PBnet/src/train/train_cvae_ganloss_ann_fast.py",
"chars": 6985,
"preview": "import sys\nsys.path.append('your_path/PBnet')\n\nimport os\nimport torch\nfrom torch.utils.tensorboard import SummaryWriter\n"
},
{
"path": "PBnet/src/train/trainer.py",
"chars": 1476,
"preview": "import torch\nfrom tqdm import tqdm\n\n\ndef train_or_test(model, optimizer, iterator, device, mode=\"train\"):\n if mode =="
},
{
"path": "PBnet/src/train/trainer_gan.py",
"chars": 2643,
"preview": "import torch\nfrom tqdm import tqdm\nimport time\n\n\ndef train_or_test(model, model_d, optimizer_g, optimizer_d, iterator, d"
},
{
"path": "PBnet/src/train/trainer_gan_ann.py",
"chars": 2817,
"preview": "import torch\nfrom tqdm import tqdm\nimport time\n\n\ndef train_or_test(model, model_d, optimizer_g, optimizer_d, iterator, d"
},
{
"path": "PBnet/src/utils/PYTORCH3D_LICENSE",
"chars": 1546,
"preview": "BSD License\n\nFor PyTorch3D software\n\nCopyright (c) Facebook, Inc. and its affiliates. All rights reserved.\n\nRedistributi"
},
{
"path": "PBnet/src/utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "PBnet/src/utils/fixseed.py",
"chars": 297,
"preview": "import numpy as np\nimport torch\nimport random\n\n\ndef fixseed(seed):\n random.seed(seed)\n np.random.seed(seed)\n to"
},
{
"path": "PBnet/src/utils/get_model_and_data.py",
"chars": 418,
"preview": "from ..datasets.get_dataset import get_datasets\nfrom ..recognition.get_model import get_model as get_rec_model\nfrom ..mo"
},
{
"path": "PBnet/src/utils/misc.py",
"chars": 649,
"preview": "import torch\n\n\ndef to_numpy(tensor):\n if torch.is_tensor(tensor):\n return tensor.cpu().numpy()\n elif type(t"
},
{
"path": "PBnet/src/utils/rotation_conversions.py",
"chars": 18097,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.\n# Check PYTORCH3D_LICENCE before use\n\nimport fun"
},
{
"path": "PBnet/src/utils/tensors.py",
"chars": 1177,
"preview": "import torch\n\n\ndef lengths_to_mask(lengths):\n max_len = max(lengths)\n mask = torch.arange(max_len, device=lengths."
},
{
"path": "PBnet/src/utils/tensors_eye.py",
"chars": 1476,
"preview": "import torch\n\n\ndef lengths_to_mask(lengths):\n max_len = max(lengths)\n mask = torch.arange(max_len, device=lengths."
},
{
"path": "PBnet/src/utils/tensors_eye_eval.py",
"chars": 1508,
"preview": "import torch\n\n\ndef lengths_to_mask(lengths):\n max_len = max(lengths)\n mask = torch.arange(max_len, device=lengths."
},
{
"path": "PBnet/src/utils/tensors_hdtf.py",
"chars": 2169,
"preview": "import torch\n\n\ndef lengths_to_mask(lengths):\n max_len = max(lengths)\n mask = torch.arange(max_len, device=lengths."
},
{
"path": "PBnet/src/utils/tensors_onlyeye.py",
"chars": 2315,
"preview": "import torch\n\n\ndef lengths_to_mask(lengths):\n max_len = max(lengths)\n mask = torch.arange(max_len, device=lengths."
},
{
"path": "PBnet/src/utils/utils.py",
"chars": 2799,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom threading import Thread\nfrom queue import Queue\n"
},
{
"path": "PBnet/src/utils/video.py",
"chars": 866,
"preview": "import numpy as np\nimport imageio\n\n\ndef load_video(filename):\n vid = imageio.get_reader(filename, 'ffmpeg')\n fps ="
},
{
"path": "PBnet/src/visualize/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "PBnet/src/visualize/anim.py",
"chars": 4413,
"preview": "import numpy as np\nimport torch\nimport imageio\n\n# from action2motion\n# Define a kinematic tree for the skeletal struture"
},
{
"path": "PBnet/src/visualize/visualize.py",
"chars": 16322,
"preview": "import os\nimport imageio\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom tqdm import tqdm\nfrom .an"
},
{
"path": "PBnet/src/visualize/visualize_checkpoint.py",
"chars": 779,
"preview": "import os\n\nimport matplotlib.pyplot as plt\nimport torch\nfrom src.utils.get_model_and_data import get_model_and_data\nfrom"
},
{
"path": "PBnet/src/visualize/visualize_dataset.py",
"chars": 777,
"preview": "import matplotlib.pyplot as plt\n# import torch\nimport os\n\nfrom src.datasets.get_dataset import get_dataset\nfrom src.util"
},
{
"path": "PBnet/src/visualize/visualize_latent_space.py",
"chars": 4774,
"preview": "import matplotlib.pyplot as plt\nimport numpy as np\nimport os\nimport scipy\n\nimport torch\nimport torch.nn.functional as F\n"
},
{
"path": "PBnet/src/visualize/visualize_nturefined.py",
"chars": 2161,
"preview": "import matplotlib.pyplot as plt\nimport torch\n\nfrom src.datasets.get_dataset import get_dataset\nfrom src.utils.anim impor"
},
{
"path": "PBnet/src/visualize/visualize_sequence.py",
"chars": 1964,
"preview": "import os\n\nimport matplotlib.pyplot as plt\nimport torch\nimport numpy as np\n\nfrom src.datasets.get_dataset import get_dat"
},
{
"path": "README.md",
"chars": 6438,
"preview": "# 🌅 DAWN: Dynamic Frame Avatar with Non-autoregressive Diffusion Framework for Talking Head Video Generation\n\n[
About this extraction
This page contains the full source code of the Hanbo-Cheng/DAWN-pytorch GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 291 files (72.5 MB), approximately 369.3k tokens, and a symbol index with 1652 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.